import os
"CUDA_VISIBLE_DEVICES"] = "0"
os.environ[# turn off preallocation by JAX
"XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ[
import numpy as np
import pandas as pd
from tqdm import tqdm
import jax
import jax.numpy as jnp
import flax.linen as nn
import distrax as dx
import optax
# load mnist dataset from tensorflow datasets
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# define initializers
def first_layer_init(key, shape, dtype=jnp.float32):
    num_input = shape[0]  # reverse order compared to torch
    return jax.random.uniform(key, shape, dtype, minval=-1.0/num_input, maxval=1.0/num_input)

def other_layers_init(key, shape, dtype=jnp.float32):
    num_input = shape[0]  # reverse order compared to torch
    return jax.random.uniform(key, shape, dtype, minval=-np.sqrt(6/num_input)/30, maxval=np.sqrt(6/num_input)/30)
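These uniform bounds pair with the sin(30*x) activations used below, in the style of the SIREN initialization recipe. As a small sanity check (added here, not part of the original notebook), the bounds can be verified on a dummy kernel shape; Flax passes kernel shapes as (fan_in, fan_out), so shape[0] is the fan-in:

# sanity check (added): a first-layer kernel with fan-in 2 should lie in [-1/2, 1/2]
_w = first_layer_init(jax.random.PRNGKey(0), (2, 256))
assert _w.min() >= -0.5 and _w.max() <= 0.5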
class Encoder(nn.Module):
    features: list
    encoding_dims: int

    @nn.compact
    def __call__(self, x_context, y_context):
        x = jnp.hstack([x_context, y_context.reshape(x_context.shape[0], -1)])

        x = nn.Dense(self.features[0], kernel_init=first_layer_init, bias_init=first_layer_init)(x)
        x = jnp.sin(30*x)
        # x = nn.Dense(self.features[0])(x)
        # x = nn.relu(x)
        for n_features in self.features[1:]:
            x = nn.Dense(n_features, kernel_init=other_layers_init, bias_init=other_layers_init)(x)
            x = jnp.sin(30*x)
            # x = nn.Dense(n_features)(x)
            # x = nn.relu(x)
        x = nn.Dense(self.encoding_dims)(x)

        representation = x.mean(axis=0, keepdims=True)  # option 1
        return representation  # (1, encoding_dims)
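The encoder maps each (coordinate, intensity) context pair through a sine-activated MLP and mean-pools over the context set, so the representation size is independent of the number of context points. A hedged shape check (added here; the layer sizes are arbitrary, the notebook's own model is built further down):

# added shape check: 10 context points of 2-D coordinates with scalar intensities
_enc = Encoder(features=[64, 64], encoding_dims=32)
_xc, _yc = jnp.zeros((10, 2)), jnp.zeros((10, 1))
_p_enc = _enc.init(jax.random.PRNGKey(0), _xc, _yc)
print(_enc.apply(_p_enc, _xc, _yc).shape)  # (1, 32)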
class Decoder(nn.Module):
    features: list
    output_dim: int

    @nn.compact
    def __call__(self, representation, x):
        representation = jnp.repeat(representation, x.shape[0], axis=0)
        x = jnp.hstack([representation, x])

        x = nn.Dense(self.features[0], kernel_init=first_layer_init, bias_init=first_layer_init)(x)
        x = jnp.sin(30*x)
        # x = nn.Dense(self.features[0])(x)
        # x = nn.relu(x)
        for n_features in self.features:
            x = nn.Dense(n_features, kernel_init=other_layers_init, bias_init=other_layers_init)(x)
            x = jnp.sin(30*x)
            # x = nn.Dense(n_features)(x)
            # x = nn.relu(x)
        x = nn.Dense(self.output_dim*2)(x)
        loc, raw_scale = x[:, :self.output_dim], x[:, self.output_dim:]
        scale = jnp.exp(raw_scale)
        return loc, scale
class CNP(nn.Module):
    encoder_features: list
    encoding_dims: int
    decoder_features: list
    output_dim: int

    @nn.compact
    def __call__(self, x_context, y_context, x_target):
        representation = Encoder(self.encoder_features, self.encoding_dims)(x_context, y_context)
        loc, scale = Decoder(self.decoder_features, self.output_dim)(representation, x_target)
        return loc, scale

    def loss_fn(self, params, x_context, y_context, x_target, y_target):
        loc, scale = self.apply(params, x_context, y_context, x_target)
        predictive_distribution = dx.MultivariateNormalDiag(loc=loc, scale_diag=0.005 + scale)
        return -predictive_distribution.log_prob(y_target)
Load MNIST
ds = tfds.load('mnist')
def dataset_to_arrays(dataset):
    data = []
    labels = []
    stopper = 0
    end = 100  # only keep the first 100 samples to keep the demo small
    for sample in dataset:
        data.append(sample["image"].numpy())
        labels.append(sample["label"].numpy())
        stopper += 1
        if stopper == end:
            break
    return np.array(data), np.array(labels)[..., None]
= dataset_to_arrays(ds["train"])
train_data, train_labels = dataset_to_arrays(ds["test"])
test_data, test_labels
train_data.shape, train_labels.shape, test_data.shape, test_labels.shape
((100, 28, 28, 1), (100, 1), (100, 28, 28, 1), (100, 1))
# pixel coordinates in [-1, 1]^2, shared across all images
coords = np.linspace(-1, 1, 28)
x, y = np.meshgrid(coords, coords)
train_X = jnp.stack([x, y], axis=-1).reshape(-1, 2)

# per-image pixel intensities rescaled to [0, 1]
train_y = jax.vmap(lambda x: x.reshape(-1, 1))(train_data) / 255.0

train_X.shape, train_y.shape, type(train_X), type(train_y)
((784, 2),
(100, 784, 1),
jaxlib.xla_extension.ArrayImpl,
jaxlib.xla_extension.ArrayImpl)
iterations = 10000
def loss_fn(params, context_X, context_y, target_X, target_y):
    def loss_fn_per_sample(context_X, context_y, target_X, target_y):
        loc, scale = model.apply(params, context_X, context_y, target_X)
        # predictive_distribution = dx.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        # return -predictive_distribution.log_prob(target_y)
        return jnp.square(loc.ravel() - target_y.ravel()).mean()
    return jax.vmap(loss_fn_per_sample, in_axes=(None, 0, None, 0))(context_X, context_y, target_X, target_y).mean()
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
model = CNP([256]*2, 128, [256]*4, 1)
params = model.init(jax.random.PRNGKey(0), train_X, train_y[0], train_X)
optimizer = optax.adam(1e-5)
state = optimizer.init(params)
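Before running the full training scan, a quick sanity check (added here; n_ctx and l0 are names introduced for this check only) confirms that the jitted loss runs and returns a scalar when the first 5% of pixels are used as context:

# added check: scalar loss averaged over all 100 training digits
n_ctx = int(train_X.shape[0] * 0.05)  # 39 of the 784 pixels
l0, _ = value_and_grad_fn(params, train_X[:n_ctx], train_y[:, :n_ctx], train_X[n_ctx:], train_y[:, n_ctx:])
print(l0)  # initial loss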
# losses = []
# for iter in tqdm(range(iterations)):
#     tmp_index = jax.random.permutation(jax.random.PRNGKey(iter), index)
#     context_X = train_X[tmp_index][:int(train_X.shape[0]*0.05)]
#     context_y = train_y[:, tmp_index, :][:, :int(train_X.shape[0]*0.05), :]
#     target_X = train_X[tmp_index][int(train_X.shape[0]*0.05):]
#     target_y = train_y[:, tmp_index, :][:, int(train_X.shape[0]*0.05):, :]
#     # print(context_X.shape, context_y.shape, target_X.shape, target_y.shape)
#     # print(loss_fn(params, context_X, context_y, target_X, target_y).shape)
#     loss, grads = value_and_grad_fn(params, context_X, context_y, target_X, target_y)
#     updates, state = optimizer.update(grads, state)
#     params = optax.apply_updates(params, updates)
#     losses.append(loss.item())
def one_step(params_and_state, key):
    params, state = params_and_state
    tmp_index = jax.random.permutation(key, train_X.shape[0])
    context_X = train_X[tmp_index][:int(train_X.shape[0]*0.05)]
    context_y = train_y[:, tmp_index, :][:, :int(train_X.shape[0]*0.05), :]
    target_X = train_X[tmp_index][int(train_X.shape[0]*0.05):]
    target_y = train_y[:, tmp_index, :][:, int(train_X.shape[0]*0.05):, :]
    loss, grads = value_and_grad_fn(params, context_X, context_y, target_X, target_y)
    updates, state = optimizer.update(grads, state)
    params = optax.apply_updates(params, updates)
    return (params, state), loss

(params, state), loss_history = jax.lax.scan(one_step, (params, state), jax.random.split(jax.random.PRNGKey(0), iterations))
plt.plot(loss_history[10:]);
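The raw curve is noisy because every step draws a fresh random context/target split. An optional smoothing pass (added here; the window size is an arbitrary choice) makes the trend easier to read:

# added: running-mean smoothing of the loss curve
window = 100  # hypothetical smoothing window
smoothed = jnp.convolve(loss_history, jnp.ones(window) / window, mode="valid")
plt.plot(smoothed);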
test_key = jax.random.PRNGKey(0)
tmp_index = jax.random.permutation(test_key, train_X.shape[0])
context_X = train_X[tmp_index][:int(train_X.shape[0]*0.5)]
context_y = train_y[:, tmp_index, :][:, :int(train_X.shape[0]*0.5), :]
target_X = train_X  # [tmp_index][int(train_X.shape[0]*0.5):]
target_y = train_y  # [:, tmp_index, :][:, int(train_X.shape[0]*0.5):, :]
id = 91
plt.imshow(train_y[id].reshape(28, 28), cmap="gray", interpolation=None);
locs, scales = jax.vmap(model.apply, in_axes=(None, None, 0, None))(params, context_X, context_y, target_X)
# full_preds = jnp.concatenate([context_y, locs], axis=1)
# full_preds = full_preds.at[:, tmp_index, :].set(full_preds).__array__()

plt.figure()
plt.imshow(locs[id].reshape(28, 28), cmap="gray", interpolation=None);
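The decoder's scale output is computed above but never visualised; a short follow-up (added here) plots it as a per-pixel uncertainty map for the same digit:

# added: per-pixel predictive scale for the same digit
plt.figure()
plt.imshow(scales[id].reshape(28, 28), cmap="gray", interpolation=None);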