
Residuum¤

trainax.trainer.ResiduumTrainer ¤

Bases: GeneralTrainer

Source code in trainax/trainer/_residuum.py
class ResiduumTrainer(GeneralTrainer):
    def __init__(
        self,
        data_trajectories: PyTree[Float[Array, "num_samples trj_len ..."]],
        *,
        ref_stepper: eqx.Module = None,  # for compatibility
        residuum_fn: eqx.Module,
        optimizer: optax.GradientTransformation,
        callback_fn: Optional[BaseCallback] = None,
        num_training_steps: int,
        batch_size: int,
        num_rollout_steps: int = 1,
        time_level_loss: BaseLoss = MSELoss(),
        cut_bptt: bool = False,
        cut_bptt_every: int = 1,
        cut_prev: bool = False,
        cut_next: bool = False,
        time_level_weights: Optional[
            Float[Array, "num_rollout_steps"]  # noqa F821
        ] = None,
        do_sub_stacking: bool = True,
    ):
        """
        Residuum (rollout) training for an autoregressive neural emulator on a
        collection of trajectories.

        If the `ref_stepper` resolves the `residuum_fn` for `u_next`, the
        losses of the `Residuum` configuration and the `DivertedChainBranchOne`
        configuration can be bounded in terms of one another. We believe,
        however, that the two induce different optimization trajectories (and
        different local optima) because the residuum-based loss is worse
        conditioned.

        **Arguments:**

        - `data_trajectories`: The batch of trajectories to slice. This must be
            a PyTree of Arrays that have at least two leading axes: a batch
            axis and a time axis. For example, the zeroth axis can be
            associated with multiple initial conditions or constitutive
            parameters and the first axis represents all temporal snapshots. A
            PyTree can also just be an array. You can provide additional leaves
            in the PyTree, e.g., for the corresponding constitutive parameters.
            Make sure that the emulator has the corresponding signature.
        - `ref_stepper`: For compatibility with other configurations; not used.
        - `residuum_fn`: The residuum function to use for the configuration.
            Must have the signature `residuum_fn(u_next: PyTree, u_prev: PyTree)
            -> residuum: PyTree`.
        - `optimizer`: The optimizer to use for training. For example, this can
            be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer
            with learning rate decay, for example
            `optax.adam(optax.exponential_decay(...))`. If your learning rate
            decay is designed for a certain number of update steps, make sure
            that it aligns with `num_training_steps`.
        - `callback_fn`: A callback to use during training. Defaults to None.
        - `num_training_steps`: The number of training steps to perform.
        - `batch_size`: The batch size to use for training. Batches are
            sampled randomly both across multiple trajectories and across
            different temporal windows within each trajectory.
        - `num_rollout_steps`: The number of time steps to autoregressively roll
            out the model during training.
        - `time_level_loss`: The loss function to use at each time step.
        - `cut_bptt`: Whether to cut the backpropagation through time (BPTT),
            i.e., insert a `jax.lax.stop_gradient` into the autoregressive
            network main chain.
        - `cut_bptt_every`: The frequency at which to cut the BPTT.
            Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after
            each step).
        - `cut_prev`: Whether to cut the previous time level contribution
            to `residuum_fn`.
        - `cut_next`: Whether to cut the next time level contribution
            to `residuum_fn`.
        - `time_level_weights`: An array of length `num_rollout_steps` that
            contains the weights for each time step. Defaults to None, which
            means that all time steps have the same weight (=1.0).

        !!! info
            * Under reverse-mode automatic differentiation memory usage grows
                linearly with `num_rollout_steps`.
        """
        trajectory_sub_stacker = TrajectorySubStacker(
            data_trajectories,
            sub_trajectory_len=num_rollout_steps + 1,  # +1 for the IC
            do_sub_stacking=do_sub_stacking,
            only_store_ic=False,
        )
        loss_configuration = Residuum(
            num_rollout_steps=num_rollout_steps,
            time_level_loss=time_level_loss,
            cut_bptt=cut_bptt,
            cut_bptt_every=cut_bptt_every,
            cut_prev=cut_prev,
            cut_next=cut_next,
            time_level_weights=time_level_weights,
        )
        super().__init__(
            trajectory_sub_stacker,
            loss_configuration,
            ref_stepper=ref_stepper,
            residuum_fn=residuum_fn,
            optimizer=optimizer,
            num_minibatches=num_training_steps,
            batch_size=batch_size,
            callback_fn=callback_fn,
        )
__init__ ¤
__init__(
    data_trajectories: PyTree[
        Float[Array, "num_samples trj_len ..."]
    ],
    *,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module,
    optimizer: optax.GradientTransformation,
    callback_fn: Optional[BaseCallback] = None,
    num_training_steps: int,
    batch_size: int,
    num_rollout_steps: int = 1,
    time_level_loss: BaseLoss = MSELoss(),
    cut_bptt: bool = False,
    cut_bptt_every: int = 1,
    cut_prev: bool = False,
    cut_next: bool = False,
    time_level_weights: Optional[
        Float[Array, num_rollout_steps]
    ] = None,
    do_sub_stacking: bool = True
)

Residuum (rollout) training for an autoregressive neural emulator on a collection of trajectories.

If the ref_stepper resolves the residuum_fn for u_next, the losses of the Residuum configuration and the DivertedChainBranchOne configuration can be bounded in terms of one another. We believe, however, that the two induce different optimization trajectories (and different local optima) because the residuum-based loss is worse conditioned.

Arguments:

  • data_trajectories: The batch of trajectories to slice. This must be a PyTree of Arrays that have at least two leading axes: a batch axis and a time axis. For example, the zeroth axis can be associated with multiple initial conditions or constitutive parameters and the first axis represents all temporal snapshots. A PyTree can also just be an array. You can provide additional leaves in the PyTree, e.g., for the corresponding constitutive parameters. Make sure that the emulator has the corresponding signature.
  • ref_stepper: For compatibility with other configurations; not used.
  • residuum_fn: The residuum function to use for the configuration. Must have the signature residuum_fn(u_next: PyTree, u_prev: PyTree) -> residuum: PyTree.
  • optimizer: The optimizer to use for training. For example, this can be optax.adam(LEARNING_RATE). Also use this to supply an optimizer with learning rate decay, for example optax.adam(optax.exponential_decay(...)). If your learning rate decay is designed for a certain number of update steps, make sure that it aligns with num_training_steps.
  • callback_fn: A callback to use during training. Defaults to None.
  • num_training_steps: The number of training steps to perform.
  • batch_size: The batch size to use for training. Batches are sampled randomly both across multiple trajectories and across different temporal windows within each trajectory.
  • num_rollout_steps: The number of time steps to autoregressively roll out the model during training.
  • time_level_loss: The loss function to use at each time step.
  • cut_bptt: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain.
  • cut_bptt_every: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
  • cut_prev: Whether to cut the previous time level contribution to residuum_fn.
  • cut_next: Whether to cut the next time level contribution to residuum_fn.
  • time_level_weights: An array of length num_rollout_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).

Info

  • Under reverse-mode automatic differentiation memory usage grows linearly with num_rollout_steps.
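
The following is a minimal usage sketch, not taken from the package itself: the residuum function (an implicit-Euler heat-equation residuum), the data shapes, and all hyperparameter values are illustrative assumptions; only the ResiduumTrainer call follows the signature documented above.

```python
import equinox as eqx
import jax.numpy as jnp
import optax
from trainax.trainer import ResiduumTrainer


class HeatResiduum(eqx.Module):
    """Hypothetical residuum of an implicit Euler step of the 1D heat equation."""

    dt: float = 0.01
    nu: float = 0.1

    def __call__(self, u_next, u_prev):
        # Periodic second derivative via central differences (grid spacing folded into nu)
        laplace = (
            jnp.roll(u_next, 1, axis=-1) - 2 * u_next + jnp.roll(u_next, -1, axis=-1)
        )
        # Implicit Euler residuum: r = u_next - u_prev - dt * nu * Laplace(u_next)
        return u_next - u_prev - self.dt * self.nu * laplace


# Stand-in trajectory data of shape (num_samples, trj_len, channels, points);
# replace with real simulation snapshots.
data_trajectories = jnp.zeros((16, 51, 1, 64))

trainer = ResiduumTrainer(
    data_trajectories,
    residuum_fn=HeatResiduum(),
    # A schedule whose step count matches num_training_steps
    optimizer=optax.adam(optax.exponential_decay(1e-3, 10_000, 0.9)),
    num_training_steps=10_000,
    batch_size=32,
    num_rollout_steps=5,
)
```
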
Source code in trainax/trainer/_residuum.py
def __init__(
    self,
    data_trajectories: PyTree[Float[Array, "num_samples trj_len ..."]],
    *,
    ref_stepper: eqx.Module = None,  # for compatibility
    residuum_fn: eqx.Module,
    optimizer: optax.GradientTransformation,
    callback_fn: Optional[BaseCallback] = None,
    num_training_steps: int,
    batch_size: int,
    num_rollout_steps: int = 1,
    time_level_loss: BaseLoss = MSELoss(),
    cut_bptt: bool = False,
    cut_bptt_every: int = 1,
    cut_prev: bool = False,
    cut_next: bool = False,
    time_level_weights: Optional[
        Float[Array, "num_rollout_steps"]  # noqa F821
    ] = None,
    do_sub_stacking: bool = True,
):
    """
    Residuum (rollout) training for an autoregressive neural emulator on a
    collection of trajectories.

    If the `ref_stepper` resolves the `residuum_fn` for `u_next`, the losses
    of the `Residuum` configuration and the `DivertedChainBranchOne`
    configuration can be bounded in terms of one another. We believe, however,
    that the two induce different optimization trajectories (and different
    local optima) because the residuum-based loss is worse conditioned.

    **Arguments:**

    - `data_trajectories`: The batch of trajectories to slice. This must be
        a PyTree of Arrays that have at least two leading axes: a batch axis
        and a time axis. For example, the zeroth axis can be associated with
        multiple initial conditions or constitutive parameters and the first
        axis represents all temporal snapshots. A PyTree can also just be an
        array. You can provide additional leaves in the PyTree, e.g., for the
        corresponding constitutive parameters. Make sure that the emulator
        has the corresponding signature.
    - `ref_stepper`: For compatibility with other configurations; not used.
    - `residuum_fn`: The residuum function to use for the configuration.
        Must have the signature `residuum_fn(u_next: PyTree, u_prev: PyTree)
        -> residuum: PyTree`.
    - `optimizer`: The optimizer to use for training. For example, this can
        be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer
        with learning rate decay, for example
        `optax.adam(optax.exponential_decay(...))`. If your learning rate
        decay is designed for a certain number of update steps, make sure
        that it aligns with `num_training_steps`.
    - `callback_fn`: A callback to use during training. Defaults to None.
    - `num_training_steps`: The number of training steps to perform.
    - `batch_size`: The batch size to use for training. Batches are
        sampled randomly both across multiple trajectories and across
        different temporal windows within each trajectory.
    - `num_rollout_steps`: The number of time steps to autoregressively roll
        out the model during training.
    - `time_level_loss`: The loss function to use at each time step.
    - `cut_bptt`: Whether to cut the backpropagation through time (BPTT),
        i.e., insert a `jax.lax.stop_gradient` into the autoregressive
        network main chain.
    - `cut_bptt_every`: The frequency at which to cut the BPTT.
        Only relevant if `cut_bptt` is True. Defaults to 1 (meaning after
        each step).
    - `cut_prev`: Whether to cut the previous time level contribution
        to `residuum_fn`.
    - `cut_next`: Whether to cut the next time level contribution
        to `residuum_fn`.
    - `time_level_weights`: An array of length `num_rollout_steps` that
        contains the weights for each time step. Defaults to None, which
        means that all time steps have the same weight (=1.0).

    !!! info
        * Under reverse-mode automatic differentiation memory usage grows
            linearly with `num_rollout_steps`.
    """
    trajectory_sub_stacker = TrajectorySubStacker(
        data_trajectories,
        sub_trajectory_len=num_rollout_steps + 1,  # +1 for the IC
        do_sub_stacking=do_sub_stacking,
        only_store_ic=False,
    )
    loss_configuration = Residuum(
        num_rollout_steps=num_rollout_steps,
        time_level_loss=time_level_loss,
        cut_bptt=cut_bptt,
        cut_bptt_every=cut_bptt_every,
        cut_prev=cut_prev,
        cut_next=cut_next,
        time_level_weights=time_level_weights,
    )
    super().__init__(
        trajectory_sub_stacker,
        loss_configuration,
        ref_stepper=ref_stepper,
        residuum_fn=residuum_fn,
        optimizer=optimizer,
        num_minibatches=num_training_steps,
        batch_size=batch_size,
        callback_fn=callback_fn,
    )
__call__ ¤
__call__(
    stepper: eqx.Module,
    key: PRNGKeyArray,
    opt_state: Optional[optax.OptState] = None,
    *,
    return_loss_history: bool = True,
    record_loss_every: int = 1,
    spawn_tqdm: bool = True
) -> Union[
    tuple[eqx.Module, Float[Array, num_minibatches]],
    eqx.Module,
    tuple[eqx.Module, Float[Array, num_minibatches], list],
    tuple[eqx.Module, list],
]

Perform the entire training of an autoregressive neural emulator whose initial state is given as stepper.

This method's return signature depends on the presence of a callback function. If a callback function is provided, this function has at most three return values. The first return value is the trained stepper, the second return value is the loss history, and the third return value is the auxiliary history. The auxiliary history is a list of the return values of the callback function at each minibatch. If no callback function is provided, this function has at most two return values. The first return value is the trained stepper, and the second return value is the loss history. If return_loss_history is set to False, the loss history will not be returned.

Arguments:

  • stepper: The equinox Module to be trained.
  • key: The random key to be used for shuffling the minibatches.
  • opt_state: The initial optimizer state. Defaults to None, meaning the optimizer will be reinitialized.
  • return_loss_history: Whether to return the loss history.
  • record_loss_every: Record the loss every record_loss_every minibatches. Defaults to 1, i.e., record every minibatch.
  • spawn_tqdm: Whether to spawn the tqdm progress meter showing the current update step and displaying the epoch with its respective minibatch counter.

Returns:

  • Varying, see above. It will always return the trained stepper as the first return value.

Tip

You can use equinox.filter_vmap to train multiple networks (of the same architecture) at the same time. For example, if your GPU is not fully utilized yet, this will give you an init-seed statistic basically for free.
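
A hedged sketch of running the training loop, assuming the trainer from the sketch above and some stepper (any equinox module mapping one state to the next); the names are illustrative, but the call and its return values follow the documentation above.

```python
import jax

# `stepper` is assumed to be an equinox module mapping u_prev -> u_next,
# e.g. a small convolutional emulator; `trainer` is the ResiduumTrainer sketched above.
key = jax.random.PRNGKey(0)

# Default: returns the trained stepper and the per-minibatch loss history.
trained_stepper, loss_history = trainer(stepper, key)

# Without a callback and with return_loss_history=False, only the stepper is returned.
trained_stepper = trainer(stepper, key, return_loss_history=False)
```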

Source code in trainax/_general_trainer.py
def __call__(
    self,
    stepper: eqx.Module,
    key: PRNGKeyArray,
    opt_state: Optional[optax.OptState] = None,
    *,
    return_loss_history: bool = True,
    record_loss_every: int = 1,
    spawn_tqdm: bool = True,
) -> Union[
    tuple[eqx.Module, Float[Array, "num_minibatches"]],
    eqx.Module,
    tuple[eqx.Module, Float[Array, "num_minibatches"], list],
    tuple[eqx.Module, list],
]:
    """
    Perform the entire training of an autoregressive neural emulator whose
    initial state is given as `stepper`.

    This method's return signature depends on the presence of a callback
    function. If a callback function is provided, this function has at most
    three return values. The first return value is the trained stepper, the
    second return value is the loss history, and the third return value is
    the auxiliary history. The auxiliary history is a list of the return
    values of the callback function at each minibatch. If no callback
    function is provided, this function has at most two return values. The
    first return value is the trained stepper, and the second return value
    is the loss history. If `return_loss_history` is set to `False`, the
    loss history will not be returned.

    **Arguments:**

    - `stepper`: The equinox Module to be trained.
    - `key`: The random key to be used for shuffling the minibatches.
    - `opt_state`: The initial optimizer state. Defaults to None, meaning
        the optimizer will be reinitialized.
    - `return_loss_history`: Whether to return the loss history.
    - `record_loss_every`: Record the loss every `record_loss_every`
        minibatches. Defaults to 1, i.e., record every minibatch.
    - `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the
        current update step and displaying the epoch with its respective
        minibatch counter.

    **Returns:**

    - Varying, see above. It will always return the trained stepper as the
        first return value.

    !!! tip
        You can use `equinox.filter_vmap` to train multiple networks (of the
        same architecture) at the same time. For example, if your GPU is not
        fully utilized yet, this will give you an init-seed statistic
        basically for free.
    """
    loss_history = []
    if self.callback_fn is not None:
        aux_history = []

    mixer = PermutationMixer(
        num_total_samples=self.trajectory_sub_stacker.num_total_samples,
        num_minibatches=self.num_minibatches,
        batch_size=self.batch_size,
        shuffle_key=key,
    )

    if spawn_tqdm:
        p_meter = tqdm(
            total=self.num_minibatches,
            desc=f"E: {0:05d}, B: {0:05d}",
        )

    update_fn = eqx.filter_jit(self.step_fn)

    trained_stepper = stepper
    if opt_state is None:
        opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))

    for update_i in range(self.num_minibatches):
        batch_indices, (epoch_id, batch_id) = mixer(update_i, return_info=True)
        data = self.trajectory_sub_stacker(batch_indices)
        if self.callback_fn is not None:
            aux = self.callback_fn(update_i, trained_stepper, data)
            aux_history.append(aux)
        trained_stepper, opt_state, loss = update_fn(
            trained_stepper, opt_state, data
        )
        if update_i % record_loss_every == 0:
            loss_history.append(loss)
        if spawn_tqdm:
            p_meter.update(1)

            p_meter.set_description(
                f"E: {expoch_id:05d}, B: {batch_id:05d}",
            )

    if spawn_tqdm:
        p_meter.close()

    loss_history = jnp.array(loss_history)

    if self.callback_fn is not None:
        if return_loss_history:
            return trained_stepper, loss_history, aux_history
        else:
            return trained_stepper, aux_history
    else:
        if return_loss_history:
            return trained_stepper, loss_history
        else:
            return trained_stepper