
Temporal Evolution

Utilities to autoregressively evaluate steppers.

exponax.repeat

repeat(
    stepper_fn: Union[
        Callable[[PyTree], PyTree],
        Callable[[PyTree, PyTree], PyTree],
    ],
    n: int,
    *,
    takes_aux: bool = False,
    constant_aux: bool = True
)

Transform a stepper function into a function that applies the stepper n times autoregressively (i.e., each application takes the previous output as its input) and returns the final state.
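Conceptually, the transformed function behaves like calling the stepper n times in a Python loop; the source below expresses this loop with jax.lax.scan so it is traced once instead of being unrolled. The following is a minimal sketch of that equivalence, assuming a hypothetical toy stepper decay_stepper (one explicit Euler step of du/dt = -u with dt = 0.1); repeat_via_scan and repeat_via_loop are illustrative helpers, not part of exponax.

```python
import jax
import jax.numpy as jnp


def decay_stepper(u):
    """Hypothetical toy stepper: one explicit Euler step of du/dt = -u with dt = 0.1."""
    return u - 0.1 * u


def repeat_via_scan(stepper_fn, n):
    """Sketch of scan-based repetition (autonomous case only)."""

    def scan_fn(u, _):
        # Carry the new state forward; nothing is stacked per step.
        return stepper_fn(u), None

    def repeated(u_0):
        final, _ = jax.lax.scan(scan_fn, u_0, None, length=n)
        return final

    return repeated


def repeat_via_loop(stepper_fn, n, u_0):
    """Reference semantics: apply the stepper n times in a plain Python loop."""
    u = u_0
    for _ in range(n):
        u = stepper_fn(u)
    return u


u_0 = jnp.ones(8)
u_scan = jax.jit(repeat_via_scan(decay_stepper, 10))(u_0)
u_loop = repeat_via_loop(decay_stepper, 10, u_0)
```

Both variants return the same state, here 0.9**10 in every entry; the scan version compiles to a single loop regardless of n.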

Depending on takes_aux, the stepper function is either autonomous, just mapping state to state, or takes an additional auxiliary input. This can be a force/control or additional metadata (such as physical parameters, or the time for non-autonomous systems).
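As an illustration of the auxiliary-input case, a non-autonomous stepper can receive the current time as aux. The sketch below assumes a hypothetical forced decay stepper forced_stepper(u, t) and a hypothetical step size DT; it mirrors the takes_aux = True branch of the source, where each scan step consumes one entry along the leading axis of aux (i.e., constant_aux = False).

```python
import jax
import jax.numpy as jnp

DT = 0.1  # hypothetical time step size


def forced_stepper(u, t):
    """Hypothetical non-autonomous stepper: Euler step of du/dt = -u + sin(t)."""
    return u + DT * (-u + jnp.sin(t))


def repeat_with_aux(stepper_fn, n):
    """Sketch of the takes_aux=True, constant_aux=False branch."""

    def scan_fn(u, aux):
        # aux is the slice of the auxiliary input belonging to the current step
        return stepper_fn(u, aux), None

    def repeated(u_0, aux):
        final, _ = jax.lax.scan(scan_fn, u_0, aux, length=n)
        return final

    return repeated


n = 5
times = DT * jnp.arange(n)  # one time value per step, leading axis of length n
u_final = repeat_with_aux(forced_stepper, n)(jnp.zeros(3), times)

# Same result with an explicit loop over the time values
u_ref = jnp.zeros(3)
for t in times:
    u_ref = forced_stepper(u_ref, t)
```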

Args:
- stepper_fn: The time stepper to transform. If takes_aux = False (default), the expected signature is u_next = stepper_fn(u), else u_next = stepper_fn(u, aux). u and u_next must be PyTrees of identical structure; in the simplest case, they are just arrays of the same shape.
- n: The number of times to apply the stepper.
- takes_aux: Whether the stepper function takes an additional PyTree as its second argument.
- constant_aux: Whether the auxiliary input is constant over the trajectory (only relevant if takes_aux = True). If True, the same auxiliary input is used in all n steps (it is repeated n times along a new leading axis); otherwise, every array in the auxiliary PyTree must have a leading axis of length n, one entry per step.
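With constant_aux = True, every leaf of the auxiliary PyTree receives a new leading axis of length n before scanning, so each step sees the same value. The snippet below reproduces that broadcasting expression from the source on a hypothetical auxiliary PyTree (a dict holding a scalar viscosity and a spatial forcing) and inspects the resulting shapes. With constant_aux = False you would pass arrays that already carry this leading axis.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

n = 4

# Hypothetical auxiliary input: a scalar parameter and a spatially varying forcing
aux = {"nu": jnp.array(0.01), "forcing": jnp.linspace(0.0, 1.0, 16)}

# Same expression as in the constant_aux=True branch:
# add a leading axis of size 1, then repeat it n times
aux_stacked = jtu.tree_map(
    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux
)

# Collect the shape of every leaf, keeping the dict structure
aux_shapes = jtu.tree_map(lambda x: x.shape, aux_stacked)
```

Here aux_shapes is {"nu": (4,), "forcing": (4, 16)}, and every slice along the new axis equals the original leaf.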

Returns:
- repeated_stepper_fn: A function that takes an initial condition u_0 (and an auxiliary input aux if takes_aux = True) and produces the final state by autoregressively applying the stepper n times. The result is a PyTree of the same structure as the initial condition.
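Because the scan carry may be any PyTree, the state does not have to be a single array. The sketch below assumes a hypothetical two-field state (position and velocity of a harmonic oscillator, advanced by symplectic Euler with a hypothetical step size DT) and runs the same scan pattern as the source; the final state keeps the tuple structure and leaf shapes of the initial condition.

```python
import jax
import jax.numpy as jnp

DT = 0.01  # hypothetical step size


def oscillator_stepper(state):
    """Hypothetical PyTree stepper: symplectic Euler for x'' = -x, with state = (x, v)."""
    x, v = state
    v_next = v - DT * x
    x_next = x + DT * v_next
    return (x_next, v_next)


def scan_fn(state, _):
    # Carry the whole tuple; scan preserves its structure
    return oscillator_stepper(state), None


u_0 = (jnp.ones(2), jnp.zeros(2))  # x(0) = 1, v(0) = 0
final, _ = jax.lax.scan(scan_fn, u_0, None, length=100)  # integrate to t = 1
```

After 100 steps, final is a tuple (x, v) with x close to cos(1) and v close to -sin(1), up to the first-order error of the scheme.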

Source code in exponax/_utils.py
def repeat(
    stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]],
    n: int,
    *,
    takes_aux: bool = False,
    constant_aux: bool = True,
):
    """
    Transform a stepper function into a function that autoregressively (i.e.,
    recursively applied to its own output) applies the stepper `n` times and
    returns the final state.

    Based on `takes_aux`, the stepper function is either fully autonomous, just
    mapping state to state, or takes an additional auxiliary input. This can be
    a force/control or additional metadata (like physical parameters, or time
    for non-autonomous systems).

    Args:
        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
            (default), expected signature is `u_next = stepper_fn(u)`, else
            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
            of identical structure, in the easiest case just arrays of same
            shape.
        - `n`: The number of times to apply the stepper.
        - `takes_aux`: Whether the stepper function takes an additional PyTree
            as second argument.
        - `constant_aux`: Whether the auxiliary input is constant over the
            trajectory. If `True`, the auxiliary input is repeated `n` times,
            otherwise the leading axis in the PyTree arrays has to be of length
            `n`.

    Returns:
        - `repeated_stepper_fn`: A function that takes an initial condition
            `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and
            produces the final state by autoregressively applying the stepper
            `n` times. Returns a PyTree of the same structure as the initial
            condition.
    """

    if takes_aux:

        def scan_fn(u, aux):
            u_next = stepper_fn(u, aux)
            return u_next, None

        def repeated_stepper_fn(u_0, aux):
            if constant_aux:
                aux = jtu.tree_map(
                    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux
                )

            final, _ = jax.lax.scan(scan_fn, u_0, aux, length=n)
            return final

        return repeated_stepper_fn

    else:

        def scan_fn(u, _):
            u_next = stepper_fn(u)
            return u_next, None

        def repeated_stepper_fn(u_0):
            final, _ = jax.lax.scan(scan_fn, u_0, None, length=n)
            return final

        return repeated_stepper_fn

exponax.rollout

rollout(
    stepper_fn: Union[
        Callable[[PyTree], PyTree],
        Callable[[PyTree, PyTree], PyTree],
    ],
    n: int,
    *,
    include_init: bool = False,
    takes_aux: bool = False,
    constant_aux: bool = True
)

Transform a stepper function into a function that produces a trajectory of length n (n + 1 if include_init = True) by applying the stepper autoregressively (i.e., each application takes the previous output as its input).
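Compared with repeat, the only change in the source is that the scan function also emits each new state, which jax.lax.scan stacks along a new leading axis. The following self-contained sketch, assuming a hypothetical decay_stepper, shows that the last entry of such a rollout coincides with the state reached by repeated application.

```python
import jax
import jax.numpy as jnp


def decay_stepper(u):
    """Hypothetical toy stepper: multiply the state by 0.9."""
    return 0.9 * u


def scan_fn(u, _):
    u_next = decay_stepper(u)
    # First output is the carry, second output is stacked into the trajectory
    return u_next, u_next


n = 6
u_0 = jnp.ones(3)
final, trj = jax.lax.scan(scan_fn, u_0, None, length=n)
```

trj has shape (6, 3); its first row is one step ahead of u_0 and its last row equals final.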

Depending on takes_aux, the stepper function is either autonomous, just mapping state to state, or takes an additional auxiliary input. This can be a force/control or additional metadata (such as physical parameters, or the time for non-autonomous systems).

Args:
- stepper_fn: The time stepper to transform. If takes_aux = False (default), the expected signature is u_next = stepper_fn(u), else u_next = stepper_fn(u, aux). u and u_next must be PyTrees of identical structure; in the simplest case, they are just arrays of the same shape.
- n: The number of time steps to roll the trajectory out into the future. With include_init = False (default), the trajectory consists of exactly these n future states.
- include_init: Whether to prepend the initial condition to the trajectory. If True, the arrays in the returned PyTree have shape (n + 1, ...), else (n, ...). Default: False.
- takes_aux: Whether the stepper function takes an additional PyTree as its second argument.
- constant_aux: Whether the auxiliary input is constant over the trajectory (only relevant if takes_aux = True). If True, the same auxiliary input is used in all n steps (it is repeated n times along a new leading axis); otherwise, every array in the auxiliary PyTree must have a leading axis of length n, one entry per step.
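For a time-varying auxiliary input (constant_aux = False), one entry per step is consumed along the leading axis. The sketch below uses a hypothetical controlled_stepper with an additive control and the same scan pattern as the rollout source; it also shows that jax.lax.scan rejects a control sequence whose leading axis does not match n.

```python
import jax
import jax.numpy as jnp


def controlled_stepper(u, f):
    """Hypothetical stepper with an additive control f applied at each step."""
    return 0.5 * u + f


def scan_fn(u, f):
    u_next = controlled_stepper(u, f)
    return u_next, u_next


n = 4
u_0 = jnp.zeros(2)
# Control value k at step k, for both state components: shape (n, 2)
controls = jnp.repeat(jnp.arange(n, dtype=jnp.float32)[:, None], 2, axis=1)
_, trj = jax.lax.scan(scan_fn, u_0, controls, length=n)

# A control sequence whose leading axis does not match n is rejected by scan
try:
    jax.lax.scan(scan_fn, u_0, controls[:-1], length=n)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

The first state component evolves as 0, 1, 2.5, 4.25.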

Returns:
- rollout_stepper_fn: A function that takes an initial condition u_0 (and an auxiliary input aux if takes_aux = True) and produces the trajectory by autoregressively applying the stepper n times. The result is a PyTree of the same structure as the initial condition, with an additional leading (time) axis of length n, or n + 1 if include_init = True.
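The include_init = True branch prepends u_0 leaf by leaf. The sketch below applies the same concatenation as the source to a hypothetical dict-valued state and a stand-in trajectory of n future states, and checks that every leaf grows from (n, ...) to (n + 1, ...).

```python
import jax.numpy as jnp
import jax.tree_util as jtu

n = 5
u_0 = {"u": jnp.zeros(8), "v": jnp.ones(8)}  # hypothetical PyTree state

# Stand-in for the stacked scan output: n future states per leaf, shape (n, 8)
trj = jtu.tree_map(lambda x: jnp.stack([x + k + 1 for k in range(n)]), u_0)

# Same concatenation as the include_init=True branch
trj_with_init = jtu.tree_map(
    lambda init, history: jnp.concatenate(
        [jnp.expand_dims(init, axis=0), history], axis=0
    ),
    u_0,
    trj,
)
```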

Source code in exponax/_utils.py
def rollout(
    stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]],
    n: int,
    *,
    include_init: bool = False,
    takes_aux: bool = False,
    constant_aux: bool = True,
):
    """
    Transform a stepper function into a function that autoregressively (i.e.,
    recursively applied to its own output) produces a trajectory of length `n`.

    Based on `takes_aux`, the stepper function is either fully autonomous, just
    mapping state to state, or takes an additional auxiliary input. This can be
    a force/control or additional metadata (like physical parameters, or time
    for non-autonomous systems).

    Args:
        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
            (default), expected signature is `u_next = stepper_fn(u)`, else
            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
            of identical structure, in the easiest case just arrays of same
            shape.
        - `n`: The number of time steps to roll the trajectory out into the
            future. If `include_init = False` (default), the trajectory
            consists of exactly these `n` future states.
        - `include_init`: Whether to include the initial condition in the
            trajectory. If `True`, the arrays in the returning PyTree have shape
            `(n + 1, ...)`, else `(n, ...)`. Default: `False`.
        - `takes_aux`: Whether the stepper function takes an additional PyTree
            as second argument.
        - `constant_aux`: Whether the auxiliary input is constant over the
            trajectory. If `True`, the auxiliary input is repeated `n` times,
            otherwise the leading axis in the PyTree arrays has to be of length
            `n`.

    Returns:
        - `rollout_stepper_fn`: A function that takes an initial condition `u_0`
            and an auxiliary input `aux` (if `takes_aux = True`) and produces
            the trajectory by autoregressively applying the stepper `n` times.
            If `include_init = True`, the trajectory has shape `(n + 1, ...)`,
            else `(n, ...)`. Returns a PyTree of the same structure as the
            initial condition, but with an additional leading axis of length
            `n` (`n + 1` if `include_init = True`).
    """

    if takes_aux:

        def scan_fn(u, aux):
            u_next = stepper_fn(u, aux)
            return u_next, u_next

        def rollout_stepper_fn(u_0, aux):
            if constant_aux:
                aux = jtu.tree_map(
                    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux
                )

            _, trj = jax.lax.scan(scan_fn, u_0, aux, length=n)

            if include_init:
                trj_with_init = jtu.tree_map(
                    lambda init, history: jnp.concatenate(
                        [jnp.expand_dims(init, axis=0), history], axis=0
                    ),
                    u_0,
                    trj,
                )
                return trj_with_init
            else:
                return trj

        return rollout_stepper_fn

    else:

        def scan_fn(u, _):
            u_next = stepper_fn(u)
            return u_next, u_next

        def rollout_stepper_fn(u_0):
            _, trj = jax.lax.scan(scan_fn, u_0, None, length=n)

            if include_init:
                trj_with_init = jtu.tree_map(
                    lambda init, history: jnp.concatenate(
                        [jnp.expand_dims(init, axis=0), history], axis=0
                    ),
                    u_0,
                    trj,
                )
                return trj_with_init
            else:
                return trj

        return rollout_stepper_fn

exponax.stack_sub_trajectories

stack_sub_trajectories(
    trj: PyTree[Float[Array, "n_timesteps ..."]],
    sub_len: int,
) -> PyTree[Float[Array, "n_stacks sub_len ..."]]

Slice a trajectory into all overlapping subtrajectories of length sub_len (a sliding window with stride one) and stack them together. Useful for rollout training of neural operators with temporal mixing.
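To illustrate what gets stacked, the sketch below builds the same sliding windows for a single-array trajectory in two ways: with plain Python slicing, and with jax.vmap over jax.lax.dynamic_slice_in_dim. The vmap variant is a swapped-in alternative to the scan used in the source, chosen only for brevity; the trajectory values are hypothetical.

```python
import jax
import jax.numpy as jnp

n_timesteps, sub_len = 6, 3
# Hypothetical trajectory: 6 time steps of a 2-entry state
trj = jnp.arange(n_timesteps * 2, dtype=jnp.float32).reshape(n_timesteps, 2)

n_stacks = n_timesteps - sub_len + 1

# Reference: window i covers time steps i, i + 1, ..., i + sub_len - 1
windows_ref = jnp.stack([trj[i : i + sub_len] for i in range(n_stacks)])

# Vectorized over the start indices; the slice size stays static
windows_vmap = jax.vmap(
    lambda i: jax.lax.dynamic_slice_in_dim(trj, i, sub_len, axis=0)
)(jnp.arange(n_stacks))
```

Both produce an array of shape (4, 3, 2) in which window i starts at time step i.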

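Note: this function can produce very large arrays. Most time steps appear in sub_len different subtrajectories, so the output holds (n_timesteps - sub_len + 1) * sub_len snapshots, i.e., up to roughly sub_len times the size of trj.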
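The memory growth can be estimated before calling the function. The arithmetic below assumes a hypothetical trajectory of 201 float32 snapshots with 1 channel and 256 grid points, sliced into windows of length 6; stacked_size_ratio is an illustrative helper, not part of exponax.

```python
def stacked_size_ratio(n_timesteps: int, sub_len: int) -> float:
    """Size of the stacked output relative to the input trajectory."""
    n_stacks = n_timesteps - sub_len + 1
    return n_stacks * sub_len / n_timesteps


# Hypothetical trajectory: 201 snapshots of 1 channel x 256 points in float32
n_timesteps, sub_len = 201, 6
frame_bytes = 1 * 256 * 4

input_bytes = n_timesteps * frame_bytes
output_bytes = (n_timesteps - sub_len + 1) * sub_len * frame_bytes
ratio = stacked_size_ratio(n_timesteps, sub_len)
```

Here 196 windows of 6 snapshots give 1,204,224 bytes versus 205,824 bytes for the input, a factor of about 5.85, just below sub_len.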

Arguments:
- trj: The trajectory to slice. Expected shape: (n_timesteps, ...).
- sub_len: The length of the subtrajectories. For rollout training with k steps, choose sub_len = k + 1 so that each subtrajectory also contains an initial condition.
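As a sketch of how sub_len = k + 1 is typically consumed in rollout training, each stacked subtrajectory is split into its initial condition and k targets, and a model is rolled out k steps from the initial condition. The model (a fixed linear map standing in for a neural operator), the data, and the helper rollout_k are all hypothetical; this is not exponax's training API.

```python
import jax
import jax.numpy as jnp

k = 3  # hypothetical number of autoregressive training steps
sub_len = k + 1  # one initial condition plus k targets


def model(u):
    """Hypothetical learned stepper; a fixed linear map stands in for a neural operator."""
    return 0.9 * u


def rollout_k(u_0):
    def scan_fn(u, _):
        u_next = model(u)
        return u_next, u_next

    _, trj = jax.lax.scan(scan_fn, u_0, None, length=k)
    return trj  # shape (k, ...)


# Hypothetical stacked data: (n_stacks, sub_len, n_channels, n_points)
sub_trjs = jnp.ones((10, sub_len, 1, 64))

u_0 = sub_trjs[:, 0]  # initial conditions, (n_stacks, 1, 64)
targets = sub_trjs[:, 1:]  # k reference steps, (n_stacks, k, 1, 64)
preds = jax.vmap(rollout_k)(u_0)  # (n_stacks, k, 1, 64)
loss = jnp.mean((preds - targets) ** 2)
```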

Returns:
- sub_trjs: The stacked subtrajectories. Expected shape: (n_stacks, sub_len, ...), where n_stacks = n_timesteps - sub_len + 1 is the number of subtrajectories stacked together.

Source code in exponax/_utils.py
def stack_sub_trajectories(
    trj: PyTree[Float[Array, "n_timesteps ..."]],
    sub_len: int,
) -> PyTree[Float[Array, "n_stacks sub_len ..."]]:
    """
    Slice a trajectory into subtrajectories of length `sub_len` and stack them
    together. Useful for rollout training of neural operators with temporal mixing.

    !!! Note that this function can produce very large arrays.

    **Arguments:**
        - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`.
        - `sub_len`: The length of the subtrajectories. If you want to perform
            rollout training with `k` steps, note that `sub_len = k + 1` to also
            have an initial condition in the subtrajectories.

    **Returns:**
        - `sub_trjs`: The stacked subtrajectories. Expected shape:
            `(n_stacks, sub_len, ...)`. `n_stacks` is the number of
            subtrajectories stacked together, i.e., `n_timesteps - sub_len + 1`.
    """
    n_time_steps = [leaf.shape[0] for leaf in jtu.tree_leaves(trj)]

    if len(set(n_time_steps)) != 1:
        raise ValueError(
            "All arrays in trj must have the same number of time steps in the leading axis"
        )
    else:
        n_time_steps = n_time_steps[0]

    if sub_len > n_time_steps:
        raise ValueError(
            "sub_len must be smaller than or equal to the number of time steps in trj"
        )

    n_sub_trjs = n_time_steps - sub_len + 1

    def scan_fn(_, i):
        sliced = jtu.tree_map(
            lambda leaf: jax.lax.dynamic_slice_in_dim(
                leaf,
                start_index=i,
                slice_size=sub_len,
                axis=0,
            ),
            trj,
        )
        return _, sliced

    _, sub_trjs = jax.lax.scan(scan_fn, None, jnp.arange(n_sub_trjs))

    return sub_trjs