I recently read a compact and clean explanation of SVGP in a blog post by Dr. Martin Ingram. Now, I am attempting to implement it from scratch as practical code (what is practical about it? Sometimes the math does not simply translate to code without careful modifications). I am assuming that you have read that blog post before moving further. Let's get coding!

Imports

# JAX
import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.scipy as jsp

# Partially initialize functions
from functools import partial

# TFP
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# GP Kernels
from tinygp import kernels

# sklearn
from sklearn.datasets import make_moons, make_blobs, make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

# Optimization
import optax

# Plotting
import matplotlib.pyplot as plt

plt.rcParams['scatter.edgecolors'] = "k"

# Progress bar
from tqdm import tqdm

# Jitter
JITTER = 1e-6

# Enable JAX 64bit
jax.config.update("jax_enable_x64", True)

Dataset

For this blog post, we will stick to the classification problem and pick a reasonable classification dataset.
n_samples = 100
noise = 0.1
random_state = 0
shuffle = True

X, y = make_moons(
    n_samples=n_samples, random_state=random_state, noise=noise, shuffle=shuffle
)
X = StandardScaler().fit_transform(X)  # Yes, this is useful for GPs
X, y = map(jnp.array, (X, y))

plt.scatter(X[:, 0], X[:, 1], c=y);
Methodology
To define a GP, we need a kernel function. Let us use the RBF kernel, also known as the Exponentiated Quadratic or Squared Exponential kernel.
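For reference, with the variance scaling applied below, this kernel takes the familiar form (assuming tinygp's ExpSquared uses the standard \(\exp(-r^2/2)\) parameterization of the scaled distance):

\[ k(x, x') = \sigma^2 \exp\left(-\frac{\lVert x - x' \rVert^2}{2\ell^2}\right) \]

where \(\sigma^2\) is the variance and \(\ell\) is the lengthscale.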
lengthscale = 1.0
variance = 1.0

kernel_fn = variance * kernels.ExpSquared(scale=lengthscale)

kernel_fn(X, X).shape
(100, 100)
As explained in the blog post, we want to minimize the following loss function:
\[ KL[q(u|\eta) || p(u|y, \theta)] = KL[q(u|\eta) || p(u | \theta)] - \mathbb{E}_{u \sim q(u|\eta)} \log p(y | u, \theta) + const \]
Let us break down the loss and discuss each component.
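As a quick reminder of why this is a sensible objective (this is standard variational inference, not something specific to this implementation), the log marginal likelihood decomposes as

\[ \log p(y | \theta) = \underbrace{\mathbb{E}_{u \sim q(u|\eta)} \log p(y | u, \theta) - KL[q(u|\eta) || p(u | \theta)]}_{\text{ELBO}} + KL[q(u|\eta) || p(u | y, \theta)] \]

so, for fixed \(\theta\), minimizing the loss above over \(\eta\) is the same as maximizing the ELBO; in practice we minimize the negative ELBO over both \(\eta\) and \(\theta\).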
KL divergence
In the first term, we want to compute the KL divergence between the variational distribution and the GP prior at the inducing points. To do that, we first need to define the inducing points.
key = jax.random.PRNGKey(0)
n_inducing = 10
n_dim = X.shape[1]

X_inducing = jax.random.normal(key, shape=(n_inducing, n_dim))
X_inducing.shape
(10, 2)
Now, let us define the prior and variational distributions.
gp_mean = 0.43  # a scalar parameter to train

prior_mean = gp_mean * jnp.zeros(n_inducing)
prior_cov = kernel_fn(X_inducing, X_inducing)

prior_distribution = tfd.MultivariateNormalFullCovariance(prior_mean, prior_cov)

variational_mean = jax.random.uniform(key, shape=(n_inducing,))  # a vector parameter to train
A covariance matrix cannot be learned directly due to the positive definiteness constraint. We can decompose a covariance matrix in the following way:
\[ \begin{aligned} K &= diag(\boldsymbol{\sigma})\Sigma diag(\boldsymbol{\sigma})\\ &= diag(\boldsymbol{\sigma})LL^T diag(\boldsymbol{\sigma}) \end{aligned} \]
where \(\Sigma\) is a correlation matrix, \(L\) is the lower triangular Cholesky factor of \(\Sigma\), and \(\boldsymbol{\sigma}\) is the vector of standard deviations. We can use tfb.CorrelationCholesky to generate \(L\) from an unconstrained vector:
random_vector = jax.random.normal(key, shape=(3,))
corr_chol = tfb.CorrelationCholesky()(random_vector)
correlation = corr_chol @ corr_chol.T
correlation
DeviceArray([[ 1. , 0.54464529, -0.7835968 ],
[ 0.54464529, 1. , -0.33059078],
[-0.7835968 , -0.33059078, 1. ]], dtype=float64)
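As a quick check of my own (not in the original post), the bijector's output is a valid Cholesky factor of a correlation matrix: its rows have unit norm, so the reconstructed matrix has ones on the diagonal and is positive semi-definite.

# Sanity check: unit diagonal and non-negative eigenvalues
print(jnp.diag(correlation))             # should be all ones
print(jnp.linalg.eigvalsh(correlation))  # should be non-negative (up to numerics)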
To constrain \(\boldsymbol{\sigma}\), any positivity transform would suffice; here we use exp. Combining these tricks, we can model the covariance as follows:
random_vector = jax.random.normal(
    key, shape=(n_inducing * (n_inducing - 1) // 2,)
)  # a trainable parameter
log_sigma = jax.random.normal(key, shape=(n_inducing, 1))  # a trainable parameter

sigma = jnp.exp(log_sigma)
corr_chol = tfb.CorrelationCholesky()(random_vector)
variational_cov = sigma * sigma.T * (corr_chol @ corr_chol.T)
print(variational_cov.shape)

variational_distribution = tfd.MultivariateNormalFullCovariance(variational_mean, variational_cov)
(10, 10)
Now, we can compute the KL divergence:
variational_distribution.kl_divergence(prior_distribution)
DeviceArray(416.89357355, dtype=float64)
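As a quick cross-check of my own (not part of the original post), a simple Monte Carlo estimate of the KL should roughly agree with TFP's analytic value above, up to Monte Carlo error.

# Monte Carlo estimate of KL[q || p] = E_q[log q - log p]
samples = variational_distribution.sample(10_000, seed=key)
mc_kl = (variational_distribution.log_prob(samples) - prior_distribution.log_prob(samples)).mean()
print(mc_kl)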
Expectation over the likelihood
We want to compute the following expectation:
\[ -\sum_{i=1}^N \mathbb{E}_{f_i \sim q(f_i | \eta, \theta)} \log p(y_i| f_i, \theta) \]
Note that \(p(y_i| f_i, \theta)\) can be any likelihood depending on the problem; for classification, we use a Bernoulli likelihood.
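Concretely, with \(f_i\) as the logit, the Bernoulli log likelihood used below is

\[ \log p(y_i | f_i, \theta) = y_i \log \sigma(f_i) + (1 - y_i) \log (1 - \sigma(f_i)) \]

where \(\sigma(\cdot)\) is the logistic sigmoid; tfd.Bernoulli(logits=f) evaluates this in a numerically stable way.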
f = jax.random.normal(key, shape=y.shape)
likelihood_distribution = tfd.Bernoulli(logits=f)

log_likelihood = likelihood_distribution.log_prob(y).sum()
log_likelihood
DeviceArray(-72.04665624, dtype=float64)
We need to sample \(f_i\) from \(q(f_i | \eta, \theta)\) which has the following form:
\[ \begin{aligned} q(u) &\sim \mathcal{N}(\boldsymbol{m}, S)\\ q(f_i | \eta, \theta) &\sim \mathcal{N}(\mu_i, \sigma_i^2)\\ \mu_i &= A\boldsymbol{m}\\ \sigma_i^2 &= K_{ii} + A(S - K_{mm})A^T\\ A &= K_{im}K_{mm}^{-1} \end{aligned} \]
Note that direct matrix inversion with jnp.linalg.inv is often numerically unstable, so we will use a Cholesky-based solve to compute \(A\).
def q_f(x_i):
    x_i = x_i.reshape(1, -1)  # ensure correct shape
    K_im = kernel_fn(x_i, X_inducing)
    K_mm = kernel_fn(X_inducing, X_inducing)
    chol_mm = jnp.linalg.cholesky(K_mm + jnp.eye(K_mm.shape[0]) * JITTER)
    A = jsp.linalg.cho_solve((chol_mm, True), K_im.T).T

    mu_i = A @ variational_mean
    sigma_sqr_i = kernel_fn(x_i, x_i) + A @ (variational_cov - prior_cov) @ A.T

    return tfd.Normal(loc=mu_i, scale=sigma_sqr_i**0.5)
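As an optional check of my own, on a problem this small we can afford the explicit inverse, so we can verify that the Cholesky solve gives the same \(A\):

# Compare the Cholesky-based solve against the explicit inverse (fine at this size)
x_test = X[0].reshape(1, -1)
K_im = kernel_fn(x_test, X_inducing)
K_mm = kernel_fn(X_inducing, X_inducing) + jnp.eye(n_inducing) * JITTER
A_chol = jsp.linalg.cho_solve((jnp.linalg.cholesky(K_mm), True), K_im.T).T
A_inv = K_im @ jnp.linalg.inv(K_mm)
print(jnp.allclose(A_chol, A_inv, atol=1e-6))  # expected: True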
Here is a function to compute the log likelihood for a single data point:
def log_likelihood(x_i, y_i, seed):
    sample = q_f(x_i).sample(seed=seed)
    log_likelihood = tfd.Bernoulli(logits=sample).log_prob(y_i)
    return log_likelihood.squeeze()

log_likelihood(X[0], y[0], seed=key)
DeviceArray(-0.17831203, dtype=float64)
We can use jax.vmap to compute log_likelihood over a batch. With that, we can use stochastic variational inference following Section 10.3.1 (Eq. 10.108) of the PML book (volume 2): in each iteration, we multiply the minibatch log likelihood by \(\frac{N}{B}\) to get an unbiased approximation of the full-data term, where \(N\) is the size of the full dataset and \(B\) is the batch size.
batch_size = 10

seeds = jax.random.split(key, num=batch_size)

ll = len(y) / batch_size * jax.vmap(log_likelihood)(X[:batch_size], y[:batch_size], seeds).sum()
ll
DeviceArray(-215.46520331, dtype=float64)
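To convince ourselves that the \(\frac{N}{B}\) scaling is unbiased, here is a small experiment of my own (not from the original post) that averages the scaled minibatch estimate over many random batches and compares it with the full-data estimate. Both are also Monte Carlo estimates over \(f\), so they only need to agree approximately.

def scaled_batch_ll(seed):
    # Draw a random minibatch and return the N/B-scaled log likelihood estimate
    idx_seed, sample_seed = jax.random.split(seed)
    idx = jax.random.choice(idx_seed, len(y), (batch_size,), replace=False)
    seeds = jax.random.split(sample_seed, num=batch_size)
    return len(y) / batch_size * jax.vmap(log_likelihood)(X[idx], y[idx], seeds).sum()

full_ll = jax.vmap(log_likelihood)(X, y, jax.random.split(key, num=len(y))).sum()
avg_batch_ll = jax.vmap(scaled_batch_ll)(jax.random.split(key, num=500)).mean()
print(full_ll, avg_batch_ll)  # the two numbers should be in the same ballpark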
Note that, once the parameters are optimized, we can reuse the derivation of \(q(f_i | \eta, \theta)\) to compute the posterior distribution at new points. We have figured out all the pieces by now, so it is time to put them together in a single class. A few pointers:

- We define a single function get_constrained_params to transform all unconstrained parameters.
- jax.lax.scan gives a huge speed boost to the training loop (see the short sketch below).
- There is some repetition of code because it is not aggressively refactored. You can optimize it at your end if needed.
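Since jax.lax.scan may be unfamiliar, here is a minimal sketch of mine (not from the original post) showing the pattern used in fit_fn below: the first output is the carried state after the loop, and the second stacks the per-step outputs.

def step(carry, x):
    carry = carry + x       # update the carried state
    return carry, carry     # (new carry, per-step output to stack)

final, history = jax.lax.scan(step, 0.0, jnp.arange(5.0))
print(final, history)  # 10.0 [ 0.  1.  3.  6. 10.]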
All in one
class SVGP:
    def __init__(self, X_inducing, data_size):
        self.X_inducing = X_inducing
        self.n_inducing = len(X_inducing)
        self.data_size = data_size

    def init_params(self, seed):
        variational_corr_chol_param = tfb.CorrelationCholesky().inverse(jnp.eye(self.n_inducing))

        dummy_params = {
            "log_variance": jnp.zeros(()),
            "log_scale": jnp.zeros(()),
            "mean": jnp.zeros(()),
            "X_inducing": self.X_inducing,
            "variational_mean": jnp.zeros(self.n_inducing),
            "variational_corr_chol_param": variational_corr_chol_param,
            "log_variational_sigma": jnp.zeros((self.n_inducing, 1)),
        }
        # Initialize all (flattened) parameters randomly from the passed-in seed
        flat_params, unravel_fn = ravel_pytree(dummy_params)
        random_params = jax.random.normal(seed, shape=(len(flat_params),))
        params = unravel_fn(random_params)
        return params

    @staticmethod
    def get_constrained_params(params):
        return {
            "mean": params["mean"],
            "variance": jnp.exp(params["log_variance"]),
            "scale": jnp.exp(params["log_scale"]),
            "X_inducing": params["X_inducing"],
            "variational_mean": params["variational_mean"],
            "variational_corr_chol_param": params["variational_corr_chol_param"],
            "variational_sigma": jnp.exp(params["log_variational_sigma"]),
        }

    @staticmethod
    def get_q_f(params, x_i, prior_distribution, variational_distribution):
        x_i = x_i.reshape(1, -1)  # ensure correct shape

        kernel_fn = params["variance"] * kernels.ExpSquared(scale=params["scale"])
        K_im = kernel_fn(x_i, params["X_inducing"])
        K_mm = prior_distribution.covariance()
        chol_mm = jnp.linalg.cholesky(K_mm)
        A = jsp.linalg.cho_solve((chol_mm, True), K_im.T).T

        mu_i = A @ params["variational_mean"]
        sigma_sqr_i = kernel_fn(x_i, x_i) + A @ (variational_distribution.covariance() - K_mm) @ A.T

        return tfd.Normal(loc=mu_i, scale=sigma_sqr_i**0.5)

    def get_distributions(self, params):
        kernel_fn = params["variance"] * kernels.ExpSquared(scale=params["scale"])
        prior_mean = params["mean"]
        prior_cov = kernel_fn(params["X_inducing"], params["X_inducing"]) + jnp.eye(self.n_inducing) * JITTER
        prior_distribution = tfd.MultivariateNormalFullCovariance(prior_mean, prior_cov)

        corr_chol = tfb.CorrelationCholesky()(params["variational_corr_chol_param"])
        sigma = params["variational_sigma"]  # shape (n_inducing, 1); broadcasting gives diag(sigma) C diag(sigma)
        variational_cov = sigma * sigma.T * (corr_chol @ corr_chol.T) + jnp.eye(self.n_inducing) * JITTER
        variational_distribution = tfd.MultivariateNormalFullCovariance(params["variational_mean"], variational_cov)

        return prior_distribution, variational_distribution

    def loss_fn(self, params, X_batch, y_batch, seed):
        params = self.get_constrained_params(params)

        # Get distributions
        prior_distribution, variational_distribution = self.get_distributions(params)

        # Compute KL divergence
        kl = variational_distribution.kl_divergence(prior_distribution)

        # Compute log likelihood
        def log_likelihood_fn(x_i, y_i, seed):
            q_f = self.get_q_f(params, x_i, prior_distribution, variational_distribution)
            sample = q_f.sample(seed=seed)
            log_likelihood = tfd.Bernoulli(logits=sample).log_prob(y_i)
            return log_likelihood.squeeze()

        seeds = jax.random.split(seed, num=len(y_batch))
        log_likelihood = jax.vmap(log_likelihood_fn)(X_batch, y_batch, seeds).sum() * self.data_size / len(y_batch)

        return kl - log_likelihood

    def fit_fn(self, X, y, init_params, optimizer, n_iters, batch_size, seed):
        state = optimizer.init(init_params)
        value_and_grad_fn = jax.value_and_grad(self.loss_fn)

        def one_step(params_and_state, seed):
            params, state = params_and_state
            idx = jax.random.choice(seed, self.data_size, (batch_size,), replace=False)
            X_batch, y_batch = X[idx], y[idx]

            seed2 = jax.random.split(seed, 1)[0]
            loss, grads = value_and_grad_fn(params, X_batch, y_batch, seed2)
            updates, state = optimizer.update(grads, state)
            params = optax.apply_updates(params, updates)
            return (params, state), (loss, params)

        seeds = jax.random.split(seed, num=n_iters)
        (best_params, _), (loss_history, params_history) = jax.lax.scan(one_step, (init_params, state), xs=seeds)
        return best_params, loss_history, params_history

    def predict_fn(self, params, X_new):
        constrained_params = self.get_constrained_params(params)
        prior_distribution, variational_distribution = self.get_distributions(constrained_params)

        def _predict_fn(x_i):
            # Get posterior q(f_i)
            q_f = self.get_q_f(constrained_params, x_i, prior_distribution, variational_distribution)
            return q_f.mean().squeeze(), q_f.variance().squeeze()

        mean, var = jax.vmap(_predict_fn)(X_new)
        return mean.squeeze(), var.squeeze()
Train and predict
n_inducing = 20
n_epochs = 100
batch_size = 10
data_size = len(y)
n_iters = n_epochs * (data_size / batch_size)
n_iters
1000.0
key = jax.random.PRNGKey(0)
key2, subkey = jax.random.split(key)
optimizer = optax.adam(learning_rate=0.01)

X_inducing = jax.random.choice(key, X, (n_inducing,), replace=False)
model = SVGP(X_inducing, data_size)

init_params = model.init_params(key2)

model.loss_fn(init_params, X, y, key)
best_params, loss_history, params_history = model.fit_fn(X, y, init_params, optimizer, n_iters, batch_size, subkey)
plt.figure()
plt.plot(loss_history)
plt.title("Loss");
x = jnp.linspace(-3.5, 3.5, 100)
seed = jax.random.PRNGKey(123)

X1, X2 = jnp.meshgrid(x, x)
f = lambda x1, x2: model.predict_fn(best_params, jnp.array([x1, x2]).reshape(1, -1))
pred_mean, pred_var = jax.vmap(jax.vmap(f))(X1, X2)
logits = tfd.Normal(pred_mean, pred_var**0.5).sample(seed=seed, sample_shape=(10000,))
proba = jax.nn.sigmoid(logits)

proba_mean = proba.mean(axis=0)
proba_std2 = proba.std(axis=0) * 2
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
cplot1 = ax[0].contourf(X1, X2, proba_mean.squeeze(), alpha=0.5, levels=20)
plt.colorbar(cplot1, ax=ax[0])

cplot2 = ax[1].contourf(X1, X2, proba_std2.squeeze(), alpha=0.5, levels=20)
plt.colorbar(cplot2, ax=ax[1])

ax[0].scatter(X[:, 0], X[:, 1], c=y)
ax[1].scatter(X[:, 0], X[:, 1], c=y)

ax[0].set_title("Posterior $\mu$")
ax[1].set_title("Posterior $\mu \pm 2*\sigma$");
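Finally, as a quick quantitative check of my own (the original post relies on the plots), we can compute the training accuracy of the plug-in prediction \(\sigma(\mu) > 0.5\) using the accuracy_score import from earlier:

# Plug-in prediction: threshold the sigmoid of the posterior mean of f
pred_mean_train, _ = model.predict_fn(best_params, X)
y_pred = (jax.nn.sigmoid(pred_mean_train) > 0.5).astype(int)
print(accuracy_score(y, y_pred))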
Some more datasets
def fit_and_plot(X, y):
    X = StandardScaler().fit_transform(X)  # Yes, this is useful for GPs
    X, y = map(jnp.array, (X, y))

    X_inducing = jax.random.choice(key, X, (n_inducing,), replace=False)
    model = SVGP(X_inducing, data_size)

    init_params = model.init_params(key2)

    model.loss_fn(init_params, X, y, key)
    best_params, loss_history, params_history = model.fit_fn(X, y, init_params, optimizer, n_iters, batch_size, subkey)

    plt.figure()
    plt.plot(loss_history)
    plt.title("Loss")

    f = lambda x1, x2: model.predict_fn(best_params, jnp.array([x1, x2]).reshape(1, -1))
    pred_mean, pred_var = jax.vmap(jax.vmap(f))(X1, X2)
    logits = tfd.Normal(pred_mean, pred_var**0.5).sample(seed=seed, sample_shape=(10000,))
    proba = jax.nn.sigmoid(logits)

    proba_mean = proba.mean(axis=0)
    proba_std2 = proba.std(axis=0) * 2

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    cplot1 = ax[0].contourf(X1, X2, proba_mean.squeeze(), alpha=0.5, levels=20)
    plt.colorbar(cplot1, ax=ax[0])

    cplot2 = ax[1].contourf(X1, X2, proba_std2.squeeze(), alpha=0.5, levels=20)
    plt.colorbar(cplot2, ax=ax[1])

    ax[0].scatter(X[:, 0], X[:, 1], c=y)
    ax[1].scatter(X[:, 0], X[:, 1], c=y)

    ax[0].set_title("Posterior $\mu$")
    ax[1].set_title("Posterior $\mu \pm 2*\sigma$");
make_blobs
X, y = make_blobs(n_samples=n_samples, random_state=random_state, centers=2)

plt.scatter(X[:, 0], X[:, 1], c=y);
fit_and_plot(X, y)
make_circles
X, y = make_circles(n_samples=n_samples, random_state=random_state, noise=noise, factor=0.1)

plt.scatter(X[:, 0], X[:, 1], c=y);
fit_and_plot(X, y)