Supervised Learning of a Burgers predictor
Let \(\mathcal{P}\) be the transient simulation operator (= time stepper) associated with the 1d Burgers equation under a certain configuration. The goal of this notebook is to learn an autoregressive neural operator (\(\approx\) a neural predictor) \(f_\theta\) that mimics the behavior of \(\mathcal{P}\), i.e., \(f_\theta \approx \mathcal{P}\).
The optimization shall be in terms of a one-step supervised loss

\[
L(\theta) = \mathbb{E}_{(u_h, u_{h+1}) \sim \mathcal{D}} \left[ \left\| f_\theta(u_h) - u_{h+1} \right\|_2^2 \right]
\]

with \(\mathcal{D}\) being a dataset of sample trajectories created with the reference solver \(\mathcal{P}\).
This notebook contains the following sections:
- Data Generation:
    - Instantiating a distribution of initial conditions
    - Sampling the distribution for a set of training and a set of validation/testing initial conditions
    - Rolling out the initial conditions of both sets by means of the reference solver \(\mathcal{P}\)
- Model Definition (just a simple periodic convolutional network)
- Training
- Evaluation
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import equinox as eqx
import optax
from typing import List
from tqdm.autonotebook import tqdm
import exponax as ex
Below are constants we fix for the training; feel free to play around with them
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
DIFFUSIVITY = 0.03
INITIAL_CONDITION_WAVENUMBER_CUTOFF = 5
TRAIN_NUM_SAMPLES = 200
TRAIN_TEMPORAL_HORIZON = 50
TRAIN_SEED = 0
INIT_SEED = 1337
OPTIMIZER = optax.adam(optax.exponential_decay(1e-4, 500, 0.90))
NUM_EPOCHS = 50
BATCH_SIZE = 100
SHUFFLE_SEED = 42
VAL_NUM_SAMPLES = 20
VAL_TEMPORAL_HORIZON = 100
VAL_SEED = 773
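As a side check, we can inspect the exponential-decay learning rate schedule passed to Adam above; it starts at 1e-4 and decays by a factor of 0.90 every 500 optimizer steps. A minimal sketch (an optax schedule is just a callable mapping the optimizer step count to a learning rate):
# Sketch: visualize the learning rate schedule used by the optimizer above.
lr_schedule = optax.exponential_decay(1e-4, 500, 0.90)
schedule_steps = list(range(0, 5001, 100))
plt.plot(schedule_steps, [lr_schedule(s) for s in schedule_steps])
plt.xlabel("optimizer step")
plt.ylabel("learning rate")
plt.yscale("log")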
Instantiate the reference stepper \(\mathcal{P}\) as a Burgers solver in 1d.
reference_stepper = ex.stepper.Burgers(
1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=DIFFUSIVITY
)
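A minimal sketch of a single application of \(\mathcal{P}\), assuming Exponax steppers are callable on a state array of shape (num_channels, num_points) and return the state advanced by one DT (the sine initial condition here is just an ad-hoc choice for illustration):
# Sketch: advance a hand-crafted sine initial condition by one time step.
# Assumption: the stepper is callable on a (1, NUM_POINTS) state array.
demo_grid = jnp.linspace(0.0, DOMAIN_EXTENT, NUM_POINTS, endpoint=False)
u_0_demo = jnp.sin(2 * jnp.pi * demo_grid / DOMAIN_EXTENT)[None, :]  # (1, NUM_POINTS)
u_1_demo = reference_stepper(u_0_demo)
u_1_demo.shape  # expected: (1, NUM_POINTS)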
1. Data Generation
The initial condition distribution consists of all functions described by the first INITIAL_CONDITION_WAVENUMBER_CUTOFF = 5 Fourier modes.
IC_GENERATOR = ex.ic.RandomTruncatedFourierSeries(
1, cutoff=INITIAL_CONDITION_WAVENUMBER_CUTOFF
)
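A minimal sketch of drawing a single initial condition, assuming the generator is callable with the number of points and a PRNG key and returns an array with a leading channel axis:
# Sketch: draw one sample initial condition from the generator.
u_0_sample = IC_GENERATOR(NUM_POINTS, key=jr.PRNGKey(0))
u_0_sample.shape  # expected: (1, NUM_POINTS) -- one channel, NUM_POINTS grid points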
The singleton axis between the sample axis and the spatial axis exists because the state has one channel.
train_ic_set = ex.build_ic_set(
IC_GENERATOR,
num_points=NUM_POINTS,
num_samples=TRAIN_NUM_SAMPLES,
key=jax.random.PRNGKey(TRAIN_SEED),
)
# (TRAIN_NUM_SAMPLES, 1, NUM_POINTS)
train_ic_set.shape
Plotting the first three initial conditions of the train IC set.
plt.plot(train_ic_set[:3, 0, :].T);
Roll out the initial conditions of the training dataset.
train_trj_set = jax.vmap(
ex.rollout(reference_stepper, TRAIN_TEMPORAL_HORIZON, include_init=True)
)(train_ic_set)
# (TRAIN_NUM_SAMPLES, TRAIN_TEMPORAL_HORIZON + 1, 1, NUM_POINTS)
train_trj_set.shape
Plot the training trajectories as spatio-temporal plots. One can see that within the first 5-10 time steps, the typical Burgers shocks develop and then propagate. Over time, the state dissipates.
fig, ax_s = plt.subplots(2, 2, figsize=(10, 7))
for i, ax in enumerate(ax_s.ravel()):
ax.imshow(
train_trj_set[i, :, 0, :].T,
cmap="RdBu_r",
vmin=-1,
vmax=1,
aspect="auto",
origin="lower",
)
ax.set_xlabel("time")
ax.set_ylabel("space")
Repeat initial condition set generation and rollout for the validation set.
val_ic_set = ex.build_ic_set(
IC_GENERATOR,
num_points=NUM_POINTS,
num_samples=VAL_NUM_SAMPLES,
key=jax.random.PRNGKey(VAL_SEED),
)
val_trj_set = jax.vmap(
ex.rollout(reference_stepper, VAL_TEMPORAL_HORIZON, include_init=True)
)(val_ic_set)
Investigating the shape of the training data (num_samples, num_timesteps, 1, num_points), we essentially have two "batch axes". To train with lagged windows of size 2, we have to slice those windows out of the num_timesteps axis. This is what the function ex.stack_sub_trajectories does. We apply it with jax.vmap to vectorize over the num_samples axis. The configuration in_axes=(0, None) ensures that we vectorize over the data input and not over the prescribed window length of 2. (A toy illustration of the windowing follows after the cell below.)
# sub-trajectories are of length 2 (corresponding to an input and an output)
train_trj_set_substacked = jax.vmap(ex.stack_sub_trajectories, in_axes=(0, None))(
train_trj_set, 2
)
# (TRAIN_NUM_SAMPLES, n_sub_trj_s, 2, 1, NUM_POINTS)
train_trj_set_substacked.shape
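To make the windowing concrete, here is a toy sketch, assuming ex.stack_sub_trajectories extracts all overlapping windows of the requested length with stride 1 along the leading time axis:
# Toy sketch: a "trajectory" of 6 scalar snapshots should yield 5 overlapping windows of length 2
# (assumption: sliding windows with stride 1 along the leading axis).
toy_trj = jnp.arange(6).reshape(6, 1, 1)  # (num_timesteps, 1 channel, 1 point)
toy_windows = ex.stack_sub_trajectories(toy_trj, 2)
toy_windows.shape  # expected: (5, 2, 1, 1)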
We use the jnp.concatenate function to merge the two batch axes into one. The resulting data array has the very typical format of (num_windows, window_length, num_channels, ...), with the ... indicating an arbitrary number of spatial dimensions; here, we have one spatial axis of size num_points. Note that we need a window length of 2 for one-step supervised training in order to have both an input and a target.
# Merge the two batch axes (samples & sub trajectories)
train_set = jnp.concatenate(train_trj_set_substacked)
# (TRAIN_NUM_SAMPLES * n_sub_trj_s, 2, 1, NUM_POINTS)
train_set.shape
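Since jnp.concatenate over the leading axis just stitches the sub-arrays together, the merge is equivalent to folding the two leading batch axes with a reshape; a quick check:
# Sketch: merging via reshape yields the same array as jnp.concatenate above.
merged_via_reshape = train_trj_set_substacked.reshape(
    -1, *train_trj_set_substacked.shape[2:]
)
bool(jnp.all(merged_via_reshape == train_set))  # expected: True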
Let's visualize the first (zeroth) window
plt.plot(train_set[0, :, 0, :].T)
2. Model Definition
Below is a simple feed-forward convolutional network with periodic padding and a hard-coded tanh activation. The user can adjust the number of hidden layers and their width. num_hidden_layers=0 refers to a single linear convolution with kernel size 3 and a bias.
class SimplePeriodicConvNet(eqx.Module):
conv_layers: List[eqx.nn.Conv1d]
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_width: int,
num_hidden_layers: int,
*,
key,
):
channel_list = (
[in_channels] + [hidden_width] * num_hidden_layers + [out_channels]
)
self.conv_layers = []
for fan_in, fan_out in zip(channel_list[:-1], channel_list[1:]):
subkey, key = jr.split(key)
self.conv_layers.append(
eqx.nn.Conv1d(
fan_in,
fan_out,
kernel_size=3,
key=subkey,
)
)
def periodic_padding(
self,
x,
):
        # periodic (wrap) padding over the spatial axis; the channel axis is left untouched
return jnp.pad(x, ((0, 0), (1, 1)), mode="wrap")
def __call__(
self,
x,
):
for conv_layer in self.conv_layers[:-1]:
x = self.periodic_padding(x)
x = conv_layer(x)
x = jax.nn.tanh(x)
x = self.periodic_padding(x)
x = self.conv_layers[-1](x)
return x
Create a sample initialization for a sanity check of the model.
neural_stepper = SimplePeriodicConvNet(1, 1, 32, 5, key=jr.PRNGKey(INIT_SEED))
neural_stepper
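As a quick shape check: thanks to the periodic padding, the network maps a state to a state of the same spatial size, which is required for autoregressive rollouts. A minimal sketch:
# Sketch: verify that the network preserves the state shape (1, NUM_POINTS).
dummy_state = jnp.zeros((1, NUM_POINTS))
neural_stepper(dummy_state).shape  # expected: (1, NUM_POINTS)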
Let's see what the network predicts with its initial parameter state
initial_neural_trj_set = jax.vmap(
ex.rollout(neural_stepper, VAL_TEMPORAL_HORIZON, include_init=True)
)(val_ic_set)
As expected, the prediction with the initial parameters is rather inaccurate, but it is a reasonable starting point for training. ;)
fig, ax_s = plt.subplots(1, 2, figsize=(8, 5))
ax_s[0].imshow(
val_trj_set[0, :, 0, :].T,
cmap="RdBu",
vmin=-1,
vmax=1,
aspect="auto",
origin="lower",
)
ax_s[0].set_xlabel("time")
ax_s[0].set_ylabel("space")
ax_s[0].set_title("Reference")
ax_s[1].imshow(
initial_neural_trj_set[0, :, 0, :].T,
cmap="RdBu",
vmin=-1,
vmax=1,
aspect="auto",
origin="lower",
)
ax_s[1].set_xlabel("time")
ax_s[1].set_ylabel("space")
ax_s[1].set_title("Initial neural")
3. Training
Below is a simple JAX-based dataloader
def dataloader(
data,
*,
batch_size: int,
key,
):
n_samples = data.shape[0]
n_batches = int(jnp.ceil(n_samples / batch_size))
permutation = jax.random.permutation(key, n_samples)
for batch_id in range(n_batches):
start = batch_id * batch_size
end = min((batch_id + 1) * batch_size, n_samples)
batch_indices = permutation[start:end]
sub_data = data[batch_indices]
yield sub_data
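With the settings above, one pass over the dataloader yields ceil(num_windows / BATCH_SIZE) batches, each of shape (BATCH_SIZE, 2, 1, NUM_POINTS) except possibly a smaller final batch. A quick sketch to confirm:
# Sketch: inspect the number and shape of batches produced by one pass of the dataloader.
demo_batches = list(dataloader(train_set, batch_size=BATCH_SIZE, key=jr.PRNGKey(0)))
len(demo_batches), demo_batches[0].shape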
Below is a straightforward training loop in JAX using the prescribed optimizer. Notice how we slice the input and target out of one data batch as the first operation in the loss function.
Training should take \(\sim 10\,\mathrm{s}\) on a GPU.
neural_stepper = SimplePeriodicConvNet(1, 1, 32, 5, key=jr.PRNGKey(5))
opt_state = OPTIMIZER.init(eqx.filter(neural_stepper, eqx.is_array))
def loss_fn(model, batch):
x, y = batch[:, 0], batch[:, 1]
y_hat = jax.vmap(model)(x)
return jnp.mean(jnp.square(y - y_hat))
@eqx.filter_jit
def step_fn(model, state, batch):
loss, grad = eqx.filter_value_and_grad(loss_fn)(model, batch)
updates, new_state = OPTIMIZER.update(grad, state, model)
new_model = eqx.apply_updates(model, updates)
return new_model, new_state, loss
shuffle_key = jr.PRNGKey(SHUFFLE_SEED)
train_loss_history = []
for epoch in tqdm(range(NUM_EPOCHS), position=0):
shuffle_key, subkey = jr.split(shuffle_key)
for batch in dataloader(train_set, batch_size=BATCH_SIZE, key=subkey):
neural_stepper, opt_state, loss = step_fn(neural_stepper, opt_state, batch)
train_loss_history.append(loss)
The loss toward the end of training is very noisy (likely due to the small batch size and a suboptimal learning rate schedule), but the network is still sufficiently trained: the loss was reduced by roughly four orders of magnitude.
plt.plot(train_loss_history)
plt.yscale("log")
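To roughly quantify this reduction, we can compare the first recorded loss to the last (keeping in mind that the final values are noisy); a quick sketch:
# Sketch: rough factor by which the training loss decreased over the run.
loss_arr = jnp.array(train_loss_history)
float(loss_arr[0] / loss_arr[-1])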
Let's use the trained neural predictor to roll out the validation initial conditions.
trained_neural_prediction = jax.vmap(
ex.rollout(neural_stepper, VAL_TEMPORAL_HORIZON, include_init=True)
)(val_ic_set)
A visual comparison with the validation reference reveals a close match.
fig, ax_s = plt.subplots(1, 2, figsize=(8, 5))
ax_s[0].imshow(
val_trj_set[0, :, 0, :].T,
cmap="RdBu",
vmin=-1,
vmax=1,
aspect="auto",
origin="lower",
)
ax_s[0].set_xlabel("time")
ax_s[0].set_ylabel("space")
ax_s[0].set_title("Reference")
ax_s[1].imshow(
trained_neural_prediction[0, :, 0, :].T,
cmap="RdBu",
vmin=-1,
vmax=1,
aspect="auto",
origin="lower",
)
ax_s[1].set_xlabel("time")
ax_s[1].set_ylabel("space")
ax_s[1].set_title("Trained neural")
For a numerical assessment, we compute the normalized root-mean-squared error (nRMSE) between the prediction and the reference for each sample and each time step.
The ex.metrics.nRMSE function takes two state arrays (consisting of a leading channel dimension and an arbitrary number of spatial dimensions). We apply jax.vmap twice to vectorize over both the sample axis and the time axis.
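Presumably, the nRMSE normalizes the root-mean-squared error by the RMS magnitude of the reference state (an assumption about the exact definition, but one that is consistent with the interpretation of the 1.0 threshold below):

\[
\text{nRMSE}(u^{\text{pred}}, u^{\text{ref}}) = \sqrt{\frac{\sum_i \left( u^{\text{pred}}_i - u^{\text{ref}}_i \right)^2}{\sum_i \left( u^{\text{ref}}_i \right)^2}}
\]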
# Compute normalized error rollout for trained network
error_trj_s = jax.vmap(jax.vmap(ex.metrics.nRMSE))(
trained_neural_prediction, val_trj_set
)
# (VAL_NUM_SAMPLES, VAL_TEMPORAL_HORIZON + 1)
error_trj_s.shape
Let's visualize the nRMSE over time for each sample. Note that once the error reaches the threshold of 1.0, the error between prediction and reference is of the same order of magnitude as the reference state itself, which likely indicates that the prediction has diverged from the ground truth.
plt.plot(error_trj_s.T)
plt.hlines(1.0, 0, VAL_TEMPORAL_HORIZON, linestyles="dashed", color="black")
plt.ylim(0, 2)
plt.grid()
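As a possible aggregate summary, one could average the nRMSE over the validation samples and determine the first time step at which the mean error crosses the 1.0 threshold; a minimal sketch:
# Sketch: mean nRMSE over validation samples, and the first time step at which it exceeds 1.0.
mean_error_over_time = jnp.mean(error_trj_s, axis=0)
first_crossing = jnp.argmax(mean_error_over_time > 1.0)  # 0 if the threshold is never reached
int(first_crossing)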