Configuration Showcase

This notebook showcases how the different configurations available in Trainax can be depicted schematically.

import jax.numpy as jnp

import trainax as tx

Supervised

A supervised configuration is special because all data can be pre-computed. Neither a ref_stepper nor a residuum_fn has to be evaluated on the fly (and hence neither has to be differentiable).
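To illustrate why the supervised data can be prepared up front, here is a minimal sketch (not Trainax code) that unrolls a reference solver once and slices the resulting trajectory into windows of `num_rollout_steps + 1` consecutive states. The linear `ref_stepper` and the helpers `rollout` and `make_windows` are hypothetical stand-ins chosen for illustration.

```python
import jax.numpy as jnp


def ref_stepper(u):
    # Hypothetical reference solver: exponential decay, u_{n+1} = 0.9 * u_n
    return 0.9 * u


def rollout(stepper, u_0, num_steps):
    # Autoregressively unroll `stepper`, returning all num_steps + 1 states
    states = [u_0]
    for _ in range(num_steps):
        states.append(stepper(states[-1]))
    return jnp.stack(states)


def make_windows(trajectory, num_rollout_steps):
    # Hypothetical helper: slice one trajectory into overlapping windows of
    # length num_rollout_steps + 1 (initial condition plus its reference targets)
    num_windows = trajectory.shape[0] - num_rollout_steps
    return jnp.stack(
        [trajectory[i : i + num_rollout_steps + 1] for i in range(num_windows)]
    )


# Computed once; the ref_stepper is never called (or differentiated) during training
trajectory = rollout(ref_stepper, jnp.array([1.0, -2.0]), num_steps=10)
windows = make_windows(trajectory, num_rollout_steps=2)  # shape (9, 3, 2)
```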

One-Step supervised

# The default is one-step supervised learning
tx.configuration.Supervised()
Supervised(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[1]
)

Two-Step supervised (rollout) Training

We roll out the neural emulator for two autoregressive steps. Its parameters are shared between the two predictions. Similarly, the ref_stepper is used to create the reference trajectory; the loss is aggregated as a sum over the two time levels.

tx.configuration.Supervised(num_rollout_steps=2)
Supervised(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)
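The two-step rollout loss described above can be sketched as follows. This is a toy illustration, not Trainax's implementation: the hypothetical emulator `theta * u` stands in for the network, `ref_stepper` is a made-up linear solver, and uniform `time_level_weights` turn the weighted sum into a plain sum over the two time levels.

```python
import jax
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical stand-in for the neural emulator with one shared parameter
    return theta * u


def ref_stepper(u):
    # Hypothetical reference solver
    return 0.9 * u


def rollout_loss(theta, u_0, num_rollout_steps, time_level_weights):
    u_pred = u_0
    u_ref = u_0
    loss = 0.0
    for t in range(num_rollout_steps):
        u_pred = emulator(theta, u_pred)  # the same theta is used at every step
        u_ref = ref_stepper(u_ref)
        loss = loss + time_level_weights[t] * jnp.mean((u_pred - u_ref) ** 2)
    return loss


u_0 = jnp.array([1.0, -2.0])
weights = jnp.ones(2)  # uniform weights: sum over the two time levels
loss, grad = jax.value_and_grad(rollout_loss)(1.0, u_0, 2, weights)  # loss ≈ 0.11525
```

Since `theta` is shared, its gradient collects contributions from both predictions.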

Three-Step supervised (rollout) Training

Same idea as above, but with an additional rollout step.

tx.configuration.Supervised(num_rollout_steps=3)
Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

Three-Step supervised (rollout) Training with loss only at final state

The loss is only taken from the last step. Essentially, this corresponds to weighting the time levels with \([0, 0, 1]\), respectively. (More weighting options are possible, of course.)

tx.configuration.Supervised(
    num_rollout_steps=3, time_level_weights=jnp.array([0.0, 0.0, 1.0])
)
Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)
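Numerically, the weighting is a dot product between the weights and the per-level losses. A small sketch with made-up prediction and reference values (purely illustrative numbers, not produced by Trainax) shows that the weights \([0, 0, 1]\) keep only the contribution of the final state:

```python
import jax.numpy as jnp

# Made-up predicted and reference states: 3 time levels x 2 degrees of freedom
pred = jnp.array([[1.0, 2.0], [1.5, 2.5], [2.0, 3.0]])
ref = jnp.array([[1.0, 2.1], [1.4, 2.7], [2.5, 2.0]])

# One MSE value per time level, shape (3,)
per_level_mse = jnp.mean((pred - ref) ** 2, axis=-1)

# Uniform weights: plain sum over all time levels
loss_all_levels = jnp.dot(jnp.ones(3), per_level_mse)
# Weights [0, 0, 1]: only the final state contributes
loss_final_only = jnp.dot(jnp.array([0.0, 0.0, 1.0]), per_level_mse)  # 0.625
```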

Three-Step supervised (rollout) Training with no backpropagation through time

(The schematic displays the primal evaluation together with the cotangent flow; a grey dashed line indicates a cut gradient.)

This interrupts the backward gradient flow through the chain of autoregressive network executions. Gradients can still flow into the network parameters from every time level.

tx.configuration.Supervised(num_rollout_steps=3, cut_bptt=True)
Supervised(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)
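A minimal sketch of what the cut does to the gradient, assuming it is realized with `jax.lax.stop_gradient` on the state that enters each network call (this mirrors the schematic, not necessarily Trainax's internals). The one-parameter emulator `theta * u` is hypothetical, and zero targets keep the gradients easy to check by hand.

```python
import jax
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def rollout_loss(theta, u_0, u_refs, cut_bptt):
    u = u_0
    loss = 0.0
    for u_ref in u_refs:
        if cut_bptt:
            # Block the cotangent flow back through earlier network executions
            u = jax.lax.stop_gradient(u)
        u = emulator(theta, u)  # theta still receives a gradient at every step
        loss = loss + jnp.mean((u - u_ref) ** 2)
    return loss


u_0 = jnp.array(1.0)
u_refs = jnp.zeros(3)  # three rollout steps with zero targets

grad_full = jax.grad(rollout_loss)(2.0, u_0, u_refs, False)  # 228.0
grad_cut = jax.grad(rollout_loss)(2.0, u_0, u_refs, True)  # 84.0
```

With the cut, each time level only contributes the derivative of its last network call with respect to `theta`; the terms that would flow through earlier states are dropped.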

Four-Step supervised (rollout) Training with sparse backpropagation through time

The gradient is only allowed to flow backward through every second network execution; the remaining connections are cut.

tx.configuration.Supervised(num_rollout_steps=4, cut_bptt=True, cut_bptt_every=2)
Supervised(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=2,
  time_level_weights=f32[4]
)
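One way to sketch sparse cutting, under our own assumption (not taken from Trainax) that the state entering step `t` is detached whenever `t % cut_bptt_every == 0`. With `cut_bptt_every=1` every connection is cut, while larger values let the cotangent pass through some network executions. The emulator `theta * u` is again a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def rollout_loss(theta, u_0, u_refs, cut_bptt_every):
    u = u_0
    loss = 0.0
    for t, u_ref in enumerate(u_refs):
        # Assumed convention: detach the incoming state at every
        # `cut_bptt_every`-th step (t = 0, cut_bptt_every, 2 * cut_bptt_every, ...)
        if t % cut_bptt_every == 0:
            u = jax.lax.stop_gradient(u)
        u = emulator(theta, u)
        loss = loss + jnp.mean((u - u_ref) ** 2)
    return loss


u_0 = jnp.array(1.0)
u_refs = jnp.zeros(4)  # four rollout steps, zero targets

grad_every_1 = jax.grad(rollout_loss)(2.0, u_0, u_refs, 1)  # 340.0: every connection cut
grad_every_2 = jax.grad(rollout_loss)(2.0, u_0, u_refs, 2)  # 612.0: every second one cut
grad_every_4 = jax.grad(rollout_loss)(2.0, u_0, u_refs, 4)  # 1252.0: full BPTT here
```

Cutting at `t = 0` is a no-op because the initial condition does not depend on `theta`.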

Diverted Chain

Two-Steps with branch length one

The ref_stepper is not run autoregressively for two steps from the initial condition but rather for one step at a time, branching off from the states of the main chain created by the emulator.

# `num_rollout_steps` refers to the number of autoregressive steps performed by
# the neural emulator
tx.configuration.DivertedChainBranchOne(num_rollout_steps=2)
DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2]
)
# Alternatively, the general interface can be used
tx.configuration.DivertedChain(num_rollout_steps=2, num_branch_steps=1)
DivertedChain(
  num_rollout_steps=2,
  num_branch_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2],
  branch_level_weights=f32[1]
)
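The branching with length one can be sketched as follows, using a hypothetical emulator `theta * u` and a hypothetical linear `ref_stepper`. Each emulator prediction is compared against one ref_stepper step taken from the previous state of the main chain. The `cut_div_chain` flag here is our assumption of what that option does (detaching the branch start so that no gradient flows through the physics), inferred from its use in the "no differentiable physics" configuration.

```python
import jax
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def ref_stepper(u):
    # Hypothetical reference solver
    return 0.9 * u


def diverted_chain_branch_one_loss(theta, u_0, num_rollout_steps, cut_div_chain):
    u_prev = u_0
    loss = 0.0
    for _ in range(num_rollout_steps):
        u_next = emulator(theta, u_prev)  # main chain: emulator only
        # Assumed semantics of cut_div_chain: no gradient through the branch
        branch_start = jax.lax.stop_gradient(u_prev) if cut_div_chain else u_prev
        target = ref_stepper(branch_start)  # branch: one ref_stepper step
        loss = loss + jnp.mean((u_next - target) ** 2)
        u_prev = u_next
    return loss


u_0 = jnp.array(1.0)
grad_fn = jax.grad(diverted_chain_branch_one_loss)
grad_diff_physics = grad_fn(1.0, u_0, 2, False)  # ≈ 0.42
grad_no_diff_physics = grad_fn(1.0, u_0, 2, True)  # ≈ 0.6
```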

Three-Steps with branch length one

tx.configuration.DivertedChainBranchOne(num_rollout_steps=3)
DivertedChainBranchOne(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3]
)

Four-Steps with branch length one

tx.configuration.DivertedChainBranchOne(num_rollout_steps=4)
DivertedChainBranchOne(
  num_rollout_steps=4,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[4]
)

Three-Steps with branch length two

# Can only be done with the general interface
tx.configuration.DivertedChain(num_rollout_steps=3, num_branch_steps=2)
DivertedChain(
  num_rollout_steps=3,
  num_branch_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[3],
  branch_level_weights=f32[2]
)

Two-Steps with no differentiable physics

tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_div_chain=True,
)
DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_div_chain=True,
  time_level_weights=f32[2]
)

Two-Steps with no backpropagation through time

tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
)
DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_div_chain=False,
  time_level_weights=f32[2]
)

Two-Steps with no backpropagation through time and no differentiable physics

tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=2,
    cut_bptt=True,
    cut_div_chain=True,
)
DivertedChainBranchOne(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_div_chain=True,
  time_level_weights=f32[2]
)

Mix-Chain

Currently, Trainax only supports "post-physics" mixing: the main chain is built by first performing a specified number of autoregressive network steps, followed by a specified number of ref_stepper steps.

The reference trajectory is always built by autoregressively unrolling the ref_stepper.
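A minimal sketch of the post-physics main chain and its reference, with a hypothetical emulator `theta * u` and a hypothetical linear `ref_stepper` as stand-ins (not Trainax's implementation). Consistent with the `time_level_weights` shapes printed in this section, one weight is used per state of the main chain, i.e., `num_rollout_steps + num_post_physics_steps` weights.

```python
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def ref_stepper(u):
    # Hypothetical reference solver
    return 0.9 * u


def mix_chain_post_physics_loss(
    theta, u_0, num_rollout_steps, num_post_physics_steps, time_level_weights
):
    u_main = u_0
    u_ref = u_0
    loss = 0.0
    for t in range(num_rollout_steps + num_post_physics_steps):
        if t < num_rollout_steps:
            u_main = emulator(theta, u_main)  # network steps first ...
        else:
            u_main = ref_stepper(u_main)  # ... then ref_stepper steps ("post physics")
        u_ref = ref_stepper(u_ref)  # reference: pure ref_stepper rollout
        loss = loss + time_level_weights[t] * jnp.mean((u_main - u_ref) ** 2)
    return loss


u_0 = jnp.array(1.0)
loss_all = mix_chain_post_physics_loss(1.0, u_0, 1, 1, jnp.ones(2))  # ≈ 0.0181
loss_final = mix_chain_post_physics_loss(1.0, u_0, 1, 1, jnp.array([0.0, 1.0]))  # ≈ 0.0081
```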

One-Step Network with one-step physics

tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
)
MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=1,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)

One-Step Network with one-step physics and loss only at final state

As in the supervised setting, this is achieved by choosing appropriate time_level_weights. For MixChainPostPhysics, the time_level_weights refer to the entire main chain, i.e., the trajectory created by the network steps followed by the physics steps.

tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=1,
    time_level_weights=jnp.array([0.0, 1.0]),
)
MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=1,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[2]
)

One-Step Network with two-step physics

tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=2,
)
MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=2,
  cut_bptt=False,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

One-Step Network with two-step physics and no backpropagation through time

tx.configuration.MixChainPostPhysics(
    num_rollout_steps=1,
    num_post_physics_steps=2,
    cut_bptt=True,
)
MixChainPostPhysics(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  num_post_physics_steps=2,
  cut_bptt=True,
  cut_bptt_every=1,
  time_level_weights=f32[3]
)

Residuum

Instead of having a ref_stepper that can be unrolled autoregressively, these configurations rely on a residuum_fn that defines a condition based on two consecutive time levels.
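A minimal sketch of such a residuum-based rollout loss, assuming a hypothetical `residuum_fn(u_next, u_prev)` for the decay law u_{n+1} = 0.9 u_n (the argument order is our choice for illustration) and a hypothetical emulator `theta * u`. No reference trajectory appears anywhere; the loss drives the residuum toward zero.

```python
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def residuum_fn(u_next, u_prev):
    # Hypothetical condition between two consecutive time levels:
    # zero if u_next is an exact step of u_{n+1} = 0.9 * u_n
    return u_next - 0.9 * u_prev


def residuum_loss(theta, u_0, num_rollout_steps):
    u_prev = u_0
    loss = 0.0
    for _ in range(num_rollout_steps):
        u_next = emulator(theta, u_prev)
        # Penalize the residuum directly instead of comparing to a reference
        loss = loss + jnp.mean(residuum_fn(u_next, u_prev) ** 2)
        u_prev = u_next
    return loss


u_0 = jnp.array(1.0)
loss_two_steps = residuum_loss(1.0, u_0, 2)  # ≈ 0.02
```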

One-Step Residuum

tx.configuration.Residuum(
    num_rollout_steps=1,
)
Residuum(
  num_rollout_steps=1,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[1]
)

Two Steps Residuum Training

tx.configuration.Residuum(
    num_rollout_steps=2,
)
Residuum(
  num_rollout_steps=2,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[2]
)

Three Steps Residuum Training

tx.configuration.Residuum(
    num_rollout_steps=3,
)
Residuum(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=False,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[3]
)

Three Steps Residuum with no backpropagation through time

tx.configuration.Residuum(num_rollout_steps=3, cut_bptt=True)
Residuum(
  num_rollout_steps=3,
  time_level_loss=MSELoss(batch_reduction=<function mean>),
  cut_bptt=True,
  cut_bptt_every=1,
  cut_prev=False,
  cut_next=False,
  time_level_weights=f32[3]
)

Other Residuum Options

It is also possible to cut the prev and next contributions to the residuum_fn via the cut_prev and cut_next options.
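A sketch of what these cuts could mean for a two-step rollout, assuming (our interpretation of the option names, not confirmed by Trainax here) that `cut_prev` and `cut_next` detach the respective argument before it enters `residuum_fn`. The emulator `theta * u` and the residuum of u_{n+1} = 0.9 u_n are hypothetical toy choices.

```python
import jax
import jax.numpy as jnp


def residuum_fn(u_next, u_prev):
    # Hypothetical residuum of the decay law u_{n+1} = 0.9 * u_n
    return u_next - 0.9 * u_prev


def residuum_loss(theta, u_0, num_rollout_steps, cut_prev, cut_next):
    u_prev = u_0
    loss = 0.0
    for _ in range(num_rollout_steps):
        u_next = theta * u_prev  # hypothetical emulator
        # Assumed semantics: detach the selected argument(s) of the residuum only;
        # the main chain itself (u_prev = u_next below) stays differentiable
        prev_in = jax.lax.stop_gradient(u_prev) if cut_prev else u_prev
        next_in = jax.lax.stop_gradient(u_next) if cut_next else u_next
        loss = loss + jnp.mean(residuum_fn(next_in, prev_in) ** 2)
        u_prev = u_next
    return loss


u_0 = jnp.array(1.0)
grad_fn = jax.grad(residuum_loss)
grad_no_cut = grad_fn(1.0, u_0, 2, False, False)  # ≈ 0.42
grad_cut_prev = grad_fn(1.0, u_0, 2, True, False)  # ≈ 0.6
grad_cut_next = grad_fn(1.0, u_0, 2, False, True)  # ≈ -0.18
```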

Teacher Forcing

Teacher forcing resets the main chain with states from the autoregressive reference chain. It is essentially the opposite of diverted chain learning.

It resembles selecting minibatches from the entire trajectories. However, this setup guarantees that multiple consecutive time levels are considered within one gradient update, without the network having to be rolled out autoregressively.
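Since no implementation is provided yet, here is a minimal sketch of the idea, not Trainax API. A hypothetical `reset_every` parameter controls how often the main chain is overwritten with the current reference state; the emulator `theta * u` and the linear `ref_stepper` are hypothetical stand-ins. With `reset_every=1`, every prediction starts from a reference state; with `reset_every` equal to the number of steps, it reduces to an ordinary supervised rollout.

```python
import jax.numpy as jnp


def emulator(theta, u):
    # Hypothetical one-parameter emulator
    return theta * u


def ref_stepper(u):
    # Hypothetical reference solver
    return 0.9 * u


def teacher_forcing_loss(theta, u_0, num_rollout_steps, reset_every):
    u_ref = u_0
    u_main = u_0
    loss = 0.0
    for t in range(num_rollout_steps):
        if t % reset_every == 0:
            # Reset the main chain with the current state of the reference chain
            u_main = u_ref
        u_main = emulator(theta, u_main)
        u_ref = ref_stepper(u_ref)
        loss = loss + jnp.mean((u_main - u_ref) ** 2)
    return loss


u_0 = jnp.array(1.0)
loss_reset_each_step = teacher_forcing_loss(1.0, u_0, 3, reset_every=1)  # ≈ 0.024661
loss_no_reset = teacher_forcing_loss(1.0, u_0, 3, reset_every=3)  # ≈ 0.119541
```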

Three Steps teacher forcing with reset every step

# TODO Implementation

Four Steps teacher forcing with reset every second step

# TODO implementation

Four Steps teacher forcing with reset every second step and no backpropagation through time

# TODO implementation

How about correction learning?

All of the above setups can also be used for correction learning, i.e., when the emulator is not a pure network but contains some (differentiable, coarse) solver component, for example, in the case of sequential correction.

See this website for potential corrector layouts and options to cut gradients within them.

None of these layouts are provided by Trainax. This section only showcases that the configurations of Trainax can be used in a more general context.
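For intuition, a minimal sketch of sequential correction, where all names are hypothetical and nothing here is Trainax API: a cheap, slightly inaccurate `coarse_solver` is followed by a one-parameter stand-in `network`, and the composition is trained with an ordinary supervised rollout loss. Because `coarse_solver` is differentiable, JAX propagates gradients through it during the rollout.

```python
import jax
import jax.numpy as jnp


def coarse_solver(u):
    # Hypothetical cheap, differentiable solver that slightly misses the decay rate
    return 0.95 * u


def network(theta, u):
    # Hypothetical one-parameter stand-in for a neural network
    return theta * u


def sequential_corrector(theta, u):
    # Sequential correction: the network post-processes the coarse solver output
    return network(theta, coarse_solver(u))


def ref_stepper(u):
    # Hypothetical high-fidelity reference solver
    return 0.9 * u


def rollout_loss(theta, u_0, num_rollout_steps):
    # Unchanged supervised rollout loss; only the emulator now contains a solver
    u_pred, u_ref, loss = u_0, u_0, 0.0
    for _ in range(num_rollout_steps):
        u_pred = sequential_corrector(theta, u_pred)
        u_ref = ref_stepper(u_ref)
        loss = loss + jnp.mean((u_pred - u_ref) ** 2)
    return loss


u_0 = jnp.array(1.0)
loss, grad = jax.value_and_grad(rollout_loss)(1.0, u_0, 2)  # ≈ 0.01105625, ≈ 0.428925
```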