from math import prod
from functools import partial
from time import time
import blackjax
import flax.linen as nn
import jax
from jax.flatten_util import ravel_pytree
import jax.tree_util as jtu
import jax.numpy as jnp
# jnp.set_printoptions(linewidth=2000)
import optax
from tqdm import trange
import arviz as az
import seaborn as sns
import matplotlib.pyplot as plt
"jax_enable_x64", False)
jax.config.update(
%reload_ext watermark
Some helper functions:
jitter = 1e-6

def get_shapes(params):
    return jtu.tree_map(lambda x: x.shape, params)

def svd_inverse(matrix):
    # Inverse via SVD, with a small jitter added to the diagonal for numerical stability
    U, S, V = jnp.linalg.svd(matrix + jnp.eye(matrix.shape[0]) * jitter)
    return V.T / S @ U.T
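As a quick sanity check (illustrative only; the small symmetric test matrix A_test below is not part of the original notebook), svd_inverse should approximately invert a well-conditioned matrix:

A_test = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # hypothetical symmetric test matrix
print(jnp.allclose(svd_inverse(A_test) @ A_test, jnp.eye(2), atol=1e-4))  # expect True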
Dataset
We take the XOR dataset to begin with:
X = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = jnp.array([0, 1, 1, 0])

X.shape, y.shape
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
((4, 2), (4,))
NN model
class MLP(nn.Module):
    features: list

    @nn.compact
    def __call__(self, x):
        for n_features in self.features[:-1]:
            x = nn.Dense(n_features, kernel_init=jax.nn.initializers.glorot_normal(), bias_init=jax.nn.initializers.normal())(x)
            x = nn.relu(x)
        x = nn.Dense(self.features[-1])(x)
        return x.ravel()
Let us initialize the weights of the NN and inspect the shapes of the parameters:
features = [2, 1]
key = jax.random.PRNGKey(0)

model = MLP(features)
params = model.init(key, X).unfreeze()

get_shapes(params)
{'params': {'Dense_0': {'bias': (2,), 'kernel': (2, 2)},
'Dense_1': {'bias': (1,), 'kernel': (2, 1)}}}
model.apply(params, X)
DeviceArray([ 0.00687164, -0.01380461, 0. , 0. ], dtype=float32)
Negative Log Joint
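For reference, the quantity implemented in the next cell combines a standard normal prior over the flattened parameters with a Gaussian likelihood, where $f_\theta$ is the MLP and the scale $\sigma$ corresponds to noise_var below:

$$-\log p(y, \theta) = -\sum_j \log \mathcal{N}(\theta_j \mid 0, 1) - \sum_i \log \mathcal{N}\big(y_i \mid f_\theta(x_i), \sigma^2\big)$$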
noise_var = 0.1

def neg_log_joint(params):
    y_pred = model.apply(params, X)
    flat_params = ravel_pytree(params)[0]
    log_prior = jax.scipy.stats.norm.logpdf(flat_params).sum()
    log_likelihood = jax.scipy.stats.norm.logpdf(y, loc=y_pred, scale=noise_var).sum()

    return -(log_prior + log_likelihood)
Testing if it works:
neg_log_joint(params)
DeviceArray(105.03511, dtype=float32)
Find MAP
key = jax.random.PRNGKey(0)
params = model.init(key, X).unfreeze()
n_iters = 1000

value_and_grad_fn = jax.jit(jax.value_and_grad(neg_log_joint))
opt = optax.adam(0.01)
state = opt.init(params)

def one_step(params_and_state, xs):
    params, state = params_and_state
    loss, grads = value_and_grad_fn(params)
    updates, state = opt.update(grads, state)
    params = optax.apply_updates(params, updates)
    return (params, state), loss

(params, state), losses = jax.lax.scan(one_step, init=(params, state), xs=None, length=n_iters)

plt.plot(losses);
y_map = model.apply(params, X)
y_map
DeviceArray([0.01383345, 0.98666817, 0.98563665, 0.01507111], dtype=float32)
x = jnp.linspace(-0.1, 1.1, 100)
X1, X2 = jnp.meshgrid(x, x)

def predict_fn(x1, x2):
    return model.apply(params, jnp.array([x1, x2]).reshape(1, 2))

predict_fn_vec = jax.jit(jax.vmap(jax.vmap(predict_fn)))

Z = predict_fn_vec(X1, X2).squeeze()
plt.contourf(X1, X2, Z); plt.colorbar()
Full Hessian Laplace
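As a one-line summary of what this section computes: the Laplace approximation replaces the posterior over the flattened parameters with a Gaussian centered at the MAP estimate, whose covariance is the inverse Hessian of the negative log joint at that point,

$$p(\theta \mid \mathcal{D}) \approx \mathcal{N}\big(\theta_{\mathrm{MAP}},\ H^{-1}\big), \qquad H = \nabla^2_\theta\,\big[-\log p(y, \theta)\big]\big|_{\theta = \theta_{\mathrm{MAP}}}.$$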
flat_params, unravel_fn = ravel_pytree(params)

def neg_log_joint_flat(flat_params):
    return neg_log_joint(unravel_fn(flat_params))

H = jax.hessian(neg_log_joint_flat)(flat_params)

sns.heatmap(H);

posterior_cov = svd_inverse(H)

sns.heatmap(posterior_cov);
Note that we can sample parameters from the posterior and revert them to the correct structure with unravel_fn.
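For instance (a minimal sketch using the flat_params, posterior_cov, and unravel_fn defined above; the key value is arbitrary):

sample_key = jax.random.PRNGKey(2)  # arbitrary key for this illustration
flat_sample = jax.random.multivariate_normal(sample_key, mean=flat_params, cov=posterior_cov)
sampled_params = unravel_fn(flat_sample)  # back to the original parameter PyTree
get_shapes(sampled_params)                # same shapes as `params`

Here is a class to do it all: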
class FullHessianLaplace:
    def __init__(self, map_params, model):
        flat_params, self.unravel_fn = ravel_pytree(map_params)

        def neg_log_joint_flat(flat_params):
            params = self.unravel_fn(flat_params)
            return neg_log_joint(params)

        # Gaussian posterior: mean at the MAP estimate, covariance = inverse Hessian
        self.H = jax.hessian(neg_log_joint_flat)(flat_params)
        self.mean = flat_params
        self.cov = svd_inverse(self.H)
        self.model = model

    def _vectorize(self, f, seed, shape, f_kwargs={}):
        # One PRNG key per requested sample, then one vmap per axis in `shape`
        length = prod(shape)
        seeds = jax.random.split(seed, num=length).reshape(shape + (2,))

        sample_fn = partial(f, **f_kwargs)
        for _ in shape:
            sample_fn = jax.vmap(sample_fn)

        return sample_fn(seed=seeds)

    def _sample(self, seed):
        sample = jax.random.multivariate_normal(seed, mean=self.mean, cov=self.cov)
        return self.unravel_fn(sample)

    def sample(self, seed, shape):
        return self._vectorize(self._sample, seed, shape)

    def _predict(self, X, seed):
        sample = self._sample(seed)
        return self.model.apply(sample, X)

    def predict(self, X, seed, shape):
        return self._vectorize(self._predict, seed, shape, {'X': X})
Estimating predictive posterior
posterior = FullHessianLaplace(params, model)

seed = jax.random.PRNGKey(1)
n_samples = 100000
y_pred_full = posterior.predict(X, seed=seed, shape=(n_samples,))
ulim = 5
llim = -5

fig, ax = plt.subplots(2, 2, figsize=(12, 4))
ax = ax.ravel()
for i in range(len(y)):
    az.plot_dist(y_pred_full[:, i], ax=ax[i])
    ax[i].grid(True)
    ax[i].set_xticks(range(llim, ulim))
    ax[i].set_xlim(llim, ulim)
    ax[i].set_title(f"X={X[i]}, y_pred_mean={y_pred_full[:, i].mean():.3f}, y_map={y_map[i]:.3f}")
fig.tight_layout()
KFAC-Laplace
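The idea in this section is to replace the single full-covariance Gaussian above with a layer-wise block-diagonal approximation, where each layer gets its own Gaussian whose covariance is the inverse of that layer's Hessian block:

$$\Sigma \approx \mathrm{blockdiag}\big(H_1^{-1}, \ldots, H_L^{-1}\big)$$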
We need to invert partial Hessians to do KFAC-Laplace. We can use tree_flatten together with ravel_pytree to ease the workflow. We need to:
1. Pick out the partial Hessians in plain matrix form so that we can invert them.
2. Create layer-wise distributions and sample from them. These samples will be 1d arrays.
3. Convert those 1d arrays back into the params dictionary form so that we can plug them into the flax model and get posterior predictions.
First, we need to segregate the parameters layer-wise. We will use the is_leaf condition to stop traversing the parameter PyTree at a particular depth. See how it differs from vanilla tree_flatten:
flat_params, tree_def = jtu.tree_flatten(params)
display(flat_params, tree_def)
[DeviceArray([-0.00024913, 0.00027019], dtype=float32),
DeviceArray([[ 0.8275324 , -0.8314813 ],
[-0.8276633 , 0.83254045]], dtype=float32),
DeviceArray([0.01351773], dtype=float32),
DeviceArray([[1.1750739],
[1.1685134]], dtype=float32)]
PyTreeDef({'params': {'Dense_0': {'bias': *, 'kernel': *}, 'Dense_1': {'bias': *, 'kernel': *}}})
is_leaf = lambda param: 'bias' in param
layers, tree_def = jtu.tree_flatten(params, is_leaf=is_leaf)
display(layers, tree_def)
[{'bias': DeviceArray([-0.00024913, 0.00027019], dtype=float32),
'kernel': DeviceArray([[ 0.8275324 , -0.8314813 ],
[-0.8276633 , 0.83254045]], dtype=float32)},
{'bias': DeviceArray([0.01351773], dtype=float32),
'kernel': DeviceArray([[1.1750739],
[1.1685134]], dtype=float32)}]
PyTreeDef({'params': {'Dense_0': *, 'Dense_1': *}})
The difference is clearly evident. Now, we need to flatten the inner dictionaries to get 1d arrays.
flat_params = list(map(lambda x: ravel_pytree(x)[0], layers))
unravel_fn_list = list(map(lambda x: ravel_pytree(x)[1], layers))
display(flat_params, unravel_fn_list)
[DeviceArray([-2.4912864e-04, 2.7019347e-04, 8.2753241e-01,
-8.3148128e-01, -8.2766330e-01, 8.3254045e-01], dtype=float32),
DeviceArray([0.01351773, 1.1750739 , 1.1685134 ], dtype=float32)]
[<function jax._src.flatten_util.ravel_pytree.<locals>.<lambda>(flat)>,
<function jax._src.flatten_util.ravel_pytree.<locals>.<lambda>(flat)>]
def modified_neg_log_joint_fn(flat_params):
    layers = jtu.tree_map(lambda unravel_fn, flat_param: unravel_fn(flat_param), unravel_fn_list, flat_params)
    params = tree_def.unflatten(layers)
    return neg_log_joint(params)

full_hessian = jax.hessian(modified_neg_log_joint_fn)(flat_params)

# Pick the diagonal (layer-wise) blocks from the Hessian
useful_hessians = [full_hessian[i][i] for i in range(len(full_hessian))]
useful_hessians
[DeviceArray([[139.07985, 0. , 138.07985, 0. , 0. ,
0. ],
[ 0. , 410.62708, 0. , 136.54236, 0. ,
273.08472],
[138.07985, 0. , 139.07985, 0. , 0. ,
0. ],
[ 0. , 136.54236, 0. , 137.54236, 0. ,
136.54236],
[ 0. , 0. , 0. , 0. , 1. ,
0. ],
[ 0. , 273.08472, 0. , 136.54236, 0. ,
274.08472]], dtype=float32),
DeviceArray([[400.99997, 82.72832, 83.44101],
[ 82.72832, 69.43975, 0. ],
[ 83.44101, 0. , 70.35754]], dtype=float32)]
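As a quick check (not part of the original notebook), the block sizes match the per-layer parameter counts obtained earlier:

print([h.shape for h in useful_hessians])  # expect [(6, 6), (3, 3)]
print([p.shape for p in flat_params])      # expect [(6,), (3,)]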
Each entry in the above list corresponds to a layer-wise Hessian matrix. Now we need to create layer-wise distributions, sample from them, and reconstruct params using the same tricks we used above:
class KFACHessianLaplace:
    def __init__(self, map_params, model):
        self.model = model
        layers, self.tree_def = jtu.tree_flatten(map_params, is_leaf=lambda x: 'bias' in x)
        flat_layers = [ravel_pytree(layer) for layer in layers]
        self.means = list(map(lambda x: x[0], flat_layers))
        self.unravel_fn_list = list(map(lambda x: x[1], flat_layers))

        def neg_log_joint_flat(flat_params):
            flat_layers = [self.unravel_fn_list[i](flat_params[i]) for i in range(len(flat_params))]
            params = self.tree_def.unflatten(flat_layers)
            return neg_log_joint(params)

        # Keep only the layer-wise (block-diagonal) Hessian blocks and invert each one
        self.H = jax.hessian(neg_log_joint_flat)(self.means)
        self.useful_H = [self.H[i][i] for i in range(len(self.H))]
        self.covs = [svd_inverse(matrix) for matrix in self.useful_H]

    def _vectorize(self, f, seed, shape, f_kwargs={}):
        length = prod(shape)
        seeds = jax.random.split(seed, num=length).reshape(shape + (2,))

        sample_fn = partial(f, **f_kwargs)
        for _ in shape:
            sample_fn = jax.vmap(sample_fn)

        return sample_fn(seed=seeds)

    def _sample_partial(self, seed, unravel_fn, mean, cov):
        sample = jax.random.multivariate_normal(seed, mean=mean, cov=cov)
        return unravel_fn(sample)

    def _sample(self, seed):
        seeds = [seed for seed in jax.random.split(seed, num=len(self.means))]
        flat_sample = jtu.tree_map(self._sample_partial, seeds, self.unravel_fn_list, self.means, self.covs)
        sample = self.tree_def.unflatten(flat_sample)
        return sample

    def sample(self, seed, shape):
        return self._vectorize(self._sample, seed, shape)

    def _predict(self, X, seed):
        sample = self._sample(seed)
        return self.model.apply(sample, X)

    def predict(self, X, seed, shape):
        return self._vectorize(self._predict, seed, shape, {'X': X})
Estimating predictive posterior
kfac_posterior = KFACHessianLaplace(params, model)

seed = jax.random.PRNGKey(1)
n_samples = 1000000
y_pred_kfac = kfac_posterior.predict(X, seed=seed, shape=(n_samples,))
ulim = 5
llim = -5

fig, ax = plt.subplots(2, 2, figsize=(12, 4))
ax = ax.ravel()
for i in range(len(y)):
    az.plot_dist(y_pred_full[:, i], ax=ax[i], label='full', color='r')
    az.plot_dist(y_pred_kfac[:, i], ax=ax[i], label='kfac', color='b')
    ax[i].grid(True)
    ax[i].set_xticks(range(llim, ulim))
    ax[i].set_xlim(llim, ulim)
    ax[i].set_title(f"X={X[i]}, y_map={y_map[i]:.3f}")
fig.tight_layout()
We can see that KFAC approximates the trend of the Full Hessian Laplace. We can visualize the covariance matrices as below.
fig, ax = plt.subplots(1, 2, figsize=(18, 5))
sns.heatmap(posterior.cov, ax=ax[0], annot=True, fmt='.2f')
ax[0].set_title('Full')

kfac_cov = posterior.cov * 0
offset = 0
for cov in kfac_posterior.covs:
    length = cov.shape[0]
    kfac_cov = kfac_cov.at[offset:offset+length, offset:offset+length].set(cov)
    offset += length

sns.heatmap(kfac_cov, ax=ax[1], annot=True, fmt='.2f')
ax[1].set_title('KFAC');
Comparison with MCMC
Inspired by a blackjax docs example.
key = jax.random.PRNGKey(0)
warmup_key, inference_key = jax.random.split(key, 2)
num_warmup = 5000
num_samples = n_samples

initial_position = model.init(key, X)

def logprob(params):
    return -neg_log_joint(params)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

init = time()
adapt = blackjax.window_adaptation(blackjax.nuts, logprob, num_warmup)
final_state, kernel, _ = adapt.run(warmup_key, initial_position)
states = inference_loop(inference_key, kernel, final_state, num_samples)
samples = states.position.unfreeze()
print(f"Sampled {n_samples} samples in {time()-init:.2f} seconds")
Sampled 1000000 samples in 27.85 seconds
y_pred_mcmc = jax.vmap(model.apply, in_axes=(0, None))(samples, X)
ulim = 5
llim = -5

fig, ax = plt.subplots(2, 2, figsize=(12, 4))
ax = ax.ravel()
for i in range(len(y)):
    az.plot_dist(y_pred_full[:, i], ax=ax[i], label='full', color='r')
    az.plot_dist(y_pred_kfac[:, i], ax=ax[i], label='kfac', color='b')
    az.plot_dist(y_pred_mcmc[:, i], ax=ax[i], label='mcmc', color='k')
    ax[i].grid(True)
    ax[i].set_xticks(range(llim, ulim))
    ax[i].set_xlim(llim, ulim)
    ax[i].set_title(f"X={X[i]}, y_map={y_map[i]:.3f}")
fig.tight_layout()
fig, ax = plt.subplots(1, 3, figsize=(18, 5))
fig.subplots_adjust(wspace=0.1)
sns.heatmap(posterior.cov, ax=ax[0], annot=True, fmt='.2f')
ax[0].set_title('Full')

kfac_cov = posterior.cov * 0
offset = 0
for cov in kfac_posterior.covs:
    length = cov.shape[0]
    kfac_cov = kfac_cov.at[offset:offset+length, offset:offset+length].set(cov)
    offset += length

sns.heatmap(kfac_cov, ax=ax[1], annot=True, fmt='.2f')
ax[1].set_title('KFAC');

mcmc_cov = jnp.cov(jax.vmap(lambda x: ravel_pytree(x)[0])(samples).T)

sns.heatmap(mcmc_cov, ax=ax[2], annot=True, fmt='.2f')
ax[2].set_title('MCMC');
Library versions
%watermark --iversions
flax : 0.6.1
blackjax : 0.8.2
optax : 0.1.3
matplotlib: 3.5.1
jax : 0.3.23
arviz : 0.12.1
seaborn : 0.11.2
json : 2.0.9