Diverted-Chain
trainax.configuration.DivertedChainBranchOne
Bases: BaseConfiguration
Source code in trainax/configuration/_diverted_chain_branch_one.py
__init__
__init__(
num_rollout_steps: int = 1,
*,
time_level_loss: BaseLoss = MSELoss(),
cut_bptt: bool = False,
cut_bptt_every: int = 1,
cut_div_chain: bool = False,
time_level_weights: Optional[
Float[Array, num_rollout_steps]
] = None
)
Diverted chain (rollout) configuration with branch length fixed to one.
Essentially, this amounts to a one-step difference to a reference (created
on the fly by the differentiable ref_stepper). Falls back to classical
one-step supervised training for num_rollout_steps=1 (default).
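Read as a formula, the objective can be sketched as follows. This is a paraphrase of the description above rather than an excerpt from the implementation; here $f_\theta$ denotes the trainable stepper, $r$ the ref_stepper, $\ell$ the time_level_loss, $w_t$ the optional time_level_weights, and $u_0$ the initial condition taken from the data:

$$
L(\theta) \;=\; \sum_{t=1}^{T} w_t \, \ell\big(f_\theta(u_{t-1}),\; r(u_{t-1})\big),
\qquad u_t = f_\theta(u_{t-1}), \quad T = \text{num\_rollout\_steps}.
$$

For $T = 1$ this reduces to the single one-step difference against the on-the-fly reference mentioned above.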
Arguments:
num_rollout_steps
: The number of time steps to autoregressively roll out the model. Defaults to 1.
time_level_loss
: The loss function to use at each time step. Defaults to MSELoss().
cut_bptt
: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain. Defaults to False.
cut_bptt_every
: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
cut_div_chain
: Whether to cut the diverted chain, i.e., insert a jax.lax.stop_gradient to not have cotangents flow over the ref_stepper. In this case, the ref_stepper does not have to be differentiable. Defaults to False.
time_level_weights
: An array of length num_rollout_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).
Info
- The ref_stepper is called on-the-fly. If its forward (and vjp) execution is expensive, this will dominate the computational cost of this configuration.
- The usage of the ref_stepper includes the first branch starting from the initial condition. Hence, no reference trajectory is required.
- Under reverse-mode automatic differentiation, memory usage grows linearly with num_rollout_steps.
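As a usage sketch, the configuration can be instantiated with the keyword arguments listed above. The trainax as tx import alias and the tx.loss.MSELoss location are assumptions about the public API, not taken from this page:

```python
import jax.numpy as jnp
import trainax as tx  # assumed import alias

# Five-step rollout; gradients flow through both the main chain and the
# on-the-fly reference branches (cut_bptt=False, cut_div_chain=False).
config = tx.configuration.DivertedChainBranchOne(
    num_rollout_steps=5,
    time_level_loss=tx.loss.MSELoss(),  # assumed location of MSELoss
    cut_bptt=False,
    cut_div_chain=False,
    time_level_weights=jnp.ones(5),  # uniform weights, equivalent to the default None
)
```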
Source code in trainax/configuration/_diverted_chain_branch_one.py
__call__
__call__(
stepper: eqx.Module,
data: PyTree[Float[Array, "batch num_snapshots ..."]],
*,
ref_stepper: eqx.Module,
residuum_fn: eqx.Module = None
) -> float
Evaluate the diverted chain (rollout) configuration on the given data.
The data only has to contain one time level, the initial condition.
Arguments:
stepper
: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
data
: The data to evaluate the configuration on. This depends on the concrete configuration. In this case, it only contains the set of initial states.
ref_stepper
: The reference stepper to use for the diverted chain. This is called on-the-fly.
residuum_fn
: For compatibility with other configurations; not used.
Returns:
- The loss value computed by this configuration.
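A minimal evaluation sketch follows. The toy ScaleStepper, the data layout (batch, one initial snapshot, spatial degrees of freedom), and the trainax as tx alias are assumptions for illustration; in practice the stepper would be a neural emulator and ref_stepper a differentiable reference solver with matching state layout:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import trainax as tx  # assumed import alias

class ScaleStepper(eqx.Module):
    """Toy stepper: u_next = factor * u_prev (shape-agnostic placeholder)."""
    factor: jax.Array

    def __call__(self, u):
        return self.factor * u

config = tx.configuration.DivertedChainBranchOne(num_rollout_steps=5)
stepper = ScaleStepper(jnp.array(1.0))      # stands in for the trainable emulator
ref_stepper = ScaleStepper(jnp.array(0.9))  # stands in for the reference solver

# Only the initial condition is required: (batch, num_snapshots=1, spatial dofs).
data = jax.random.normal(jax.random.PRNGKey(0), (8, 1, 32))

loss = config(stepper, data, ref_stepper=ref_stepper)
```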
Source code in trainax/configuration/_diverted_chain_branch_one.py
trainax.configuration.DivertedChain
Bases: BaseConfiguration
Source code in trainax/configuration/_diverted_chain.py
__init__
__init__(
num_rollout_steps: int = 1,
num_branch_steps: int = 1,
*,
time_level_loss: BaseLoss = MSELoss(),
cut_bptt: bool = False,
cut_bptt_every: int = 1,
cut_div_chain: bool = False,
time_level_weights: Optional[
Float[Array, num_rollout_steps]
] = None,
branch_level_weights: Optional[
Float[Array, num_branch_steps]
] = None
)
General diverted chain (rollout) configuration.
Contains the Supervised configuration as a special case of
num_branch_steps=num_rollout_steps and the DivertedChainBranchOne
configuration as a special case of num_branch_steps=1.
Arguments:
num_rollout_steps
: The number of time steps to autoregressively roll out the model. Defaults to 1.
num_branch_steps
: The number of time steps to branch off the main chain. Must be less than or equal to num_rollout_steps. Defaults to 1.
time_level_loss
: The loss function to use at each time step. Defaults to MSELoss().
cut_bptt
: Whether to cut the backpropagation through time (BPTT), i.e., insert a jax.lax.stop_gradient into the autoregressive network main chain. Defaults to False.
cut_bptt_every
: The frequency at which to cut the BPTT. Only relevant if cut_bptt is True. Defaults to 1 (meaning after each step).
cut_div_chain
: Whether to cut the diverted chain, i.e., insert a jax.lax.stop_gradient to not have cotangents flow over the ref_stepper. In this case, the ref_stepper does not have to be differentiable. Defaults to False.
time_level_weights
: An array of length num_rollout_steps that contains the weights for each time step. Defaults to None, which means that all time steps have the same weight (=1.0).
branch_level_weights
: An array of length num_branch_steps that contains the weights for each branch step. Defaults to None, which means that all branch steps have the same weight (=1.0).
Raises:
- ValueError: If num_branch_steps is greater than num_rollout_steps.
Info
- The ref_stepper is called on-the-fly. If its forward (and vjp) evaluation is expensive, this will dominate the computational cost of this configuration.
- The usage of the ref_stepper includes the first branch starting from the initial condition. Hence, no reference trajectory is required.
- Under reverse-mode automatic differentiation, memory usage grows with the product of num_rollout_steps and num_branch_steps.
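As a construction sketch (the trainax as tx alias is an assumption), the two special cases mentioned above correspond to the following instantiations:

```python
import trainax as tx  # assumed import alias

# General setting: six rollout steps, branches of length three.
general = tx.configuration.DivertedChain(num_rollout_steps=6, num_branch_steps=3)

# num_branch_steps=1 behaves like DivertedChainBranchOne.
branch_one_like = tx.configuration.DivertedChain(num_rollout_steps=6, num_branch_steps=1)

# num_branch_steps=num_rollout_steps behaves like the Supervised configuration.
supervised_like = tx.configuration.DivertedChain(num_rollout_steps=6, num_branch_steps=6)
```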
Source code in trainax/configuration/_diverted_chain.py
__call__
__call__(
stepper: eqx.Module,
data: PyTree[Float[Array, "batch num_snapshots ..."]],
*,
ref_stepper: eqx.Module,
residuum_fn: eqx.Module = None
) -> float
Evaluate the general diverted chain (rollout) configuration on the given data.
The data only has to contain one time level, the initial condition.
Arguments:
stepper
: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
data
: The data to evaluate the configuration on. This depends on the concrete configuration. In this case, it only has to contain the set of initial states.
ref_stepper
: The reference stepper to use for the diverted chain. This is called on-the-fly.
residuum_fn
: For compatibility with other configurations; not used.
Returns:
- The loss value computed by this configuration.
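Since the call returns a scalar, the configuration composes with standard Equinox gradient utilities. The following is a sketch under the same toy-stepper and data-layout assumptions as above, not the library's prescribed training loop:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import trainax as tx  # assumed import alias

class ScaleStepper(eqx.Module):
    """Toy stepper: u_next = factor * u_prev (shape-agnostic placeholder)."""
    factor: jax.Array

    def __call__(self, u):
        return self.factor * u

config = tx.configuration.DivertedChain(num_rollout_steps=4, num_branch_steps=2)
stepper = ScaleStepper(jnp.array(1.0))
ref_stepper = ScaleStepper(jnp.array(0.9))
data = jax.random.normal(jax.random.PRNGKey(0), (8, 1, 32))

# Differentiate the scalar loss with respect to the trainable stepper.
loss, grads = eqx.filter_value_and_grad(
    lambda s: config(s, data, ref_stepper=ref_stepper)
)(stepper)
```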
Source code in trainax/configuration/_diverted_chain.py