Repeated Stepper

Use this to create steppers that perform substepping. To do so, instantiate the inner stepper with a fraction of the desired dt and wrap it in a RepeatedStepper. For example,

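import exponax

# Inner stepper: 1D Burgers, domain extent 1.0, 64 points, time step 0.1/5
substepped_stepper = exponax.stepper.Burgers(1, 1.0, 64, 0.1/5)
# Wrapper: 5 inner steps per call
stepper = exponax.RepeatedStepper(substepped_stepper, 5)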

This creates a stepper that performs 5 substeps of 0.1/5 = 0.02 each time it is called, for an effective time step of 0.1.
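The wrapping idea can be sketched without the library. The following pure-Python stand-in is an illustration only: `ToyStepper`, `repeat_sketch`, and `ToyRepeatedStepper` are hypothetical names that are not part of exponax, and the toy advances a scalar decay problem instead of a PDE. It shows that the wrapper applies the inner step `num_sub_steps` times per call and reports the inner `dt` times `num_sub_steps` as its own `dt`.

```python
import math


class ToyStepper:
    """Hypothetical stand-in for a time stepper: exact step of du/dt = -u."""

    def __init__(self, dt):
        self.dt = dt

    def __call__(self, u):
        return u * math.exp(-self.dt)


def repeat_sketch(fn, n):
    """Return a function that applies `fn` n times in a row."""

    def repeated(u):
        for _ in range(n):
            u = fn(u)
        return u

    return repeated


class ToyRepeatedStepper:
    """Wraps a stepper so that one call performs `num_sub_steps` inner steps."""

    def __init__(self, stepper, num_sub_steps):
        self.stepper = stepper
        self.num_sub_steps = num_sub_steps
        # Effective time step covered by one call of the wrapper
        self.dt = stepper.dt * num_sub_steps

    def __call__(self, u):
        return repeat_sketch(self.stepper, self.num_sub_steps)(u)


substepped = ToyStepper(0.1 / 5)             # inner time step of 0.02
stepper = ToyRepeatedStepper(substepped, 5)  # 5 inner steps per call

u_next = stepper(1.0)  # advances the state by 5 * 0.02 = 0.1 in time
```

Because the toy inner step is exact, five steps of 0.02 land on the same value as one exact step of 0.1; with an approximate integrator the two would generally differ.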

exponax.RepeatedStepper

Bases: Module

Source code in exponax/_repeated_stepper.py
class RepeatedStepper(eqx.Module):
    num_spatial_dims: int
    domain_extent: float
    num_points: int
    num_channels: int
    dt: float
    dx: float

    stepper: BaseStepper
    num_sub_steps: int

    def __init__(
        self,
        stepper: BaseStepper,
        num_sub_steps: int,
    ):
        """
        Sugarcoat the utility function `repeat` in a callable PyTree for easy
        composition with other equinox modules.

        !!! info
            Performs the substepping in Fourier space to avoid unnecessary
            back-and-forth transformations.

        One intended usage is to get "more accurate" or "more stable" time steppers
        that perform substeps.

        The effective time step is `self.stepper.dt * self.num_sub_steps`. In order to
        get a time step of X with Y substeps, first instantiate a stepper with a
        time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.

        **Arguments:**

        - `stepper`: The stepper to repeat.
        - `num_sub_steps`: The number of substeps to perform.
        """
        self.stepper = stepper
        self.num_sub_steps = num_sub_steps

        self.dt = stepper.dt * num_sub_steps

        self.num_spatial_dims = stepper.num_spatial_dims
        self.domain_extent = stepper.domain_extent
        self.num_points = stepper.num_points
        self.num_channels = stepper.num_channels
        self.dx = stepper.dx

    def step(
        self,
        u: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Step the PDE forward in time by `self.num_sub_steps` time steps given the
        current state `u`.

        !!! info
            Performs the substepping in Fourier space to avoid unnecessary
            back-and-forth transformations.

        **Arguments:**

        - `u`: The current state.

        **Returns:**

        - `u_next`: The state after `self.num_sub_steps` time steps.
        """
        u_hat = fft(u, num_spatial_dims=self.num_spatial_dims)
        u_hat_after_steps = self.step_fourier(u_hat)
        u_after_steps = ifft(
            u_hat_after_steps,
            num_spatial_dims=self.num_spatial_dims,
            num_points=self.num_points,
        )
        return u_after_steps

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Step the PDE forward in time by self.num_sub_steps time steps given the
        current state `u_hat` in real-valued Fourier space.

        **Arguments:**

        - `u_hat`: The current state in Fourier space.

        **Returns:**

        - `u_next_hat`: The state after `self.num_sub_steps` time steps in Fourier
            space.
        """
        return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat)

    def __call__(
        self,
        u: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Step the PDE forward in time by self.num_sub_steps time steps given the
        current state `u`.

        !!! info
            Performs the substepping in Fourier space to avoid unnecessary
            back-and-forth transformations.

        **Arguments:**

        - `u`: The current state.

        **Returns:**

        - `u_next`: The state after `self.num_sub_steps` time steps.

        !!! tip
            Use this call method together with `exponax.rollout` to efficiently
            produce temporal trajectories.

        !!! info
            For batched operation, use `jax.vmap` on this function.
        """
        expected_shape = (self.num_channels,) + spatial_shape(
            self.num_spatial_dims, self.num_points
        )
        if u.shape != expected_shape:
            raise ValueError(
                f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
            )
        return self.step(u)

__init__
__init__(stepper: BaseStepper, num_sub_steps: int)

Wraps the utility function repeat in a callable PyTree for easy composition with other Equinox modules.

Info

Performs the substepping in Fourier space to avoid unnecessary back-and-forth transformations.

One intended use is to obtain a "more accurate" or "more stable" time stepper by performing several smaller substeps per call.

The effective time step is self.stepper.dt * self.num_sub_steps. In order to get a time step of X with Y substeps, first instantiate a stepper with a time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.

Arguments:

  • stepper: The stepper to repeat.
  • num_sub_steps: The number of substeps to perform.
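Why substepping can help with accuracy and stability is easiest to see on a toy problem. The sketch below is not exponax code: it assumes a hypothetical explicit Euler step for the scalar ODE du/dt = -λu (exponax steppers use different integrators) and an arbitrarily chosen λ = 30. It only illustrates the X/Y recipe with X = 0.1 and Y = 5.

```python
import math

LAMBDA = 30.0  # decay rate of the toy ODE du/dt = -LAMBDA * u (assumed value)


def make_euler_step(dt):
    """Hypothetical explicit Euler step; stable only if |1 - LAMBDA * dt| <= 1."""

    def step(u):
        return u + dt * (-LAMBDA * u)

    return step


def repeat_sketch(fn, n):
    """Apply `fn` n times in a row (stand-in for a repeat utility)."""

    def repeated(u):
        for _ in range(n):
            u = fn(u)
        return u

    return repeated


X, Y = 0.1, 5  # target time step and number of substeps

single = make_euler_step(X)                            # one step of size X
substepped = repeat_sketch(make_euler_step(X / Y), Y)  # Y steps of size X/Y

exact = math.exp(-LAMBDA * X)  # true decay factor over a time span of X
err_single = abs(single(1.0) - exact)
err_substepped = abs(substepped(1.0) - exact)
```

In this toy setting the single step of size X amplifies the state (factor of about -2), while the substepped version decays it (factor 0.4 per substep) and lands closer to the exact value.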

Source code in exponax/_repeated_stepper.py
def __init__(
    self,
    stepper: BaseStepper,
    num_sub_steps: int,
):
    """
    Sugarcoat the utility function `repeat` in a callable PyTree for easy
    composition with other equinox modules.

    !!! info
        Performs the substepping in Fourier space to avoid unnecessary
        back-and-forth transformations.

    One intended usage is to get "more accurate" or "more stable" time steppers
    that perform substeps.

    The effective time step is `self.stepper.dt * self.num_sub_steps`. In order to
    get a time step of X with Y substeps, first instantiate a stepper with a
    time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.

    **Arguments:**

    - `stepper`: The stepper to repeat.
    - `num_sub_steps`: The number of substeps to perform.
    """
    self.stepper = stepper
    self.num_sub_steps = num_sub_steps

    self.dt = stepper.dt * num_sub_steps

    self.num_spatial_dims = stepper.num_spatial_dims
    self.domain_extent = stepper.domain_extent
    self.num_points = stepper.num_points
    self.num_channels = stepper.num_channels
    self.dx = stepper.dx

__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Step the PDE forward in time by self.num_sub_steps time steps given the current state u.

Info

Performs the substepping in Fourier space to avoid unnecessary back-and-forth transformations.

Arguments:

  • u: The current state.

Returns:

  • u_next: The state after self.num_sub_steps time steps.

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.
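The idea behind combining a callable stepper with a rollout helper can be sketched as follows. `rollout_sketch` is a hypothetical pure-Python stand-in and not the implementation of `exponax.rollout`; it assumes the initial state is not included in the result and only shows how each output is fed back in as the next input while the states are collected.

```python
def rollout_sketch(stepper, num_steps):
    """Return a function that maps an initial state to a list of later states."""

    def rollout_fn(u0):
        trajectory = []
        u = u0
        for _ in range(num_steps):
            u = stepper(u)        # feed the previous output back in
            trajectory.append(u)  # collect the state after each call
        return trajectory

    return rollout_fn


def halve(u):
    """Toy stand-in for a stepper call: halves the state."""
    return 0.5 * u


trajectory = rollout_sketch(halve, 3)(8.0)
```

Any callable with the same input and output type, including a wrapper that performs several substeps per call, can be passed in place of `halve`.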

Source code in exponax/_repeated_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Step the PDE forward in time by self.num_sub_steps time steps given the
    current state `u`.

    !!! info
        Performs the substepping in Fourier space to avoid unnecessary
        back-and-forth transformations.

    **Arguments:**

    - `u`: The current state.

    **Returns:**

    - `u_next`: The state after `self.num_sub_steps` time steps.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
        )
    return self.step(u)