Residuum
trainax.configuration.Residuum
Bases: BaseConfiguration
Source code in trainax/configuration/_residuum.py
__init__
__init__(
    num_rollout_steps: int = 1,
    *,
    time_level_loss: BaseLoss = MSELoss(),
    cut_bptt: bool = False,
    cut_bptt_every: int = 1,
    cut_prev: bool = False,
    cut_next: bool = False,
    time_level_weights: Optional[Float[Array, "num_rollout_steps"]] = None
)
Residuum (rollout) configuration for residua between two consecutive time levels.

If the ref_stepper resolves the residuum_fn for u_next, the Residuum configuration and the DivertedChainBranchOne configuration can be bounded (their losses bound each other). We believe, however, that the two induce different optimization trajectories (and different local optima) because the residuum-based loss has worse conditioning.
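As an illustrative sketch of that bounding relation (the concrete residuum form below is an assumption for illustration, not taken from the library): suppose $r(u_{t+1}, u_t) = u_{t+1} - \mathcal{P}(u_t)$ with $\mathcal{P}$ denoting the ref_stepper, and $\ell$ the single-input MSELoss. Then

$$
\mathcal{L}_{\text{Residuum}} = \ell\big(r(f_\theta(u_t), u_t)\big) = \ell\big(f_\theta(u_t) - \mathcal{P}(u_t)\big),
$$

which coincides in value with the DivertedChainBranchOne loss $\operatorname{MSE}\big(f_\theta(u_t), \mathcal{P}(u_t)\big)$. The two losses can thus bound each other, even though gradients reach $f_\theta$ through $r$, which changes the conditioning of the optimization problem.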
Arguments:

- num_rollout_steps: The number of time steps to autoregressively roll out the model. Defaults to 1.
- time_level_loss: The loss function to use at each time step. Must operate on a single input. Defaults to MSELoss().
- cut_bptt: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain. Defaults to False.
- cut_bptt_every: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
- cut_prev: Whether to cut the previous time level contribution to residuum_fn. Defaults to False.
- cut_next: Whether to cut the next time level contribution to residuum_fn. Defaults to False.
- time_level_weights: An array of length num_rollout_steps that contains the weight for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).
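A minimal construction sketch using the arguments above (the concrete option values are illustrative only; the class path follows the documented trainax.configuration.Residuum):

```python
import trainax as tx

# Hypothetical setup: a three-step residuum rollout that truncates
# backpropagation through time after every autoregressive step.
config = tx.configuration.Residuum(
    num_rollout_steps=3,
    cut_bptt=True,
    cut_bptt_every=1,
)
```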
Source code in trainax/configuration/_residuum.py
__call__
__call__(
    stepper: eqx.Module,
    data: PyTree[Float[Array, "batch num_snapshots ..."]],
    *,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module
) -> float
Evaluate the residuum (rollout) configuration on the given data.

The data only has to contain one time level, the initial condition. The residuum_fn will be used to compute a loss based on two consecutive time levels.
Arguments:

- stepper: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
- data: The data to evaluate the configuration on. This depends on the concrete configuration. In this case, it only contains the initial condition.
- ref_stepper: The reference stepper to use for the configuration. Must have the signature ref_stepper(u_prev: PyTree) -> u_next: PyTree. Defaults to None.
- residuum_fn: The residuum function to use for the configuration. Must have the signature residuum_fn(u_next: PyTree, u_prev: PyTree) -> residuum: PyTree.
Returns:
- The loss of the configuration.
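An end-to-end sketch of calling the configuration. The toy stepper and residuum modules below are invented for illustration; only the call signatures and the data layout follow the documentation above:

```python
import equinox as eqx
import jax.numpy as jnp
import trainax as tx

class LinearStepper(eqx.Module):
    # Toy emulator: scales the state by a fixed factor.
    factor: float

    def __call__(self, u):
        return self.factor * u

class DecayResiduum(eqx.Module):
    # Toy residuum: zero iff u_next equals the decayed previous state.
    decay: float

    def __call__(self, u_next, u_prev):
        return u_next - self.decay * u_prev

config = tx.configuration.Residuum(num_rollout_steps=2)
stepper = LinearStepper(factor=0.9)

# Per the documented layout (batch, num_snapshots, ...), with only the
# initial condition as the single snapshot.
data = jnp.ones((8, 1, 64))

loss = config(stepper, data, residuum_fn=DecayResiduum(decay=0.95))
```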
Source code in trainax/configuration/_residuum.py