Mixed-Chain
trainax.configuration.MixChainPostPhysics
Bases: BaseConfiguration
Source code in trainax/configuration/_mix_chain_post_physics.py
__init__
__init__(
num_rollout_steps: int = 1,
num_post_physics_steps: int = 1,
*,
time_level_loss: BaseLoss = MSELoss(),
cut_bptt: bool = False,
cut_bptt_every: int = 1,
time_level_weights: Optional[
Float[
Array,
num_rollout_steps + num_post_physics_steps,
]
] = None
)
Mix chain (rollout) configuration with autoregressive physics steps after the autoregressive emulator steps in the main chain.
This is a special case of potentially more complicated combinations of the neural stepper with the reference physics stepper in the main chain.
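The structure of the main chain described above can be illustrated with a minimal sketch. It is not the trainax implementation: the function `mix_chain_loss`, the toy steppers `emulator` and `physics`, and the plain mean-squared error are hypothetical stand-ins, and the batch axis is omitted for brevity.

```python
import jax.numpy as jnp


def mix_chain_loss(stepper, ref_stepper, ic, targets,
                   num_rollout_steps=1, num_post_physics_steps=1,
                   time_level_weights=None):
    """Hypothetical sketch: emulator steps first, then physics steps."""
    num_steps = num_rollout_steps + num_post_physics_steps
    if time_level_weights is None:
        # Equal weight for every time level
        time_level_weights = jnp.ones(num_steps)

    pred = ic
    loss = 0.0
    for t in range(num_steps):
        # The first num_rollout_steps use the emulator, the rest use physics
        step = stepper if t < num_rollout_steps else ref_stepper
        pred = step(pred)
        # Weighted mean-squared error against the target at this time level
        loss = loss + time_level_weights[t] * jnp.mean((pred - targets[t]) ** 2)
    return loss


# Toy linear steppers (assumptions for illustration only)
emulator = lambda u: 0.5 * u
physics = lambda u: 0.9 * u

ic = jnp.array([1.0, 2.0])
# Reference trajectory produced by rolling out the physics stepper
targets = [physics(ic), physics(physics(ic))]

# Default chain: model -> physics
loss = mix_chain_loss(emulator, physics, ic, targets)
```

With the defaults, the prediction passes through the emulator once and then through the physics stepper once, and a loss term is collected after each of the two steps.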
Arguments:
num_rollout_steps
: The number of time steps to autoregressively roll out the model. Defaults to 1.
num_post_physics_steps
: The number of time steps to autoregressively roll out the physics after the model in the main chain. Defaults to 1. Hence, in the default configuration, the main chain is model -> physics.
time_level_loss
: The loss function to use at each time step. Defaults to trainax.loss.MSELoss.
cut_bptt
: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain. This excludes the post-physics steps; those are not cut. Defaults to False.
cut_bptt_every
: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
time_level_weights
: An array of length num_rollout_steps + num_post_physics_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).
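To make the effect of cut_bptt concrete, here is a small sketch using a scalar toy emulator `u -> a * u`. The function `rollout_loss` is hypothetical, and placing the `jax.lax.stop_gradient` call after the loss contribution of each step is an assumption of this sketch, not a statement about the trainax source.

```python
import jax
import jax.numpy as jnp


def rollout_loss(a, u0, targets, cut_bptt=False, cut_bptt_every=1):
    """Summed MSE over an autoregressive rollout of the toy emulator a * u."""
    pred = u0
    loss = 0.0
    for t, target in enumerate(targets):
        pred = a * pred
        loss = loss + jnp.mean((pred - target) ** 2)
        # Assumed placement: cut the gradient flow after every
        # cut_bptt_every-th step, once the loss term has been collected
        if cut_bptt and (t + 1) % cut_bptt_every == 0:
            pred = jax.lax.stop_gradient(pred)
    return loss


u0 = jnp.array([1.0])
targets = [jnp.array([0.0]), jnp.array([0.0])]

# Full BPTT: loss = a^2 + a^4, so the gradient at a = 2 is 2a + 4a^3
grad_full = jax.grad(rollout_loss)(2.0, u0, targets)

# Cut after each step: the second term only sees the last application of a
grad_cut = jax.grad(rollout_loss)(2.0, u0, targets, True)
```

In this toy setting the full gradient is 36, while the cut gradient is 20, because the second loss term no longer differentiates through the first emulator step.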
__call__
__call__(
stepper: eqx.Module,
data: PyTree[Float[Array, "batch num_snapshots ..."]],
*,
ref_stepper: eqx.Module,
residuum_fn: eqx.Module = None
) -> float
Evaluate the mix chain (rollout) configuration on the given data.
The data needs to contain only num_rollout_steps + num_post_physics_steps + 1 time levels: the initial condition plus one target snapshot for each step in the main chain.
Arguments:
stepper
: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
data
: The data to evaluate the configuration on. Has to contain the initial condition and the target trajectory.
ref_stepper
: The reference stepper to use for the configuration. Must have the signature ref_stepper(u_prev: PyTree) -> u_next: PyTree.
residuum_fn
: For compatibility with other configurations; not used here.
Returns:
- The loss value computed by this configuration.
Raises:
- ValueError: If the number of snapshots in the trajectory is less than the number of rollout steps and post physics steps plus one.
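The snapshot requirement can be sketched as follows, assuming the data is a single array of shape (batch, num_snapshots, ...) whose first snapshot is the initial condition. The helper `split_ic_and_targets` is hypothetical and only illustrates the check; it is not part of trainax.

```python
import jax.numpy as jnp


def split_ic_and_targets(data, num_rollout_steps, num_post_physics_steps):
    """Hypothetical helper: validate the snapshot count and split the data."""
    needed = num_rollout_steps + num_post_physics_steps + 1
    num_snapshots = data.shape[1]
    if num_snapshots < needed:
        raise ValueError(
            f"Need at least {needed} snapshots, but got {num_snapshots}."
        )
    ic = data[:, 0]            # initial condition, shape (batch, ...)
    targets = data[:, 1:needed]  # one target per step in the main chain
    return ic, targets


# 4 samples, 3 snapshots, 16 spatial points: enough for 1 + 1 steps
data = jnp.zeros((4, 3, 16))
ic, targets = split_ic_and_targets(data, 1, 1)
```

Any snapshots beyond the required count are simply ignored in this sketch.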