
Diverted Chain

trainax.trainer.DivertedChainBranchOneTrainer

Bases: GeneralTrainer

Source code in trainax/trainer/_diverted_chain_branch_one.py
class DivertedChainBranchOneTrainer(GeneralTrainer):
    def __init__(
        self,
        data_trajectories,
        *,
        ref_stepper: eqx.Module,
        residuum_fn: eqx.Module = None,  # for compatibility
        optimizer: optax.GradientTransformation,
        callback_fn: Optional[BaseCallback] = None,
        num_training_steps: int,
        batch_size: int,
        num_rollout_steps: int = 1,
        time_level_loss: BaseLoss = MSELoss(),
        cut_bptt: bool = False,
        cut_bptt_every: int = 1,
        cut_div_chain: bool = False,
        time_level_weights: Optional[
            Float[Array, "num_rollout_steps"]  # noqa F821
        ] = None,
        do_sub_stacking: bool = True,
    ):
        """
        Diverted chain (rollout) configuration with branch length fixed to one.

        Essentially, this amounts to a one-step difference to a reference
        (created on the fly by the differentiable `ref_stepper`). Falls back to
        classical one-step supervised training for `num_rollout_steps=1`
        (default).

        **Arguments:**

        - `data_trajectories`: The batch of trajectories to slice. This must be
            a PyTree of arrays that have at least two leading axes: a batch axis
            and a time axis. For example, the zeroth axis can be associated with
            multiple initial conditions or constitutive parameters, and the
            first axis represents all temporal snapshots. A PyTree can also just
            be an array. You can provide additional leaves in the PyTree, e.g.,
            for the corresponding constitutive parameters. Make sure that the
            emulator has the corresponding signature.
        - `ref_stepper`: The reference stepper to use for the diverted chain.
            This is called on-the-fly.
        - `residuum_fn`: For compatibility with other configurations; not used.
        - `optimizer`: The optimizer to use for training. For example, this can
            be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer
            with learning rate decay, for example
            `optax.adam(optax.exponential_decay(...))`. If your learning rate
            decay is designed for a certain number of update steps, make sure
            that it aligns with `num_training_steps`.
        - `callback_fn`: A callback to use during training. Defaults to None.
        - `num_training_steps`: The number of training steps to perform.
        - `batch_size`: The batch size to use for training. Batches are
            randomly sampled both across multiple trajectories and across
            different windows within one trajectory.
        - `num_rollout_steps`: The number of time steps to autoregressively
            roll out the model.
        - `time_level_loss`: The loss function to use at each time step.
        - `cut_bptt`: Whether to cut the backpropagation through time (BPTT),
            i.e., insert a `jax.lax.stop_gradient` into the autoregressive
            network main chain.
        - `cut_bptt_every`: The frequency at which to cut the BPTT. Only
            relevant if `cut_bptt` is True. Defaults to 1 (meaning after each
            step).
        - `cut_div_chain`: Whether to cut the diverted chain, i.e.,
            insert a `jax.lax.stop_gradient` to not have cotangents flow over
            the `ref_stepper`. In this case, the `ref_stepper` does not have to
            be differentiable.
        - `time_level_weights`: An array of length `num_rollout_steps` that
            contains the weights for each time step. Defaults to None, which
            means that all time steps have the same weight (=1.0). (keyword-only
            argument)


        !!! info
            * The `ref_stepper` is called on-the-fly. If its forward (and vjp)
                execution are expensive, this will dominate the computational
                cost of this configuration.
            * The usage of the `ref_stepper` includes the first branch starting
                from the initial condition. Hence, no reference trajectory is
                required.
            * Under reverse-mode automatic differentiation memory usage grows
                linearly with `num_rollout_steps`.
        """
        trajectory_sub_stacker = TrajectorySubStacker(
            data_trajectories,
            sub_trajectory_len=num_rollout_steps + 1,  # +1 for the IC
            do_sub_stacking=do_sub_stacking,
            only_store_ic=True,  # Not needed because we use the ref_stepper
        )
        loss_configuration = DivertedChainBranchOne(
            num_rollout_steps=num_rollout_steps,
            time_level_loss=time_level_loss,
            cut_bptt=cut_bptt,
            cut_bptt_every=cut_bptt_every,
            cut_div_chain=cut_div_chain,
            time_level_weights=time_level_weights,
        )
        super().__init__(
            trajectory_sub_stacker,
            loss_configuration,
            ref_stepper=ref_stepper,
            residuum_fn=residuum_fn,
            optimizer=optimizer,
            num_minibatches=num_training_steps,
            batch_size=batch_size,
            callback_fn=callback_fn,
        )
__init__
__init__(
    data_trajectories,
    *,
    ref_stepper: eqx.Module,
    residuum_fn: eqx.Module = None,
    optimizer: optax.GradientTransformation,
    callback_fn: Optional[BaseCallback] = None,
    num_training_steps: int,
    batch_size: int,
    num_rollout_steps: int = 1,
    time_level_loss: BaseLoss = MSELoss(),
    cut_bptt: bool = False,
    cut_bptt_every: int = 1,
    cut_div_chain: bool = False,
    time_level_weights: Optional[
        Float[Array, num_rollout_steps]
    ] = None,
    do_sub_stacking: bool = True
)

Diverted chain (rollout) configuration with branch length fixed to one.

Essentially, this amounts to a one-step difference to a reference (created on the fly by the differentiable ref_stepper). Falls back to classical one-step supervised training for num_rollout_steps=1 (default).
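
The training objective can be sketched as follows. This is a schematic only, not the library's internal implementation; network stands for the emulator being trained and u_0 for one initial condition, while the remaining names correspond to the arguments listed below.

u = u_0
loss = 0.0
for t in range(num_rollout_steps):
    branch = ref_stepper(u)  # one-step branch diverging from the main chain
    u = network(u)           # autoregressive main chain of the emulator
    loss += time_level_weights[t] * time_level_loss(u, branch)

With cut_div_chain=True, a stop_gradient is placed on the branch so no cotangents flow through ref_stepper; with cut_bptt=True, the main chain is cut in the same way every cut_bptt_every steps.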

Arguments:

  • data_trajectories: The batch of trajectories to slice. This must be a PyTree of arrays that have at least two leading axes: a batch axis and a time axis. For example, the zeroth axis can be associated with multiple initial conditions or constitutive parameters, and the first axis represents all temporal snapshots. A PyTree can also just be an array. You can provide additional leaves in the PyTree, e.g., for the corresponding constitutive parameters. Make sure that the emulator has the corresponding signature.
  • ref_stepper: The reference stepper to use for the diverted chain. This is called on-the-fly.
  • residuum_fn: For compatibility with other configurations; not used.
  • optimizer: The optimizer to use for training. For example, this can be optax.adam(LEARNING_RATE). Also use this to supply an optimizer with learning rate decay, for example optax.adam(optax.exponential_decay(...)). If your learning rate decay is designed for a certain number of update steps, make sure that it aligns with num_training_steps.
  • callback_fn: A callback to use during training. Defaults to None.
  • num_training_steps: The number of training steps to perform.
  • batch_size: The batch size to use for training. Batches are randomly sampled both across multiple trajectories and across different windows within one trajectory.
  • num_rollout_steps: The number of time steps to autoregressively roll out the model.
  • time_level_loss: The loss function to use at each time step.
  • cut_bptt: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain.
  • cut_bptt_every: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
  • cut_div_chain: Whether to cut the diverted chain, i.e., insert a jax.lax.stop_gradient to not have cotangents flow over the ref_stepper. In this case, the ref_stepper does not have to be differentiable.
  • time_level_weights: An array of length num_rollout_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0). (keyword-only argument)

Info

  • The ref_stepper is called on-the-fly. If its forward (and vjp) execution are expensive, this will dominate the computational cost of this configuration.
  • The usage of the ref_stepper includes the first branch starting from the initial condition. Hence, no reference trajectory is required.
  • Under reverse-mode automatic differentiation memory usage grows linearly with num_rollout_steps.
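
A minimal construction sketch. The data shape and the toy reference stepper below are placeholders; any differentiable equinox Module with a call signature matching your data can serve as ref_stepper.

import equinox as eqx
import jax.numpy as jnp
import optax
from trainax.trainer import DivertedChainBranchOneTrainer

class RollStepper(eqx.Module):
    # Toy differentiable reference stepper: shifts the state by one grid point.
    # Stands in for a fine/reference numerical solver.
    def __call__(self, u):
        return jnp.roll(u, shift=1, axis=-1)

# Placeholder data: 10 trajectories, 50 snapshots each, 1 channel, 64 grid points.
data_trajectories = jnp.zeros((10, 50, 1, 64))

trainer = DivertedChainBranchOneTrainer(
    data_trajectories,
    ref_stepper=RollStepper(),
    optimizer=optax.adam(3e-4),
    num_training_steps=5_000,
    batch_size=32,
    num_rollout_steps=5,
)
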
Source code in trainax/trainer/_diverted_chain_branch_one.py
def __init__(
    self,
    data_trajectories,
    *,
    ref_stepper: eqx.Module,
    residuum_fn: eqx.Module = None,  # for compatibility
    optimizer: optax.GradientTransformation,
    callback_fn: Optional[BaseCallback] = None,
    num_training_steps: int,
    batch_size: int,
    num_rollout_steps: int = 1,
    time_level_loss: BaseLoss = MSELoss(),
    cut_bptt: bool = False,
    cut_bptt_every: int = 1,
    cut_div_chain: bool = False,
    time_level_weights: Optional[
        Float[Array, "num_rollout_steps"]  # noqa F821
    ] = None,
    do_sub_stacking: bool = True,
):
    """
    Diverted chain (rollout) configuration with branch length fixed to one.

    Essentially, this amounts to a one-step difference to a reference
    (created on the fly by the differentiable `ref_stepper`). Falls back to
    classical one-step supervised training for `num_rollout_steps=1`
    (default).

    **Arguments:**

    - `data_trajectories`: The batch of trajectories to slice. This must be
        a PyTree of arrays that have at least two leading axes: a batch axis
        and a time axis. For example, the zeroth axis can be associated with
        multiple initial conditions or constitutive parameters, and the
        first axis represents all temporal snapshots. A PyTree can also just
        be an array. You can provide additional leaves in the PyTree, e.g.,
        for the corresponding constitutive parameters. Make sure that the
        emulator has the corresponding signature.
    - `ref_stepper`: The reference stepper to use for the diverted chain.
        This is called on-the-fly.
    - `residuum_fn`: For compatibility with other configurations; not used.
    - `optimizer`: The optimizer to use for training. For example, this can
        be `optax.adam(LEARNING_RATE)`. Also use this to supply an optimizer
        with learning rate decay, for example
        `optax.adam(optax.exponential_decay(...))`. If your learning rate
        decay is designed for a certain number of update steps, make sure
        that it aligns with `num_training_steps`.
    - `callback_fn`: A callback to use during training. Defaults to None.
    - `num_training_steps`: The number of training steps to perform.
    - `batch_size`: The batch size to use for training. Batches are
        randomly sampled both across multiple trajectories and across
        different windows within one trajectory.
    - `num_rollout_steps`: The number of time steps to autoregressively
        roll out the model.
    - `time_level_loss`: The loss function to use at each time step.
    - `cut_bptt`: Whether to cut the backpropagation through time (BPTT),
        i.e., insert a `jax.lax.stop_gradient` into the autoregressive
        network main chain.
    - `cut_bptt_every`: The frequency at which to cut the BPTT. Only
        relevant if `cut_bptt` is True. Defaults to 1 (meaning after each
        step).
    - `cut_div_chain`: Whether to cut the diverted chain, i.e.,
        insert a `jax.lax.stop_gradient` to not have cotangents flow over
        the `ref_stepper`. In this case, the `ref_stepper` does not have to
        be differentiable.
    - `time_level_weights`: An array of length `num_rollout_steps` that
        contains the weights for each time step. Defaults to None, which
        means that all time steps have the same weight (=1.0). (keyword-only
        argument)


    !!! info
        * The `ref_stepper` is called on-the-fly. If its forward (and vjp)
            execution are expensive, this will dominate the computational
            cost of this configuration.
        * The usage of the `ref_stepper` includes the first branch starting
            from the initial condition. Hence, no reference trajectory is
            required.
        * Under reverse-mode automatic differentiation memory usage grows
            linearly with `num_rollout_steps`.
    """
    trajectory_sub_stacker = TrajectorySubStacker(
        data_trajectories,
        sub_trajectory_len=num_rollout_steps + 1,  # +1 for the IC
        do_sub_stacking=do_sub_stacking,
        only_store_ic=True,  # Not needed because we use the ref_stepper
    )
    loss_configuration = DivertedChainBranchOne(
        num_rollout_steps=num_rollout_steps,
        time_level_loss=time_level_loss,
        cut_bptt=cut_bptt,
        cut_bptt_every=cut_bptt_every,
        cut_div_chain=cut_div_chain,
        time_level_weights=time_level_weights,
    )
    super().__init__(
        trajectory_sub_stacker,
        loss_configuration,
        ref_stepper=ref_stepper,
        residuum_fn=residuum_fn,
        optimizer=optimizer,
        num_minibatches=num_training_steps,
        batch_size=batch_size,
        callback_fn=callback_fn,
    )
__call__
__call__(
    stepper: eqx.Module,
    key: PRNGKeyArray,
    opt_state: Optional[optax.OptState] = None,
    *,
    return_loss_history: bool = True,
    record_loss_every: int = 1,
    spawn_tqdm: bool = True
) -> Union[
    tuple[eqx.Module, Float[Array, num_minibatches]],
    eqx.Module,
    tuple[eqx.Module, Float[Array, num_minibatches], list],
    tuple[eqx.Module, list],
]

Perform the entire training of an autoregressive neural emulator whose initial state is given as stepper.

This method's return signature depends on the presence of a callback function. If a callback function is provided, this method returns at most three values: the trained stepper, the loss history, and the auxiliary history, which is a list of the callback's return values at each minibatch. If no callback function is provided, it returns at most two values: the trained stepper and the loss history. If return_loss_history is set to False, the loss history is not returned.

Arguments:

  • stepper: The equinox Module to be trained.
  • key: The random key to be used for shuffling the minibatches.
  • opt_state: The initial optimizer state. Defaults to None, meaning the optimizer will be reinitialized.
  • return_loss_history: Whether to return the loss history.
  • record_loss_every: Record the loss every record_loss_every minibatches. Defaults to 1, i.e., record every minibatch.
  • spawn_tqdm: Whether to spawn the tqdm progress meter showing the current update step and displaying the epoch with its respective minibatch counter.

Returns:

  • Varying, see above. It will always return the trained stepper as the first return value.

Tip

You can use equinox.filter_vmap to train multiple networks (of the same architecture) at the same time. For example, if your GPU is not fully utilized yet, this gives you an init-seed statistic essentially for free.
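
A minimal sketch of running the training loop, assuming the trainer constructed in the example above and a hypothetical emulator; any equinox Module with a call signature matching the data works.

import equinox as eqx
import jax

init_key, train_key = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical emulator: a small 1D convolution mapping one state to the next.
emulator = eqx.nn.Conv1d(1, 1, kernel_size=3, padding=1, key=init_key)

# Default return signature (no callback, return_loss_history=True):
trained_emulator, loss_history = trainer(emulator, train_key)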

Source code in trainax/_general_trainer.py
def __call__(
    self,
    stepper: eqx.Module,
    key: PRNGKeyArray,
    opt_state: Optional[optax.OptState] = None,
    *,
    return_loss_history: bool = True,
    record_loss_every: int = 1,
    spawn_tqdm: bool = True,
) -> Union[
    tuple[eqx.Module, Float[Array, "num_minibatches"]],
    eqx.Module,
    tuple[eqx.Module, Float[Array, "num_minibatches"], list],
    tuple[eqx.Module, list],
]:
    """
    Perform the entire training of an autoregressive neural emulator whose
    initial state is given as `stepper`.

    This method's return signature depends on the presence of a callback
    function. If a callback function is provided, this method returns at
    most three values: the trained stepper, the loss history, and the
    auxiliary history, which is a list of the callback's return values at
    each minibatch. If no callback function is provided, it returns at most
    two values: the trained stepper and the loss history. If
    `return_loss_history` is set to `False`, the loss history is not
    returned.

    **Arguments:**

    - `stepper`: The equinox Module to be trained.
    - `key`: The random key to be used for shuffling the minibatches.
    - `opt_state`: The initial optimizer state. Defaults to None, meaning
        the optimizer will be reinitialized.
    - `return_loss_history`: Whether to return the loss history.
    - `record_loss_every`: Record the loss every `record_loss_every`
        minibatches. Defaults to 1, i.e., record every minibatch.
    - `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the
        current update step and displaying the epoch with its respective
        minibatch counter.

    **Returns:**

    - Varying, see above. It will always return the trained stepper as the
        first return value.

    !!! tip
        You can use `equinox.filter_vmap` to train multiple networks (of the
        same architecture) at the same time. For example, if your GPU is not
        fully utilized yet, this gives you an init-seed statistic
        essentially for free.
    """
    loss_history = []
    if self.callback_fn is not None:
        aux_history = []

    mixer = PermutationMixer(
        num_total_samples=self.trajectory_sub_stacker.num_total_samples,
        num_minibatches=self.num_minibatches,
        batch_size=self.batch_size,
        shuffle_key=key,
    )

    if spawn_tqdm:
        p_meter = tqdm(
            total=self.num_minibatches,
            desc=f"E: {0:05d}, B: {0:05d}",
        )

    update_fn = eqx.filter_jit(self.step_fn)

    trained_stepper = stepper
    if opt_state is None:
        opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))

    for update_i in range(self.num_minibatches):
        batch_indices, (epoch_id, batch_id) = mixer(update_i, return_info=True)
        data = self.trajectory_sub_stacker(batch_indices)
        if self.callback_fn is not None:
            aux = self.callback_fn(update_i, trained_stepper, data)
            aux_history.append(aux)
        trained_stepper, opt_state, loss = update_fn(
            trained_stepper, opt_state, data
        )
        if update_i % record_loss_every == 0:
            loss_history.append(loss)
        if spawn_tqdm:
            p_meter.update(1)

            p_meter.set_description(
                f"E: {expoch_id:05d}, B: {batch_id:05d}",
            )

    if spawn_tqdm:
        p_meter.close()

    loss_history = jnp.array(loss_history)

    if self.callback_fn is not None:
        if return_loss_history:
            return trained_stepper, loss_history, aux_history
        else:
            return trained_stepper, aux_history
    else:
        if return_loss_history:
            return trained_stepper, loss_history
        else:
            return trained_stepper