Training a linear convolution to be an advection stepper
This example demonstrates that we can (re-)learn the first-order upwind (FOU) stencil for the 1D advection equation from trajectories generated by that very scheme. We will see that this two-dimensional optimization problem is convex and quadratic, and that (supervised) rollout training does not change the global minimizer.
import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import matplotlib.pyplot as plt
import trainax as tx
Let's define the physics of the advection equation
\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0, \]
where \(c = -0.1\) is the advection speed. The domain is \(\Omega = [0, 1]\) with homogeneous Dirichlet boundary conditions. We discretize it using \(N=40\) interior degrees of freedom. The size of the time step is \(\Delta t = 0.1\).
DOMAIN_EXTENT = 1.0
NUM_POINTS = 40
DT = 0.1
ADVECTION_SPEED = -0.1
The grid
\[ x_i = i \, \Delta x, \quad i = 1, \dots, N, \quad \Delta x = \frac{1}{N + 1}, \]
denotes the coordinates of the interior degrees of freedom.
grid = jnp.linspace(0.0, DOMAIN_EXTENT, NUM_POINTS + 2)[1:-1]
Let's extract the grid spacing \(\Delta x\) to build the numerical reference stepper.
dx = grid[1] - grid[0]
def step_fou(u):
    # Zero ghost value on the right (homogeneous Dirichlet), then shift left
    u_right = jnp.pad(u, (0, 1), mode="constant")[1:]
    u_next = u - ADVECTION_SPEED * DT / dx * (u_right - u)
    return u_next
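The update in step_fou can be rearranged into a two-point stencil acting on each value and its right neighbor. The following plain-Python sketch (no JAX needed, variable names chosen for illustration) computes the two coefficients and the signed Courant number \(c \Delta t / \Delta x\) for the constants used here.

```python
# Plain-Python sketch: rewrite the FOU update as a two-point stencil.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 40
DT = 0.1
ADVECTION_SPEED = -0.1

# linspace over NUM_POINTS + 2 nodes yields NUM_POINTS + 1 intervals
dx = DOMAIN_EXTENT / (NUM_POINTS + 1)

# Signed Courant number c * dt / dx
courant = ADVECTION_SPEED * DT / dx

# u_next[i] = u[i] - courant * (u[i+1] - u[i])
#           = (1 + courant) * u[i] + (-courant) * u[i+1]
center_weight = 1.0 + courant
right_weight = -courant

print(f"courant = {courant:.2f}")  # courant = -0.41
print(f"center = {center_weight:.2f}, right = {right_weight:.2f}")  # center = 0.59, right = 0.41
```

Both weights are non-negative and sum to one, so every new value is a convex combination of two old values. This corresponds to the usual stability condition \(|c| \Delta t / \Delta x \le 1\) for the upwind scheme, which holds here with a value of 0.41.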
The initial condition is a sine arc.
u_0 = jnp.sin(jnp.pi * grid)
Let us integrate the initial condition for \(100\) time steps and stack the states together with the initial state in a tensor of shape (101, 40).
trj = [u_0]
for _ in range(100):
    trj.append(step_fou(trj[-1]))
trj = jnp.stack(trj)
trj.shape
Visualizing the trajectory as a spatio-temporal plot reveals that the initial condition moves to the left before ultimately disappearing beyond the left boundary.
What is only barely noticeable is the introduced numerical diffusion of the first-order upwind method.
plt.imshow(trj.T)
plt.xlabel("time")
plt.ylabel("space")
plt.colorbar()
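The numerical diffusion mentioned above can also be checked without a plot. The sketch below re-implements the stepper on plain Python lists (step_fou_list is a hypothetical stand-in for step_fou, not part of the notebook) and records the peak value after every step. Because each update averages a value with its right neighbor using non-negative weights, the peak can never grow.

```python
import math

NUM_POINTS = 40
dx = 1.0 / (NUM_POINTS + 1)
courant = -0.1 * 0.1 / dx  # c * dt / dx


def step_fou_list(u):
    """One FOU step on a plain list with a zero ghost value on the right."""
    u_right = u[1:] + [0.0]
    return [ui - courant * (uri - ui) for ui, uri in zip(u, u_right)]


# Sine arc on the interior grid points x_i = (i + 1) * dx
u = [math.sin(math.pi * (i + 1) * dx) for i in range(NUM_POINTS)]

peaks = [max(u)]
for _ in range(100):
    u = step_fou_list(u)
    peaks.append(max(u))

# The peak never increases from one step to the next
print(all(later <= earlier for earlier, later in zip(peaks[:-1], peaks[1:])))
print(peaks[0], peaks[-1])
```

Note that the decay of the peak in this run mixes two effects: the smearing of the profile by the scheme and the profile leaving the domain through the left boundary.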
Now, let's instantiate a supervised trainer. Internally, it will slice the long trajectory into windows of length \(2\) which are necessary for one-step supervised training.
Note that we need to add singleton axes for the channel and the batch (which was not yet present in the trajectory).
We will set the batch size high such that training will always be full batch. Adam can be chosen with a higher learning rate than usual because of the convex quadratic nature of the optimization problem.
one_step_trainer = tx.trainer.SupervisedTrainer(
    trj[None, :, None, :],
    optimizer=optax.adam(1e-2),
    num_training_steps=3000,
    batch_size=200,
)
A trainer is an Equinox Module (=PyTree) with the following members:
- A trajectory_sub_stacker which provides a systematic way to slice specified windows out of (batches of) trajectories. The window length has to be two because one-step supervised learning needs one input and one target placed at two consecutive time levels. This member attribute also reveals that there were \(100\) total windows in our trajectory of length \(101\). Hence, with the chosen batch size of \(200\), the trainer will resort to full batch training and reduce the batch size to the maximum possible value of \(100\).
- A loss_configuration describing how this windowed data is used. Here we have a Supervised configuration with num_rollout_steps=1 (-> one-step supervised learning). ref_stepper and residuum_fn are required for other configurations we do not consider in this notebook.
- Then, the trainer stores optimization-specific attributes like the optimizer, num_minibatches (equals the number of update steps that will be performed), batch_size, and potential callback functions.
one_step_trainer
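The window bookkeeping described above can be reproduced with a few lines of plain Python. This is only an illustration of the counting, not Trainax's implementation: slice_windows is a hypothetical helper, and the assumption that a rollout of \(k\) steps needs windows of length \(k + 1\) is inferred from the one-step case.

```python
def slice_windows(trajectory, num_rollout_steps):
    """Return all consecutive windows of length num_rollout_steps + 1.

    Hypothetical helper: one initial state plus one target per rollout step.
    """
    window_length = num_rollout_steps + 1
    num_windows = len(trajectory) - window_length + 1
    return [trajectory[i : i + window_length] for i in range(num_windows)]


# Use time indices as stand-ins for the 101 stored states
snapshots = list(range(101))

windows = slice_windows(snapshots, num_rollout_steps=1)
num_windows = len(windows)

# A requested batch size larger than the dataset falls back to full batch
requested_batch_size = 200
effective_batch_size = min(requested_batch_size, num_windows)

print(num_windows, effective_batch_size)  # 100 100
print(windows[0], windows[-1])  # [0, 1] [99, 100]
```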
Below is a function that builds an Ansatz linear convolution of kernel size two with no bias.
def build_fou_learner(center, right):
    fou_learner = eqx.nn.Conv1d(
        1,
        1,
        2,
        padding=((0, 1),),
        use_bias=False,
        key=jax.random.PRNGKey(0),  # Dummy key because we will overwrite the init
    )
    fou_learner = eqx.tree_at(
        lambda leaf: leaf.weight,
        fou_learner,
        jnp.array([[[center, right]]]),
    )
    return fou_learner
Let us visualize the loss landscape over the two-dimensional parameter space.
center_range = jnp.linspace(0.0, 1.0, 50)
right_range = jnp.linspace(0.0, 1.0, 50)
We will use Equinox's vmapping routines for neural components to create a "structure of arrays" approach to have 50x50 different neural emulators.
fou_learners = eqx.filter_vmap(
    eqx.filter_vmap(
        build_fou_learner,
        in_axes=(0, None),
    ),
    in_axes=(None, 0),
)(center_range, right_range)
The batching is visible in that the weight array (which stores the convolution kernel) has two leading batch axes of size \(50\) each.
fou_learners
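To see which batch axis belongs to which parameter, here is a small sketch of the same nesting pattern with plain jax.vmap. The function make_pair is a hypothetical stand-in for build_fou_learner that only returns the two kernel entries, and the candidate arrays are deliberately of different lengths so that the axes can be told apart.

```python
import jax
import jax.numpy as jnp


def make_pair(center, right):
    # Stand-in for build_fou_learner: just return the two kernel entries
    return jnp.array([center, right])


centers = jnp.array([0.0, 0.5, 1.0])  # 3 candidates
rights = jnp.array([0.1, 0.2, 0.3, 0.4])  # 4 candidates

pairs = jax.vmap(
    jax.vmap(make_pair, in_axes=(0, None)),  # inner map: over centers
    in_axes=(None, 0),  # outer map: over rights
)(centers, rights)

# The outer map contributes the leading axis, the inner map the second one
print(pairs.shape)  # (4, 3, 2)
```

So pairs[i, j] holds (centers[j], rights[i]): rows index the right weight and columns index the center weight. This matches the convention of contourf(x, y, Z), which expects Z to have shape (len(y), len(x)).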
All trainers come with a method full_loss that evaluates a proposed stepper on the entire dataset at once. We will apply a nested vmap to this function to obtain a matrix of loss values. (Be careful when using this function on very large datasets.)
loss_landscape = eqx.filter_vmap(
    eqx.filter_vmap(one_step_trainer.full_loss),
)(fou_learners)
Then we can visualize a contour plot of the (log10) loss landscape. This reveals the expected convex quadratic shape of the loss landscape.
plt.contourf(center_range, right_range, jnp.log10(loss_landscape), levels=50)
plt.colorbar()
plt.xlabel("Center Stencil")
plt.ylabel("Right Stencil")
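The quadratic shape has a simple explanation: the one-step prediction is linear in the two weights, so a squared-error loss is a quadratic form whose Hessian is (up to a constant factor) the Gram matrix of the inputs. The sketch below, written in plain Python under the assumption of a mean-squared-error loss, builds that \(2 \times 2\) Gram matrix from a re-created trajectory and solves the normal equations with Cramer's rule. All helper and variable names are illustrative.

```python
import math

NUM_POINTS = 40
dx = 1.0 / (NUM_POINTS + 1)
courant = -0.1 * 0.1 / dx  # c * dt / dx


def step_fou_list(u):
    """One FOU step on a plain list with a zero ghost value on the right."""
    u_right = u[1:] + [0.0]
    return [ui - courant * (uri - ui) for ui, uri in zip(u, u_right)]


# Re-create the reference trajectory (101 states)
trj = [[math.sin(math.pi * (i + 1) * dx) for i in range(NUM_POINTS)]]
for _ in range(100):
    trj.append(step_fou_list(trj[-1]))

# Gram matrix entries and right-hand side of the normal equations
g_cc = g_cr = g_rr = r_c = r_r = 0.0
for u, target in zip(trj[:-1], trj[1:]):
    u_right = u[1:] + [0.0]
    for ui, uri, ti in zip(u, u_right, target):
        g_cc += ui * ui
        g_cr += ui * uri
        g_rr += uri * uri
        r_c += ui * ti
        r_r += uri * ti

# A positive determinant (with positive diagonal) means strict convexity
det = g_cc * g_rr - g_cr * g_cr

# Cramer's rule for the 2x2 system G @ [center, right] = r
center_opt = (r_c * g_rr - r_r * g_cr) / det
right_opt = (g_cc * r_r - g_cr * r_c) / det

print(center_opt, right_opt)
print(1.0 + courant, -courant)
```

The closed-form minimizer agrees with the FOU coefficients up to rounding, which is the point that gradient-based training should converge to.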
Let's perform the corresponding optimization/training. For this, we will choose one sample stepper that we instantiate close to the expected global minimizer.
trained_stepper, loss_history = one_step_trainer(
    build_fou_learner(0.4, 0.5),
    jax.random.PRNGKey(0),
)
Visualizing the training progress indicates full convergence all the way down to floating-point machine precision (single precision, unless JAX's 64-bit mode has been enabled).
plt.semilogy(loss_history)
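Below, we inspect the found weight. Up to some final rounding error, we found the estimate
\[ w \approx \left[ 1 + c \frac{\Delta t}{\Delta x}, \; -c \frac{\Delta t}{\Delta x} \right] = [0.59, \; 0.41]. \]
This is exactly the FOU stencil!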
trained_stepper.weight
plt.contourf(center_range, right_range, jnp.log10(loss_landscape), levels=50)
plt.colorbar()
plt.xlabel("Center Stencil")
plt.ylabel("Right Stencil")
plt.scatter(
    trained_stepper.weight[0, 0, 0],
    trained_stepper.weight[0, 0, 1],
    marker="x",
    color="red",
)
fou_stencil = jnp.array([[[1 + ADVECTION_SPEED * DT / dx, -ADVECTION_SPEED * DT / dx]]])
fou_stencil
Beyond one-step training
Let's see how the optimization landscape changes if we introduce increasingly long training rollouts.
one_step_trainer = tx.trainer.SupervisedTrainer(
    trj[None, :, None, :],
    optimizer=optax.adam(1e-2),
    num_training_steps=3000,
    batch_size=200,
    num_rollout_steps=1,
)
two_step_trainer = tx.trainer.SupervisedTrainer(
    trj[None, :, None, :],
    optimizer=optax.adam(1e-2),
    num_training_steps=3000,
    batch_size=200,
    num_rollout_steps=2,
)
three_step_trainer = tx.trainer.SupervisedTrainer(
    trj[None, :, None, :],
    optimizer=optax.adam(1e-2),
    num_training_steps=3000,
    batch_size=200,
    num_rollout_steps=3,
)
four_step_trainer = tx.trainer.SupervisedTrainer(
    trj[None, :, None, :],
    optimizer=optax.adam(1e-2),
    num_training_steps=3000,
    batch_size=200,
    num_rollout_steps=4,
)
one_step_loss_landscape = eqx.filter_vmap(
    eqx.filter_vmap(one_step_trainer.full_loss),
)(fou_learners)
two_step_loss_landscape = eqx.filter_vmap(
    eqx.filter_vmap(two_step_trainer.full_loss),
)(fou_learners)
three_step_loss_landscape = eqx.filter_vmap(
    eqx.filter_vmap(three_step_trainer.full_loss),
)(fou_learners)
four_step_loss_landscape = eqx.filter_vmap(
    eqx.filter_vmap(four_step_trainer.full_loss),
)(fou_learners)
Visualizing all landscapes together suggests that the global minimizer does not move and that convexity is seemingly maintained; only the landscape becomes steeper.
fig, ax_s = plt.subplots(2, 2, figsize=(12, 10))
for ax, loss_landscape, trainer in zip(
    ax_s.flat,
    [
        one_step_loss_landscape,
        two_step_loss_landscape,
        three_step_loss_landscape,
        four_step_loss_landscape,
    ],
    [one_step_trainer, two_step_trainer, three_step_trainer, four_step_trainer],
):
    im = ax.contourf(
        center_range,
        right_range,
        jnp.log10(loss_landscape),
        levels=jnp.linspace(-9, 2, 50),
    )
    ax.set_xlabel("Center Stencil")
    ax.set_ylabel("Right Stencil")
    ax.grid()
fig.colorbar(im, ax=ax_s.ravel().tolist(), label="Log10 Loss")
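Why the minimizer cannot move is easy to argue: the rollout loss is non-negative, and it is exactly zero at the stencil that generated the data, regardless of the rollout length. What does change is the shape. A \(k\)-step prediction is a polynomial of degree \(k\) in the weights, so the loss has degree \(2k\) and is no longer strictly quadratic. The plain-Python sketch below illustrates both points. It assumes the rollout loss is the sum over time levels of the per-level mean squared error, and rollout_loss is a hypothetical helper, not the Trainax implementation.

```python
import math

NUM_POINTS = 40
dx = 1.0 / (NUM_POINTS + 1)
courant = -0.1 * 0.1 / dx  # c * dt / dx
TRUE_STENCIL = (1.0 + courant, -courant)


def apply_stencil(u, stencil):
    """Apply a (center, right) stencil with a zero ghost value on the right."""
    center, right = stencil
    u_right = u[1:] + [0.0]
    return [center * ui + right * uri for ui, uri in zip(u, u_right)]


# Reference trajectory generated by the true stencil (101 states)
trj = [[math.sin(math.pi * (i + 1) * dx) for i in range(NUM_POINTS)]]
for _ in range(100):
    trj.append(apply_stencil(trj[-1], TRUE_STENCIL))


def rollout_loss(stencil, num_rollout_steps, num_starts=97):
    """Sum over time levels of the MSE, averaged over start indices."""
    total = 0.0
    for start in range(num_starts):
        u = trj[start]
        for level in range(1, num_rollout_steps + 1):
            u = apply_stencil(u, stencil)  # autoregressive prediction
            target = trj[start + level]
            total += sum((p - t) ** 2 for p, t in zip(u, target)) / NUM_POINTS
    return total / num_starts


perturbed = (TRUE_STENCIL[0] + 0.05, TRUE_STENCIL[1] - 0.05)
losses_true = [rollout_loss(TRUE_STENCIL, k) for k in (1, 2, 3, 4)]
losses_perturbed = [rollout_loss(perturbed, k) for k in (1, 2, 3, 4)]

print(losses_true)
print(losses_perturbed)
```

Using the same 97 start indices for every rollout length keeps the four numbers comparable. Away from the true stencil the loss grows with the rollout length, which is consistent with the steeper landscapes in the plots.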
Let us also perform all corresponding optimizations to confirm that the global optimizer indeed does not change.
one_step_trained_stepper, one_step_loss_history = one_step_trainer(
    build_fou_learner(0.4, 0.5),
    jax.random.PRNGKey(0),
)
two_step_trained_stepper, two_step_loss_history = two_step_trainer(
    build_fou_learner(0.4, 0.5),
    jax.random.PRNGKey(0),
)
three_step_trained_stepper, three_step_loss_history = three_step_trainer(
    build_fou_learner(0.4, 0.5),
    jax.random.PRNGKey(0),
)
four_step_trained_stepper, four_step_loss_history = four_step_trainer(
    build_fou_learner(0.4, 0.5),
    jax.random.PRNGKey(0),
)
All four optimizations find the same optimum.
(
    one_step_trained_stepper.weight,
    two_step_trained_stepper.weight,
    three_step_trained_stepper.weight,
    four_step_trained_stepper.weight,
)
Qualitatively, all optimizations progress similarly, but the loss values differ by a scale factor. This is because the time_level_weights argument of the SupervisedTrainer is None, which leads to a uniform weighting of all time levels (summation instead of mean).
plt.semilogy(one_step_loss_history, label="1-step")
plt.semilogy(two_step_loss_history, label="2-step")
plt.semilogy(three_step_loss_history, label="3-step")
plt.semilogy(four_step_loss_history, label="4-step")
plt.legend()