import os

# Hide every CUDA device so JAX falls back to the CPU. Set this before
# importing JAX so it is already in place when JAX initialises its backends.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm
import matplotlib.pyplot as plt
# Provides the "3d" projection used for the loss-surface plot; importing it
# explicitly is only required on older matplotlib releases.
from mpl_toolkits.mplot3d import Axes3D
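By Fourier's theorem, any sufficiently well-behaved periodic signal can be represented as a sum of sinusoids. The question is: how many sinusoids do we need to represent a given signal? In this notebook we explore that question by building a signal from two sinusoids and trying to recover their frequencies with gradient descent.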
Random Combination of Sinusoids
N = 1000
# Standard deviation of optional Gaussian noise added to the target;
# 0.0 keeps the clean signal (no noise term is added at all).
noise_std = 0.0

# N evenly spaced sample points on [-10, 10], as a column vector of shape (N, 1).
x = jnp.linspace(-10, 10, N).reshape(-1, 1)
# Target signal: two unit-amplitude sinusoids with angular frequencies 1 and 2.
y = jnp.sin(x) + jnp.sin(2 * x)
if noise_std > 0:
    y = y + jax.random.normal(jax.random.PRNGKey(0), (N, 1)) * noise_std

plt.plot(x, y, "kx")
print(x.shape, y.shape)
(1000, 1) (1000, 1)
Recover the Signal
def get_weights(key, minval=0.0, maxval=5.0):
    """Draw two random initial frequencies for the two-sinusoid model.

    Args:
        key: A JAX PRNG key.
        minval: Inclusive lower bound of the uniform draw (default 0.0).
        maxval: Exclusive upper bound of the uniform draw (default 5.0).

    Returns:
        A tuple ``(w1, w2)`` of scalar JAX arrays, each drawn from
        ``Uniform[minval, maxval)`` using its own subkey.
    """
    # Split once and consume each subkey exactly once. Using a key and then
    # also splitting it (reusing a consumed key) is discouraged by JAX's PRNG
    # guidelines.
    key1, key2 = jax.random.split(key)
    w1 = jax.random.uniform(key1, (), minval=minval, maxval=maxval)
    w2 = jax.random.uniform(key2, (), minval=minval, maxval=maxval)
    return w1, w2
def get_sine(weights, x):
    """Evaluate a sum of unit-amplitude sinusoids at ``x``.

    Args:
        weights: Iterable of angular frequencies, e.g. the ``(w1, w2)`` tuple
            produced by ``get_weights``. Any number of frequencies is
            accepted, so the same model can be used with more components.
        x: Points at which to evaluate the signal.

    Returns:
        ``sin(w_1 * x) + ... + sin(w_k * x)``, with the shape of ``x``
        (assuming scalar frequencies). An empty ``weights`` yields zeros.
    """
    # Start the sum from zeros shaped like x so the empty case still returns
    # an array; adding exact zeros leaves the non-empty result unchanged.
    return sum((jnp.sin(w * x) for w in weights), jnp.zeros_like(x))
def loss_fn(weights, x, y):
    """Mean squared error between the sinusoid model and the target.

    Args:
        weights: Frequencies forwarded to ``get_sine``.
        x: Input points.
        y: Target values. Both the model output and ``y`` are flattened
            before the residual is taken.

    Returns:
        Scalar mean of the squared residuals.
    """
    output = get_sine(weights, x)
    residual = output.ravel() - y.ravel()
    return jnp.mean(residual ** 2)
def one_step(weights_and_state, xs):
    """Run one optimisation step; shaped for use as a ``jax.lax.scan`` body.

    Reads ``value_and_grad_fn``, ``optimizer``, ``x`` and ``y`` from module
    scope.

    Args:
        weights_and_state: Carry tuple ``(weights, optimizer_state)``.
        xs: Per-iteration scan input. It is unused; in this notebook the scan
            is driven with ``None`` and a fixed ``length``.

    Returns:
        ``((new_weights, new_state), (loss, new_weights))``. ``loss`` is
        evaluated at the weights *before* this step's update.
    """
    weights, state = weights_and_state
    loss, grads = value_and_grad_fn(weights, x, y)
    # Pass the current parameters so optimizers whose update rule needs them
    # (e.g. decoupled weight decay in optax.adamw) also work; Adam ignores them.
    updates, state = optimizer.update(grads, state, weights)
    weights = optax.apply_updates(weights, updates)
    return (weights, state), (loss, weights)
epochs = 1000
optimizer = optax.adam(1e-2)
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

# One panel per random seed: `fig` shows data vs. fitted signal,
# `fig2` shows the corresponding loss curve.
n_rows, n_cols = 4, 3
fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 12))
fig2, ax2 = plt.subplots(n_rows, n_cols, figsize=(15, 12))
ax = ax.ravel()
ax2 = ax2.ravel()

for seed in tqdm(range(n_rows * n_cols)):
    key = jax.random.PRNGKey(seed)
    init_weights = get_weights(key)
    state = optimizer.init(init_weights)
    # Run `epochs` optimisation steps; loss_history[i] is the loss before step i.
    (weights, _), (loss_history, _) = jax.lax.scan(
        one_step, (init_weights, state), None, length=epochs
    )
    y_pred = get_sine(weights, x)

    ax[seed].plot(x, y, "kx")
    ax[seed].plot(x, y_pred, "r-")
    ax[seed].set_title(
        f"w_init=({init_weights[0]:.2f}, {init_weights[1]:.2f}), "
        f"w_pred=({weights[0]:.2f}, {weights[1]:.2f}), "
        f"loss={loss_fn(weights, x, y):.2f}"
    )
    ax2[seed].plot(loss_history)
    ax2[seed].set_title(f"seed={seed}")

# Lay out both figures; previously only `fig` was tightened.
fig.tight_layout()
fig2.tight_layout()
100%|██████████| 12/12 [00:00<00:00, 15.91it/s]
Plot the Loss Surface
# Evaluate the loss on a 100 x 100 grid of frequencies (w1, w2) in [0, 3].
w1 = jnp.linspace(0, 3, 100)
w2 = jnp.linspace(0, 3, 100)
W1, W2 = jnp.meshgrid(w1, w2)
# Nested vmap: the outer map runs over grid rows and the inner over columns,
# so each grid point gets its own scalar loss and `loss` matches W1's shape.
loss = jax.vmap(jax.vmap(lambda a, b: loss_fn((a, b), x, y)))(W1, W2)

# plot the loss surface in 3D
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(W1, W2, loss, cmap="viridis", alpha=0.9)
ax.set_xlabel("w1")
ax.set_ylabel("w2")
# Oblique camera: 30 degrees elevation, 45 degrees azimuth.
ax.view_init(elev=30, azim=45)