Seed-Parallel training¤
Since PDEquinox builds on Equinox, which is just a very thin extension to JAX PyTrees, we can make use of automatic vectorization (jax.vmap) in many interesting ways.
One particularly efficient use case is "seed-parallel" training. If your network does not fully utilize the GPU's compute resources, you can train multiple independent networks (initialized with different seeds) at the same time. The most naive approach would be to spawn multiple JAX processes that each claim their share of the GPU. This has two disadvantages:
- Sharing a GPU between two compute-heavy processes is error-prone.
- Some tensors are the same among all processes (e.g. the dataset) and would have to be duplicated.
If we write the training of the network, from the seed to the final trained network, as a (pure) function, we can simply use jax.vmap on it. More precisely, we have to use eqx.filter_vmap because the return value is not a PyTree of arrays only.
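As a small illustration of why this matters (using a toy eqx.nn.MLP here instead of the ConvNet we train below): an Equinox module is a PyTree whose leaves include non-arrays such as the activation function, which plain jax.vmap cannot batch, whereas eqx.filter_vmap only maps over the array leaves.
import equinox as eqx
import jax

def make_mlp(key):
    # The returned module's PyTree contains non-array leaves (e.g. the
    # activation function), so jax.vmap over this function would fail.
    return eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)

demo_keys = jax.random.split(jax.random.PRNGKey(0), 4)
# The weight arrays of the resulting ensemble gain a leading axis of size 4;
# all non-array leaves are left untouched.
mlp_ensemble = eqx.filter_vmap(make_mlp)(demo_keys)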
In this example, we will train a simple feedforward ConvNet to become a solver for the heat equation.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import equinox as eqx
import optax
from tqdm.autonotebook import tqdm
import pdequinox as pdeqx
Below is the data generation with the BTCS method; it is not relevant to the seed-parallel training itself.
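For reference, BTCS (backward time, centered space) treats the diffusion term implicitly: with diffusivity $\nu$, time step $\Delta t$, and the three-point finite-difference Laplacian $L$ on the interior grid, each step solves the linear system

$$
(I - \nu \, \Delta t \, L) \, u^{n+1} = u^n,
$$

which is exactly the system_matrix assembled and solved below.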
NUM_POINTS = 48
NUM_SAMPLES = 1000
DOMAIN_EXTENT = 3.0
DT = 0.1
DIFFUSIVITY = 0.1
# Grid excludes the two Dirichlet points
grid = jnp.linspace(0, DOMAIN_EXTENT, NUM_POINTS + 2)[1:-1]
dx = grid[1] - grid[0]
laplacian = (
    jnp.diag(jnp.ones(NUM_POINTS - 1), -1)
    - 2 * jnp.diag(jnp.ones(NUM_POINTS), 0)
    + jnp.diag(jnp.ones(NUM_POINTS - 1), 1)
)
laplacian = laplacian / dx**2
system_matrix = jnp.eye(NUM_POINTS) - DIFFUSIVITY * DT * laplacian
def advance_state(u):
    return jnp.linalg.solve(system_matrix, u)
def create_discontinuity(key):
    limit_1_key, limit_2_key = jax.random.split(key)
    lower_limit = jax.random.uniform(
        limit_1_key, (), minval=0.2 * DOMAIN_EXTENT, maxval=0.4 * DOMAIN_EXTENT
    )
    upper_limit = jax.random.uniform(
        limit_2_key, (), minval=0.6 * DOMAIN_EXTENT, maxval=0.8 * DOMAIN_EXTENT
    )
    discontinuity = jnp.where((grid >= lower_limit) & (grid <= upper_limit), 1.0, 0.0)
    return discontinuity
primary_key = jax.random.PRNGKey(0)
keys = jax.random.split(primary_key, NUM_SAMPLES)
initial_states = jax.vmap(create_discontinuity)(keys)
next_states = jax.vmap(advance_state)(initial_states)
# Add a singleton channel axis to be compatible with convolutional layers
initial_states = initial_states[:, None, :]
next_states = next_states[:, None, :]
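Each sample now has shape (channels, points) = (1, 48); the leading axis of size 1000 is the sample axis that the network will later be vmapped over.
initial_states.shape, next_states.shape  # ((1000, 1, 48), (1000, 1, 48))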
plt.plot(grid, initial_states[0, 0], label="Initial state")
plt.plot(grid, next_states[0, 0], label="Next state")
plt.legend()
# 4:1 train-test split
train_size = NUM_SAMPLES // 5 * 4
train_initial_states, test_initial_states = jnp.split(initial_states, [train_size])
train_next_states, test_next_states = jnp.split(next_states, [train_size])
Below is a straightforward training loop for reference.
# Training loop
heat_stepper_conv_net = pdeqx.arch.ConvNet(
    1,  # spatial dimensions
    1,  # input channels
    1,  # output channels
    hidden_channels=16,
    depth=6,
    activation=jax.nn.relu,
    key=jax.random.PRNGKey(0),
    boundary_mode="dirichlet",
)
parameter_count = pdeqx.count_parameters(heat_stepper_conv_net)
print("Number of parameters: ", parameter_count)
optimizer = optax.adam(3e-4)
opt_state = optimizer.init(eqx.filter(heat_stepper_conv_net, eqx.is_array))
def loss_fn(model, x, y):
    # The network operates on a single sample; vmap batches it over the
    # leading sample axis.
    y_pred = jax.vmap(model)(x)
    mse = jnp.mean(jnp.square(y_pred - y))
    return mse
@eqx.filter_jit
def update_fn(model, state, x, y):
    # filter_value_and_grad differentiates w.r.t. the (inexact) array leaves
    # of the model and treats everything else as static.
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, new_state = optimizer.update(grad, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss
loss_history = []
shuffle_key = jax.random.PRNGKey(151)
for epoch in tqdm(range(100)):
    shuffle_key, subkey = jax.random.split(shuffle_key)

    for batch in pdeqx.dataloader(
        (train_initial_states, train_next_states), batch_size=32, key=subkey
    ):
        heat_stepper_conv_net, opt_state, loss = update_fn(
            heat_stepper_conv_net,
            opt_state,
            *batch,
        )
        loss_history.append(loss)
plt.semilogy(loss_history)
BATCH_I = 5
plt.plot(grid, test_initial_states[BATCH_I, 0], label="initial state")
plt.plot(grid, test_next_states[BATCH_I, 0], label="gt next state")
plt.plot(
    grid, heat_stepper_conv_net(test_initial_states[BATCH_I])[0], label="prediction"
)
plt.legend()
Let's now pack this into a training function that maps a jax.random.PRNGKey to a trained network.
Internally, this key will be used for pseudo-random parts:
- Initialization of the network
- Shuffling the dataset
def perform_training(key):
    init_key, shuffle_key = jax.random.split(key)

    heat_stepper_conv_net = pdeqx.arch.ConvNet(
        1,
        1,
        1,
        hidden_channels=16,
        depth=6,
        activation=jax.nn.relu,
        key=init_key,
        boundary_mode="dirichlet",
    )

    optimizer = optax.adam(3e-4)
    opt_state = optimizer.init(eqx.filter(heat_stepper_conv_net, eqx.is_array))

    def loss_fn(model, x, y):
        y_pred = jax.vmap(model)(x)
        mse = jnp.mean(jnp.square(y_pred - y))
        return mse

    @eqx.filter_jit
    def update_fn(model, state, x, y):
        loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
        updates, new_state = optimizer.update(grad, state, model)
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_state, loss

    loss_history = []

    for epoch in tqdm(range(100)):
        shuffle_key, subkey = jax.random.split(shuffle_key)

        for batch in pdeqx.dataloader(
            (train_initial_states, train_next_states), batch_size=32, key=subkey
        ):
            heat_stepper_conv_net, opt_state, loss = update_fn(
                heat_stepper_conv_net,
                opt_state,
                *batch,
            )
            loss_history.append(loss)

    loss_history = jnp.array(loss_history)

    return heat_stepper_conv_net, loss_history
First, let's see if the function works for a single key.
heat_stepper_conv_net, loss_history = perform_training(jax.random.PRNGKey(0))
plt.semilogy(loss_history)
BATCH_I = 5
plt.plot(grid, test_initial_states[BATCH_I, 0], label="initial state")
plt.plot(grid, test_next_states[BATCH_I, 0], label="gt next state")
plt.plot(
    grid, heat_stepper_conv_net(test_initial_states[BATCH_I])[0], label="prediction"
)
plt.legend()
Now, we will train it with ten keys at once. Note that we have to use eqx.filter_vmap instead of jax.vmap because the return value is not a PyTree of arrays only.
The training is now seed-parallel. On my GPU (an RTX 3060) it takes ~23 seconds to finish, whereas the single training took ~7 seconds. There is certainly some overhead, but we still get a speedup of roughly 3x.
keys = jax.random.split(jax.random.PRNGKey(0), 10)
heat_stepper_conv_net_ensemble, loss_history_ensemble = eqx.filter_vmap(
    perform_training
)(keys)
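If you want to reproduce the timing comparison on your own hardware, a minimal sketch could look like the following (the timed helper is only for illustration; blocking on the result is necessary because JAX dispatches computations asynchronously):
import time

def timed(fn, *args):
    # Block on the result before reading the clock; otherwise we would mostly
    # measure the asynchronous dispatch, not the actual computation.
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start

_, seconds_single = timed(perform_training, jax.random.PRNGKey(0))
_, seconds_ensemble = timed(eqx.filter_vmap(perform_training), keys)
print(f"single: {seconds_single:.1f} s, ensemble of {len(keys)}: {seconds_ensemble:.1f} s")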
Note that the returned network ensemble has almost the same PyTree structure as the original ConvNet, except that all weight arrays have an additional leading "seed" axis.
heat_stepper_conv_net_ensemble
The returned loss_history_ensemble is also an array with an additional leading axis.
loss_history_ensemble
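Each seed contributes one loss value per update step, i.e. 100 epochs times 25 batches (800 training samples at a batch size of 32), so the array should have shape (10, 2500).
loss_history_ensemble.shape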
Let's plot the loss history of the ensemble.
plt.semilogy(loss_history_ensemble.T);
Let's also plot the various predictions of the ensemble, which requires us to execute the inference forward pass in batch.
predictions_ensemble = eqx.filter_vmap(
    lambda model: jax.vmap(model)(test_initial_states)
)(heat_stepper_conv_net_ensemble)
BATCH_I = 5
plt.plot(grid, test_initial_states[BATCH_I, 0], label="initial state")
plt.plot(grid, test_next_states[BATCH_I, 0], label="gt next state")
plt.plot(grid, predictions_ensemble[:, BATCH_I, 0].T, label="prediction");
BATCH_I = 5
differences = predictions_ensemble[:, BATCH_I, 0] - test_next_states[BATCH_I, 0]
plt.plot(grid, differences.T);
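The seed axis also makes it easy to summarize the ensemble, for example via the mean prediction and the standard deviation across seeds:
BATCH_I = 5
mean_prediction = jnp.mean(predictions_ensemble[:, BATCH_I, 0], axis=0)
std_prediction = jnp.std(predictions_ensemble[:, BATCH_I, 0], axis=0)

plt.plot(grid, test_next_states[BATCH_I, 0], label="gt next state")
plt.plot(grid, mean_prediction, label="ensemble mean")
plt.fill_between(
    grid,
    mean_prediction - std_prediction,
    mean_prediction + std_prediction,
    alpha=0.3,
    label="ensemble +/- 1 std",
)
plt.legend()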
Manipulating ensembles¤
We saw that once we have an ensemble of networks, calling it on new inputs always requires some juggling with Equinox's filtered vmap. Let's say we are interested in the network associated with seed 4 (index 4 using 0-based indexing).
heat_stepper_conv_net_seed_4 = pdeqx.extract_from_ensemble(
    heat_stepper_conv_net_ensemble, 4
)
BATCH_I = 5
plt.plot(grid, test_initial_states[BATCH_I, 0], label="initial state")
plt.plot(grid, test_next_states[BATCH_I, 0], label="gt next state")
plt.plot(
    grid,
    heat_stepper_conv_net_seed_4(test_initial_states[BATCH_I])[0],
    label="prediction of seed 4",
)
plt.legend()
Let's say we want to combine seeds 4 and 6 into a new ensemble; we could do it like this:
heat_stepper_conv_net_seed_4 = pdeqx.extract_from_ensemble(
    heat_stepper_conv_net_ensemble, 4
)
heat_stepper_conv_net_seed_6 = pdeqx.extract_from_ensemble(
    heat_stepper_conv_net_ensemble, 6
)

heat_stepper_conv_net_ensemble_4_6 = pdeqx.combine_to_ensemble(
    [heat_stepper_conv_net_seed_4, heat_stepper_conv_net_seed_6]
)
If we inspect the shape of its first weight array, we see that the seed axis is now of size 2.
heat_stepper_conv_net_ensemble_4_6.layers[0].weight.shape
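As before, evaluating this smaller ensemble on data goes through eqx.filter_vmap; the leading axis of the predictions should now have size 2.
predictions_ensemble_4_6 = eqx.filter_vmap(
    lambda model: jax.vmap(model)(test_initial_states)
)(heat_stepper_conv_net_ensemble_4_6)
predictions_ensemble_4_6.shape  # expected: (2, 200, 1, 48)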