Supervised
trainax.configuration.Supervised
Bases: BaseConfiguration
Source code in trainax/configuration/_supervised.py
__init__
__init__(
num_rollout_steps: int = 1,
*,
time_level_loss: BaseLoss = MSELoss(),
cut_bptt: bool = False,
cut_bptt_every: int = 1,
time_level_weights: Optional[
Float[Array, "num_rollout_steps"]
] = None
)
General supervised (rollout) configuration.

Falls back to classical one-step supervised training for num_rollout_steps=1 (default).
Arguments:

- num_rollout_steps: The number of time steps to autoregressively roll out the model. When this configuration is called, it requires a reference trajectory of matching length to be available. Defaults to 1.
- time_level_loss: The loss function to use at each time step. Defaults to MSELoss().
- cut_bptt: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain. Defaults to False.
- cut_bptt_every: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
- time_level_weights: An array of length num_rollout_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).
Warning

Under reverse-mode automatic differentiation, memory usage grows linearly with num_rollout_steps.
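The rollout loss this configuration computes can be sketched in plain Python. This is an illustrative toy, not trainax's actual implementation: the stepper and trajectories are plain lists of floats, and the BPTT cut is only indicated by a comment since no automatic differentiation is involved here.

```python
def supervised_rollout_loss(stepper, trajectory, num_rollout_steps=1,
                            time_level_weights=None):
    """Toy sketch of a supervised rollout loss (per-step MSE).

    `trajectory` holds the reference snapshots [u_0, u_1, ..., u_T],
    so it must contain at least `num_rollout_steps + 1` entries.
    """
    if len(trajectory) < num_rollout_steps + 1:
        raise ValueError(
            "trajectory must contain num_rollout_steps + 1 snapshots"
        )
    # Default: every time level contributes with weight 1.0
    weights = time_level_weights or [1.0] * num_rollout_steps

    u = trajectory[0]  # initial condition
    total = 0.0
    for t in range(num_rollout_steps):
        u = stepper(u)  # autoregressive prediction
        # In trainax, `cut_bptt` would insert jax.lax.stop_gradient on `u`
        # every `cut_bptt_every` steps here to truncate backpropagation.
        ref = trajectory[t + 1]
        mse = sum((a - b) ** 2 for a, b in zip(u, ref)) / len(u)
        total += weights[t] * mse
    return total


# A "perfect" stepper that doubles each entry gives zero loss on a
# doubling reference trajectory:
double = lambda u: [2.0 * x for x in u]
loss = supervised_rollout_loss(double, [[1.0], [2.0], [4.0]],
                               num_rollout_steps=2)
```

With num_rollout_steps=1 this reduces to classical one-step supervised training, matching the default behavior described above.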
Source code in trainax/configuration/_supervised.py
__call__
__call__(
stepper: eqx.Module,
data: PyTree[Float[Array, "batch num_snapshots ..."]],
*,
ref_stepper: eqx.Module = None,
residuum_fn: eqx.Module = None
) -> float
Evaluate the supervised (rollout) configuration on the given data.
The data is expected to contain one more time step than the number of rollout steps. No ref_stepper or residuum_fn is needed.
Arguments:

- stepper: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
- data: The data to evaluate the configuration on. This should contain the initial condition and the target trajectory.
- ref_stepper: For compatibility with other configurations; not used.
- residuum_fn: For compatibility with other configurations; not used.
Returns:
- The loss value computed by this configuration.
Raises:
- ValueError: If the number of snapshots in the trajectory is less than the number of rollout steps plus one.
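The expected data layout can be illustrated with a small sketch. `split_batch` is a hypothetical helper, not part of trainax's API: it shows how each trajectory of `num_rollout_steps + 1` snapshots splits into an initial condition and the rollout targets, and where the ValueError above would arise.

```python
def split_batch(data, num_rollout_steps):
    """Hypothetical helper (not part of trainax): split a batch of
    trajectories into initial conditions and rollout targets.

    `data` is a list of trajectories; each trajectory is a list of
    snapshots and must contain `num_rollout_steps + 1` of them.
    """
    num_snapshots = len(data[0])
    if num_snapshots < num_rollout_steps + 1:
        raise ValueError(
            f"need {num_rollout_steps + 1} snapshots, got {num_snapshots}"
        )
    # First snapshot of each trajectory is the initial condition ...
    init = [traj[0] for traj in data]
    # ... the remaining snapshots are the per-step supervision targets.
    targets = [traj[1 : num_rollout_steps + 1] for traj in data]
    return init, targets


# One trajectory with 3 snapshots supports a 2-step rollout:
init, targets = split_batch([["u0", "u1", "u2"]], num_rollout_steps=2)
```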
Source code in trainax/configuration/_supervised.py