Welcome to the final stage of our journey!
The Goal¶
We have successfully trained a Physics-Informed Neural Network (PINN) to solve a Partial Differential Equation (PDE). However, the neural network is a “black box”: it gives us the right numbers, but no mathematical understanding.
Our Mission: Use Symbolic Regression to extract the analytical law directly from the trained neural network.
The Process¶
1. Train a PINN: We will re-train a PINN on the Heat Equation: $u_t = \alpha u_{xx}$.
2. The Oracle: Use the trained PINN as an “Oracle” to generate clean, high-resolution data.
3. Distillation: Use PySR to find the symbolic equation that best fits the PINN’s output.
4. Discovery: Recover the analytical solution: $u(x, t) = \sin(\pi x)\, e^{-\alpha \pi^2 t}$.
1. Setup and Dependencies¶
We need torch for the PINN and pysr for symbolic regression.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from pysr import PySRRegressor
# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cpu
2. Train the PINN (The Teacher)¶
We solved the Heat Equation in Day 2. Let’s quickly re-implement it here to have a fresh model.
Problem: $u_t = \alpha u_{xx}$
with $\alpha = 0.01$, $x \in [0, 1]$, $t \in [0, 1]$, boundary conditions $u(0, t) = u(1, t) = 0$, and initial condition $u(x, 0) = \sin(\pi x)$.
class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple MLP: 2 inputs (x, t) -> 1 output (u)
        self.net = nn.Sequential(
            nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 32), nn.Tanh(), nn.Linear(32, 1)
        )

    def forward(self, x, t):
        # Concatenate x and t to form a (N, 2) input
        inputs = torch.cat([x, t], dim=1)
        return self.net(inputs)
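As a quick sanity check on the architecture (not part of the original notebook), the same MLP should map a batch of N points $(x, t)$ to N scalar predictions:

```python
import torch
import torch.nn as nn

# Same architecture as the PINN above: 2 inputs -> 1 output
net = nn.Sequential(
    nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 32), nn.Tanh(), nn.Linear(32, 1)
)
x = torch.rand(5, 1)  # 5 spatial coordinates
t = torch.rand(5, 1)  # 5 time coordinates
u = net(torch.cat([x, t], dim=1))
print(u.shape)  # torch.Size([5, 1])
```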
def compute_pde_residual(model, x, t, alpha=0.01):
    x.requires_grad = True
    t.requires_grad = True
    u = model(x, t)

    # Gradients
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]

    # Heat equation residual: u_t - alpha * u_xx
    return u_t - alpha * u_xx
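A useful sanity check (an addition, not in the original notebook): plugging the known analytical solution $u = \sin(\pi x)\, e^{-\alpha \pi^2 t}$ into the same autograd machinery should give a residual of essentially zero, since $u_t = \alpha u_{xx}$ holds exactly for it:

```python
import numpy as np
import torch

alpha = 0.01
x = torch.rand(100, 1, requires_grad=True)
t = torch.rand(100, 1, requires_grad=True)

# Closed-form solution of the heat equation on [0, 1] with these IC/BCs
u = torch.sin(np.pi * x) * torch.exp(-alpha * np.pi**2 * t)

u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]

residual = u_t - alpha * u_xx
print(residual.abs().max().item())  # ~0, up to float32 round-off
```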
# True solution for comparison
def true_solution(x, t, alpha=0.01):
    return np.sin(np.pi * x) * np.exp(-alpha * np.pi**2 * t)
Training Loop¶
model = PINN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
alpha = 0.01
iterations = 5000
print("Training PINN...")
for i in range(iterations):
    optimizer.zero_grad()

    # 1. Collocation points for PDE
    x_pde = torch.rand(2000, 1).to(device)  # x in [0, 1]
    t_pde = torch.rand(2000, 1).to(device)  # t in [0, 1]
    res = compute_pde_residual(model, x_pde, t_pde, alpha)
    loss_pde = torch.mean(res**2)

    # 2. Initial Condition (t=0)
    x_ic = torch.rand(500, 1).to(device)
    t_ic = torch.zeros(500, 1).to(device)
    u_ic_pred = model(x_ic, t_ic)
    u_ic_true = torch.sin(np.pi * x_ic)
    loss_ic = torch.mean((u_ic_pred - u_ic_true) ** 2)

    # 3. Boundary Conditions (x=0, x=1)
    t_bc = torch.rand(500, 1).to(device)
    x_bc0 = torch.zeros(500, 1).to(device)
    x_bc1 = torch.ones(500, 1).to(device)
    u_bc0 = model(x_bc0, t_bc)
    u_bc1 = model(x_bc1, t_bc)
    loss_bc = torch.mean(u_bc0**2) + torch.mean(u_bc1**2)

    loss = loss_pde + loss_ic + loss_bc
    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
        print(f"Iter {i}, Loss: {loss.item():.5f}")

print("Training Complete.")
3. The Oracle: Data Distillation¶
Now that we have a trained model, we will use it to generate a high-quality dataset. This dataset is cleaner than real-world data and allows symbolic regression to work efficiently.
# Generate a grid of points
x_vals = np.linspace(0, 1, 50)
t_vals = np.linspace(0, 1, 50)
X_grid, T_grid = np.meshgrid(x_vals, t_vals)
X_flat = X_grid.flatten()[:, None]
T_flat = T_grid.flatten()[:, None]
x_tensor = torch.tensor(X_flat, dtype=torch.float32).to(device)
t_tensor = torch.tensor(T_flat, dtype=torch.float32).to(device)
# Ask the Oracle (PINN) for the solution
model.eval()
with torch.no_grad():
    u_pred = model(x_tensor, t_tensor).cpu().numpy()
# Prepare data for PySR
# Input: [x, t]
# Output: u
X_sr = np.hstack([X_flat, T_flat])
y_sr = u_pred.flatten()
print(f"Distilled Dataset Shape: {X_sr.shape} -> {y_sr.shape}")
plt.figure(figsize=(8, 6))
scatter = plt.scatter(X_sr[:, 0], X_sr[:, 1], c=y_sr, cmap="viridis", s=10)
plt.colorbar(scatter, label="y_sr value")
plt.xlabel("x")
plt.ylabel("t")
plt.title("y(x,t)")
plt.show()
4. Symbolic Regression with PySR¶
We will now feed this data into PySRRegressor.
We give it a hint by providing the unary operators sin and exp, as we suspect the solution might involve waves or decay.
model_sr = PySRRegressor(
    niterations=100,
    binary_operators=["+", "*", "-", "/"],
    unary_operators=[
        "sin",  # Expected in spatial part
        "exp",  # Expected in temporal part
        # "cos",  # Optional, let's see if it picks it up
    ],
    variable_names=["x", "t"],  # newer PySR versions expect this in fit() instead
    verbosity=1,
)
print("Starting Symbolic Regression... might take a minute...")
model_sr.fit(X_sr, y_sr)
Compiling Julia backend...
Starting Symbolic Regression... might take a minute...
[ Info: Started!
Expressions evaluated per second: 3.650e+04
Progress: 706 / 3100 total iterations (22.774%)
───────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           9.193e-02  0.000e+00  y = 0.59287
4           1.478e-03  1.377e+00  y = sin(x₀ * 3.1622)
6           3.781e-04  6.816e-01  y = sin(x₀ * 3.1424) * 0.95102
8           3.781e-04  3.874e-07  y = sin(x₀ + (x₀ * 2.1424)) * 0.95102
10          3.780e-04  5.371e-05  y = sin((x₀ * 2.1431) + (x₀ - 0.00051257)) * 0.95113
12          3.889e-07  3.440e+00  y = ((x₁ * -0.094289) + 0.99817) * sin(x₀ + (x₀ * 2.1424))
(higher-complexity rows omitted)
───────────────────────────────────────────────────────────────────────────────
[ Info: Final population:
[ Info: Results saved to:
model_sr.latex(0)
'0.593'
5. Analysis and “Verification”¶
Let’s look at the best equation found.
print("Best equation:")
print(model_sr.sympy())
print("\nTarget Analytical Equation:")
print(f"sin(pi * x) * exp(-{alpha * np.pi**2:.4f} * t)")
print("approx: sin(3.1415*x) * exp(-0.0987*t)")
Best equation:
(0.998159 + x1*(-0.09428466))*sin(x0*3.1424408)
Target Analytical Equation:
sin(pi * x) * exp(-0.0987 * t)
approx: sin(3.1415*x) * exp(-0.0987*t)
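Note how PySR recovered the spatial part almost exactly (3.1424 ≈ π) but replaced the exponential decay with a linear factor. Over $t \in [0, 1]$ the term $e^{-0.0987 t}$ is nearly linear, and the discovered $(0.9982 - 0.0943\, t)$ is close to the best linear fit. A quick check (an addition, not in the original notebook):

```python
import numpy as np

# Best least-squares linear fit to exp(-0.0987*t) over t in [0, 1]
t = np.linspace(0, 1, 200)
slope, intercept = np.polyfit(t, np.exp(-0.0987 * t), 1)
print(intercept, slope)  # roughly 0.999 and -0.094, matching PySR's factor
max_err = np.max(np.abs(intercept + slope * t - np.exp(-0.0987 * t)))
print(max_err)  # small (~1e-3), so the linearization is a good fit here
```

This also explains why PySR preferred the linear form: within this time window, the extra complexity of `exp` buys almost no loss reduction.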
# Verify accuracy
y_sym_pred = model_sr.predict(X_sr)
mse = np.mean((y_sym_pred - y_sr) ** 2)
print(f"MSE between Symbolic and PINN: {mse:.6e}")
MSE between Symbolic and PINN: 3.889278e-07
6. Exercises¶
1. Noisy Oracle: Add Gaussian noise to y_sr before passing it to PySR. How robust is the symbolic regression?
2. Different PDE: Change the PINN to solve the Burgers’ equation (from Day 1) and try to distill the shockwave equation.
3. Missing Operators: Remove "sin" from the unary_operators list. Can PySR approximate the sine wave using Taylor expansion terms (polynomials)?
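A hint for the third exercise: on $[0, 1]$ a modest-degree polynomial can already track $\sin(\pi x)$ closely, so PySR has a fighting chance even without `sin`. A quick illustration of the target it would be chasing (this uses a degree-5 least-squares fit as a stand-in, not PySR itself):

```python
import numpy as np

# How well can a degree-5 polynomial mimic sin(pi*x) on [0, 1]?
x = np.linspace(0, 1, 500)
coeffs = np.polyfit(x, np.sin(np.pi * x), 5)
max_err = np.max(np.abs(np.polyval(coeffs, x) - np.sin(np.pi * x)))
print(f"max |poly - sin| on [0, 1]: {max_err:.2e}")
```

The catch is that PySR must build such a polynomial out of `+`, `-`, `*`, `/` with evolved constants, which usually costs noticeably more complexity than a single `sin` node.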