from tqdm import tqdm
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from gpytorch.kernels import RBFKernel, Kernel
class Rank1Kernel(nn.Module):
    """A low-rank kernel k(x1, x2) = phi(x1) @ phi(x2).T, where phi is an MLP."""

    def __init__(self, input_dim, output_dim, n_neurons_per_layer, activation):
        super().__init__()
        self.init = nn.Linear(input_dim, n_neurons_per_layer[0])
        self.n_neurons_per_layer = n_neurons_per_layer
        self.activation = activation
        for i in range(1, len(n_neurons_per_layer)):
            setattr(self, f'fc{i}', nn.Linear(n_neurons_per_layer[i-1], n_neurons_per_layer[i]))
        self.out = nn.Linear(n_neurons_per_layer[-1], output_dim)

    def forward(self, x1, x2):
        # Map both inputs through the shared feature extractor phi.
        def _forward(x):
            x = self.init(x)
            for i in range(1, len(self.n_neurons_per_layer)):
                x = getattr(self, f'fc{i}')(x)
                x = self.activation(x)
            return self.out(x)

        x1 = _forward(x1)
        x2 = _forward(x2)
        # Gram matrix of the feature maps; its rank is at most output_dim.
        covar = x1 @ x2.T
        return covar
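# Quick sanity check (an illustrative aside, not from the original notebook):
# since k(X, X) = Phi(X) @ Phi(X).T, the Gram matrix can have rank at most
# output_dim, no matter how many inputs we evaluate it on.
_demo = Rank1Kernel(1, 10, [64, 64], torch.sin)
with torch.no_grad():
    _K = _demo(torch.linspace(-1, 1, 50).view(-1, 1),
               torch.linspace(-1, 1, 50).view(-1, 1))
print(torch.linalg.matrix_rank(_K))  # at most 10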
fixed_kernel = RBFKernel()
fixed_kernel.lengthscale = 0.3
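# For reference, gpytorch's RBFKernel (without an outputscale) computes
# k(x, x') = exp(-(x - x')^2 / (2 * lengthscale^2)); a hand-rolled version
# for 1-D column inputs would be a sketch like:
def rbf(x1, x2, lengthscale=0.3):
    return torch.exp(-(x1 - x2.T) ** 2 / (2 * lengthscale ** 2))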
X1 = torch.linspace(-1, 1, 100).view(-1, 1)

epochs = 1000
n_neurons_per_layer = [64]*4
output_dim = 10
kernel = Rank1Kernel(1, output_dim, n_neurons_per_layer, torch.sin)
optimizer = torch.optim.Adam(kernel.parameters(), lr=0.001)
losses = []
with torch.no_grad():
    gt_covar = fixed_kernel(X1, X1).evaluate_kernel().tensor

bar = tqdm(range(epochs))
for epoch in bar:
    optimizer.zero_grad()
    pred_covar = kernel(X1, X1)
    # MSE between the ground-truth and predicted Gram matrices.
    loss = torch.mean((gt_covar - pred_covar)**2)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    bar.set_description(f"Loss: {loss.item():.4f}")

Loss: 0.0001: 100%|██████████| 1000/1000 [00:06<00:00, 150.34it/s]
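# Optional diagnostic (a sketch, not part of the original run): the eigenvalue
# decay of the ground-truth covariance shows how much a rank-10 approximation
# can capture in principle.
evals = torch.linalg.eigvalsh(gt_covar)   # ascending order
print(evals[-output_dim:] / evals.sum())  # share of mass in the top 10 eigenvalues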
plt.plot(losses);
fig, ax = plt.subplots(1, 2, figsize=(8, 3))

sns.heatmap(gt_covar, ax=ax[0], cmap='RdYlGn_r', cbar=True, vmin=-2, vmax=2)
ax[0].set_title('Ground Truth Covariance')

X_new = torch.linspace(-1.5, 1.5, 100).view(-1, 1)
with torch.no_grad():
    est_covar = kernel(X_new, X_new)
sns.heatmap(est_covar, ax=ax[1], cmap='RdYlGn_r', cbar=True, vmin=-2, vmax=2)
ax[1].set_title('Estimated Covariance');
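# Illustrative check (not in the original notebook): the kernel was only fit on
# [-1, 1], so we can compare its squared error inside vs. outside that range.
with torch.no_grad():
    gt_new = fixed_kernel(X_new, X_new).evaluate_kernel().tensor
    err = (est_covar - gt_new) ** 2
inside = (X_new.abs() <= 1).squeeze()
print(err[inside][:, inside].mean(), err[~inside][:, ~inside].mean())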
X2 = torch.zeros(1, 1) + 1
with torch.no_grad():
    # Slice of each covariance at x' = 1: k(x, 1) as a function of x.
    variance = gt_covar[-1, :]
    plt.plot(X1, variance.numpy(), label="fixed kernel");
    variance = kernel(X1, X2)
    plt.plot(X1, variance.numpy(), label=f"rank-{output_dim} kernel");
plt.legend()
print(gt_covar.shape)

torch.Size([100, 100])

torch.random.manual_seed(2)
norm = dist.MultivariateNormal(torch.zeros(100), gt_covar + 1e-5 * torch.eye(100))
y = norm.sample()
plt.plot(X1, y);
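# Illustrative follow-up (an assumption, not part of the original run): draw a
# sample from the learned rank-10 covariance the same way; the same jitter makes
# the low-rank (and hence singular) Gram matrix positive definite.
with torch.no_grad():
    learned_covar = kernel(X1, X1)
learned_norm = dist.MultivariateNormal(torch.zeros(100), learned_covar + 1e-5 * torch.eye(100))
plt.plot(X1, learned_norm.sample());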
n = 6

fig, ax = plt.subplots(1, n, figsize=(15, 2))
d_x = X1
d_y = y
for i in range(n):
    # The variance shrinks as we repeatedly take first differences of the sample.
    print(f"{i}: {torch.var(d_y)}")
    ax[i].plot(d_x, d_y)
    d_x = d_x[1:] - d_x[:-1]
    d_x = torch.cumsum(d_x, dim=0)
    d_y = d_y[1:] - d_y[:-1]

f = lambda x: torch.zeros_like(x)
ax[-1].plot(d_x, f(d_x), c="r", label="f(x)")
0: 0.5698477029800415
1: 0.006691396702080965
2: 0.0001796285796444863
3: 0.00022799619182478637
4: 0.0008216467685997486
5: 0.00304242386482656
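# Aside (illustrative): the manual differencing above matches torch.diff,
# e.g. d_y[1:] - d_y[:-1] is equivalent to torch.diff(d_y, dim=0).
assert torch.equal(y[1:] - y[:-1], torch.diff(y, dim=0))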