B2B Operator Learning: Basis-to-Basis Transformations¶
Learning Objectives:
- Master the B2B (Basis-to-Basis) framework for operator learning
- Learn operators as explicit transformation matrices
- Implement the same examples as DeepONet for direct comparison
- Understand zero-shot and few-shot operator learning
- Compare B2B with DeepONet on performance and interpretability
Examples covered:
- Derivative operator
- Poisson equation solver
- 1D nonlinear Darcy flow
The B2B Framework¶
Core idea: Decompose operator learning into three steps:
- Encode source: $f \xrightarrow{E_1} c_f \in \mathbb{R}^{n_1}$
- Transform: $c_f \xrightarrow{A} c_g \in \mathbb{R}^{n_2}$
- Decode target: $c_g \xrightarrow{D_2} g$
The operator $\mathcal{G}$ is represented as: $\mathcal{G}[f] \approx D_2(A \cdot E_1(f))$
Key advantage: The transformation matrix $A$ is explicit and interpretable!
Theory: B2B Operator Learning Framework¶
Mathematical Foundation¶
The B2B framework learns operators between function spaces by decomposing the problem into basis representations. Given an operator $\mathcal{T}: \mathcal{U} \rightarrow \mathcal{V}$, we learn:
- Source encoding: $u \mapsto \alpha^u = E_\mathcal{U}(u)$
- Transformation: $\alpha^u \mapsto \alpha^v = A \alpha^u$
- Target decoding: $\alpha^v \mapsto v = D_\mathcal{V}(\alpha^v)$
The complete operator: $\mathcal{T}[u] \approx D_\mathcal{V}(A \cdot E_\mathcal{U}(u))$
Learning the Transformation Matrix¶
Given training pairs $\{(u_i, v_i)\}_{i=1}^N$, we solve:
$$\min_A \sum_{i=1}^N \|E_\mathcal{V}(v_i) - A \cdot E_\mathcal{U}(u_i)\|^2 + \lambda \|A\|_F^2$$
Solution via regularized least squares: $$A = C_V^\top C_U \,(C_U^\top C_U + \lambda I)^{-1}$$
where the rows of $C_U \in \mathbb{R}^{N \times n_1}$ and $C_V \in \mathbb{R}^{N \times n_2}$ stack the encoded source and target coefficients, so that $C_V \approx C_U A^\top$.
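A minimal sketch of this solve, assuming the coefficients are stacked row-wise so that $C_V \approx C_U A^\top$ (the function name and toy data below are illustrative, not part of the notebook pipeline):
import torch

def fit_transformation(C_U, C_V, lam=1e-6):
    """Ridge least-squares fit of A such that C_V ≈ C_U @ A.T (rows index training samples)."""
    XtX = C_U.T @ C_U + lam * torch.eye(C_U.shape[1])   # regularized normal equations
    XtY = C_U.T @ C_V
    return torch.linalg.solve(XtX, XtY).T                # shape (n2, n1)

# Toy check against a known operator
torch.manual_seed(0)
A_true = torch.randn(8, 12)                              # maps 12 source coeffs to 8 target coeffs
C_U_toy = torch.randn(200, 12)
C_V_toy = C_U_toy @ A_true.T + 0.01 * torch.randn(200, 8)
A_hat = fit_transformation(C_U_toy, C_V_toy)
print(f"relative error: {(torch.norm(A_hat - A_true) / torch.norm(A_true)).item():.4f}")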
Key Properties¶
- Interpretability: The matrix $A$ explicitly represents the operator's action
- Sample efficiency: Pre-trained encoders enable few-shot operator learning
- Transfer learning: Encoders can be reused across related operators
- Spectral analysis: SVD of $A$ reveals operator characteristics
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
from scipy.stats import multivariate_normal
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
# Set random seeds
np.random.seed(42)
torch.manual_seed(42)
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
Using device: mps
Part 1: Function Encoder Architecture¶
First, we need function encoders to learn representations of source and target function spaces.
class FunctionEncoder(nn.Module):
"""Function encoder for learning basis representations"""
def __init__(self, sensor_dim, n_basis, hidden_dim=64):
super().__init__()
self.sensor_dim = sensor_dim
self.n_basis = n_basis
# Encoder: maps function samples to coefficients
self.encoder = nn.Sequential(
nn.Linear(sensor_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, n_basis)
)
# Decoder: generates basis functions at query points
self.decoder = nn.Sequential(
nn.Linear(1, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, n_basis)
)
def encode(self, function_samples):
"""Extract coefficients from function samples"""
return self.encoder(function_samples)
def decode_basis(self, x):
"""Get basis function values at points x"""
return self.decoder(x)
def reconstruct(self, coefficients, x):
"""Reconstruct function from coefficients"""
if x.dim() == 2:
x = x.unsqueeze(0)
batch_size, n_points, _ = x.shape
basis_values = self.decoder(x.reshape(-1, 1))
basis_values = basis_values.view(batch_size, n_points, self.n_basis)
if coefficients.dim() == 1:
coefficients = coefficients.unsqueeze(0)
return torch.einsum('bn,bpn->bp', coefficients, basis_values)
def forward(self, function_samples, query_points):
coeffs = self.encode(function_samples)
return self.reconstruct(coeffs, query_points)
class B2BOperator:
"""B2B Operator Learning Framework"""
def __init__(self, source_encoder, target_encoder):
self.source_encoder = source_encoder
self.target_encoder = target_encoder
self.transformation_matrix = None
def learn_transformation(self, source_functions, target_functions, regularization=1e-6):
"""Learn transformation matrix A using least squares"""
self.source_encoder.eval()
self.target_encoder.eval()
with torch.no_grad():
# Encode all functions
source_coeffs = self.source_encoder.encode(source_functions)
target_coeffs = self.target_encoder.encode(target_functions)
# Solve least squares: Y = X @ A.T
# Add regularization for stability
X = source_coeffs.cpu()
Y = target_coeffs.cpu()
# Regularized least squares
XtX = X.T @ X + regularization * torch.eye(X.shape[1])
XtY = X.T @ Y
A = torch.linalg.solve(XtX, XtY).T
self.transformation_matrix = A.to(device)
# Compute fitting error
Y_pred = X @ A.T
mse = F.mse_loss(Y_pred, Y).item()
return A, mse
def apply(self, source_function, query_points):
"""Apply the learned operator"""
if self.transformation_matrix is None:
raise ValueError("Transformation matrix not learned yet")
self.source_encoder.eval()
self.target_encoder.eval()
with torch.no_grad():
# Encode source
source_coeffs = self.source_encoder.encode(source_function)
# Transform
target_coeffs = source_coeffs @ self.transformation_matrix.T
# Decode
return self.target_encoder.reconstruct(target_coeffs, query_points)
print("B2B Framework initialized")
print("Components: Encoder → Transformation → Decoder")
B2B Framework initialized Components: Encoder → Transformation → Decoder
Part 2: Helper Functions for Training¶
def train_encoder(encoder, functions, x_points, n_epochs=500, lr=1e-3, name="Encoder"):
"""Train a function encoder to reconstruct functions"""
optimizer = optim.Adam(encoder.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=0.5)
losses = []
encoder.train()
x_tensor = torch.tensor(x_points, dtype=torch.float32).unsqueeze(-1).to(device)
pbar = tqdm(range(n_epochs), desc=f"Training {name}")
for epoch in pbar:
# Random batch
idx = np.random.choice(len(functions), min(32, len(functions)))
batch_functions = torch.tensor(functions[idx], dtype=torch.float32).to(device)
# Prepare query points
batch_x = x_tensor.unsqueeze(0).repeat(len(idx), 1, 1)
# Forward pass
pred = encoder(batch_functions, batch_x)
loss = F.mse_loss(pred, batch_functions)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
scheduler.step(loss)
if epoch % 50 == 0:
pbar.set_postfix({'Loss': f'{loss.item():.6f}'})
return losses
def visualize_basis_functions(encoder, x_range=(-2, 2), name="Encoder"):
"""Visualize learned basis functions"""
x = torch.linspace(x_range[0], x_range[1], 200).unsqueeze(-1).to(device)
with torch.no_grad():
basis = encoder.decode_basis(x).cpu().numpy()
x = x.cpu().numpy().squeeze()
plt.figure(figsize=(10, 4))
for i in range(min(basis.shape[1], 10)):
plt.plot(x, basis[:, i], linewidth=2, alpha=0.7, label=f'φ_{i+1}')
plt.title(f'{name} Basis Functions')
plt.xlabel('x')
plt.grid(True, alpha=0.3)
if basis.shape[1] <= 10:
plt.legend(ncol=2)
plt.show()
print("Helper functions defined")
Helper functions defined
Example 1: Derivative Operator¶
Learn the derivative operator $\mathcal{T}[f] = f'$, where $f'(x) = \frac{df}{dx}$; here it maps cubic polynomials to their quadratic derivatives.
# Generate derivative data - SAME AS DEEPONET
def generate_derivative_data(num_functions=2000, num_points=100, x_range=(-2, 2)):
"""Generate cubic polynomials and their derivatives (same as DeepONet)"""
np.random.seed(42)
# Random cubic polynomial coefficients - SAME AS DEEPONET
coeffs = np.random.randn(num_functions, 4) * 0.5
x = np.linspace(x_range[0], x_range[1], num_points)
functions = np.zeros((num_functions, num_points))
derivatives = np.zeros((num_functions, num_points))
for i in range(num_functions):
a, b, c, d = coeffs[i]
# f(x) = ax^3 + bx^2 + cx + d
functions[i] = a * x**3 + b * x**2 + c * x + d
# f'(x) = 3ax^2 + 2bx + c
derivatives[i] = 3 * a * x**2 + 2 * b * x + c
return coeffs, x, functions, derivatives
print("=== DERIVATIVE OPERATOR EXAMPLE ===")
print("Learning: f(x) → f'(x)")
coeffs, x, functions, derivatives = generate_derivative_data()
# Split data - same 80/20 split
n_train = int(0.8 * len(functions))
train_functions = functions[:n_train]
train_derivatives = derivatives[:n_train]
test_functions = functions[n_train:]
test_derivatives = derivatives[n_train:]
print(f"Data: {n_train} training, {len(test_functions)} test functions")
print(f"Domain: x ∈ [{x[0]:.1f}, {x[-1]:.1f}]")
print(f"Coefficients: ax³ + bx² + cx + d with scale 0.5")
# Visualize samples
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
ax = axes[i]
ax.plot(x, functions[i], 'b-', linewidth=2, label='f(x)')
ax.plot(x, derivatives[i], 'r-', linewidth=2, label="f'(x)")
a, b, c, d = coeffs[i]
ax.set_title(f'Sample {i+1}: [{a:.2f},{b:.2f},{c:.2f},{d:.2f}]')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('Derivative Operator: Function → Derivative', fontsize=14)
plt.tight_layout()
plt.show()
=== DERIVATIVE OPERATOR EXAMPLE === Learning: f(x) → f'(x) (same setup as DeepONet) Data: 1600 training, 400 test functions Domain: x ∈ [-2.0, 2.0] Coefficients: ax³ + bx² + cx + d with scale 0.5
Train Function Encoders¶
print("\nTraining function encoders...")
# Create encoders - use small basis first like DeepONet (p=3)
# Then increase if needed
deriv_source_encoder = FunctionEncoder(sensor_dim=100, n_basis=10, hidden_dim=64).to(device)
deriv_target_encoder = FunctionEncoder(sensor_dim=100, n_basis=10, hidden_dim=64).to(device)
# Train source encoder (cubic polynomials)
print("\n1. Source encoder (cubic space - functions):")
source_losses = train_encoder(deriv_source_encoder, train_functions, x,
n_epochs=500, lr=0.001, name="Function Encoder")
# Train target encoder (quadratic polynomials - derivatives)
print("\n2. Target encoder (quadratic space - derivatives):")
target_losses = train_encoder(deriv_target_encoder, train_derivatives, x,
n_epochs=500, lr=0.001, name="Derivative Encoder")
# Visualize basis functions
visualize_basis_functions(deriv_source_encoder, x_range=(-2, 2), name="Source (Cubic)")
visualize_basis_functions(deriv_target_encoder, x_range=(-2, 2), name="Target (Quadratic)")
print(f"\nFinal losses - Source: {source_losses[-1]:.6f}, Target: {target_losses[-1]:.6f}")
Training function encoders... 1. Source encoder (cubic space - functions):
Training Function Encoder: 100%|██████████| 500/500 [00:01<00:00, 303.39it/s, Loss=0.028090]
2. Target encoder (quadratic space - derivatives):
Training Derivative Encoder: 100%|██████████| 500/500 [00:01<00:00, 273.72it/s, Loss=0.172096]
Final losses - Source: 0.020061, Target: 0.323011
Learn and Apply the Derivative Operator¶
# Create B2B operator
deriv_b2b = B2BOperator(deriv_source_encoder, deriv_target_encoder)
# Learn transformation matrix
print("Learning transformation matrix...")
train_source_tensor = torch.tensor(train_functions, dtype=torch.float32).to(device)
train_target_tensor = torch.tensor(train_derivatives, dtype=torch.float32).to(device)
A_deriv, fit_error = deriv_b2b.learn_transformation(
train_source_tensor, train_target_tensor
)
print(f"Transformation matrix shape: {A_deriv.shape}")
print(f"Fitting error: {fit_error:.6f}")
# Visualize transformation matrix and SVD
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Matrix heatmap
ax = axes[0]
im = ax.imshow(A_deriv.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(im, ax=ax, label='Weight')
ax.set_title('Derivative Operator Transformation Matrix')
ax.set_xlabel('Source Basis')
ax.set_ylabel('Target Basis')
# Singular values (line plot)
ax = axes[1]
U, S, Vt = torch.linalg.svd(A_deriv.cpu())
ax.plot(range(1, len(S)+1), S.numpy(), 'o-', linewidth=2, markersize=6)
ax.set_title('Singular Value Decay')
ax.set_xlabel('Index')
ax.set_ylabel('Singular Value')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
# Cumulative energy
ax = axes[2]
cumsum = torch.cumsum(S**2, dim=0) / torch.sum(S**2)
ax.plot(range(1, len(S)+1), cumsum.numpy(), 'o-', linewidth=2, markersize=6, color='green')
ax.axhline(0.99, color='r', linestyle='--', label='99% energy')
ax.set_title('Cumulative Energy')
ax.set_xlabel('Number of Singular Values')
ax.set_ylabel('Fraction of Total Energy')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nMatrix rank: {torch.linalg.matrix_rank(A_deriv).item()}")
print(f"Condition number: {torch.linalg.cond(A_deriv).item():.2e}")
Learning transformation matrix... Transformation matrix shape: torch.Size([10, 10]) Fitting error: 0.001159
Matrix rank: 10 Condition number: 6.56e+03
Test the Derivative Operator¶
# Test on unseen functions
n_test_vis = 6
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
test_errors = []
x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)
for i in range(n_test_vis):
test_idx = i * 10
# Apply B2B operator
source_func = torch.tensor(test_functions[test_idx:test_idx+1], dtype=torch.float32).to(device)
pred_deriv = deriv_b2b.apply(source_func, x_tensor).squeeze().cpu().numpy()
true_deriv = test_derivatives[test_idx]
# Compute error
mse = np.mean((pred_deriv - true_deriv)**2)
test_errors.append(mse)
# Plot
ax = axes[i]
ax.plot(x, test_functions[test_idx], 'b-', linewidth=2, alpha=0.7, label='f(x)')
ax.plot(x, true_deriv, 'r-', linewidth=2, label="True f'(x)")
ax.plot(x, pred_deriv, 'g--', linewidth=2, label="B2B f'(x)")
ax.set_title(f'Test {i+1}: MSE = {mse:.6f}')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('B2B Derivative Operator Results', fontsize=14)
plt.tight_layout()
plt.show()
print(f"Average test MSE: {np.mean(test_errors):.6f} ± {np.std(test_errors):.6f}")
Average test MSE: 0.145266 ± 0.164010
Understanding Matrix Decompositions: SVD and Eigendecomposition¶
Before we compare transfer methods, let's build intuition for two fundamental matrix decompositions and how they're used in the B2B framework.
Singular Value Decomposition (SVD)¶
The geometric picture: Every matrix $A$ describes a linear transformation. SVD reveals this transformation as three simple steps:
$$A = U \Sigma V^T$$
- $V^T$: Rotate input to align with "principal input directions"
- $\Sigma$: Stretch along these directions (diagonal matrix of singular values)
- $U$: Rotate to "principal output directions"
Key insight: Any linear transformation = rotation → stretch → rotation.
Example in B2B: Our transformation matrix $A: \mathbb{R}^{10} \to \mathbb{R}^{10}$ maps source coefficients to target coefficients.
- Columns of $V$ are the "best input modes" (right singular vectors)
- Columns of $U$ are the corresponding "output modes" (left singular vectors)
- Diagonal entries $\sigma_i$ in $\Sigma$ measure the importance of each mode
Why it matters:
- Rank and compactness: Small singular values $\sigma_i \approx 0$ indicate the operator is compact—those directions contribute little
- Regularization: Drop small $\sigma_i$ to remove noise and improve stability
- Conditioning: The ratio $\sigma_{\text{max}}/\sigma_{\text{min}}$ reveals numerical sensitivity
The rank-$k$ approximation: $$A \approx U_k \Sigma_k V_k^T = \sum_{i=1}^{k} \sigma_i \mathbf{u}_i \mathbf{v}_i^T$$
This expresses the operator as a sum of rank-1 operators, ordered by importance.
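A short sketch of rank-$k$ truncation and its Frobenius-norm error (illustrative; `A_demo` stands in for any learned transformation matrix):
import torch

def rank_k_approx(A, k):
    """Best rank-k approximation of A in the Frobenius norm (Eckart-Young)."""
    U, S, Vt = torch.linalg.svd(A, full_matrices=False)
    return U[:, :k] @ torch.diag(S[:k]) @ Vt[:k, :]

torch.manual_seed(0)
A_demo = torch.randn(10, 10)
for k in (2, 5, 10):
    rel = torch.norm(A_demo - rank_k_approx(A_demo, k)) / torch.norm(A_demo)
    print(f"rank {k}: relative Frobenius error {rel:.3f}")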
Two Ways to Use SVD in B2B¶
Approach 1: Post-hoc SVD Analysis (What we use for B2B)
After learning the transformation matrix $A$ via least squares, compute its SVD to analyze the operator:
- Train encoders independently for input and output spaces
- Learn the transformation: $A = C_V^\top C_U\,(C_U^\top C_U + \lambda I)^{-1}$
- Analyze $A$ via SVD: $A = U\Sigma V^T$
- Examine singular values for insight: decay rate, effective rank, conditioning
This gives us the "Direct Matrix" approach with spectral analysis afterward.
Approach 2: End-to-End SVD Learning (The SVD variant)
Learn the SVD structure directly by parameterizing $\{u_i\}, \{v_i\}, \{\sigma_i\}$ as neural networks and trainable parameters:
- Initialize basis networks $\{u_i(y | \theta_i^u)\}$ and $\{v_i(x | \theta_i^v)\}$
- Initialize singular values $\sigma = [\sigma_1, \sigma_2, \ldots, \sigma_k]$ as learnable parameters
- For input $f$, compute coefficients: $\alpha = \arg\min_\alpha \|f - \sum_j \alpha_j v_j\|^2$
- Predict output: $\hat{T}f = \sum_i \sigma_i \alpha_i u_i$
- Train end-to-end: minimize $\|Tf - \hat{T}f\|^2$ via gradient descent on $\{\theta^u\}, \{\theta^v\}, \{\sigma\}$
Trade-off: End-to-end training learns the SVD that minimizes operator error directly, but doesn't explicitly train bases to span the input/output spaces independently. This can hurt generalization outside the training distribution.
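The notebook only implements the direct-matrix approach; the sketch below shows what an end-to-end SVD variant could look like on a fixed 1D grid. Everything here is an illustrative assumption: the `BasisNet` architecture, the regularized least-squares projection used for the coefficients $\alpha$, and the short training loop that reuses the derivative-example arrays `x`, `train_functions`, and `train_derivatives` defined above.
import torch
import torch.nn as nn

class BasisNet(nn.Module):
    """Maps a 1D point x to the values of k learnable basis functions at x."""
    def __init__(self, k, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(),
                                 nn.Linear(hidden, hidden), nn.Tanh(),
                                 nn.Linear(hidden, k))
    def forward(self, x):          # x: (n_points, 1)
        return self.net(x)         # (n_points, k)

class SVDOperator(nn.Module):
    """End-to-end SVD variant: T[f] ≈ sum_i sigma_i * alpha_i * u_i, with alpha from least squares."""
    def __init__(self, k=10):
        super().__init__()
        self.v_net = BasisNet(k)                   # input basis {v_i}
        self.u_net = BasisNet(k)                   # output basis {u_i}
        self.sigma = nn.Parameter(torch.ones(k))   # learnable singular values

    def forward(self, f_vals, x_in, x_out):
        V = self.v_net(x_in)                       # (n_in, k)
        U = self.u_net(x_out)                      # (n_out, k)
        # alpha = argmin_a ||f - V a||^2, via regularized normal equations (differentiable)
        G = V.T @ V + 1e-6 * torch.eye(V.shape[1])
        alpha = torch.linalg.solve(G, V.T @ f_vals.T).T   # (batch, k)
        return (alpha * self.sigma) @ U.T                 # (batch, n_out)

# Illustrative training on the derivative example from above (kept on CPU for simplicity)
svd_op = SVDOperator(k=10)
opt = torch.optim.Adam(svd_op.parameters(), lr=1e-3)
x_col = torch.tensor(x, dtype=torch.float32).unsqueeze(-1)     # (100, 1) grid
F_in = torch.tensor(train_functions, dtype=torch.float32)      # (1600, 100) inputs f
F_out = torch.tensor(train_derivatives, dtype=torch.float32)   # (1600, 100) targets f'
for step in range(200):
    idx = torch.randint(0, len(F_in), (32,))
    loss = nn.functional.mse_loss(svd_op(F_in[idx], x_col, x_col), F_out[idx])
    opt.zero_grad(); loss.backward(); opt.step()
print(f"SVD-variant training loss after 200 steps: {loss.item():.4f}")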
Eigendecomposition¶
The geometric picture: For square matrices, eigendecomposition finds special directions that only stretch—no rotation.
$$A = V \Lambda V^{-1}$$
where $\Lambda$ is diagonal (eigenvalues) and columns of $V$ are eigenvectors.
Key insight: If $\mathbf{v}$ is an eigenvector with eigenvalue $\lambda$, then $A\mathbf{v} = \lambda \mathbf{v}$. The transformation preserves the direction, only changing magnitude.
Example: For a self-adjoint operator (input and output spaces are the same), eigenvectors reveal the "natural modes."
- Each eigenvector $\mathbf{v}_i$ is an invariant direction
- The eigenvalue $\lambda_i$ is the "gain" in that direction
- Large $|\lambda_i|$ → mode amplified; small → suppressed
Eigendecomposition in B2B: Two Approaches¶
Approach 1: Via the Gram Matrix (Regularization)
For regularized least squares, we can solve via eigendecomposition of $X^TX$:
$$X^T X = V \Lambda V^T$$
Since $X^TX$ is symmetric positive semi-definite:
- $V$ is orthogonal
- $\Lambda$ has non-negative eigenvalues
The solution, with rows of $X$ and $Y$ holding the encoded source and target coefficients: $$A = Y^\top X\,(X^\top X + \alpha I)^{-1} = Y^\top X\, V(\Lambda + \alpha I)^{-1} V^\top$$
Tikhonov regularization: The term $\alpha I$ prevents division by tiny eigenvalues, stabilizing the solution.
Truncation: Keep only the top $k$ eigenvalues for a low-rank approximation.
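A quick numerical check that the two routes agree (toy matrices with illustrative names; rows of `X_toy` and `Y_toy` play the role of encoded source and target coefficients):
import torch

torch.manual_seed(0)
X_toy = torch.randn(200, 10, dtype=torch.float64)
Y_toy = torch.randn(200, 10, dtype=torch.float64)
alpha_reg = 1e-3

# Route 1: regularized normal equations, A = Y^T X (X^T X + alpha I)^{-1}
A_direct = torch.linalg.solve(X_toy.T @ X_toy + alpha_reg * torch.eye(10, dtype=torch.float64),
                              X_toy.T @ Y_toy).T

# Route 2: eigendecomposition of the Gram matrix X^T X = V diag(lam) V^T
lam_gram, V_gram = torch.linalg.eigh(X_toy.T @ X_toy)
A_eig = Y_toy.T @ X_toy @ V_gram @ torch.diag(1.0 / (lam_gram + alpha_reg)) @ V_gram.T

print(torch.allclose(A_direct, A_eig, atol=1e-8))   # True: both routes give the same matrix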
Approach 2: End-to-End Eigendecomposition Learning (The ED variant)
For self-adjoint operators (same input/output space), learn the eigendecomposition directly:
- Initialize basis networks $\{v_i(x | \theta_i)\}$
- Initialize eigenvalues $\lambda = [\lambda_1, \lambda_2, \ldots, \lambda_k]$ as learnable parameters
- For input $f$, compute coefficients: $\alpha = \arg\min_\alpha \|f - \sum_j \alpha_j v_j\|^2$
- Predict output: $\hat{T}f = \sum_i \lambda_i \alpha_i v_i$
- Train end-to-end: minimize $\|Tf - \hat{T}f\|^2$ via gradient descent on $\{\theta\}, \{\lambda\}$
Constraint: Input and output must be on the same domain (self-adjoint operator).
Advantage: Directly learn physically meaningful eigenmodes and their eigenvalues.
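For completeness, here is a compact sketch of the ED variant (illustrative only). It reuses the `BasisNet` class from the SVD sketch above; because the operator is self-adjoint, the input and output share one basis and one grid.
class EDOperator(nn.Module):
    """End-to-end ED variant: T[f] ≈ sum_i lambda_i * alpha_i * v_i with a shared basis."""
    def __init__(self, k=10):
        super().__init__()
        self.v_net = BasisNet(k)                      # one basis serves both input and output
        self.eigvals = nn.Parameter(torch.ones(k))    # learnable eigenvalues (may become negative)

    def forward(self, f_vals, x_grid):
        V = self.v_net(x_grid)                                   # (n_points, k)
        G = V.T @ V + 1e-6 * torch.eye(V.shape[1])               # basis Gram matrix
        alpha = torch.linalg.solve(G, V.T @ f_vals.T).T          # least-squares coefficients
        return (alpha * self.eigvals) @ V.T                      # (batch, n_points)

# Training mirrors the SVD sketch: minimize MSE between EDOperator(f, grid) and the
# operator output evaluated on the same grid.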
Comparison: Three Approaches¶
| Property | Direct Matrix (B2B) | SVD (End-to-End) | ED (End-to-End) |
|---|---|---|---|
| Training | Two-stage: encoders, then matrix | End-to-end | End-to-end |
| Transformation | $A$ via least squares | $\{U, \Sigma, V\}$ via gradient descent | $\{V, \Lambda\}$ via gradient descent |
| Input/Output spaces | Can be different | Can be different | Must be the same |
| Basis quality | Explicitly trained to span the spaces | Trained only for operator loss | Trained only for operator loss |
| Generalization | Best for unseen functions | May struggle outside training | May struggle outside training |
| Interpretability | SVD analysis post hoc | Direct access to singular values | Direct access to eigenvalues |
| Best for | General linear operators | Linear operators needing spectral analysis | Self-adjoint operators |
Key Takeaway¶
Direct Matrix (B2B): Learns good representations of function spaces, then finds the transformation. Better generalization.
SVD/ED (End-to-End): Learns the spectral decomposition directly to minimize operator error. More interpretable eigenstructure, but may not capture the full input/output spaces as well.
Both approaches give you spectral information, but they optimize different objectives!
SVD vs Eigendecomposition: When to Use Each¶
| Property | SVD | Eigendecomposition |
|---|---|---|
| Works for | Any matrix (rectangular OK) | Square matrices only |
| Geometric meaning | Input→output direction pairs | Invariant directions |
| Decomposition | $A = U\Sigma V^T$ | $A = V\Lambda V^{-1}$ |
| Orthogonality | $U, V$ always orthogonal | $V$ orthogonal only if $A$ symmetric |
| Stability | Always numerically stable | Can be unstable for non-symmetric $A$ |
| Our use case | Direct analysis of $A$ | Regularized least squares via $X^TX$ |
For our B2B framework:
- SVD of $A$: Reveals the most important basis transformations
- Eigen of $X^T X$: Regularizes the least squares solution by controlling small eigenvalues
Both give us tools to understand and regularize the operator transformation matrix.
Transfer Methods Comparison¶
We'll compare four approaches for learning the transformation:
- Direct Matrix (Least Squares): $A = C_V^\top C_U\,(C_U^\top C_U + \lambda I)^{-1}$
- SVD-based: Truncate small singular values for regularization
- Eigendecomposition: Regularize the least-squares solve via a truncated eigendecomposition of the Gram matrix $X^\top X$
- Non-linear MLP: Learn a non-linear transformation between coefficient spaces
class TransferMethodComparison:
"""Compare different transformation learning methods"""
def __init__(self, source_encoder, target_encoder):
self.source_encoder = source_encoder
self.target_encoder = target_encoder
self.methods = {}
def learn_direct_matrix(self, source_funcs, target_funcs, reg=1e-6):
"""Method 1: Direct least squares matrix"""
with torch.no_grad():
C_U = self.source_encoder.encode(source_funcs)
C_V = self.target_encoder.encode(target_funcs)
X = C_U.cpu()
Y = C_V.cpu()
# Regularized least squares
XtX = X.T @ X + reg * torch.eye(X.shape[1])
XtY = X.T @ Y
A = torch.linalg.solve(XtX, XtY).T
self.methods['matrix'] = A.to(device)
return A.to(device)
def learn_svd_truncated(self, source_funcs, target_funcs, k=None):
"""Method 2: SVD with truncation"""
with torch.no_grad():
C_U = self.source_encoder.encode(source_funcs)
C_V = self.target_encoder.encode(target_funcs)
X = C_U.cpu()
Y = C_V.cpu()
# SVD of correlation matrix
U_x, S_x, Vt_x = torch.linalg.svd(X, full_matrices=False)
U_y, S_y, Vt_y = torch.linalg.svd(Y, full_matrices=False)
# Truncate to k components
if k is None:
k = min(X.shape[1], Y.shape[1])
# Compute transformation via SVD
# A = Y^T X (X^T X)^{-1} ≈ V_y S_y U_y^T U_x S_x^{-1} V_x^T
S_x_inv = torch.zeros_like(S_x)
S_x_inv[:k] = 1.0 / (S_x[:k] + 1e-6)
A = Vt_y[:k, :].T @ torch.diag(S_y[:k]) @ U_y[:, :k].T @ U_x[:, :k] @ torch.diag(S_x_inv[:k]) @ Vt_x[:k, :]
self.methods['svd'] = A.to(device)
return A.to(device)
def learn_eigendecomposition(self, source_funcs, target_funcs, k=None):
"""Method 3: Eigenvalue-based regularization (Tikhonov)
Uses eigendecomposition for regularized least squares:
1. Decompose X^T X = V D V^T
2. Truncate small eigenvalues
3. Solve: A = Y^T X V D_k^{-1} V^T
This is closely related to truncated SVD of X: the eigenvalues of X^T X are the squared singular values of X.
"""
with torch.no_grad():
C_U = self.source_encoder.encode(source_funcs)
C_V = self.target_encoder.encode(target_funcs)
X = C_U.cpu()
Y = C_V.cpu()
# Eigendecomposition of Gram matrix X^T X
Gram = X.T @ X
eigvals, eigvecs = torch.linalg.eigh(Gram)
# Sort descending
idx = torch.argsort(eigvals, descending=True)
eigvals = eigvals[idx]
eigvecs = eigvecs[:, idx]
if k is None:
k = len(eigvals)
# Truncate to k components and regularize
eigvals_k = eigvals[:k]
eigvecs_k = eigvecs[:, :k]
# Compute A = Y^T X V D^{-1} V^T
# This is equivalent to regularized least squares with eigenvalue truncation
D_inv = torch.diag(1.0 / (eigvals_k + 1e-6))
A = Y.T @ X @ eigvecs_k @ D_inv @ eigvecs_k.T
self.methods['eigen'] = A.to(device)
return A.to(device)
def learn_mlp(self, source_funcs, target_funcs, hidden_dim=64, n_epochs=500, lr=1e-3):
"""Method 4: Non-linear MLP transformation"""
# Get coefficients
with torch.no_grad():
C_U = self.source_encoder.encode(source_funcs)
C_V = self.target_encoder.encode(target_funcs)
n_source = C_U.shape[1]
n_target = C_V.shape[1]
# MLP architecture
mlp = nn.Sequential(
nn.Linear(n_source, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, n_target)
).to(device)
optimizer = optim.Adam(mlp.parameters(), lr=lr)
# Train
dataset = TensorDataset(C_U, C_V)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
mlp.train()
for epoch in tqdm(range(n_epochs), desc="Training MLP"):
for batch_u, batch_v in loader:
pred_v = mlp(batch_u)
loss = F.mse_loss(pred_v, batch_v)
optimizer.zero_grad()
loss.backward()
optimizer.step()
self.methods['mlp'] = mlp
return mlp
def apply_method(self, method_name, source_func, query_points):
"""Apply a specific method"""
self.source_encoder.eval()
self.target_encoder.eval()
with torch.no_grad():
source_coeffs = self.source_encoder.encode(source_func)
if method_name == 'mlp':
target_coeffs = self.methods[method_name](source_coeffs)
else:
target_coeffs = source_coeffs @ self.methods[method_name].T
return self.target_encoder.reconstruct(target_coeffs, query_points)
print("Transfer method comparison framework ready")
Transfer method comparison framework ready
# Train all transfer methods on derivative operator
print("=== COMPARING TRANSFER METHODS ===\n")
transfer_comp = TransferMethodComparison(deriv_source_encoder, deriv_target_encoder)
print("Learning transformations...")
print("\n1. Direct Matrix (Least Squares)")
A_matrix = transfer_comp.learn_direct_matrix(train_source_tensor, train_target_tensor)
print(f" Matrix shape: {A_matrix.shape}")
print("\n2. SVD-based (k=7 components)")
A_svd = transfer_comp.learn_svd_truncated(train_source_tensor, train_target_tensor, k=7)
print(f" Matrix shape: {A_svd.shape}")
print("\n3. Eigendecomposition (k=7 components)")
A_eigen = transfer_comp.learn_eigendecomposition(train_source_tensor, train_target_tensor, k=7)
print(f" Matrix shape: {A_eigen.shape}")
print("\n4. Non-linear MLP")
mlp_transform = transfer_comp.learn_mlp(train_source_tensor, train_target_tensor,
hidden_dim=32, n_epochs=300)
print(f" MLP parameters: {sum(p.numel() for p in mlp_transform.parameters())}")
=== COMPARING TRANSFER METHODS === Learning transformations... 1. Direct Matrix (Least Squares) Matrix shape: torch.Size([10, 10]) 2. SVD-based (k=7 components) Matrix shape: torch.Size([10, 10]) 3. Eigendecomposition (k=7 components) Matrix shape: torch.Size([10, 10]) 4. Non-linear MLP
Training MLP: 100%|██████████| 300/300 [00:16<00:00, 18.73it/s]
MLP parameters: 1738
# Visualize transformation matrices and SVD decay
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
methods_to_plot = [('matrix', 'Direct Matrix'), ('svd', 'SVD (k=7)'), ('eigen', 'Eigen (k=7)')]
# Top row: Transformation matrices
for col, (method, title) in enumerate(methods_to_plot):
ax = axes[0, col]
A = transfer_comp.methods[method].cpu().numpy()
im = ax.imshow(A, cmap='RdBu_r', aspect='auto', vmin=-2, vmax=2)
ax.set_title(title)
ax.set_xlabel('Source Basis')
ax.set_ylabel('Target Basis')
plt.colorbar(im, ax=ax, label='Weight')
# Bottom row: SVD decay (line plots)
ax = axes[1, 0]
for method, label, color in [('matrix', 'Direct', 'blue'), ('svd', 'SVD', 'orange'), ('eigen', 'Eigen', 'green')]:
A = transfer_comp.methods[method].cpu()
_, S, _ = torch.linalg.svd(A)
ax.plot(range(1, len(S)+1), S.numpy(), 'o-', label=label, linewidth=2, markersize=6, color=color)
ax.set_xlabel('Index')
ax.set_ylabel('Singular Value')
ax.set_title('Singular Value Decay Comparison')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)
# Condition numbers
ax = axes[1, 1]
cond_nums = []
method_labels = []
for method, label in [('matrix', 'Direct'), ('svd', 'SVD'), ('eigen', 'Eigen')]:
A = transfer_comp.methods[method].cpu()
cond = torch.linalg.cond(A).item()
cond_nums.append(cond)
method_labels.append(label)
ax.bar(method_labels, cond_nums, alpha=0.7, color=['blue', 'orange', 'green'])
ax.set_ylabel('Condition Number')
ax.set_title('Matrix Conditioning')
ax.set_yscale('log')
ax.grid(True, alpha=0.3, axis='y')
# Rank comparison
ax = axes[1, 2]
ranks = []
for method, label in [('matrix', 'Direct'), ('svd', 'SVD'), ('eigen', 'Eigen')]:
A = transfer_comp.methods[method].cpu()
rank = torch.linalg.matrix_rank(A).item()
ranks.append(rank)
ax.bar(method_labels, ranks, alpha=0.7, color=['blue', 'orange', 'green'])
ax.set_ylabel('Matrix Rank')
ax.set_title('Effective Rank')
ax.set_ylim([0, max(ranks) + 2])
ax.grid(True, alpha=0.3, axis='y')
plt.suptitle('Transfer Methods: Matrices and Spectral Analysis', fontsize=14)
plt.tight_layout()
plt.show()
# Test all methods
methods = ['matrix', 'svd', 'eigen', 'mlp']
method_names = ['Direct Matrix', 'SVD (k=7)', 'Eigendecomp (k=7)', 'Non-linear MLP']
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
test_indices = [5, 25, 45, 65]
method_errors = {m: [] for m in methods}
for row, method in enumerate(methods):
for col, test_idx in enumerate(test_indices):
ax = axes[row, col]
# Apply method
source_func = torch.tensor(test_functions[test_idx:test_idx+1], dtype=torch.float32).to(device)
pred_deriv = transfer_comp.apply_method(method, source_func, x_tensor).squeeze().cpu().numpy()
true_deriv = test_derivatives[test_idx]
# Compute error
mse = np.mean((pred_deriv - true_deriv)**2)
method_errors[method].append(mse)
# Plot
ax.plot(x, test_functions[test_idx], 'b-', linewidth=2, alpha=0.6, label='f(x)')
ax.plot(x, true_deriv, 'r-', linewidth=2, label="True f'")
ax.plot(x, pred_deriv, 'g--', linewidth=2, label=method_names[row])
if row == 0:
ax.set_title(f'Test {col+1}', fontsize=12)
if col == 0:
ax.set_ylabel(method_names[row], fontsize=11)
ax.text(0.95, 0.05, f'MSE={mse:.6f}', transform=ax.transAxes,
ha='right', va='bottom', fontsize=9, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
if row == 0 and col == 0:
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
ax.set_xlabel('x')
plt.suptitle('Transfer Methods Comparison: Derivative Operator', fontsize=16)
plt.tight_layout()
plt.show()
# Summary statistics
print("\n=== PERFORMANCE SUMMARY ===")
print(f"{'Method':<20} {'Mean MSE':>12} {'Std MSE':>12}")
print("-" * 45)
for method, name in zip(methods, method_names):
errors = method_errors[method]
print(f"{name:<20} {np.mean(errors):>12.6f} {np.std(errors):>12.6f}")
=== PERFORMANCE SUMMARY === Method Mean MSE Std MSE --------------------------------------------- Direct Matrix 0.126488 0.077209 SVD (k=7) 0.125766 0.073900 Eigendecomp (k=7) 0.125748 0.073900 Non-linear MLP 0.127705 0.077591
Example 2: Poisson Equation Solver¶
Learn the solution operator for the Poisson equation: $$-\nabla^2 u = f \text{ in } \Omega, \quad u = 0 \text{ on } \partial\Omega$$
In 1D: $-\frac{d^2u}{dx^2} = f(x)$ with $u(0) = u(1) = 0$
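As a quick sanity check on the finite-difference discretization used by the data generator below, solve a case with a known answer: for $f(x) = \pi^2 \sin(\pi x)$ the exact solution is $u(x) = \sin(\pi x)$. This standalone sketch rebuilds the same tridiagonal matrix (variable names are illustrative):
import numpy as np
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve

n_chk = 100
x_chk = np.linspace(0, 1, n_chk)
dx = x_chk[1] - x_chk[0]

# Finite-difference matrix for -d^2/dx^2 on interior nodes (homogeneous Dirichlet BCs)
main_chk = 2 * np.ones(n_chk - 2) / dx**2
off_chk = -np.ones(n_chk - 3) / dx**2
A_chk = diags([off_chk, main_chk, off_chk], [-1, 0, 1], format='csc')

f_chk = np.pi**2 * np.sin(np.pi * x_chk)      # source whose exact solution is sin(pi x)
u_chk = np.zeros(n_chk)
u_chk[1:-1] = spsolve(A_chk, f_chk[1:-1])

print("Max error vs sin(pi x):", np.max(np.abs(u_chk - np.sin(np.pi * x_chk))))   # O(dx^2), ~1e-4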
def generate_poisson_data(n_samples=1000, n_points=100):
"""Generate Poisson equation data"""
x = np.linspace(0, 1, n_points)
dx = x[1] - x[0]
# Create finite difference matrix for -d²/dx²
main_diag = 2 * np.ones(n_points - 2) / dx**2
off_diag = -np.ones(n_points - 3) / dx**2
A_fd = diags([off_diag, main_diag, off_diag], [-1, 0, 1]).toarray()
A_inv = np.linalg.inv(A_fd)
sources = []
solutions = []
np.random.seed(42)
for i in range(n_samples):
# Generate random source function (combination of sines)
f = np.zeros(n_points)
n_modes = np.random.randint(2, 6)
for k in range(n_modes):
mode = np.random.randint(1, 10)
amplitude = np.random.randn()
phase = np.random.rand() * 2 * np.pi
f += amplitude * np.sin(mode * np.pi * x + phase)
# Solve Poisson equation
u = np.zeros(n_points)
u[1:-1] = A_inv @ f[1:-1]
sources.append(f)
solutions.append(u)
return np.array(sources), np.array(solutions), x
print("\n=== POISSON EQUATION EXAMPLE ===")
poisson_sources, poisson_solutions, x_poisson = generate_poisson_data(n_samples=1500)
# Split data
n_train_poisson = 1200
train_sources_poisson = poisson_sources[:n_train_poisson]
train_solutions_poisson = poisson_solutions[:n_train_poisson]
test_sources_poisson = poisson_sources[n_train_poisson:]
test_solutions_poisson = poisson_solutions[n_train_poisson:]
print(f"Data: {n_train_poisson} training, {len(test_sources_poisson)} test")
# Visualize samples
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
ax = axes[i]
ax.plot(x_poisson, poisson_sources[i], 'r-', linewidth=2, label='f(x)')
ax.plot(x_poisson, poisson_solutions[i], 'b-', linewidth=2, label='u(x)')
ax.set_title(f'Sample {i+1}')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('Poisson Equation: Source → Solution')
plt.tight_layout()
plt.show()
=== POISSON EQUATION EXAMPLE === Data: 1200 training, 300 test
Train Encoders for Poisson¶
print("\nTraining Poisson encoders...")
# Create encoders with more basis functions for this problem
poisson_source_encoder = FunctionEncoder(sensor_dim=100, n_basis=15, hidden_dim=64).to(device)
poisson_solution_encoder = FunctionEncoder(sensor_dim=100, n_basis=15, hidden_dim=64).to(device)
# Train encoders
print("1. Source encoder (f space):")
poisson_source_losses = train_encoder(poisson_source_encoder, train_sources_poisson, x_poisson,
n_epochs=400, name="Source Encoder")
print("\n2. Solution encoder (u space):")
poisson_solution_losses = train_encoder(poisson_solution_encoder, train_solutions_poisson, x_poisson,
n_epochs=400, name="Solution Encoder")
print(f"\nFinal losses - Source: {poisson_source_losses[-1]:.6f}, Solution: {poisson_solution_losses[-1]:.6f}")
Training Poisson encoders... 1. Source encoder (f space):
Training Source Encoder: 100%|██████████| 400/400 [00:01<00:00, 380.19it/s, Loss=1.213085]
2. Solution encoder (u space):
Training Solution Encoder: 100%|██████████| 400/400 [00:01<00:00, 313.71it/s, Loss=0.000266]
Final losses - Source: 1.497604, Solution: 0.000363
Learn and Apply Poisson Solver¶
# Create B2B operator for Poisson
poisson_b2b = B2BOperator(poisson_source_encoder, poisson_solution_encoder)
# Learn transformation
print("Learning Poisson transformation matrix...")
train_source_poisson_tensor = torch.tensor(train_sources_poisson, dtype=torch.float32).to(device)
train_solution_poisson_tensor = torch.tensor(train_solutions_poisson, dtype=torch.float32).to(device)
A_poisson, fit_error_poisson = poisson_b2b.learn_transformation(
train_source_poisson_tensor, train_solution_poisson_tensor
)
print(f"Transformation matrix shape: {A_poisson.shape}")
print(f"Fitting error: {fit_error_poisson:.6f}")
# Visualize transformation matrix and SVD
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Matrix heatmap
ax = axes[0]
im = ax.imshow(A_poisson.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(im, ax=ax, label='Weight')
ax.set_title('Poisson Solver Transformation Matrix')
ax.set_xlabel('Source Basis')
ax.set_ylabel('Solution Basis')
# Singular values (line plot)
ax = axes[1]
U, S, Vt = torch.linalg.svd(A_poisson.cpu())
ax.plot(range(1, len(S)+1), S.numpy(), 'o-', linewidth=2, markersize=6, color='orange')
ax.set_title('Singular Value Decay')
ax.set_xlabel('Index')
ax.set_ylabel('Singular Value')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
# Cumulative energy
ax = axes[2]
cumsum = torch.cumsum(S**2, dim=0) / torch.sum(S**2)
ax.plot(range(1, len(S)+1), cumsum.numpy(), 'o-', linewidth=2, markersize=6, color='green')
ax.axhline(0.99, color='r', linestyle='--', label='99% energy')
ax.set_title('Cumulative Energy')
ax.set_xlabel('Number of Singular Values')
ax.set_ylabel('Fraction of Total Energy')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nMatrix rank: {torch.linalg.matrix_rank(A_poisson).item()}")
print(f"Condition number: {torch.linalg.cond(A_poisson).item():.2e}")
Learning Poisson transformation matrix... Transformation matrix shape: torch.Size([15, 15]) Fitting error: 0.000034
Matrix rank: 15 Condition number: 3.96e+04
Test Poisson Solver¶
# Test Poisson solver
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
poisson_test_errors = []
x_poisson_tensor = torch.tensor(x_poisson, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)
for i in range(6):
test_idx = i * 10
# Apply B2B operator
source_func = torch.tensor(test_sources_poisson[test_idx:test_idx+1], dtype=torch.float32).to(device)
pred_solution = poisson_b2b.apply(source_func, x_poisson_tensor).squeeze().cpu().numpy()
true_solution = test_solutions_poisson[test_idx]
# Compute error
mse = np.mean((pred_solution - true_solution)**2)
rel_error = np.sqrt(mse) / np.sqrt(np.mean(true_solution**2) + 1e-8)
poisson_test_errors.append(rel_error)
# Plot
ax = axes[i]
ax.plot(x_poisson, test_sources_poisson[test_idx], 'r-', linewidth=2, alpha=0.7, label='f(x)')
ax.plot(x_poisson, true_solution, 'b-', linewidth=2, label='True u(x)')
ax.plot(x_poisson, pred_solution, 'g--', linewidth=2, label='B2B u(x)')
ax.set_title(f'Test {i+1}: Rel. Error = {rel_error:.4f}')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('B2B Poisson Solver Results', fontsize=14)
plt.tight_layout()
plt.show()
print(f"Average relative error: {np.mean(poisson_test_errors):.4f} ± {np.std(poisson_test_errors):.4f}")
Average relative error: 0.8479 ± 0.1633
Example 3: 1D Nonlinear Darcy Flow¶
Same as DeepONet: solve the nonlinear Darcy equation with solution-dependent permeability, $-\frac{d}{dx}\!\left(\kappa(s)\,\frac{ds}{dx}\right) = u(x)$ on $(0, 1)$ with $\kappa(s) = 0.2 + s^2$ and $s(0) = s(1) = 0$.
def generate_darcy_data(n_funcs=1000, n_points=40):
"""Generate 1D nonlinear Darcy flow data"""
def permeability(s):
return 0.2 + s**2
# Gaussian process for source function
x = np.linspace(0, 1, n_points)
l, sigma = 0.04, 1.0
K = sigma**2 * np.exp(-0.5 * (x[:, None] - x[None, :])**2 / l**2)
K += 1e-6 * np.eye(n_points)
def solve_darcy(u_func):
dx = x[1] - x[0]
s = np.zeros(n_points)
for _ in range(100): # Fixed point iteration
kappa = permeability(s)
main_diag = (kappa[1:] + kappa[:-1]) / dx**2
upper_diag = -kappa[1:-1] / dx**2
lower_diag = -kappa[1:-1] / dx**2
A = diags([lower_diag, main_diag, upper_diag], [-1, 0, 1],
shape=(n_points-2, n_points-2))
s_interior = spsolve(A, u_func[1:-1])
s_new = np.zeros(n_points)
s_new[1:-1] = s_interior
s = 0.5 * s_new + 0.5 * s
return s
# Generate dataset
np.random.seed(42)
U, S = [], []
print("Generating Darcy dataset...")
for i in tqdm(range(n_funcs), desc="Solving PDEs"):
u = multivariate_normal.rvs(mean=np.zeros(n_points), cov=K)
s = solve_darcy(u)
U.append(u)
S.append(s)
return np.array(U), np.array(S), x
print("\n=== 1D NONLINEAR DARCY EXAMPLE ===")
darcy_sources, darcy_solutions, x_darcy = generate_darcy_data(n_funcs=1000)
# Split data
n_train_darcy = 800
train_sources_darcy = darcy_sources[:n_train_darcy]
train_solutions_darcy = darcy_solutions[:n_train_darcy]
test_sources_darcy = darcy_sources[n_train_darcy:]
test_solutions_darcy = darcy_solutions[n_train_darcy:]
print(f"\nData: {n_train_darcy} training, {len(test_sources_darcy)} test")
# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
ax = axes[i]
ax.plot(x_darcy, darcy_sources[i], 'g-', linewidth=2, label='f(x)')
ax.plot(x_darcy, darcy_solutions[i], 'b-', linewidth=2, label='u(x)')
ax.set_title(f'Sample {i+1}')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('1D Nonlinear Darcy Flow')
plt.tight_layout()
plt.show()
=== 1D NONLINEAR DARCY EXAMPLE === Generating Darcy dataset...
Solving PDEs: 100%|██████████| 1000/1000 [00:07<00:00, 125.79it/s]
Data: 800 training, 200 test
Train Encoders for Darcy¶
print("\nTraining Darcy encoders...")
# Create encoders
darcy_source_encoder = FunctionEncoder(sensor_dim=40, n_basis=20, hidden_dim=128).to(device)
darcy_solution_encoder = FunctionEncoder(sensor_dim=40, n_basis=20, hidden_dim=128).to(device)
# Train
print("1. Source encoder:")
darcy_source_losses = train_encoder(darcy_source_encoder, train_sources_darcy, x_darcy,
n_epochs=500, lr=0.001, name="Darcy Source")
print("\n2. Solution encoder:")
darcy_solution_losses = train_encoder(darcy_solution_encoder, train_solutions_darcy, x_darcy,
n_epochs=500, lr=0.001, name="Darcy Solution")
print(f"\nFinal losses - Source: {darcy_source_losses[-1]:.6f}, Solution: {darcy_solution_losses[-1]:.6f}")
Training Darcy encoders... 1. Source encoder:
Training Darcy Source: 100%|██████████| 500/500 [00:01<00:00, 332.48it/s, Loss=0.817789]
2. Solution encoder:
Training Darcy Solution: 100%|██████████| 500/500 [00:01<00:00, 307.72it/s, Loss=0.000827]
Final losses - Source: 0.847877, Solution: 0.000871
Learn and Apply Darcy Operator¶
# Create B2B operator for Darcy
darcy_b2b = B2BOperator(darcy_source_encoder, darcy_solution_encoder)
# Learn transformation
print("Learning Darcy transformation matrix...")
train_source_darcy_tensor = torch.tensor(train_sources_darcy, dtype=torch.float32).to(device)
train_solution_darcy_tensor = torch.tensor(train_solutions_darcy, dtype=torch.float32).to(device)
A_darcy, fit_error_darcy = darcy_b2b.learn_transformation(
train_source_darcy_tensor, train_solution_darcy_tensor
)
print(f"Transformation matrix shape: {A_darcy.shape}")
print(f"Fitting error: {fit_error_darcy:.6f}")
# Visualize transformation matrix and SVD
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Matrix heatmap
ax = axes[0]
im = ax.imshow(A_darcy.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(im, ax=ax, label='Weight')
ax.set_title('Darcy Operator Transformation Matrix')
ax.set_xlabel('Source Basis')
ax.set_ylabel('Solution Basis')
# Singular values (line plot)
ax = axes[1]
U, S, Vt = torch.linalg.svd(A_darcy.cpu())
ax.plot(range(1, len(S)+1), S.numpy(), 'o-', linewidth=2, markersize=6, color='purple')
ax.set_title('Singular Value Decay')
ax.set_xlabel('Index')
ax.set_ylabel('Singular Value')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
# Cumulative energy
ax = axes[2]
cumsum = torch.cumsum(S**2, dim=0) / torch.sum(S**2)
ax.plot(range(1, len(S)+1), cumsum.numpy(), 'o-', linewidth=2, markersize=6, color='green')
ax.axhline(0.99, color='r', linestyle='--', label='99% energy')
ax.set_title('Cumulative Energy')
ax.set_xlabel('Number of Singular Values')
ax.set_ylabel('Fraction of Total Energy')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nMatrix rank: {torch.linalg.matrix_rank(A_darcy).item()}")
print(f"Condition number: {torch.linalg.cond(A_darcy).item():.2e}")
Learning Darcy transformation matrix... Transformation matrix shape: torch.Size([20, 20]) Fitting error: 0.000621
Matrix rank: 20 Condition number: 4.39e+04
Test Darcy Operator¶
# Test Darcy operator
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
darcy_test_errors = []
x_darcy_tensor = torch.tensor(x_darcy, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)
for i in range(6):
test_idx = i * 5
# Apply B2B
source_func = torch.tensor(test_sources_darcy[test_idx:test_idx+1], dtype=torch.float32).to(device)
pred_solution = darcy_b2b.apply(source_func, x_darcy_tensor).squeeze().cpu().numpy()
true_solution = test_solutions_darcy[test_idx]
# Error
mse = np.mean((pred_solution - true_solution)**2)
rel_error = np.sqrt(mse) / np.sqrt(np.mean(true_solution**2) + 1e-8)
darcy_test_errors.append(rel_error)
# Plot
ax = axes[i]
ax.plot(x_darcy, test_sources_darcy[test_idx], 'g-', linewidth=2, alpha=0.7, label='f(x)')
ax.plot(x_darcy, true_solution, 'b-', linewidth=2, label='True u(x)')
ax.plot(x_darcy, pred_solution, 'r--', linewidth=2, label='B2B u(x)')
ax.set_title(f'Test {i+1}: Rel. Error = {rel_error:.4f}')
ax.grid(True, alpha=0.3)
if i == 0:
ax.legend()
ax.set_xlabel('x')
plt.suptitle('B2B Darcy Operator Results', fontsize=14)
plt.tight_layout()
plt.show()
print(f"Average relative error: {np.mean(darcy_test_errors):.4f} ± {np.std(darcy_test_errors):.4f}")
Average relative error: 0.2108 ± 0.1021