Trajectory Mixing
trainax.TrajectorySubStacker
Bases: Module
Source code in trainax/_mixer.py
__init__
__init__(
    data_trajectories: PyTree[Float[Array, "num_samples trj_len ..."]],
    sub_trajectory_len: int,
    *,
    do_sub_stacking: bool = True,
    only_store_ic: bool = False
)
Slice a batch of trajectories into sub-trajectories.
Useful to create windows of a specific length for (unrolled) training methodologies of autoregressive neural emulators.
Arguments:
- `data_trajectories`: The batch of trajectories to slice. This must be a PyTree of Arrays that have at least two leading axes: a batch axis and a time axis. For example, the zeroth axis can be associated with multiple initial conditions or constitutive parameters, and the first axis represents all temporal snapshots. A PyTree can also just be an array. You can provide additional leaves in the PyTree, e.g., for the corresponding constitutive parameters; make sure that the emulator has the corresponding signature.
- `sub_trajectory_len`: The length of the sub-trajectories. This must be smaller than or equal to the length of the trajectories (`trj_len`). For unrolled training with `t` steps, set this to `t+1` to include the necessary initial condition.
- `do_sub_stacking`: Whether to slice out all possible (overlapping) windows along the `trj_len` axis, or to only slice the `trj_len` axis from `0:sub_trajectory_len`.
- `only_store_ic`: Whether to only store the initial condition of each sub-trajectory. This can be helpful for configurations that do not need the reference trajectory, like residuum-based learning strategies.
Info
- Since the windows sliced out are overlapping, the produced internal array can be large, especially if `sub_trajectory_len` is large. Certainly, this is not the most memory-efficient solution, but it is sufficient if your problem easily fits into memory. Consider overwriting this class with a more memory-efficient implementation if you run into memory issues.
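For illustration, a minimal sketch of constructing a sub-stacker; the array shape, window length, and random data below are hypothetical:

```python
import jax
import trainax

# Hypothetical batch: 10 trajectories with 201 snapshots each of a field
# with one channel and 64 spatial points
data = jax.random.normal(jax.random.PRNGKey(0), (10, 201, 1, 64))

# Windows of length 3, i.e., t + 1 = 3 for two-step unrolled training
sub_stacker = trainax.TrajectorySubStacker(
    data,
    sub_trajectory_len=3,
)
```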
Source code in trainax/_mixer.py
__call__
__call__(
    indices: slice,
) -> PyTree[Float[Array, "len(indices) sub_trj_len ..."]]
Slice out sub-samples based on the given indices.
Arguments:
- `indices`: The indices of the sub-trajectories to slice out. For example, this can be `[0, 4, 5]` to select the zeroth, fourth, and fifth sub-trajectories, or it can be a `slice` object.
Returns:
- `PyTree[Float[Array, "len(indices) sub_trj_len ..."]]`: The sliced sub-trajectories.
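As a sketch, continuing the hypothetical `sub_stacker` from above, both an explicit index collection and a `slice` object can be passed:

```python
import jax.numpy as jnp

# Explicit indices: the zeroth, fourth, and fifth sub-trajectory windows
windows = sub_stacker(jnp.array([0, 4, 5]))

# A contiguous block of the first 32 windows via a slice object
first_block = sub_stacker(slice(0, 32))
```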
Source code in trainax/_mixer.py
trainax.PermutationMixer
Bases: Module
Source code in trainax/_mixer.py
__init__
__init__(
    num_total_samples: int,
    num_minibatches: int,
    batch_size: int,
    shuffle_key: PRNGKeyArray,
)
Precompute permutations for a given number of minibatches within a dataset. Automatically determines the number of necessary epochs (runs over the entire dataset). Upon calling, it returns a collection of indices to produce a new minibatch.
If the remainder minibatch in one epoch is smaller than the batch size, it is not extended with data from the next epoch but returned as a smaller list of indices.
Arguments:
- `num_total_samples`: The total number of samples in the dataset.
- `num_minibatches`: The number of minibatches to train on.
- `batch_size`: The size of the minibatches.
- `shuffle_key`: The PRNG key to create the permutation; needed for deterministic reproducibility.
Warning
ValueError: If the batch size is larger than the total number of samples for one epoch.
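A minimal sketch of instantiating a mixer; the sample count, number of minibatches, batch size, and key below are hypothetical:

```python
import jax
import trainax

mixer = trainax.PermutationMixer(
    num_total_samples=1000,  # e.g., the number of stacked sub-trajectories
    num_minibatches=500,     # how many minibatches (optimization steps) to draw
    batch_size=32,
    shuffle_key=jax.random.PRNGKey(42),
)
```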
Source code in trainax/_mixer.py
__call__
__call__(
    i: int, *, return_info: bool = False
) -> Int[Array, "batch_size"]
Given the batch index `i`, return the corresponding indices to slice out the minibatch.
Arguments:
- `i`: The batch index.
- `return_info`: Whether to additionally return information about the current epoch and the batch index within it.
Returns:
- The indices to slice out the minibatch, in the form of an array of integers.
Warning
ValueError: If the batch index is larger than the number of minibatches (because there will likely be no permutation for it).
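As a sketch, the mixer can drive a training loop together with the hypothetical `sub_stacker` from above, assuming the mixer's `num_total_samples` matches the number of stacked sub-trajectories:

```python
for i in range(500):                    # one call per minibatch; 500 is hypothetical
    batch_indices = mixer(i)            # integer indices, at most batch_size many
    minibatch = sub_stacker(batch_indices)
    # ... run one optimization step of the emulator on `minibatch` ...
```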
Source code in trainax/_mixer.py