# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging
logger = logging.getLogger()

class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()

logger.addFilter(CheckTypesFilter())
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
try:
    import flax.linen as nn
except ModuleNotFoundError:
    %pip install flax
    import flax.linen as nn

try:
    import optax
except ModuleNotFoundError:
    %pip install optax
    import optax

try:
    import tensorflow_probability.substrates.jax as tfp
except ModuleNotFoundError:
    %pip install tensorflow-probability
    import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
Model
class Encoder(nn.Module):
    features: list
    encoding_dims: int

    @nn.compact
    def __call__(self, x_context, y_context):
        # Pair each context input with its observed value: (n_context, x_dim + y_dim)
        x = jnp.hstack([x_context, y_context.reshape(x_context.shape[0], -1)])
        for n_features in self.features:
            x = nn.Dense(n_features)(x)
            x = nn.relu(x)
        x = nn.Dense(self.encoding_dims)(x)
        representation = x.mean(axis=0, keepdims=True)  # option 1: mean aggregation
        return representation  # (1, encoding_dims)
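The mean over context points makes the representation invariant to how the context set is ordered. A minimal sanity check of that property (illustrative, not part of the original notebook; enc, xc, yc are made-up names):

# Sketch: the aggregated representation should not change when the
# context points are shuffled (mean is permutation-invariant).
enc = Encoder(features=[8], encoding_dims=4)
xc = jnp.linspace(0, 1, 5).reshape(-1, 1)
yc = jnp.sin(xc).flatten()
enc_vars = enc.init(jax.random.PRNGKey(0), xc, yc)
perm = jax.random.permutation(jax.random.PRNGKey(1), jnp.arange(5))
assert jnp.allclose(enc.apply(enc_vars, xc, yc),
                    enc.apply(enc_vars, xc[perm], yc[perm]), atol=1e-6)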
class Decoder(nn.Module):
    features: list

    @nn.compact
    def __call__(self, representation, x):
        # Broadcast the single context representation to every target input.
        representation = jnp.repeat(representation, x.shape[0], axis=0)
        x = jnp.hstack([representation, x])

        for n_features in self.features:
            x = nn.Dense(n_features)(x)
            x = nn.relu(x)

        x = nn.Dense(2)(x)
        loc, raw_scale = x[:, 0], x[:, 1]
        scale = jax.nn.softplus(raw_scale)  # keep the predictive scale positive
        return loc, scale
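An illustrative shape check for the decoder (again not from the original): one shared representation is broadcast to all target inputs, and softplus guarantees a strictly positive scale for every prediction.

# Sketch with made-up sizes: 7 target points, a 4-dimensional representation.
dec = Decoder(features=[8])
rep = jnp.ones((1, 4))
xt = jnp.linspace(-1, 1, 7).reshape(-1, 1)
dec_vars = dec.init(jax.random.PRNGKey(0), rep, xt)
loc_d, scale_d = dec.apply(dec_vars, rep, xt)
assert loc_d.shape == (7,) and scale_d.shape == (7,) and bool((scale_d > 0).all())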
class CNP(nn.Module):
    encoder_features: list
    encoding_dims: int
    decoder_features: list

    @nn.compact
    def __call__(self, x_context, y_context, x_target):
        representation = Encoder(self.encoder_features, self.encoding_dims)(x_context, y_context)
        loc, scale = Decoder(self.decoder_features)(representation, x_target)
        return loc, scale

    def loss_fn(self, params, x_context, y_context, x_target, y_target):
        loc, scale = self.apply(params, x_context, y_context, x_target)
        predictive_distribution = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        return -predictive_distribution.log_prob(y_target)  # negative log-likelihood
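Since MultivariateNormalDiag treats all target points as a single event, log_prob returns one scalar equal to the sum of independent 1-D Normal log-densities, so loss_fn is the joint negative log-likelihood of the targets. A quick equivalence check (illustrative values, not from the original notebook):

# Sketch: diagonal MVN log-prob == sum of independent Normal log-probs.
mu, sigma = jnp.zeros(3), jnp.ones(3)
y_demo = jnp.array([0.1, -0.2, 0.3])
lp_joint = tfd.MultivariateNormalDiag(loc=mu, scale_diag=sigma).log_prob(y_demo)
lp_sum = tfd.Normal(mu, sigma).log_prob(y_demo).sum()
assert jnp.allclose(lp_joint, lp_sum)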
Data
N = 100
seed = jax.random.PRNGKey(0)
x = jnp.linspace(-1, 1, N).reshape(-1, 1)
f = lambda x: (jnp.sin(10*x) + x).flatten()
noise = jax.random.normal(seed, shape=(N,)) * 0.2
y = f(x) + noise

x_test = jnp.linspace(-2, 2, N*2+10).reshape(-1, 1)
y_test = f(x_test)

plt.scatter(x, y, label='train', zorder=5)
plt.scatter(x_test, y_test, label='test', alpha=0.5)
plt.legend();
Training
def train_fn(model, optimizer, seed, n_iterations, n_context):
    params = model.init(seed, x, y, x)
    value_and_grad_fn = jax.value_and_grad(model.loss_fn)
    state = optimizer.init(params)
    indices = jnp.arange(N)

    def one_step(params_and_state, seed):
        params, state = params_and_state
        # Randomly split the training set into context and target points.
        shuffled_indices = jax.random.permutation(seed, indices)
        context_indices = shuffled_indices[:n_context]
        target_indices = shuffled_indices[n_context:]
        x_context, y_context = x[context_indices], y[context_indices]
        x_target, y_target = x[target_indices], y[target_indices]
        loss, grads = value_and_grad_fn(params, x_context, y_context, x_target, y_target)
        updates, state = optimizer.update(grads, state)
        params = optax.apply_updates(params, updates)
        return (params, state), loss

    seeds = jax.random.split(seed, num=n_iterations)
    (params, state), losses = jax.lax.scan(one_step, (params, state), seeds)
    return params, losses
encoder_features = [64, 16, 8]
encoding_dims = 1
decoder_features = [16, 8]
model = CNP(encoder_features, encoding_dims, decoder_features)
optimizer = optax.adam(learning_rate=0.001)

seed = jax.random.PRNGKey(2)
n_context = int(0.7 * N)
n_iterations = 20000

params, losses = train_fn(model, optimizer, seed, n_iterations=n_iterations, n_context=n_context)
plt.plot(losses);
Predict
loc, scale = model.apply(params, x, y, x_test)
lower, upper = loc - 2*scale, loc + 2*scale

plt.scatter(x, y, label='train', alpha=0.5)
plt.scatter(x_test, y_test, label='test', alpha=0.5)
plt.plot(x_test, loc);
plt.fill_between(x_test.flatten(), lower, upper, alpha=0.4)
plt.ylim(-5, 5);
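Optionally (not in the original notebook), conditioning on a small random subset of the training points shows the predictive uncertainty widening when less context is available; a sketch reusing the trained params:

# Sketch: predict from only 10 randomly chosen context points.
context_subset = jax.random.choice(jax.random.PRNGKey(3), jnp.arange(N), shape=(10,), replace=False)
loc_s, scale_s = model.apply(params, x[context_subset], y[context_subset], x_test)
plt.scatter(x[context_subset], y[context_subset], label='context', zorder=5)
plt.plot(x_test, loc_s)
plt.fill_between(x_test.flatten(), loc_s - 2*scale_s, loc_s + 2*scale_s, alpha=0.4)
plt.ylim(-5, 5); plt.legend();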