Multi-Output Gaussian Processes

Exploring MOGPs from scratch
ML
Author

Zeel B Patel

Published

October 27, 2022

Inspired by this GPSS video.

import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

import optax

import matplotlib.pyplot as plt
from tinygp import kernels

Helper functions

def random_fill(key, params):
    # Flatten the pytree into a single vector, draw standard-normal
    # values of the same size, and unflatten back into the original structure.
    values, unravel_fn = ravel_pytree(params)
    random_values = jax.random.normal(key, shape=values.shape)
    return unravel_fn(random_values)
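
As a quick check (our own example, not from the original notebook; demo is a made-up pytree), random_fill returns a pytree with the same structure but fresh standard-normal values:

demo = {'a': jnp.zeros((2, 2)), 'b': 0.0}
filled = random_fill(jax.random.PRNGKey(0), demo)
# Same structure, different (random) values
assert jax.tree_util.tree_structure(filled) == jax.tree_util.tree_structure(demo)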

def get_real_params(params):
    # Ensure the mixing weights have shape (n_outputs, rank).
    for i in range(1, q_len+1):
        params[f'a{i}'] = params[f'a{i}'].reshape(n_outputs, rank)
    # Positive hyperparameters are optimized in log space;
    # exponentiate to recover their real values.
    if method == 'icm':
        params['var'] = jnp.exp(params['log_var'])
        params['scale'] = jnp.exp(params['log_scale'])
        params['noise'] = jnp.exp(params['log_noise'])
    elif method == 'lmc':
        for i in range(1, q_len+1):
            params[f'var{i}'] = jnp.exp(params[f'log_var{i}'])
            params[f'scale{i}'] = jnp.exp(params[f'log_scale{i}'])
            params[f'noise{i}'] = jnp.exp(params[f'log_noise{i}'])
    return params

def kron_cov_fn(params, x1, x2, add_noise=False):
    params = get_real_params(params)
    a_list = [params[f'a{i}'] for i in range(1, q_len+1)]

    if method == 'icm':
        # ICM: one shared data kernel across all outputs.
        kernel_fn = params['var'] * kernels.ExpSquared(scale=params['scale'])
        cov = kernel_fn(x1, x2)
        if add_noise:
            cov = cov + jnp.eye(cov.shape[0])*params['noise']

        # Coregionalization matrix B = sum_q a_q @ a_q.T
        B = sum(a @ a.T for a in a_list)
        return jnp.kron(B, cov)

    elif method == 'lmc':
        # LMC: one data kernel per latent process, each with its own
        # coregionalization matrix B_q = a_q @ a_q.T.
        cov_list = []
        for idx in range(1, q_len+1):
            kernel_fn = params[f'var{idx}'] * kernels.ExpSquared(scale=params[f'scale{idx}'])
            cov = kernel_fn(x1, x2)
            if add_noise:
                cov = cov + jnp.eye(cov.shape[0])*params[f'noise{idx}']

            B = a_list[idx-1]@a_list[idx-1].T
            cov_list.append(jnp.kron(B, cov))

        return sum(cov_list)
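
For reference, here is a summary (our notation, not from the code) of the covariance structures kron_cov_fn builds, writing $a_q$ for the matrices a1, ..., a{q_len} and $\otimes$ for the Kronecker product:

$$K_{\mathrm{ICM}} = B \otimes k(X, X'), \qquad B = \sum_{q=1}^{Q} a_q a_q^\top$$

$$K_{\mathrm{LMC}} = \sum_{q=1}^{Q} B_q \otimes k_q(X, X'), \qquad B_q = a_q a_q^\top$$

where $Q$ = q_len, each $a_q$ is an (n_outputs × rank) matrix, and $k$, $k_q$ are exponentiated-quadratic kernels with their own variances and lengthscales.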

Configuration

q_len = 2
rank = 2 # if 1, slfm
n_outputs = 2

method = 'lmc' # lmc, icm

If rank = 1, each $B_q$ becomes rank one and LMC reduces to SLFM (the semiparametric latent factor model).

Generative process

x_key = jax.random.PRNGKey(4)

x = jax.random.uniform(x_key, shape=(40, 1)).sort(axis=0)
x_test = jnp.linspace(0,1,100).reshape(-1, 1)

e1_key, e2_key = jax.random.split(x_key)

e1 = jax.random.normal(e1_key, shape=(x.shape[0],))
e2 = jax.random.normal(e2_key, shape=(x.shape[0],))

if method == 'icm':
    noise = 0.01
    gen_kernel = 1.2*kernels.ExpSquared(scale=0.2)
    gen_covariance = gen_kernel(x, x) + jnp.eye(x.shape[0])*noise
    gen_chol = jnp.linalg.cholesky(gen_covariance)
    
    # Two independent samples drawn from the same shared kernel
    y1 = gen_chol@e1
    y2 = gen_chol@e2

    y = jnp.concatenate([y1, y2])
    
elif method == 'lmc':
    noise1 = 0.01
    noise2 = 0.1
    gen_kernel1 = 1.2*kernels.ExpSquared(scale=0.1)
    gen_covariance1 = gen_kernel1(x, x) + jnp.eye(x.shape[0])*noise1
    gen_chol1 = jnp.linalg.cholesky(gen_covariance1)

    gen_kernel2 = 0.8*kernels.ExpSquared(scale=0.2)
    gen_covariance2 = gen_kernel2(x, x) + jnp.eye(x.shape[0])*noise2
    gen_chol2 = jnp.linalg.cholesky(gen_covariance2)
    
    # One sample drawn from each latent kernel
    y1 = gen_chol1@e1
    y2 = gen_chol2@e2

    y = jnp.concatenate([y1, y2])
    

plt.scatter(x, y1, label='y1')
plt.scatter(x, y2, label='y2')
plt.legend();

def loss_fn(params):
    # Negative log marginal likelihood of the stacked outputs
    # under a zero-mean multi-output GP.
    mo_cov = kron_cov_fn(params, x, x, add_noise=True)
    return -jax.scipy.stats.multivariate_normal.logpdf(y, jnp.zeros_like(y), mo_cov)
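
As a side note, the same negative log likelihood can be written directly via a Cholesky factorization, which tends to be more stable for larger covariance matrices; a minimal sketch (loss_fn_chol is our addition, not part of the original notebook):

def loss_fn_chol(params):
    # NLL = 0.5 * y^T K^{-1} y + 0.5 * log|K| + (n/2) * log(2*pi),
    # with log|K| = 2 * sum(log(diag(L))) for K = L @ L.T.
    mo_cov = kron_cov_fn(params, x, x, add_noise=True)
    chol = jnp.linalg.cholesky(mo_cov)
    alpha = jax.scipy.linalg.cho_solve((chol, True), y)
    return (0.5 * y @ alpha
            + jnp.sum(jnp.log(jnp.diag(chol)))
            + 0.5 * y.shape[0] * jnp.log(2 * jnp.pi))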
key = jax.random.PRNGKey(1)
if method == 'icm':
    params = {'log_var':0.0, 'log_scale':0.0, 'log_noise':0.0}
    for i in range(1, q_len+1):
        params[f'a{i}'] = jnp.zeros((n_outputs, rank))
elif method == 'lmc':
    params = {}
    for i in range(1, q_len+1):
        params[f'a{i}'] = jnp.zeros((n_outputs, rank))
        params[f'log_var{i}'] = 0.0
        params[f'log_scale{i}'] = 0.0
        params[f'log_noise{i}'] = 0.0

params = random_fill(key, params)
params
{'a1': DeviceArray([[-0.764527 ,  1.0286916],
              [-1.0690447, -0.7921495]], dtype=float32),
 'a2': DeviceArray([[ 0.8845895, -1.1941622],
              [-1.7434924,  1.5159688]], dtype=float32),
 'log_noise1': DeviceArray(-1.1254696, dtype=float32),
 'log_noise2': DeviceArray(-0.22446911, dtype=float32),
 'log_scale1': DeviceArray(0.39719132, dtype=float32),
 'log_scale2': DeviceArray(-0.22453257, dtype=float32),
 'log_var1': DeviceArray(-0.7590596, dtype=float32),
 'log_var2': DeviceArray(-0.08601531, dtype=float32)}
loss_fn(params)
DeviceArray(116.04026, dtype=float32)
key = jax.random.PRNGKey(3)
params = random_fill(key, params)

n_iters = 1000

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
opt = optax.adam(0.01)
state = opt.init(params)

def one_step(params_and_state, xs):
    # One optimizer step: scan carries (params, state) and
    # stacks (params, loss) at every iteration.
    params, state = params_and_state
    loss, grads = value_and_grad_fn(params)
    updates, state = opt.update(grads, state)
    params = optax.apply_updates(params, updates)
    return (params, state), (params, loss)

(tuned_params, state), (params_history, loss_history) = jax.lax.scan(one_step, init=(params, state), xs=None, length=n_iters)

plt.plot(loss_history);
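
Since one_step also returns the parameters at every step, lax.scan stacks each leaf of params_history along the iteration axis; a small sketch (our addition, assuming method == 'lmc') to inspect how two of the log-hyperparameters evolve:

plt.plot(params_history['log_var1'], label='log_var1')
plt.plot(params_history['log_var2'], label='log_var2')
plt.xlabel('iteration')
plt.legend();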

def predict_fn(params, x_test):
    cov = kron_cov_fn(params, x, x, add_noise=True)
    # add_noise=True for the test covariance, so the predictive band
    # covers noisy observations rather than the latent functions.
    test_cov = kron_cov_fn(params, x_test, x_test, add_noise=True)
    cross_cov = kron_cov_fn(params, x_test, x, add_noise=False)

    chol = jnp.linalg.cholesky(cov)
    k_inv_y = jax.scipy.linalg.cho_solve((chol, True), y)
    k_inv_cross_cov = jax.scipy.linalg.cho_solve((chol, True), cross_cov.T)

    # Standard GP conditional: mean = K_*x K^{-1} y,
    # covariance = K_** - K_*x K^{-1} K_x*.
    pred_mean = cross_cov@k_inv_y
    pred_cov = test_cov - cross_cov@k_inv_cross_cov
    return pred_mean, pred_cov
pred_mean, pred_cov = predict_fn(tuned_params, x_test)
pred_conf = 2 * jnp.diag(pred_cov)**0.5  # 2 standard deviations (~95% interval)

plt.scatter(x, y1, label='y1')
plt.scatter(x, y2, label='y2')
plt.plot(x_test, pred_mean[:x_test.shape[0]], label='pred_y1')
plt.plot(x_test, pred_mean[x_test.shape[0]:], label='pred_y2')
plt.fill_between(x_test.ravel(),
                 pred_mean[:x_test.shape[0]] - pred_conf[:x_test.shape[0]],
                 pred_mean[:x_test.shape[0]] + pred_conf[:x_test.shape[0]],
                 label='pred_conf_y1', alpha=0.3)
plt.fill_between(x_test.ravel(),
                 pred_mean[x_test.shape[0]:] - pred_conf[x_test.shape[0]:],
                 pred_mean[x_test.shape[0]:] + pred_conf[x_test.shape[0]:],
                 label='pred_conf_y2', alpha=0.3)
plt.legend(bbox_to_anchor=(1,1));

# Learned hyperparameters, mapped back to real space
for name, value in get_real_params(tuned_params).items():
    if not name.startswith('log_'):
        print(name, value)
a1 [[0.03664799 0.00039898]
 [0.3191718  0.00344488]]
a2 [[ 0.1351072   0.00248941]
 [-0.05392759 -0.04239884]]
noise1 0.6797133
noise2 0.4154678
scale1 5.048228
scale2 0.10743636
var1 0.016275918
var2 41.034225