%%capture
%pip install -U jax
import jax
import jax.numpy as jnp
try:
import jaxopt
except ModuleNotFoundError:
%pip install -qq jaxopt
import jaxopt
try:
import optax
except ModuleNotFoundError:
%pip install -qq optax
import optax
import tensorflow_probability.substrates.jax as tfp
Loss function
def loss_fun(x, a):
    """Scalar demo loss: sum over entries of ((param1 - a) + (param2 - (a+1)))**2.

    NOTE(review): this squares the SUM of the two residuals (not each one
    separately), so any params with param1 + param2 == 2*a + 1 minimize it —
    presumably intentional for this comparison demo.
    """
    residual1 = x['param1'] - a
    residual2 = x['param2'] - (a + 1)
    return ((residual1 + residual2) ** 2).sum()
Initial parameters
# Problem setup: each assignment's LHS had been shifted onto the next line
# by the notebook export; restored here.
N = 3  # number of entries per parameter vector
# Factory (not a constant) so every optimizer below starts from the same
# fresh initial point instead of sharing mutated state.
init_params = lambda: {'param1': jnp.zeros(N), 'param2': jnp.ones(N)}
a = 2.0  # target offset forwarded to loss_fun by each solver
Optimizers
JaxOpt ScipyMinimize
%%time
# JaxOpt wrapper around scipy.optimize: runs L-BFGS-B to convergence in one
# call. (De-scrambled: assignment targets were shifted by the export.)
solver = jaxopt.ScipyMinimize('L-BFGS-B', fun=loss_fun)
ans = solver.run(init_params(), a)  # extra positional args go to loss_fun
print(ans)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
OptStep(params={'param1': DeviceArray([1.9999999, 1.9999999, 1.9999999], dtype=float32), 'param2': DeviceArray([3., 3., 3.], dtype=float32)}, state=ScipyMinimizeInfo(fun_val=DeviceArray(4.2632564e-14, dtype=float32), success=True, status=0, iter_num=2))
CPU times: user 78.3 ms, sys: 18.5 ms, total: 96.8 ms
Wall time: 95.8 ms
Pros:
- Two lines of code will do it all.
Cons:
- It only returns the final parameters and final loss. No option to retrieve in-between loss values.
Optax
%%time
# Manual Optax training loop (de-scrambled: each LHS was shifted onto the
# following line by the export; loop-body indentation restored).
optimizer = optax.adam(learning_rate=0.1)
# JIT the value-and-grad so each of the 100 steps reuses the compiled fn;
# argnums=0 differentiates w.r.t. the params dict only, not `a`.
value_and_grad_fun = jax.jit(jax.value_and_grad(loss_fun, argnums=0))
params = init_params()
state = optimizer.init(params)

for _ in range(100):
    loss_value, gradients = value_and_grad_fun(params, a)
    updates, state = optimizer.update(gradients, state)
    params = optax.apply_updates(params, updates)

print(params)
{'param1': DeviceArray([2.0084236, 2.0084236, 2.0084236], dtype=float32), 'param2': DeviceArray([3.0084238, 3.0084238, 3.0084238], dtype=float32)}
CPU times: user 3.09 s, sys: 63.4 ms, total: 3.16 s
Wall time: 4.2 s
Pros:
- Full control in user’s hand. We can save intermediate loss values.
Cons:
- Its code is verbose, similar to PyTorch optimizers.
Jaxopt OptaxSolver
%%time
# JaxOpt's OptaxSolver drives an optax optimizer for maxiter steps
# internally. (De-scrambled: assignment targets were shifted by the export.)
optimizer = optax.adam(learning_rate=0.1)
solver = jaxopt.OptaxSolver(loss_fun, optimizer, maxiter=100)
ans = solver.run(init_params(), a)  # extra positional args go to loss_fun
print(ans)
OptStep(params={'param1': DeviceArray([2.008423, 2.008423, 2.008423], dtype=float32), 'param2': DeviceArray([3.008423, 3.008423, 3.008423], dtype=float32)}, state=OptaxState(iter_num=DeviceArray(100, dtype=int32, weak_type=True), value=DeviceArray(0.00113989, dtype=float32), error=DeviceArray(0.09549397, dtype=float32), internal_state=(ScaleByAdamState(count=DeviceArray(100, dtype=int32), mu={'param1': DeviceArray([0.02871927, 0.02871927, 0.02871927], dtype=float32), 'param2': DeviceArray([0.02871927, 0.02871927, 0.02871927], dtype=float32)}, nu={'param1': DeviceArray([0.44847375, 0.44847375, 0.44847375], dtype=float32), 'param2': DeviceArray([0.44847375, 0.44847375, 0.44847375], dtype=float32)}), EmptyState()), aux=None))
CPU times: user 719 ms, sys: 13.4 ms, total: 732 ms
Wall time: 1.09 s
Pros:
- Fewer lines of code.
- Applies
lax.scan
internally to make it fast [reference].
Cons:
- Not able to get in-between state/loss values
tfp math minimize
%%time
# TFP's stateless minimizer returns both the final params and the per-step
# loss history. (De-scrambled: assignment targets were shifted by the export.)
# NOTE(review): the initial point is the tuple (init_params(), a), so `a`
# itself is optimized too — see the Cons note below.
optimizer = optax.adam(learning_rate=0.1)
params, losses = tfp.math.minimize_stateless(loss_fun, (init_params(), a), num_steps=1000, optimizer=optimizer)
print(params)
print(losses[:5])
({'param1': DeviceArray([1.0000008, 1.0000008, 1.0000008], dtype=float32), 'param2': DeviceArray([1.9999989, 1.9999989, 1.9999989], dtype=float32)}, DeviceArray(0.9999999, dtype=float32))
[48. 38.88006 30.751791 23.626852 17.507807]
CPU times: user 880 ms, sys: 15.2 ms, total: 895 ms
Wall time: 1.53 s
Pros:
- One line of code to optimize the function and return in-between losses.
Cons:
- By default, it optimizes all arguments passed to the loss function. In the above example, we cannot control whether
a
should be optimized or not. I have raised an issue here for this problem.