
Day 3: Symbolic Distillation of PINNs

Welcome to the final stage of our journey!

The Goal

We have successfully trained a Physics-Informed Neural Network (PINN) to solve a Partial Differential Equation (PDE). However, the neural network is a “black box” — it gives us the right numbers, but not the mathematical understanding.

Our Mission: Use Symbolic Regression to extract the analytical law $u(x, t)$ directly from the trained neural network.

The Process

  1. Train a PINN: We will re-train a PINN on the Heat Equation, $u_t = \alpha u_{xx}$.

  2. The Oracle: Use the trained PINN as an “Oracle” to generate clean, high-resolution data.

  3. Distillation: Use PySR to find the symbolic equation that best fits the PINN’s output.

  4. Discovery: Recover the analytical solution: $u(x, t) = \sin(\pi x)\, e^{-\alpha \pi^2 t}$.
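As a quick consistency check, the target expression really does satisfy the heat equation. Differentiating once in $t$ and twice in $x$:

```latex
u_t = -\alpha \pi^2 \sin(\pi x)\, e^{-\alpha \pi^2 t},
\qquad
u_{xx} = -\pi^2 \sin(\pi x)\, e^{-\alpha \pi^2 t},
```

so $u_t = \alpha u_{xx}$ holds identically, and $u(x, 0) = \sin(\pi x)$ matches the initial condition.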

1. Setup and Dependencies

We need torch for the PINN and pysr for symbolic regression.
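If you are starting from a fresh environment, something like the following should install the dependencies (exact package set is an assumption; PySR compiles its Julia backend automatically on first import):

```shell
pip install torch pysr numpy matplotlib
```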

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from pysr import PySRRegressor

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cpu

2. Train the PINN (The Teacher)

We solved the Heat Equation in Day 2. Let’s quickly re-implement it here to have a fresh model.

Problem:

$$\frac{\partial u}{\partial t} = \alpha \frac{\partial^2 u}{\partial x^2}$$

with $\alpha = 0.01$, initial condition $u(x, 0) = \sin(\pi x)$, and boundary conditions $u(0, t) = u(1, t) = 0$.

class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple MLP: 2 inputs (x, t) -> 1 output (u)
        self.net = nn.Sequential(
            nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 32), nn.Tanh(), nn.Linear(32, 1)
        )

    def forward(self, x, t):
        # Concatenate x and t to form a (N, 2) input
        inputs = torch.cat([x, t], dim=1)
        return self.net(inputs)


def compute_pde_residual(model, x, t, alpha=0.01):
    x.requires_grad = True
    t.requires_grad = True
    u = model(x, t)

    # Gradients
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[
        0
    ]
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[
        0
    ]
    u_xx = torch.autograd.grad(
        u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True
    )[0]

    return u_t - alpha * u_xx


# True solution for comparison
def true_solution(x, t, alpha=0.01):
    return np.sin(np.pi * x) * np.exp(-alpha * np.pi**2 * t)
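A minimal standalone sketch (point and step size chosen arbitrarily here) that verifies `true_solution` against the PDE using central finite differences:

```python
import numpy as np

def true_solution(x, t, alpha=0.01):
    return np.sin(np.pi * x) * np.exp(-alpha * np.pi**2 * t)

# Central differences at an interior point (x0, t0)
x0, t0, h, alpha = 0.3, 0.5, 1e-4, 0.01
u_t = (true_solution(x0, t0 + h) - true_solution(x0, t0 - h)) / (2 * h)
u_xx = (true_solution(x0 + h, t0) - 2 * true_solution(x0, t0)
        + true_solution(x0 - h, t0)) / h**2
residual = u_t - alpha * u_xx
print(f"PDE residual at ({x0}, {t0}): {residual:.2e}")  # ~0 up to truncation error
```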

Training Loop

model = PINN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
alpha = 0.01
iterations = 5000

print("Training PINN...")
for i in range(iterations):
    optimizer.zero_grad()

    # 1. Collocation points for PDE
    x_pde = torch.rand(2000, 1).to(device)  # x in [0, 1]
    t_pde = torch.rand(2000, 1).to(device)  # t in [0, 1]
    res = compute_pde_residual(model, x_pde, t_pde, alpha)
    loss_pde = torch.mean(res**2)

    # 2. Initial Condition (t=0)
    x_ic = torch.rand(500, 1).to(device)
    t_ic = torch.zeros(500, 1).to(device)
    u_ic_pred = model(x_ic, t_ic)
    u_ic_true = torch.sin(np.pi * x_ic)
    loss_ic = torch.mean((u_ic_pred - u_ic_true) ** 2)

    # 3. Boundary Conditions (x=0, x=1)
    t_bc = torch.rand(500, 1).to(device)
    x_bc0 = torch.zeros(500, 1).to(device)
    x_bc1 = torch.ones(500, 1).to(device)
    u_bc0 = model(x_bc0, t_bc)
    u_bc1 = model(x_bc1, t_bc)
    loss_bc = torch.mean(u_bc0**2) + torch.mean(u_bc1**2)

    loss = loss_pde + loss_ic + loss_bc
    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
        print(f"Iter {i}, Loss: {loss.item():.5f}")

print("Training Complete.")

3. The Oracle: Data Distillation

Now that we have a trained model, we will use it to generate a high-quality dataset. This dataset is cleaner than real-world data and allows symbolic regression to work efficiently.

# Generate a grid of points
x_vals = np.linspace(0, 1, 50)
t_vals = np.linspace(0, 1, 50)
X_grid, T_grid = np.meshgrid(x_vals, t_vals)

X_flat = X_grid.flatten()[:, None]
T_flat = T_grid.flatten()[:, None]

x_tensor = torch.tensor(X_flat, dtype=torch.float32).to(device)
t_tensor = torch.tensor(T_flat, dtype=torch.float32).to(device)

# Ask the Oracle (PINN) for the solution
model.eval()
with torch.no_grad():
    u_pred = model(x_tensor, t_tensor).cpu().numpy()

# Prepare data for PySR
# Input: [x, t]
# Output: u
X_sr = np.hstack([X_flat, T_flat])
y_sr = u_pred.flatten()

print(f"Distilled Dataset Shape: {X_sr.shape} -> {y_sr.shape}")
plt.figure(figsize=(8, 6))
scatter = plt.scatter(X_sr[:, 0], X_sr[:, 1], c=y_sr, cmap="viridis", s=10)
plt.colorbar(scatter, label="u (PINN prediction)")
plt.xlabel("x")
plt.ylabel("t")
plt.title("Distilled dataset: u(x, t) from the PINN oracle")
plt.show()

4. Symbolic Regression with PySR

We will now feed this data into PySRRegressor. We give it a hint by providing the unary operators sin and exp, as we suspect the solution might involve waves or decay.

model_sr = PySRRegressor(
    niterations=100,
    binary_operators=["+", "*", "-", "/"],
    unary_operators=[
        "sin",  # Expected in the spatial part
        "exp",  # Expected in the temporal part
        # "cos",  # Optional, let's see if it picks it up
    ],
    verbosity=1,
)

print("Starting Symbolic Regression... might take a minute...")
# variable_names is data-dependent, so it belongs in fit(), not the constructor
model_sr.fit(X_sr, y_sr, variable_names=["x", "t"])
Compiling Julia backend...
Starting Symbolic Regression... might take a minute...
[ Info: Started!

Expressions evaluated per second: 3.650e+04
Progress: 706 / 3100 total iterations (22.774%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           9.193e-02  0.000e+00  y = 0.59287
4           1.478e-03  1.377e+00  y = sin(x₀ * 3.1622)
6           3.781e-04  6.816e-01  y = sin(x₀ * 3.1424) * 0.95102
8           3.781e-04  3.874e-07  y = sin(x₀ + (x₀ * 2.1424)) * 0.95102
10          3.780e-04  5.371e-05  y = sin((x₀ * 2.1431) + (x₀ - 0.00051257)) * 0.95113
12          3.889e-07  3.440e+00  y = ((x₁ * -0.094289) + 0.99817) * sin(x₀ + (x₀ * 2.1424))
14          3.889e-07  1.377e-05  y = ((x₁ * -0.094289) + 0.99817) * sin(x₀ + ((x₀ * 1.1424)...
                                       + x₀))
15          3.826e-07  1.624e-02  y = sin(((x₀ * 1.1424) + x₀) + x₀) * (0.99819 + sin(x₁ * -...
                                      0.094413))
16          3.476e-07  9.616e-02  y = sin((((x₀ + x₀) * 1.0716) + x₀) + -0.00050485) * ((x₁ ...
                                      * -0.094292) + 0.99827)
20          3.474e-07  1.112e-04  y = ((x₁ * -0.094293) + 0.99828) * sin(x₀ + (((x₀ + (x₀ - ...
                                      (x₀ * 0.34965))) + -0.00040343) * 1.2986))
21          3.412e-07  1.812e-02  y = (sin(x₁ * -0.094407) + 0.9983) * sin((x₀ + ((x₀ + (x₀ ...
                                      - (x₀ * 0.34925))) * 1.2983)) + -0.00051925)
22          3.352e-07  1.762e-02  y = (sin(sin(x₁ * -0.094531)) + 0.99832) * sin((x₀ + (((x₀...
                                       + x₀) - (x₀ * 0.34925)) * 1.2983)) + -0.00051228)
23          3.294e-07  1.760e-02  y = sin((x₀ + -0.00053016) + ((x₀ + (x₀ - (x₀ * 0.34933)))...
                                       * 1.2984)) * (sin(sin(sin(x₁ * -0.094674))) + 0.99836)
24          3.238e-07  1.694e-02  y = sin((x₀ + -0.00053143) + (((x₀ + x₀) - (x₀ * 0.34933))...
                                       * 1.2984)) * (sin(sin(sin(sin(x₁ * -0.094802)))) + 0.9983...
                                      9)
25          3.187e-07  1.585e-02  y = (sin(sin(sin(sin(sin(x₁ * -0.094913))))) + 0.9984) * s...
                                      in(x₀ + ((((x₀ + x₀) - (x₀ * 0.34917)) * 1.2982) - 0.00050...
                                      908))
26          3.136e-07  1.623e-02  y = sin(x₀ + (((x₀ + (x₀ - (x₀ * 0.34934))) * 1.2984) + -0...
                                      .00054347)) * (sin(sin(sin(sin(sin(sin(x₁ * -0.095077)))))...
                                      ) + 0.99845)
27          3.089e-07  1.513e-02  y = (sin(sin(sin(sin(sin(sin(sin(x₁ * -0.095195))))))) + 0...
                                      .99848) * sin((x₀ + (((x₀ - (x₀ * 0.34919)) + x₀) * 1.2983...
                                      )) + -0.00054563)
28          3.045e-07  1.445e-02  y = (sin(sin(sin(sin(sin(sin(sin(sin(x₁ * -0.095325)))))))...
                                      ) + 0.9985) * sin((((x₀ + (x₀ - (x₀ * 0.34913))) * 1.2982)...
                                       + x₀) + -0.00052455)
29          3.002e-07  1.394e-02  y = sin(x₀ + ((((x₀ - (x₀ * 0.3494)) + x₀) * 1.2984) + -0....
                                      00053199)) * (sin(sin(sin(sin(sin(sin(sin(sin(sin(x₁ * -0....
                                      095492))))))))) + 0.99855)
30          2.963e-07  1.313e-02  y = (sin(sin(sin(sin(sin(sin(sin(sin(sin(sin(x₁ * -0.09562...
                                      8)))))))))) + 0.99858) * sin(-0.00053874 + (x₀ + ((x₀ + (x...
                                      ₀ - (x₀ * 0.34937))) * 1.2984)))
───────────────────────────────────────────────────────────────────────────────────────────────────
════════════════════════════════════════════════════════════════════════════════════════════════════

Expressions evaluated per second: 4.070e+04
Progress: 992 / 3100 total iterations (32.000%)
════════════════════════════════════════════════════════════════════════════════════════════════════
[ Info: Final population:
[ Info: Results saved to:
model_sr.latex(0)
'0.593'

5. Analysis and “Verification”

Let’s look at the best equation found.

print("Best equation:")
print(model_sr.sympy())

print("\nTarget Analytical Equation:")
print(f"sin(pi * x) * exp(-{alpha * np.pi**2:.4f} * t)")
print(f"approx: sin(3.1415*x) * exp(-0.0987*t)")
Best equation:
(0.998159 + x1*(-0.09428466))*sin(x0*3.1424408)

Target Analytical Equation:
sin(pi * x) * exp(-0.0987 * t)
approx: sin(3.1415*x) * exp(-0.0987*t)
# Verify accuracy
y_sym_pred = model_sr.predict(X_sr)
mse = np.mean((y_sym_pred - y_sr) ** 2)
print(f"MSE between Symbolic and PINN: {mse:.6e}")
MSE between Symbolic and PINN: 3.889278e-07
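Why is the MSE so low even though the recovered expression has no `exp`? On $t \in [0, 1]$ the decay $e^{-\alpha \pi^2 t}$ is nearly linear, and PySR has found its first-order Taylor polynomial $\approx 1 - \alpha \pi^2 t$. A standalone comparison, with the constants copied from the printed equation:

```python
import numpy as np

alpha = 0.01
t = np.linspace(0, 1, 100)
exact = np.exp(-alpha * np.pi**2 * t)   # true temporal decay
linear = 0.998159 - 0.09428466 * t      # PySR's linear surrogate
max_gap = np.max(np.abs(exact - linear))
print(f"max |exp - linear| on [0, 1]: {max_gap:.2e}")  # a few times 1e-3
```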

6. Exercises

  1. Noisy Oracle: Add Gaussian noise to y_sr before passing it to PySR. How robust is the symbolic regression?

  2. Different PDE: Change the PINN to solve the Burgers’ equation (from Day 1) and try to distill the shockwave equation.

  3. Missing Operators: Remove “sin” from the unary_operators list. Can PySR approximate the sine wave using Taylor expansion terms (polynomials)?
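For Exercise 1, a hedged starting sketch (the noise level `sigma` and the use of the analytical solution as a stand-in oracle are choices made here, not part of the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for the distilled PINN output: the analytical solution on a grid
x = np.linspace(0, 1, 50)
t = np.linspace(0, 1, 50)
X, T = np.meshgrid(x, t)
y_clean = (np.sin(np.pi * X) * np.exp(-0.01 * np.pi**2 * T)).flatten()

sigma = 0.05  # try 0.01, 0.05, 0.1 and watch the recovered equations change
y_noisy = y_clean + rng.normal(0.0, sigma, size=y_clean.shape)
# Pass y_noisy to PySRRegressor.fit in place of y_sr
print(f"signal std: {y_clean.std():.3f}, noise std: {sigma}")
```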