JAX Optimizers

Pros and cons of several JAX optimizers.
ML
Author

Zeel B Patel

Published

June 10, 2022

%%capture
%pip install -U jax
import jax
import jax.numpy as jnp
try:
  import jaxopt
except ModuleNotFoundError:
  %pip install -qq jaxopt
  import jaxopt
try:
  import optax
except ModuleNotFoundError:
  %pip install -qq optax
  import optax

import tensorflow_probability.substrates.jax as tfp

Loss function

def loss_fun(x, a):
  # Toy quadratic loss over a pytree of parameters: per element, the combined
  # residual (param1 - a) + (param2 - (a + 1)) is squared, then summed.
  return (((x['param1'] - a) + (x['param2'] - (a+1)))**2).sum()

Initial parameters

N = 3
# A callable so that each optimizer starts from a fresh copy of the same initial pytree
init_params = lambda: {'param1': jnp.zeros(N), 'param2': jnp.ones(N)}
a = 2.0
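
As a quick sanity check, the loss at these initial values with a = 2.0 is 48 (each element contributes ((0 - 2) + (1 - 3))**2 = 16, and N = 3), which matches the first traced loss in the TFP example further below.

# Sanity check: per element (0 - 2) + (1 - 3) = -4, squared = 16; summed over N = 3 -> 48
print(loss_fun(init_params(), a))  # 48.0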

Optimizers

JaxOpt ScipyMinimize

%%time
solver = jaxopt.ScipyMinimize('L-BFGS-B', fun=loss_fun)
ans = solver.run(init_params(), a)
print(ans)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
OptStep(params={'param1': DeviceArray([1.9999999, 1.9999999, 1.9999999], dtype=float32), 'param2': DeviceArray([3., 3., 3.], dtype=float32)}, state=ScipyMinimizeInfo(fun_val=DeviceArray(4.2632564e-14, dtype=float32), success=True, status=0, iter_num=2))
CPU times: user 78.3 ms, sys: 18.5 ms, total: 96.8 ms
Wall time: 95.8 ms
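
For reference, the fields visible in the printed OptStep can be read directly from the result:

# OptStep is a namedtuple of (params, state); here state is a ScipyMinimizeInfo
final_params, info = ans
print(final_params['param1'])  # optimized parameters, ~[2., 2., 2.]
print(info.fun_val)            # final loss value
print(info.iter_num)           # number of L-BFGS-B iterations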

Pros

  • Two lines of code will do it all.

Cons

  • It only returns the final parameters and the final loss; there is no option to retrieve intermediate loss values. A possible workaround using jaxopt’s native LBFGS solver is sketched below.
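
If intermediate values matter more than the convenience of ScipyMinimize, one option (a sketch based on jaxopt’s manual-iteration API, not from the original post) is jaxopt’s native LBFGS solver, which can be stepped one iteration at a time:

# Sketch: step jaxopt.LBFGS manually; the solver state tracks the current loss
lbfgs = jaxopt.LBFGS(fun=loss_fun, maxiter=100)
params = init_params()
state = lbfgs.init_state(params, a)
losses = []
for _ in range(20):
  params, state = lbfgs.update(params, state, a)
  losses.append(state.value)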

Optax

%%time
optimizer = optax.adam(learning_rate=0.1)
value_and_grad_fun = jax.jit(jax.value_and_grad(loss_fun, argnums=0))
params = init_params()
state = optimizer.init(params)

for _ in range(100):
  loss_value, gradients = value_and_grad_fun(params, a)
  updates, state = optimizer.update(gradients, state)
  params = optax.apply_updates(params, updates)

print(params)
{'param1': DeviceArray([2.0084236, 2.0084236, 2.0084236], dtype=float32), 'param2': DeviceArray([3.0084238, 3.0084238, 3.0084238], dtype=float32)}
CPU times: user 3.09 s, sys: 63.4 ms, total: 3.16 s
Wall time: 4.2 s

Pros:

  • Full control stays in the user’s hands; intermediate loss values can be saved easily (see the sketch below).

Cons:

  • Its code is verbose, similar to PyTorch optimizers.
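
For example, a minimal sketch (not from the post) that records per-step losses and jits a single update step to reduce the Python-loop overhead:

# Sketch: jit one optimizer step and collect the loss at every iteration
@jax.jit
def step(params, state):
  loss_value, gradients = value_and_grad_fun(params, a)
  updates, state = optimizer.update(gradients, state)
  params = optax.apply_updates(params, updates)
  return params, state, loss_value

params = init_params()
state = optimizer.init(params)
losses = []
for _ in range(100):
  params, state, loss_value = step(params, state)
  losses.append(loss_value)

Since the whole step is compiled once, this loop should also run noticeably faster than the unjitted version timed above.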

Jaxopt OptaxSolver

%%time
optimizer = optax.adam(learning_rate=0.1)
solver = jaxopt.OptaxSolver(loss_fun, optimizer, maxiter=100)
ans = solver.run(init_params(), a)
print(ans)
OptStep(params={'param1': DeviceArray([2.008423, 2.008423, 2.008423], dtype=float32), 'param2': DeviceArray([3.008423, 3.008423, 3.008423], dtype=float32)}, state=OptaxState(iter_num=DeviceArray(100, dtype=int32, weak_type=True), value=DeviceArray(0.00113989, dtype=float32), error=DeviceArray(0.09549397, dtype=float32), internal_state=(ScaleByAdamState(count=DeviceArray(100, dtype=int32), mu={'param1': DeviceArray([0.02871927, 0.02871927, 0.02871927], dtype=float32), 'param2': DeviceArray([0.02871927, 0.02871927, 0.02871927], dtype=float32)}, nu={'param1': DeviceArray([0.44847375, 0.44847375, 0.44847375], dtype=float32), 'param2': DeviceArray([0.44847375, 0.44847375, 0.44847375], dtype=float32)}), EmptyState()), aux=None))
CPU times: user 719 ms, sys: 13.4 ms, total: 732 ms
Wall time: 1.09 s

Pros:

  • Fewer lines of code.
  • Applies lax.scan internally to make it fast [reference].

Cons:

  • No way to retrieve intermediate state/loss values from a single run() call; a possible workaround is sketched below.
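
As with LBFGS above, jaxopt’s OptaxSolver can also be stepped manually via init_state/update (a sketch, assuming the standard jaxopt manual-iteration API and reusing the solver defined above):

# Sketch: drive the solver one update at a time and record state.value
params = init_params()
state = solver.init_state(params, a)
losses = []
for _ in range(100):
  params, state = solver.update(params, state, a)
  losses.append(state.value)

Doing so gives up the internal loop that run() compiles, so it trades some speed for visibility.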

TFP math.minimize_stateless

%%time
optimizer = optax.adam(learning_rate=0.1)
params, losses = tfp.math.minimize_stateless(loss_fun, (init_params(), a), num_steps=1000, optimizer=optimizer)
print(params)
print(losses[:5])
({'param1': DeviceArray([1.0000008, 1.0000008, 1.0000008], dtype=float32), 'param2': DeviceArray([1.9999989, 1.9999989, 1.9999989], dtype=float32)}, DeviceArray(0.9999999, dtype=float32))
[48.       38.88006  30.751791 23.626852 17.507807]
CPU times: user 880 ms, sys: 15.2 ms, total: 895 ms
Wall time: 1.53 s
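
The traced losses come back as a plain array, so (assuming matplotlib is available in the notebook environment) they can be inspected or plotted directly:

import matplotlib.pyplot as plt

# Plot the full loss trace returned by minimize_stateless
plt.plot(losses)
plt.xlabel('step')
plt.ylabel('loss')
plt.show()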

Pros:

  • One line of code optimizes the function and returns the intermediate losses.

Cons:

  • By default, it optimizes all arguments passed to the loss function. In the above example, we cannot control whether a should be optimized or not. I have raised an issue here for this problem; a possible workaround is sketched below.
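
One possible workaround (an untested sketch, assuming minimize_stateless calls loss_fn(*init) and returns parameters in the same structure as init) is to close over a so that only the parameter pytree appears in init:

# Untested sketch: freeze `a` by closing over it, so only the parameter pytree is optimized
frozen_loss = lambda params: loss_fun(params, a)
(params,), losses = tfp.math.minimize_stateless(
    frozen_loss, (init_params(),), num_steps=1000, optimizer=optimizer)
print(params)

functools.partial(loss_fun, a=a) would achieve the same effect, though neither replaces the built-in option requested in the issue.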