Zero-Shot Neural ODEs for the Van der Pol Oscillator¶
Learning objectives
- Implement the function-encoder framework where a latent code summarises context pairs $(t_i, x_i)$ and conditions a Neural ODE
- Generate Van der Pol trajectories across varying stiffness parameter $\mu$ and initial states
- Train a conditional Neural ODE that reconstructs full trajectories from a handful of context points
- Compare against a vanilla Neural ODE baseline that lacks function encoding, highlighting why zero-shot conditioning matters
1. Motivation¶
Zero-shot Neural ODEs learn an encoder that ingests a few function evaluations $(t_i, x_i)$ and produces a latent summary vector $$c = \frac{1}{K} \sum_{i=1}^{K} \phi([t_i, x_i]).$$ A conditional ODE $\dot{x} = f_\theta(x, t, c)$ then reconstructs the full trajectory and extrapolates to unseen timestamps or parameter values. This notebook rebuilds the Van der Pol example from scratch to show how the encoder, conditional dynamics, and baseline model interact.
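As a minimal sketch of this pooling step (the context size, state dimension, and MLP width below are illustrative choices, not the notebook's actual hyperparameters):

```python
import torch
import torch.nn as nn

# Toy mean-pooled function encoder: embed each (t_i, x_i) pair with an MLP phi
# and average over the K context points to obtain a single latent code c.
torch.manual_seed(0)
K, state_dim, latent_dim = 5, 2, 16
phi = nn.Sequential(nn.Linear(1 + state_dim, 32), nn.ReLU(), nn.Linear(32, latent_dim))

t_ctx = torch.rand(K, 1)             # context timestamps, normalized to [0, 1]
x_ctx = torch.randn(K, state_dim)    # states observed at those timestamps
c = phi(torch.cat([t_ctx, x_ctx], dim=-1)).mean(dim=0)
print(c.shape)  # one latent_dim-sized code summarizing the whole context set
```

Because the embeddings are averaged, the code $c$ is invariant to the order of the context points and accepts any number of them.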
2. Environment setup¶
import sys
import time
from dataclasses import dataclass
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
REPO_ROOT = Path.cwd()
TORCHDIFFEQ_DIR = (REPO_ROOT / 'docs' / '08-neural-ode' / 'torchdiffeq').resolve()
if TORCHDIFFEQ_DIR.exists() and str(TORCHDIFFEQ_DIR) not in sys.path:
    sys.path.insert(0, str(TORCHDIFFEQ_DIR))
from torchdiffeq import odeint
plt.style.use('seaborn-v0_8')
plt.rcParams.update({'axes.grid': True, 'figure.figsize': (6, 4)})
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE
def set_seed(seed=123):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(7)
3. Van der Pol dynamics and data generation¶
The Van der Pol oscillator obeys $$\dot{x} = y, \qquad \dot{y} = \mu (1 - x^2) y - x$$ where $\mu$ controls nonlinearity/stiffness. We'll sample multiple $(\mu, x_0, y_0)$ combinations, integrate the system with a high-accuracy Dormand–Prince solver, and store the full trajectories as ground truth.
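Before building the PyTorch version, a quick dependency-free sanity check on the equations: for $\mu = 0$ the system reduces to the harmonic oscillator $\ddot{x} = -x$, so a trajectory started at $(x, y) = (1, 0)$ should return to $x \approx 1$ after one period $2\pi$. The hand-rolled RK4 stepper below exists only for this check:

```python
import math

def vdp(mu, x, y):
    # dx/dt = y,  dy/dt = mu * (1 - x^2) * y - x
    return y, mu * (1 - x * x) * y - x

x, y, h, mu = 1.0, 0.0, 0.001, 0.0
for step in range(int(2 * math.pi / h)):
    # one classic RK4 step
    k1 = vdp(mu, x, y)
    k2 = vdp(mu, x + h / 2 * k1[0], y + h / 2 * k1[1])
    k3 = vdp(mu, x + h / 2 * k2[0], y + h / 2 * k2[1])
    k4 = vdp(mu, x + h * k3[0], y + h * k3[1])
    x += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
    y += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
print(round(x, 3))  # x returns to ~1.0 after one full period
```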
class VanDerPol(nn.Module):
    def __init__(self, mu: float):
        super().__init__()
        self.register_buffer('mu', torch.tensor(mu, dtype=torch.float32))

    def forward(self, t, y):
        x = y[:, 0]
        v = y[:, 1]
        dx = v
        dy = self.mu * (1 - x ** 2) * v - x
        return torch.stack([dx, dy], dim=-1)
def simulate_vdp(mu, y0, t_eval):
    y0 = y0.unsqueeze(0)
    dyn = VanDerPol(mu)
    with torch.no_grad():
        sol = odeint(dyn, y0, t_eval, method='dopri5', rtol=1e-7, atol=1e-9)
    return sol.squeeze(1)
T_END = 20.0
T_STEPS = 300
T_GRID_CPU = torch.linspace(0.0, T_END, T_STEPS)
DT = float((T_GRID_CPU[1] - T_GRID_CPU[0]).item())
train_split = {'num_samples': 48, 'mu_range': (0.5, 3.0)}
test_split = {'num_samples': 16, 'mu_range': (3.0, 5.0)}
def sample_initial_state():
    x0 = np.random.uniform(-2.0, 2.0)
    v0 = np.random.uniform(-2.0, 2.0)
    return torch.tensor([x0, v0], dtype=torch.float32)
def build_dataset(num_samples, mu_range):
    trajs, y0s, mus = [], [], []
    for _ in range(num_samples):
        mu = np.random.uniform(*mu_range)
        y0 = sample_initial_state()
        traj = simulate_vdp(mu, y0, T_GRID_CPU)
        trajs.append(traj)
        y0s.append(y0)
        mus.append(mu)
    return (
        torch.stack(trajs, dim=0),
        torch.stack(y0s, dim=0),
        torch.tensor(mus, dtype=torch.float32)
    )
train_trajs, train_y0, train_mus = build_dataset(**train_split)
test_trajs, test_y0, test_mus = build_dataset(**test_split)
train_trajs.shape, test_trajs.shape
Visualize sample trajectories¶
def plot_samples(trajs, title, n=3):
    t_np = T_GRID_CPU.numpy()
    plt.figure(figsize=(7, 4))
    for i in range(n):
        plt.plot(t_np, trajs[i, :, 0], label=f'x sample {i}')
    plt.title(title)
    plt.xlabel('time')
    plt.ylabel('x(t)')
    plt.legend()
    plt.show()
plot_samples(train_trajs, 'Train trajectories (x component)')
plot_samples(test_trajs, 'Held-out trajectories (x component)')
4. Dataset wrappers¶
class TrajectoryDataset(Dataset):
    def __init__(self, trajs, y0s, mus):
        self.trajs = trajs
        self.y0s = y0s
        self.mus = mus

    def __len__(self):
        return self.trajs.shape[0]

    def __getitem__(self, idx):
        return {
            'traj': self.trajs[idx],
            'y0': self.y0s[idx],
            'mu': self.mus[idx]
        }
train_dataset = TrajectoryDataset(train_trajs, train_y0, train_mus)
test_dataset = TrajectoryDataset(test_trajs, test_y0, test_mus)
5. Context sampling helper¶
The function encoder observes $K$ context points per trajectory. We randomly select indices, gather $(t, x, y)$ tuples, and normalize time to $[0, 1]$ before feeding them to the encoder (per Section 3 of the zero-shot paper).
Function encoder and conditional dynamics (math view)¶
Given context points $(t_i, x_i)$ we normalize time to $[0, 1]$, embed each pair with an MLP $\phi([t_i, x_i])$, and average to obtain a latent code $$c = \frac{1}{K} \sum_{i=1}^{K} \phi([t_i, x_i]).$$ The conditional Neural ODE then integrates $$\dot{x} = f_\theta(x, t, c)$$ with the same initial condition as the ground-truth system. Changing the context set changes $c$, so the decoder adapts to different $\mu$ values without retraining. We'll compare this zero-shot model against a universal Neural ODE that lacks the latent code and therefore must memorize a single vector field.
T_GRID = T_GRID_CPU.to(DEVICE)
TIME_MIN = T_GRID[0]
TIME_RANGE = T_GRID[-1] - T_GRID[0]
def sample_contexts(traj_batch, num_context):
    batch, steps, state_dim = traj_batch.shape
    device = traj_batch.device
    idx = torch.stack([torch.randperm(steps, device=device)[:num_context]
                       for _ in range(batch)], dim=0)
    idx, _ = torch.sort(idx, dim=1)
    times = T_GRID[idx]
    times_norm = ((times - TIME_MIN) / TIME_RANGE).unsqueeze(-1)
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, state_dim)
    values = traj_batch.gather(1, gather_idx)
    return times_norm, values
6. Function encoder + conditional Neural ODE¶
class FunctionEncoder(nn.Module):
    def __init__(self, latent_dim=16, hidden_dim=64):
        super().__init__()
        self.embed = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.project = nn.Linear(hidden_dim, latent_dim)

    def forward(self, t_ctx, x_ctx):
        inp = torch.cat([t_ctx, x_ctx], dim=-1)
        feats = self.embed(inp)
        pooled = feats.mean(dim=1)
        return self.project(pooled)
class ConditionalODEFunc(nn.Module):
    def __init__(self, latent_dim=16, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + 2 + 1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2)
        )
        self.current_latent = None

    def set_latent(self, latent):
        self.current_latent = latent

    def forward(self, t, y):
        if self.current_latent is None:
            raise RuntimeError('latent code not set')
        lat = self.current_latent
        if lat.shape[0] != y.shape[0]:
            lat = lat.expand(y.shape[0], -1)
        t_input = torch.full((y.shape[0], 1), float(t), device=y.device)
        inp = torch.cat([y, lat, t_input], dim=-1)
        return self.net(inp)
class ZeroShotVanDerPol(nn.Module):
    def __init__(self, latent_dim=16, hidden_dim=64, solver='rk4', step_size=DT):
        super().__init__()
        self.encoder = FunctionEncoder(latent_dim, hidden_dim)
        self.func = ConditionalODEFunc(latent_dim, hidden_dim)
        self.solver = solver
        self.step_size = step_size

    def forward(self, t_ctx, x_ctx, t_eval, y0):
        latent = self.encoder(t_ctx, x_ctx)
        self.func.set_latent(latent)
        options = {'step_size': self.step_size} if self.solver == 'rk4' else None
        traj = odeint(self.func, y0, t_eval, method=self.solver, options=options)
        self.func.set_latent(None)
        return traj.permute(1, 0, 2)
7. Baseline: universal Neural ODE without function encoding¶
Following up on the limitation noted in the previous notebook, we also train a vanilla Neural ODE that fits a single vector field to every trajectory (no latent code, no context). This highlights how poorly a single shared field generalizes to unseen $\mu$ values.
class VanillaODEFunc(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 + 1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, t, y):
        t_input = torch.full((y.shape[0], 1), float(t), device=y.device)
        inp = torch.cat([y, t_input], dim=-1)
        return self.net(inp)
class UniversalVanDerPol(nn.Module):
    def __init__(self, hidden_dim=64, solver='rk4', step_size=DT):
        super().__init__()
        self.func = VanillaODEFunc(hidden_dim)
        self.solver = solver
        self.step_size = step_size

    def forward(self, t_eval, y0):
        options = {'step_size': self.step_size} if self.solver == 'rk4' else None
        traj = odeint(self.func, y0, t_eval, method=self.solver, options=options)
        return traj.permute(1, 0, 2)
8. Training utilities¶
def train_zero_shot(model, dataset, epochs=150, batch_size=8, n_context=20, lr=1e-3):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    model.train()
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        for batch in loader:
            traj = batch['traj'].to(DEVICE)
            y0 = batch['y0'].to(DEVICE)
            t_ctx, x_ctx = sample_contexts(traj, n_context)
            pred = model(t_ctx, x_ctx, T_GRID, y0)
            loss = F.mse_loss(pred, traj)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(loader)
        history.append(epoch_loss)
        if epoch % 25 == 0:
            print(f"Zero-shot epoch {epoch:03d} | MSE {epoch_loss:.4f}")
    return history
def train_baseline(model, dataset, epochs=150, batch_size=8, lr=1e-3):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    model.train()
    for epoch in range(1, epochs + 1):
        epoch_loss = 0.0
        for batch in loader:
            traj = batch['traj'].to(DEVICE)
            y0 = batch['y0'].to(DEVICE)
            pred = model(T_GRID, y0)
            loss = F.mse_loss(pred, traj)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(loader)
        history.append(epoch_loss)
        if epoch % 25 == 0:
            print(f"Baseline epoch {epoch:03d} | MSE {epoch_loss:.4f}")
    return history
9. Train both models¶
set_seed(21)
zero_shot_model = ZeroShotVanDerPol(latent_dim=16, hidden_dim=64).to(DEVICE)
zs_history = train_zero_shot(zero_shot_model, train_dataset, epochs=150, batch_size=8, n_context=20, lr=1e-3)
set_seed(21)
baseline_model = UniversalVanDerPol(hidden_dim=64).to(DEVICE)
baseline_history = train_baseline(baseline_model, train_dataset, epochs=150, batch_size=8, lr=1e-3)
Training curves¶
plt.figure()
plt.plot(zs_history, label='Zero-shot Neural ODE')
plt.plot(baseline_history, label='Universal baseline')
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.title('Training dynamics')
plt.legend()
plt.show()
10. Evaluation on held-out $\mu$¶
We randomly sample new context points for each batch during evaluation to mimic the zero-shot setting.
def evaluate_zero_shot(model, dataset, n_context=20):
    loader = DataLoader(dataset, batch_size=8, shuffle=False)
    model.eval()
    mses = []
    with torch.no_grad():
        for batch in loader:
            traj = batch['traj'].to(DEVICE)
            y0 = batch['y0'].to(DEVICE)
            t_ctx, x_ctx = sample_contexts(traj, n_context)
            pred = model(t_ctx, x_ctx, T_GRID, y0)
            mse = F.mse_loss(pred, traj, reduction='none').mean(dim=[1, 2])
            mses.append(mse.cpu())
    return torch.cat(mses).mean().item()
def evaluate_baseline(model, dataset):
    loader = DataLoader(dataset, batch_size=8, shuffle=False)
    model.eval()
    mses = []
    with torch.no_grad():
        for batch in loader:
            traj = batch['traj'].to(DEVICE)
            y0 = batch['y0'].to(DEVICE)
            pred = model(T_GRID, y0)
            mse = F.mse_loss(pred, traj, reduction='none').mean(dim=[1, 2])
            mses.append(mse.cpu())
    return torch.cat(mses).mean().item()
train_mse_zs = evaluate_zero_shot(zero_shot_model, train_dataset)
test_mse_zs = evaluate_zero_shot(zero_shot_model, test_dataset)
train_mse_base = evaluate_baseline(baseline_model, train_dataset)
test_mse_base = evaluate_baseline(baseline_model, test_dataset)
print(f"Zero-shot model | train MSE {train_mse_zs:.4f} | test MSE {test_mse_zs:.4f}")
print(f"Universal baseline | train MSE {train_mse_base:.4f} | test MSE {test_mse_base:.4f}")
11. Qualitative comparison on a held-out trajectory¶
def visualize_prediction(sample_idx=0):
    sample = test_dataset[sample_idx]
    traj = sample['traj'].unsqueeze(0).to(DEVICE)
    y0 = sample['y0'].unsqueeze(0).to(DEVICE)
    t_ctx, x_ctx = sample_contexts(traj, num_context=20)
    with torch.no_grad():
        zs_pred = zero_shot_model(t_ctx, x_ctx, T_GRID, y0).squeeze(0).cpu()
        base_pred = baseline_model(T_GRID, y0).squeeze(0).cpu()
    true_traj = traj.squeeze(0).cpu()
    t_np = T_GRID_CPU.numpy()
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
    axes[0].plot(t_np, true_traj[:, 0], label='true x(t)')
    axes[0].plot(t_np, zs_pred[:, 0], '--', label='zero-shot')
    axes[0].set_title('Zero-shot vs truth (x component)')
    axes[0].legend()
    axes[1].plot(t_np, true_traj[:, 0], label='true x(t)')
    axes[1].plot(t_np, base_pred[:, 0], '--', label='baseline')
    axes[1].set_title('Baseline vs truth (x component)')
    axes[1].legend()
    plt.show()
visualize_prediction(sample_idx=3)
Phase portrait¶
def plot_phase(sample_idx=1):
    sample = test_dataset[sample_idx]
    traj = sample['traj'].unsqueeze(0).to(DEVICE)
    y0 = sample['y0'].unsqueeze(0).to(DEVICE)
    t_ctx, x_ctx = sample_contexts(traj, num_context=20)
    with torch.no_grad():
        zs_pred = zero_shot_model(t_ctx, x_ctx, T_GRID, y0).squeeze(0).cpu()
        base_pred = baseline_model(T_GRID, y0).squeeze(0).cpu()
    true_traj = traj.squeeze(0).cpu()
    plt.figure(figsize=(6, 5))
    plt.plot(true_traj[:, 0], true_traj[:, 1], label='true')
    plt.plot(zs_pred[:, 0], zs_pred[:, 1], '--', label='zero-shot')
    plt.plot(base_pred[:, 0], base_pred[:, 1], ':', label='baseline')
    plt.xlabel('x')
    plt.ylabel('y')
    # raw string so matplotlib's mathtext receives \mu intact
    plt.title(r'Phase portrait comparison (held-out $\mu$)')
    plt.legend()
    plt.axis('equal')
    plt.show()
plot_phase(sample_idx=5)
12. Exercises¶
- Context budget: vary the number of context points $K\in\{5, 10, 20, 40\}$ and measure test MSE to reproduce the ablation from the paper.
- Latent prior: add an $\ell_2$ penalty on the latent code to avoid degenerate solutions and inspect the effect on extrapolation.
- Adaptive solvers: swap the RK4 solver for `dopri5` and track the number of function evaluations vs. accuracy.
- Different dynamics: replace Van der Pol with FitzHugh–Nagumo or Lotka–Volterra to test how well the function encoder transfers.
These prompts align with the other course sections and encourage deeper experimentation with zero-shot Neural ODEs.
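For the adaptive-solvers exercise, one common trick (a sketch, not part of the notebook above) is to wrap the dynamics so every call to `forward()` increments a counter; passing the wrapper to `odeint` with `method='rk4'` vs. `'dopri5'` then exposes the cost/accuracy trade-off. The toy fixed-step RK4 loop below just demonstrates the counter on stand-in linear dynamics:

```python
import torch
import torch.nn as nn

# Counts function evaluations (NFE) by intercepting every forward() call.
# In the notebook you would wrap ConditionalODEFunc instead of the toy lambda.
class NFEWrapper(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
        self.nfe = 0

    def forward(self, t, y):
        self.nfe += 1
        return self.func(t, y)

dyn = NFEWrapper(lambda t, y: -y)   # toy dynamics, stands in for the learned field

# One classic RK4 step costs exactly four function evaluations.
def rk4_step(f, t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

y = torch.ones(1, 2)
for step in range(10):
    y = rk4_step(dyn, step * 0.1, y, 0.1)
print(dyn.nfe)  # 40 evaluations: 10 steps x 4 calls each
```

With an adaptive solver the counter varies with `rtol`/`atol` instead of being fixed per step, which is exactly the quantity the exercise asks you to plot against accuracy.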