Configuration Showcase¤
This notebook showcases how the different configurations available in Trainax can be depicted schematically.
import jax.numpy as jnp
import trainax as tx
Supervised¤
A supervised configuration is special because all data can be pre-computed. No ref_stepper or residuum_fn is needed on the fly (and hence neither has to be differentiable).
One-Step supervised¤
# The default is one-step supervised learning
tx.configuration.Supervised()
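Conceptually, a one-step supervised loss compares a single emulator prediction against the precomputed reference. Below is a minimal sketch of this idea; the emulator and data are hypothetical placeholders, not Trainax internals.
def one_step_supervised_loss(emulator, u_t, u_next_ref):
    # A single emulator step, compared against the precomputed reference state
    pred = emulator(u_t)
    return jnp.mean((pred - u_next_ref) ** 2)

# Hypothetical linear "emulator" and dummy data, purely for illustration
emulator = lambda u: 0.9 * u
u_t = jnp.ones((16,))
one_step_supervised_loss(emulator, u_t, 0.95 * u_t)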
Two-Step supervised (rollout) Training¤
We roll out the neural emulator for two autoregressive steps; its parameters are shared between the two predictions. Similarly, the ref_stepper is used to create the reference trajectory; the loss is aggregated as a sum over the two time levels.
tx.configuration.Supervised(num_rollout_steps=2)
Three-Step supervised (rollout) Training¤
tx.configuration.Supervised(num_rollout_steps=3)
Three-Step supervised (rollout) Training with loss only at final state¤
The loss is only taken from the last step. Essentially, this corresponds to weighting the time levels with \([0, 0, 1]\), respectively. (More weighting options are possible, of course.)
tx.configuration.Supervised(
    num_rollout_steps=3, time_level_weights=jnp.array([0.0, 0.0, 1.0])
)
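A sketch of the weighted rollout loss this corresponds to, assuming a mean-squared-error distance per time level (placeholder emulator and data, not Trainax internals):
def weighted_rollout_loss(emulator, u_0, ref_trajectory, time_level_weights):
    loss = 0.0
    u = u_0
    for t, weight in enumerate(time_level_weights):
        u = emulator(u)  # autoregressive step; parameters shared across steps
        loss += weight * jnp.mean((u - ref_trajectory[t]) ** 2)
    return loss

emulator = lambda u: 0.9 * u
u_0 = jnp.ones((16,))
ref_trajectory = jnp.stack([0.95 ** (t + 1) * u_0 for t in range(3)])
weighted_rollout_loss(emulator, u_0, ref_trajectory, jnp.array([0.0, 0.0, 1.0]))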
Three-Step supervised (rollout) Training with no backpropagation through time¤
(Displays the primal evaluation together with the cotangent flow; the grey dashed line indicates a cut gradient.)
This interrupts the gradient flow backward over the autoregressive network execution. Gradients can still flow into the parameter space.
tx.configuration.Supervised(num_rollout_steps=3, cut_bptt=True)
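Conceptually, cutting backpropagation through time means detaching the carried-over state between autoregressive steps, e.g. with jax.lax.stop_gradient. A sketch under the same placeholder assumptions as above, not Trainax internals:
import jax

def rollout_loss_no_bptt(emulator, u_0, ref_trajectory):
    loss = 0.0
    u = u_0
    for t in range(len(ref_trajectory)):
        u = emulator(u)
        loss += jnp.mean((u - ref_trajectory[t]) ** 2)
        # Detach the state before it enters the next step: gradients still reach
        # the parameters within each step, but no longer flow across steps
        u = jax.lax.stop_gradient(u)
    return loss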
Four-Step supervised (rollout) Training with sparse backpropagation through time¤
Gradients are only allowed to flow backward through every second autoregressive step of the network.
tx.configuration.Supervised(num_rollout_steps=4, cut_bptt=True, cut_bptt_every=2)
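In the sketch above, this corresponds to guarding the detachment with a step counter; a variant mirroring cut_bptt_every=2 (again with placeholders, not Trainax internals):
def rollout_loss_sparse_bptt(emulator, u_0, ref_trajectory, cut_every=2):
    loss = 0.0
    u = u_0
    for t in range(len(ref_trajectory)):
        u = emulator(u)
        loss += jnp.mean((u - ref_trajectory[t]) ** 2)
        if (t + 1) % cut_every == 0:
            # Only every `cut_every`-th state hand-over is detached
            u = jax.lax.stop_gradient(u)
    return loss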
Diverted Chain¤
Two-Steps with branch length one¤
The ref_stepper is not run autoregressively for two steps from the initial condition but rather for one step, branching off from the main chain created by the emulator.
# `num_rollout_steps` refers to the number of autoregressive steps performed by
# the neural emulator
tx.configuration.DivertedChainBranchOne(num_rollout_steps=2)
# Alternatively, the general interface can be used
tx.configuration.DivertedChain(num_rollout_steps=2, num_branch_steps=1)
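A conceptual sketch of the branch-one diverted chain with a mean-squared-error loss (placeholder emulator and ref_stepper, not Trainax internals): the reference performs a single step from the emulator's previous state instead of following its own trajectory.
def diverted_chain_branch_one_loss(emulator, ref_stepper, u_0, num_rollout_steps=2):
    loss = 0.0
    u = u_0
    for _ in range(num_rollout_steps):
        # The reference branches off the current main-chain state for one step
        branch_target = ref_stepper(u)
        u = emulator(u)
        loss += jnp.mean((u - branch_target) ** 2)
    return loss

emulator = lambda u: 0.9 * u
ref_stepper = lambda u: 0.95 * u
diverted_chain_branch_one_loss(emulator, ref_stepper, jnp.ones((16,)))
In this picture, cut_bptt would detach u before the next emulator step, while cut_div_chain would detach branch_target, so no gradients flow through the (then not necessarily differentiable) physics; these correspond to the variants shown below.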
Three-Steps with branch length one¤
tx.configuration.DivertedChainBranchOne(num_rollout_steps=3)
Four-Steps with branch length one¤
tx.configuration.DivertedChainBranchOne(num_rollout_steps=4)
Three-Steps with branch length two¤
# Can only be done with the general interface
tx.configuration.DivertedChain(num_rollout_steps=3, num_branch_steps=2)
Two-Steps with no differentiable physics¤
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_div_chain=True,
)
Two-Steps with no backpropagation through time¤
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
)
Two-Steps with no backpropagation through time and no differentiable physics¤
tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
    cut_div_chain=True,
)
Mix-Chain¤
So far, Trainax only supports "post-physics" mixing, meaning that the main chain is built by first performing a specified number of autoregressive network steps and then a specified number of ref_stepper steps. The reference trajectory is always built by autoregressively unrolling the ref_stepper.
One-Step Network with one step physics¤
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
)
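Conceptually, the main chain first advances with the network and then with the physics, and each state is compared against the purely physics-based reference. A sketch with placeholder steppers, not Trainax internals:
def mix_chain_post_physics_loss(
    emulator, ref_stepper, u_0, num_rollout_steps=1, num_post_physics_steps=1
):
    # Main chain: network steps first, then physics steps
    main_chain = []
    u = u_0
    for _ in range(num_rollout_steps):
        u = emulator(u)
        main_chain.append(u)
    for _ in range(num_post_physics_steps):
        u = ref_stepper(u)
        main_chain.append(u)
    # Reference chain: purely autoregressive unrolling of the ref_stepper
    loss = 0.0
    u_ref = u_0
    for u_main in main_chain:
        u_ref = ref_stepper(u_ref)
        loss += jnp.mean((u_main - u_ref) ** 2)
    return loss

emulator = lambda u: 0.9 * u
ref_stepper = lambda u: 0.95 * u
mix_chain_post_physics_loss(emulator, ref_stepper, jnp.ones((16,)))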
One-Step Network with one step physics and loss only at final state¤
Similar to the supervised setting, this is achieved by choosing appropriate time_level_weights. For MixChainPostPhysics, the time_level_weights refer to the entire main chain, i.e., the trajectory created by the initial network steps and the subsequent physics steps.
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
    time_level_weights=jnp.array([0.0, 1.0]),
)
Two-Step Network with one step physics¤
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=2,
    num_post_physics_steps=1,
)
Two-Step Network with one step physics and no backpropagation through time¤
tx.configuration.MixChainPostPhysics(
    num_rollout_steps=2,
    num_post_physics_steps=1,
    cut_bptt=True,
)
Residuum¤
Instead of having a ref_stepper that can be unrolled autoregressively, these configurations rely on a residuum_fn that defines a condition based on two consecutive time levels.
One-Step Residuum¤
tx.configuration.Residuum(
    num_rollout_steps=1,
)
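A conceptual sketch, assuming the residuum_fn measures how well two consecutive states satisfy, e.g., an implicit update rule (all placeholders, not Trainax internals):
def residuum_loss(emulator, residuum_fn, u_0, num_rollout_steps=1):
    loss = 0.0
    u_prev = u_0
    for _ in range(num_rollout_steps):
        u_next = emulator(u_prev)
        # The residuum couples two consecutive time levels
        loss += jnp.mean(residuum_fn(u_next, u_prev) ** 2)
        u_prev = u_next
    return loss

# Hypothetical residuum of an implicit Euler step for du/dt = -u with dt = 0.1
residuum_fn = lambda u_next, u_prev: u_next - u_prev + 0.1 * u_next
emulator = lambda u: u / 1.1  # exactly satisfies the implicit step above
residuum_loss(emulator, residuum_fn, jnp.ones((16,)))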
Two-Step Residuum Training¤
tx.configuration.Residuum(
    num_rollout_steps=2,
)
Three-Step Residuum Training¤
tx.configuration.Residuum(
    num_rollout_steps=3,
)
Three-Step Residuum with no backpropagation through time¤
tx.configuration.Residuum(num_rollout_steps=3, cut_bptt=True)
Other Residuum Options¤
It is possible to cut the prev and next contributions to the residuum_fn.
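In the sketch above, this would correspond to wrapping one of the two arguments of residuum_fn in jax.lax.stop_gradient. On the configuration itself, the flag names below are assumed from the description above, not confirmed by the source:
# Flag names (cut_prev / cut_next) are assumed; check the Residuum docstring
tx.configuration.Residuum(num_rollout_steps=2, cut_prev=True)
tx.configuration.Residuum(num_rollout_steps=2, cut_next=True)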
Teacher Forcing¤
Resets the main chain with information from the autoregressive reference chain; it is essentially the opposite of diverted chain learning.
This is similar to selecting minibatches across the entire trajectories. However, this setup guarantees that within one gradient update, multiple consecutive time levels are considered, without requiring the network to roll out autoregressively.
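Since the implementations below are still marked as TODO, here is only a conceptual sketch of the idea (placeholder emulator and data, not Trainax internals): the emulator's input is periodically reset to the reference trajectory.
def teacher_forcing_loss(emulator, ref_trajectory, u_0, reset_every=1):
    loss = 0.0
    u = u_0
    for t in range(len(ref_trajectory)):
        u = emulator(u)
        loss += jnp.mean((u - ref_trajectory[t]) ** 2)
        if (t + 1) % reset_every == 0:
            # Reset the main chain with the reference state
            u = ref_trajectory[t]
    return loss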
Three Steps teacher forcing with reset every step¤
# TODO Implementation
Four Steps teacher forcing with reset every second step¤
# TODO implementation
Four Steps teacher forcing with reset every second step and no backpropagation through time¤
# TODO implementation
How about correction learning?¤
All of the setups mentioned above are also usable for correction learning, i.e., when the emulator is not just a pure network but also contains some (differentiable) (coarse) solver component, for example in the case of sequential correction.
See this website for potential corrector layouts and for options to cut gradients within them.
None of these layouts are provided by Trainax; this is just to showcase that the configurations of Trainax can be used in a more general context.
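For illustration only (such corrector layouts are not part of Trainax): a sequential correction emulator could, for instance, wrap a differentiable coarse solver followed by a network correction, and the combined stepper could then be trained with any of the configurations above.
def sequential_correction(coarse_solver, network):
    # The network corrects the state produced by the coarse solver
    return lambda u: network(coarse_solver(u))

coarse_solver = lambda u: 0.9 * u  # hypothetical differentiable coarse solver
network = lambda u: 1.05 * u       # hypothetical correction network
emulator = sequential_correction(coarse_solver, network)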