import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tinygp import kernels
Inspired by this GPSS video.

- ICM - Intrinsic Coregionalization Model
- SLFM - Semiparametric Latent Factor Model
- LMC - Linear Model of Coregionalization
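For quick reference, all three models mix $Q$ latent GPs into $d$ outputs through low-rank mixing matrices (standard notation from the coregionalization literature, not from the video):

$$k\big((x, i), (x', j)\big) = \sum_{q=1}^{Q} B_q[i, j]\, k_q(x, x'), \qquad B_q = A_q A_q^\top, \quad A_q \in \mathbb{R}^{d \times r}.$$

With a single shared kernel ($Q = 1$) this is ICM; with rank-one mixing ($r = 1$) it is SLFM; the general sum is LMC. On a fixed set of inputs the joint covariance matrix is $\sum_q B_q \otimes K_q$, which is exactly what `kron_cov_fn` below builds.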
Helper functions
def random_fill(key, params):
    # Fill a pytree of parameters with standard-normal values of the same structure.
    values, unravel_fn = ravel_pytree(params)
    random_values = jax.random.normal(key, shape=values.shape)
    return unravel_fn(random_values)
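A quick usage sketch (hypothetical toy pytree, not part of the original notebook):

demo = {'w': jnp.zeros((2, 3)), 'log_noise': 0.0}
filled = random_fill(jax.random.PRNGKey(0), demo)
# `filled` has the same keys and shapes as `demo`, with every entry
# replaced by an i.i.d. standard-normal draw.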
def get_real_params(params):
    # Map unconstrained parameters to the constrained space: the exp transform
    # keeps variances, lengthscales, and noise levels positive.
    params = dict(params)  # shallow copy so the caller's dict is not mutated
    for i in range(1, q_len+1):
        params[f'a{i}'] = params[f'a{i}'].reshape(n_outputs, rank)
    if method == 'icm':
        params['var'] = jnp.exp(params['log_var'])
        params['scale'] = jnp.exp(params['log_scale'])
        params['noise'] = jnp.exp(params['log_noise'])
    elif method == 'lmc':
        for i in range(1, q_len+1):
            params[f'var{i}'] = jnp.exp(params[f'log_var{i}'])
            params[f'scale{i}'] = jnp.exp(params[f'log_scale{i}'])
            params[f'noise{i}'] = jnp.exp(params[f'log_noise{i}'])
    return params
def kron_cov_fn(params, x1, x2, add_noise=False):
    params = get_real_params(params)
    a_list = [params[f'a{i}'] for i in range(1, q_len+1)]

    if method == 'icm':
        # One shared input kernel; the coregionalization matrix B mixes outputs.
        kernel_fn = params['var'] * kernels.ExpSquared(scale=params['scale'])
        cov = kernel_fn(x1, x2)
        if add_noise:
            cov = cov + jnp.eye(cov.shape[0]) * params['noise']

        # B = sum_i a_i @ a_i.T
        B = jax.tree_util.tree_reduce(lambda acc, a: acc + a @ a.T, a_list,
                                      jnp.zeros((n_outputs, n_outputs)))
        return jnp.kron(B, cov)
    elif method == 'lmc':
        # One kernel per latent process; sum the per-process Kronecker products.
        cov_list = []
        for idx in range(1, q_len+1):
            kernel_fn = params[f'var{idx}'] * kernels.ExpSquared(scale=params[f'scale{idx}'])
            cov = kernel_fn(x1, x2)
            if add_noise:
                cov = cov + jnp.eye(cov.shape[0]) * params[f'noise{idx}']

            B = a_list[idx-1] @ a_list[idx-1].T
            cov_list.append(jnp.kron(B, cov))
        return jax.tree_util.tree_reduce(lambda c1, c2: c1 + c2, cov_list)
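A quick sanity check (hypothetical; runnable once `x` and the `params` dict from the initialization cell below exist): the joint covariance over `N` inputs and `n_outputs` outputs should be square of size `n_outputs * N` and symmetric.

N = x.shape[0]
K = kron_cov_fn(params, x, x, add_noise=True)
assert K.shape == (n_outputs * N, n_outputs * N)
assert jnp.allclose(K, K.T, atol=1e-5)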
Configuration
q_len = 2
rank = 2  # if 1, slfm
n_outputs = 2

method = 'lmc'  # lmc, icm
If `rank = 1`, LMC becomes SLFM.
Generative process
x_key = jax.random.PRNGKey(4)

x = jax.random.uniform(x_key, shape=(40, 1)).sort(axis=0)
x_test = jnp.linspace(0, 1, 100).reshape(-1, 1)

e1_key, e2_key = jax.random.split(x_key)

e1 = jax.random.normal(e1_key, shape=(x.shape[0],))
e2 = jax.random.normal(e2_key, shape=(x.shape[0],))

if method == 'icm':
    noise = 0.01
    gen_kernel = 1.2 * kernels.ExpSquared(scale=0.2)
    gen_covariance = gen_kernel(x, x) + jnp.eye(x.shape[0]) * noise
    gen_chol = jnp.linalg.cholesky(gen_covariance)

    # If e ~ N(0, I) and K = L L^T, then L @ e ~ N(0, K).
    y1 = gen_chol @ e1
    y2 = gen_chol @ e2

    y = jnp.concatenate([y1, y2])

elif method == 'lmc':
    noise1 = 0.01
    noise2 = 0.1
    gen_kernel1 = 1.2 * kernels.ExpSquared(scale=0.1)
    gen_covariance1 = gen_kernel1(x, x) + jnp.eye(x.shape[0]) * noise1
    gen_chol1 = jnp.linalg.cholesky(gen_covariance1)

    gen_kernel2 = 0.8 * kernels.ExpSquared(scale=0.2)
    gen_covariance2 = gen_kernel2(x, x) + jnp.eye(x.shape[0]) * noise2
    gen_chol2 = jnp.linalg.cholesky(gen_covariance2)

    # Draw each output from its own GP (different lengthscales and noise levels).
    y1 = gen_chol1 @ e1
    y2 = gen_chol2 @ e2

    y = jnp.concatenate([y1, y2])

plt.scatter(x, y1, label='y1')
plt.scatter(x, y2, label='y2'); plt.legend()
def loss_fn(params):
    # Negative log marginal likelihood of the zero-mean multi-output GP.
    mo_cov = kron_cov_fn(params, x, x, add_noise=True)
    return -jax.scipy.stats.multivariate_normal.logpdf(y, jnp.zeros_like(y), mo_cov)
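For reference, `loss_fn` is the textbook negative log marginal likelihood of a zero-mean GP with noisy covariance $K$ (a standard identity, not specific to this notebook):

$$-\log p(\mathbf{y}) = \tfrac{1}{2}\, \mathbf{y}^\top K^{-1} \mathbf{y} + \tfrac{1}{2} \log \lvert K \rvert + \tfrac{n}{2} \log 2\pi,$$

where $n$ is the length of the stacked target vector $\mathbf{y}$.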
key = jax.random.PRNGKey(1)
if method == 'icm':
    params = {'log_var': 0.0, 'log_scale': 0.0, 'log_noise': 0.0}
    for i in range(1, q_len+1):
        params[f'a{i}'] = jnp.zeros((n_outputs, rank))
elif method == 'lmc':
    params = {}
    for i in range(1, q_len+1):
        params[f'a{i}'] = jnp.zeros((n_outputs, rank))
        params[f'log_var{i}'] = 0.0
        params[f'log_scale{i}'] = 0.0
        params[f'log_noise{i}'] = 0.0

params = random_fill(key, params)
params
{'a1': DeviceArray([[-0.764527 , 1.0286916],
[-1.0690447, -0.7921495]], dtype=float32),
'a2': DeviceArray([[ 0.8845895, -1.1941622],
[-1.7434924, 1.5159688]], dtype=float32),
'log_noise1': DeviceArray(-1.1254696, dtype=float32),
'log_noise2': DeviceArray(-0.22446911, dtype=float32),
'log_scale1': DeviceArray(0.39719132, dtype=float32),
'log_scale2': DeviceArray(-0.22453257, dtype=float32),
'log_var1': DeviceArray(-0.7590596, dtype=float32),
'log_var2': DeviceArray(-0.08601531, dtype=float32)}
loss_fn(params)
DeviceArray(116.04026, dtype=float32)
key = jax.random.PRNGKey(3)
params = random_fill(key, params)

n_iters = 1000

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
opt = optax.adam(0.01)
state = opt.init(params)
def one_step(params_and_state, xs):
    params, state = params_and_state
    loss, grads = value_and_grad_fn(params)
    updates, state = opt.update(grads, state)
    params = optax.apply_updates(params, updates)
    return (params, state), (params, loss)

# lax.scan compiles the training loop once and runs it on-device,
# returning the full parameter and loss histories.
(tuned_params, state), (params_history, loss_history) = jax.lax.scan(
    one_step, init=(params, state), xs=None, length=n_iters)

plt.plot(loss_history);
def predict_fn(params, x_test):
    cov = kron_cov_fn(params, x, x, add_noise=True)
    test_cov = kron_cov_fn(params, x_test, x_test, add_noise=True)
    cross_cov = kron_cov_fn(params, x_test, x, add_noise=False)

    chol = jnp.linalg.cholesky(cov)
    k_inv_y = jax.scipy.linalg.cho_solve((chol, True), y)
    k_inv_cross_cov = jax.scipy.linalg.cho_solve((chol, True), cross_cov.T)

    pred_mean = cross_cov @ k_inv_y
    pred_cov = test_cov - cross_cov @ k_inv_cross_cov
    return pred_mean, pred_cov
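`predict_fn` is the standard GP posterior, applied to the stacked multi-output covariance (textbook identities, stated here for reference):

$$\mu_* = K_{*f} K^{-1} \mathbf{y}, \qquad \Sigma_* = K_{**} - K_{*f} K^{-1} K_{*f}^\top,$$

where `cov` plays the role of $K$, `test_cov` of $K_{**}$, and `cross_cov` of $K_{*f}$. The Cholesky factorization plus `cho_solve` applies $K^{-1}$ without ever forming the inverse explicitly.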
pred_mean, pred_cov = predict_fn(tuned_params, x_test)
pred_conf = 2 * jnp.diag(pred_cov) ** 0.5  # two standard deviations, ~95% band
plt.scatter(x, y1, label='y1')
plt.scatter(x, y2, label='y2')
plt.plot(x_test, pred_mean[:x_test.shape[0]], label='pred_y1')
plt.plot(x_test, pred_mean[x_test.shape[0]:], label='pred_y2')
plt.fill_between(x_test.ravel(), pred_mean[:x_test.shape[0]] - pred_conf[:x_test.shape[0]],
                 pred_mean[:x_test.shape[0]] + pred_conf[:x_test.shape[0]],
                 label='pred_conf_y1', alpha=0.3)
plt.fill_between(x_test.ravel(), pred_mean[x_test.shape[0]:] - pred_conf[x_test.shape[0]:],
                 pred_mean[x_test.shape[0]:] + pred_conf[x_test.shape[0]:],
                 label='pred_conf_y2', alpha=0.3)
plt.legend(bbox_to_anchor=(1, 1));
for name, value in get_real_params(tuned_params).items():
    if not name.startswith('log_'):
        print(name, value)
a1 [[0.03664799 0.00039898]
[0.3191718 0.00344488]]
a2 [[ 0.1351072 0.00248941]
[-0.05392759 -0.04239884]]
noise1 0.6797133
noise2 0.4154678
scale1 5.048228
scale2 0.10743636
var1 0.016275918
var2 41.034225