import os
"CUDA_VISIBLE_DEVICES"] = "3"
os.environ[
import numpy as np
import pandas as pd
import regdata as rd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from tqdm import tqdm
import matplotlib.pyplot as plt
= "cuda"
device "torch") rd.set_backend(
Generate data
# x = torch.linspace(-1, 1, 100)
# y = (torch.sin(x * 2 * torch.pi) + torch.randn(x.size()) * 0.1).unsqueeze(1)
x, y, _ = rd.MotorcycleHelmet().get_data()
x = x.ravel().to(torch.float32)
idx = np.argsort(x)
x = x[idx]
y = y.to(torch.float32)
y = y[idx]

x = torch.vstack([torch.ones_like(x), x]).T  # prepend a column of ones as the bias feature
print(x.shape, y.shape)
x = x.to(device)
y = y.to(device)
print(x.dtype, y.dtype)
plt.scatter(x.cpu().numpy()[:, 1], y.cpu().numpy())
torch.Size([94, 2]) torch.Size([94])
torch.float32 torch.float32
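The column of ones lets the first weight of a linear model act as an intercept, i.e. the design matrix is

$$
\Phi = \begin{bmatrix} \mathbf{1} & \mathbf{x} \end{bmatrix} \in \mathbb{R}^{94 \times 2},
$$

so $\Phi w$ already includes a bias term.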
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, neurons, transform=None):
        super().__init__()
        self.layers = nn.ModuleList()
        self.transform = transform
        if transform is None:
            # no feature transform: identity, first layer sees the raw input
            self.transform = lambda x: x
            self.layers.append(nn.Linear(in_dim, neurons[0]))
        else:
            # with an RBF transform, the first layer sees n_grid + 1 features
            self.layers.append(nn.Linear(self.transform.n_grid + 1, neurons[0]))
        for i in range(1, len(neurons)):
            self.layers.append(nn.Linear(neurons[i - 1], neurons[i]))
        self.layers.append(nn.Linear(neurons[-1], out_dim))

    def forward(self, x):
        x = self.transform(x)
        # print(x.shape)
        for layer in self.layers[:-1]:
            x = F.gelu(layer(x))
        return self.layers[-1](x)
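As a quick sanity check (added here, mirroring the RBF shape check below; `mlp_check` is an illustrative name, not part of the original notebook), an untrained MLP maps the 94 × 2 design matrix to one output per point:

# illustrative only: untrained MLP on the design matrix
mlp_check = MLP(2, 1, [256, 256, 256]).to(device)
mlp_check(x).shape  # torch.Size([94, 1])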
class RBF(nn.Module):
    def __init__(self, log_gauss_var, n_grid):
        super().__init__()
        self.log_gauss_var = nn.Parameter(torch.tensor(log_gauss_var))
        self.n_grid = n_grid
        self.grid = nn.Parameter(torch.linspace(-1, 1, n_grid))
        self.register_buffer("bias", torch.zeros(1))

    def forward(self, x):
        # exp(log_gauss_var) is used as the scale of the Normal
        self.dist = dist.Normal(self.grid, torch.exp(self.log_gauss_var))
        features = torch.exp(self.dist.log_prob(x[:, 1:2]))
        # print(features.shape)
        features = torch.cat(
            [
                torch.ones_like(self.bias.repeat(features.shape[0])).reshape(-1, 1),
                features,
            ],
            dim=1,
        )
        return features


RBF(0.0, 10).to(device)(x).shape
torch.Size([94, 11])
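Each input yields `n_grid` Gaussian bumps plus one constant column, hence the `n_grid + 1 = 11` features above. With grid points $g_j$ and scale $s = \exp(\texttt{log\_gauss\_var})$ (the parameter is named as a log-variance but is used as the Normal scale), each feature is the Normal density

$$
\phi_j(x) = \mathcal{N}(x \mid g_j, s^2) = \frac{1}{s\sqrt{2\pi}} \exp\!\left(-\frac{(x - g_j)^2}{2 s^2}\right).
$$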
# def transform_fn(x):
#     all_x = []
#     for i in range(2, 11):
#         all_x.append(x[:, 1:2] ** i)
#     return torch.hstack([x] + all_x)
def get_mn_sn(x, s0):
    x = transform_fn(x)
    sn_inv = (x.T @ x) / torch.exp(log_var_noise)
    diag = sn_inv.diagonal()
    diag += 1 / s0  # in-place: adds the prior precision to the diagonal
    sn = torch.inverse(sn_inv)
    mn = sn @ ((x.T @ y) / torch.exp(log_var_noise))
    return mn, sn
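`get_mn_sn` is the standard Bayesian linear regression posterior update for a zero-mean isotropic prior $\mathcal{N}(0, s_0 I)$ over the weights and noise variance $\sigma^2 = \exp(\texttt{log\_var\_noise})$:

$$
S_N^{-1} = \frac{1}{s_0} I + \frac{1}{\sigma^2}\,\Phi^\top \Phi, \qquad
m_N = \frac{1}{\sigma^2}\, S_N \Phi^\top y,
$$

where $\Phi = \texttt{transform\_fn}(x)$ is the (learned) feature matrix.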
def neg_log_likelihood(x, y, m0, s0):
    x = transform_fn(x)
    cov = (x @ x.T) / s0
    diag = cov.diagonal()
    diag += torch.exp(log_var_noise)  # in-place: adds the noise variance to the diagonal
    return (
        -dist.MultivariateNormal(m0.repeat(y.shape[0]), cov).log_prob(y.ravel()).sum()
    )
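This is the (negative log) marginal likelihood of the targets under the prior. As written, the covariance is

$$
y \sim \mathcal{N}\!\left(m_0 \mathbf{1},\; \tfrac{1}{s_0}\,\Phi\Phi^\top + \sigma^2 I\right),
$$

so $s_0$ enters here as a prior *precision*, whereas `get_mn_sn` above adds $1/s_0$ to the posterior precision, treating $s_0$ as a prior *variance*. Since $s_0$ is a single learned scalar the objective is still well-defined, but note that the two conventions are inverted relative to each other.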
def get_pred_post(sn, mn, x):
    x = transform_fn(x)
    pred_cov = x @ sn @ x.T
    diag = pred_cov.diagonal()
    diag += torch.exp(log_var_noise)
    pred_mean = x @ mn
    return pred_mean, pred_cov
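`get_pred_post` evaluates the posterior predictive at the same feature matrix $\Phi_* = \texttt{transform\_fn}(x)$:

$$
p(y_* \mid x_*) = \mathcal{N}\!\left(\Phi_* m_N,\; \Phi_* S_N \Phi_*^\top + \sigma^2 I\right).
$$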
def plot_preds_and_95(ax, x, pred_mean, pred_cov):
    with torch.no_grad():
        x = x[:, 1].cpu().numpy()
        pred_mean = pred_mean.ravel().cpu().numpy()
        pred_var = pred_cov.diagonal().cpu().numpy()
        ax.plot(x, pred_mean, color="red", label="mean")
        ax.fill_between(
            x,
            (pred_mean - 2 * np.sqrt(pred_var)),
            (pred_mean + 2 * np.sqrt(pred_var)),
            color="red",
            alpha=0.2,
            label="95% CI",
        )
    return ax
mlp = MLP(2, 1, [256, 256, 256]).to(device)
# mlp = RBF(0.1, 20).to(device)
transform_fn = mlp.forward

m0 = torch.zeros((1,)).to(device)
s0 = torch.tensor(1.0).to(device)
with torch.no_grad():
    log_var_noise = nn.Parameter(torch.tensor(0.1)).to(device)
log_var_noise.requires_grad = True
m0.requires_grad = True
s0.requires_grad = True

optimizer = torch.optim.Adam([*list(mlp.parameters()), log_var_noise, m0, s0], lr=0.01)
losses = []
pbar = tqdm(range(500))
for i in pbar:
    optimizer.zero_grad()
    loss = neg_log_likelihood(x, y, m0, s0)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    pbar.set_description(f"loss: {loss.item():.4f}")
plt.plot(losses)
loss: 30.6285: 100%|██████████| 500/500 [00:02<00:00, 209.49it/s]
mn, sn = get_mn_sn(x, s0)
pred_mean, pred_var = get_pred_post(sn, mn, x)

fig, ax = plt.subplots()
ax = plot_preds_and_95(ax, x, pred_mean, pred_var)
with torch.no_grad():
    ax.scatter(x.cpu().numpy()[:, 1], y.cpu().numpy())
    # ax.vlines(mlp.transform.grid.cpu().numpy(), -1, 1, color="black", alpha=0.2)
plt.show()

torch.exp(log_var_noise), s0, m0
(tensor(0.1191, device='cuda:0', grad_fn=<ExpBackward0>),
tensor(1.3897, device='cuda:0', requires_grad=True),
tensor([-0.0693], device='cuda:0', requires_grad=True))
Add Gaussian transform

Same model as above, but the MLP input is first passed through the learnable RBF feature map, so the network operates on Gaussian basis activations rather than the raw design matrix (m0 is frozen here).
mlp = MLP(2, 1, [256, 256, 256], transform=RBF(0.1, 10)).to(device)
# mlp = RBF(0.1, 20).to(device)
transform_fn = mlp.forward

m0 = torch.zeros((1,)).to(device)
s0 = torch.tensor(1.0).to(device)
with torch.no_grad():
    log_var_noise = nn.Parameter(torch.tensor(0.1)).to(device)
log_var_noise.requires_grad = True
m0.requires_grad = False
s0.requires_grad = True

optimizer = torch.optim.Adam([*list(mlp.parameters()), log_var_noise, m0, s0], lr=0.01)
losses = []
pbar = tqdm(range(500))
for i in pbar:
    optimizer.zero_grad()
    loss = neg_log_likelihood(x, y, m0, s0)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    pbar.set_description(f"loss: {loss.item():.4f}")
plt.plot(losses)
loss: -29.9227: 100%|██████████| 500/500 [00:03<00:00, 156.90it/s]
mn, sn = get_mn_sn(x, s0)
pred_mean, pred_var = get_pred_post(sn, mn, x)

fig, ax = plt.subplots()
ax = plot_preds_and_95(ax, x, pred_mean, pred_var)
with torch.no_grad():
    ax.scatter(x.cpu().numpy()[:, 1], y.cpu().numpy())
    ax.vlines(mlp.transform.grid.cpu().numpy(), -1, 1, color="black", alpha=0.2)
plt.show()
Just Gaussian basis

Here the features come from the RBF layer alone, with no MLP on top: plain Bayesian linear regression on learnable Gaussian basis functions.
# mlp = MLP(2, 1, [32, 32, 32], transform=RBF(0.1, 10)).to(device)
mlp = RBF(1.0, 5).to(device)
transform_fn = mlp.forward

m0 = torch.zeros((1,)).to(device)
s0 = torch.tensor(1.0).to(device)
with torch.no_grad():
    log_var_noise = nn.Parameter(torch.tensor(0.1)).to(device)
log_var_noise.requires_grad = True
m0.requires_grad = False
s0.requires_grad = True

optimizer = torch.optim.Adam([*list(mlp.parameters()), log_var_noise, m0, s0], lr=0.001)
losses = []
pbar = tqdm(range(500))
for i in pbar:
    optimizer.zero_grad()
    loss = neg_log_likelihood(x, y, m0, s0)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    pbar.set_description(f"loss: {loss.item():.4f}")
plt.plot(losses)
loss: 207.0843: 100%|██████████| 500/500 [00:02<00:00, 195.61it/s]
mn, sn = get_mn_sn(x, s0)
pred_mean, pred_var = get_pred_post(sn, mn, x)

fig, ax = plt.subplots()
ax = plot_preds_and_95(ax, x, pred_mean, pred_var)
with torch.no_grad():
    ax.scatter(x.cpu().numpy()[:, 1], y.cpu().numpy())
    ax.vlines(mlp.grid.cpu().numpy(), -1, 1, color="black", alpha=0.2)
plt.show()
Appendix

Sanity check of the MLP on a larger problem: plain MSE regression on the UCI bike-sharing data (hour.csv).
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
data = pd.read_csv("~/datasets/uci/bike/hour.csv", header=None).iloc[:, 1:]
data.shape
(17379, 18)
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

X_train.shape, X_test.shape, y_train.shape, y_test.shape
x_scaler = MinMaxScaler()
y_scaler = StandardScaler()
X_train = x_scaler.fit_transform(X_train)
y_train = y_scaler.fit_transform(y_train.reshape(-1, 1))
X_test = x_scaler.transform(X_test)
y_test = y_scaler.transform(y_test.reshape(-1, 1))

X_train.shape, X_test.shape, y_train.shape, y_test.shape
((10427, 17), (6952, 17), (10427, 1), (6952, 1))
[X_train, X_test, y_train, y_test] = map(
    lambda x: torch.tensor(x, dtype=torch.float32).to(device),
    [X_train, X_test, y_train, y_test],
)
mlp = MLP(17, 1, [10, 10]).to(device)

optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01)
losses = []
pbar = tqdm(range(500))
for i in pbar:
    optimizer.zero_grad()
    loss = F.mse_loss(mlp(X_train), y_train)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    pbar.set_description(f"loss: {loss.item():.4f}")
plt.plot(losses)
loss: 0.0040: 100%|██████████| 500/500 [00:01<00:00, 482.25it/s]
with torch.no_grad():
    y_pred = mlp(X_test).cpu().numpy()
if isinstance(y_test, torch.Tensor):
    y_test = y_test.cpu().numpy()
print(y_pred.shape, y_test.shape)
print("RMSE", mean_squared_error(y_test, y_pred, squared=False))
(6952, 1) (6952, 1)
RMSE 0.08354535