import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from time import time
# Enable high precision
from jax.config import config
"jax_enable_x64", True)
config.update(
# To enable animation inside notebook
"animation", html="jshtml") plt.rc(
Create dataset
features, labels = make_blobs(100, n_features=2, centers=2, random_state=0)
plt.scatter(features[:, 0], features[:, 1], c=labels);
print(features.shape, features.dtype, labels.shape, labels.dtype)
(100, 2) float64 (100,) int64
Implementing Newton’s method (naive way)
We will first try to implement Eq. 10.31 from the PML book (Book 1) directly:
\[ \boldsymbol{w}_{t+1}=\boldsymbol{w}_{t}-\eta_{t} \mathbf{H}_{t}^{-1} \boldsymbol{g}_{t} \]
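Here \(\boldsymbol{g}_{t}\) and \(\mathbf{H}_{t}\) are the gradient and Hessian of the loss at \(\boldsymbol{w}_{t}\). For reference (the code below obtains them via autodiff instead), the mean binary cross-entropy loss has the closed forms

\[ \boldsymbol{g}_{t}=\frac{1}{N}\mathbf{X}^{\top}(\boldsymbol{\mu}_{t}-\boldsymbol{y}), \qquad \mathbf{H}_{t}=\frac{1}{N}\mathbf{X}^{\top}\mathbf{S}_{t}\mathbf{X}, \qquad \mathbf{S}_{t}=\operatorname{diag}\big(\mu_{tn}(1-\mu_{tn})\big) \]

where \(\mathbf{X}\) includes a column of ones for the bias and \(\boldsymbol{\mu}_{t}=\sigma(\mathbf{X}\boldsymbol{w}_{t})\). Note that the \(1/N\) factors cancel inside \(\mathbf{H}_{t}^{-1}\boldsymbol{g}_{t}\), so the Newton step is the same whether we sum or average the per-example losses.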
def get_logits(params, feature):  # for a single data-point
    logits = jnp.sum(feature * params["w"]) + params["b"]
    return logits
def naive_loss(params, feature, label):  # for a single data-point
    logits = get_logits(params, feature)
    prob = jax.nn.sigmoid(logits)

    # Check if label is 1 or 0
    is_one = (label == 1)
    loss_if_one = lambda: -jnp.log(prob)  # loss if label is 1
    loss_if_zero = lambda: -jnp.log(1 - prob)  # loss if label is 0

    # Use lax.cond to express the if..else.. in a jittable way
    loss = jax.lax.cond(is_one, loss_if_one, loss_if_zero)

    return loss
def naive_loss_batch(params, features, labels):  # for a batch of data-points
    losses = jax.vmap(naive_loss, in_axes=(None, 0, 0))(params, features, labels)
    return jnp.mean(losses)
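As a quick sanity check (not part of the original notebook; the names stable_loss and stable_loss_batch are ours), the same loss can be written with jax.nn.log_sigmoid, which is numerically stabler for extreme logits. Evaluating both versions on the same parameters should give values that agree to several decimal places.

# Optional sanity check: a log-sigmoid formulation of the same loss
def stable_loss(params, feature, label):
    logits = get_logits(params, feature)
    # -[y * log sigma(z) + (1 - y) * log(1 - sigma(z))]
    return -(label * jax.nn.log_sigmoid(logits)
             + (1 - label) * jax.nn.log_sigmoid(-logits))

def stable_loss_batch(params, features, labels):
    losses = jax.vmap(stable_loss, in_axes=(None, 0, 0))(params, features, labels)
    return jnp.mean(losses)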
Writing the train function
def naive_train_step(params, features, labels, learning_rate):
    # Find gradient
    loss_value, grads = jax.value_and_grad(naive_loss_batch)(params, features, labels)

    # Find Hessian
    hess = jax.hessian(naive_loss_batch)(params, features, labels)

    # Assemble the Hessian blocks into a single matrix
    hess_matrix = jnp.block([[hess["b"]["b"], hess["b"]["w"]],
                             [hess["w"]["b"], hess["w"]["w"]]])

    # Assemble the gradient into a single vector
    grad_vector = jnp.r_[grads["b"], grads["w"]]

    # Find H^-1 g
    h_inv_g = jnp.dot(jnp.linalg.inv(hess_matrix), grad_vector)

    # Get back the original structure
    h_inv_g = {"b": h_inv_g[0], "w": h_inv_g[1:]}

    # Apply the update
    params = jax.tree_map(lambda p, g: p - learning_rate*g, params, h_inv_g)

    return params, loss_value
# First order method
# vg = jax.value_and_grad(naive_loss_batch)
# def train_step(params, features, labels, learning_rate):
#     # Find gradient
#     loss_value, grads = vg(params, features, labels)
#     # Apply the update
#     params = jax.tree_map(lambda p, g: p - learning_rate*g, params, grads)
#     return params, loss_value
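An aside (not in the original notebook): for this 3x3 system an explicit inverse is harmless, but jnp.linalg.solve is generally the more robust way to compute \(\mathbf{H}^{-1}\boldsymbol{g}\). A minimal sketch, using the hypothetical name naive_train_step_solve:

# Variant of naive_train_step that solves H x = g instead of forming inv(H)
def naive_train_step_solve(params, features, labels, learning_rate):
    loss_value, grads = jax.value_and_grad(naive_loss_batch)(params, features, labels)
    hess = jax.hessian(naive_loss_batch)(params, features, labels)

    hess_matrix = jnp.block([[hess["b"]["b"], hess["b"]["w"]],
                             [hess["w"]["b"], hess["w"]["w"]]])
    grad_vector = jnp.r_[grads["b"], grads["w"]]

    # Solve the linear system rather than computing inv(H) @ g
    h_inv_g = jnp.linalg.solve(hess_matrix, grad_vector)
    h_inv_g = {"b": h_inv_g[0], "w": h_inv_g[1:]}

    params = jax.tree_map(lambda p, g: p - learning_rate*g, params, h_inv_g)
    return params, loss_value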
key = jax.random.PRNGKey(0)
random_params = jax.random.normal(key, shape=(3, ))

# "b" should have shape (1,) for the Hessian trick with jnp.block to work
params = {"w": random_params[:2], "b": random_params[2].reshape(1,)}
learning_rate = 1.0
epochs = 20

train_step_jitted = jax.jit(naive_train_step)

history = {"loss": [], "params": []}

# warm up
train_step_jitted(params, features, labels, learning_rate)

init = time()
for _ in range(epochs):
    history["params"].append(params)
    params, loss_value = train_step_jitted(params, features, labels, learning_rate)
    history["loss"].append(loss_value)
print(time() - init, "seconds")
print(params)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.0015490055084228516 seconds
{'b': DeviceArray([13.22076694], dtype=float64), 'w': DeviceArray([ 0.59021174, -5.18797851], dtype=float64)}
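As an extra check (added here; not in the original notebook, and the variable names preds/logits are ours), we can compute the training accuracy implied by the learned parameters. The decision rule is prob > 0.5, which is equivalent to logits > 0.

# Quick check: classification accuracy on the training set
logits = jax.vmap(get_logits, in_axes=(None, 0))(params, features)
preds = (logits.ravel() > 0).astype(jnp.int64)
print("training accuracy:", jnp.mean(preds == labels))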
A helper function to animate the learning.
def animate(history):
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))

    def update(idx):
        # Clear previous frame
        ax[0].cla()
        ax[1].cla()

        # Plot data
        params = history["params"][idx]
        losses = history["loss"][:idx]
        ax[0].scatter(features[:, 0], features[:, 1], c=labels)

        # Calculate and plot decision boundary
        x0_min, x0_max = features[:, 0].min(), features[:, 0].max()
        x1_min = -(params["b"] + params["w"][0] * x0_min)/params["w"][1]
        x1_max = -(params["b"] + params["w"][0] * x0_max)/params["w"][1]

        ax[0].plot([x0_min, x0_max], [x1_min, x1_max], label='decision boundary')

        # Plot losses
        ax[1].plot(losses, label="loss")
        ax[1].set_xlabel("Iterations")

        ax[0].legend()
        ax[1].legend()

    anim = FuncAnimation(fig, update, range(epochs))
    plt.close()
    return anim

animate(history)
Implementing IRLS algorithm
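Setting \(\eta_t = 1\) and substituting the closed forms of the gradient and Hessian given earlier, the Newton update can be rewritten as a weighted least-squares problem on a working response \(\boldsymbol{z}_{t}\):

\[ \boldsymbol{z}_{t}=\mathbf{X}\boldsymbol{w}_{t}+\mathbf{S}_{t}^{-1}(\boldsymbol{y}-\boldsymbol{\mu}_{t}), \qquad \boldsymbol{w}_{t+1}=\left(\mathbf{X}^{\top}\mathbf{S}_{t}\mathbf{X}\right)^{-1}\mathbf{X}^{\top}\mathbf{S}_{t}\boldsymbol{z}_{t} \]

where \(\mathbf{S}_{t}=\operatorname{diag}\big(\mu_{tn}(1-\mu_{tn})\big)\) as before. The code below computes the per-example weight s and working response z, then solves this weighted least-squares problem directly.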
def get_s_and_z(params, feature, label):  # for a single data-point
    logits = get_logits(params, feature)
    prob = jax.nn.sigmoid(logits)
    s = prob * (1 - prob)
    z = logits + (label - prob)/s
    return s, z
def irls_train_step(params, features, labels):
    s, z = jax.vmap(get_s_and_z, in_axes=(None, 0, 0))(params, features, labels)
    S = jnp.diag(s.flatten())  # convert into a diagonal matrix

    # Add a column of ones for the bias
    X = jnp.c_[jnp.ones(len(z)), features]

    # Get weights
    weights = jnp.linalg.inv(X.T@S@X)@X.T@S@z.flatten()

    # Get back the original structure
    params = {"b": weights[0], "w": weights[1:]}

    return params
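A possible refinement (our addition, with the hypothetical name irls_train_step_v2): materializing the N x N diagonal matrix S is wasteful for large N, and the same weighted least-squares solve can be written with broadcasting and jnp.linalg.solve:

# Variant of irls_train_step that avoids building the N x N matrix S
def irls_train_step_v2(params, features, labels):
    s, z = jax.vmap(get_s_and_z, in_axes=(None, 0, 0))(params, features, labels)
    s, z = s.flatten(), z.flatten()

    # Design matrix with a column of ones for the bias
    X = jnp.c_[jnp.ones(len(z)), features]

    # Weighted normal equations: (X^T S X) w = X^T S z, with S = diag(s)
    XtS = X.T * s  # broadcasting, equivalent to X.T @ jnp.diag(s)
    weights = jnp.linalg.solve(XtS @ X, XtS @ z)

    return {"b": weights[0], "w": weights[1:]}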
key = jax.random.PRNGKey(0)
random_params = jax.random.normal(key, shape=(3,))
params = {"w": random_params[:2], "b": random_params[2]}
epochs = 20

train_step_jitted = jax.jit(irls_train_step)

irls_history = {"params": []}

# warm up
train_step_jitted(params, features, labels)

init = time()
for _ in range(epochs):
    irls_history["params"].append(params)
    params = train_step_jitted(params, features, labels)
print(time() - init, "seconds")
print(params)
0.0016303062438964844 seconds
{'b': DeviceArray(13.22076694, dtype=float64), 'w': DeviceArray([ 0.59021174, -5.18797851], dtype=float64)}
Comparison
naive_params_b = list(map(lambda x: x["b"], history["params"]))
irls_params_b = list(map(lambda x: x["b"], irls_history["params"]))

naive_params_w = list(map(lambda x: x["w"], history["params"]))
irls_params_w = list(map(lambda x: x["w"], irls_history["params"]))

plt.plot(naive_params_b, "o-", label="Naive")
plt.plot(irls_params_b, label="IRLS")
plt.xlabel("Iterations")
plt.title("Bias")
plt.legend();

plt.plot(naive_params_w, "o-", label="Naive")
plt.plot(irls_params_w, label="IRLS")
plt.xlabel("Iterations")
plt.title("Weights")
plt.legend();
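With a learning rate of 1, the Newton update above is algebraically the same as the IRLS update, so the two trajectories coincide (as the plots show). A quick numerical confirmation, added here and not part of the original notebook:

# Check that the Newton (naive) and IRLS parameter trajectories agree
for naive_p, irls_p in zip(history["params"], irls_history["params"]):
    assert np.allclose(naive_p["b"], irls_p["b"])
    assert np.allclose(naive_p["w"], irls_p["w"])
print("Newton (learning rate 1) and IRLS produce the same iterates")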