
Repeated Stepper

Use this to create steppers that perform substepping. To do so, instantiate the inner stepper with a fraction of the desired `dt` and wrap it. For example, to advance by 0.1 per call using 5 substeps:

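```python
import exponax

# 1D Burgers stepper: domain extent 1.0, 64 points, dt = 0.1/5 = 0.02
substepped_stepper = exponax.stepper.Burgers(1, 1.0, 64, 0.1/5)
# Each call performs 5 substeps, so the effective dt is 5 * 0.02 = 0.1
stepper = exponax.RepeatedStepper(substepped_stepper, 5)
```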

This creates a stepper that performs 5 substeps of 0.1/5 = 0.02 each time it is called, so one call advances the state by 0.1 in total.
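To illustrate why substepping can improve accuracy, here is a plain-Python sketch that does not use exponax at all. A hypothetical forward-Euler step for the scalar ODE du/dt = -u stands in for a real stepper, and the hypothetical helper `repeat_fn` mimics the idea of the `repeat` utility that `RepeatedStepper` wraps; it is not that function's actual implementation. Forward Euler is chosen only because its error behavior is easy to see; it is not the scheme exponax's steppers use.

```python
import math
from typing import Callable


def euler_step(u: float, dt: float, lam: float = 1.0) -> float:
    """One forward-Euler step for du/dt = -lam * u (hypothetical toy stepper)."""
    return u + dt * (-lam * u)


def repeat_fn(fn: Callable[[float], float], n: int) -> Callable[[float], float]:
    """Return a function that applies `fn` n times (illustrative stand-in for `repeat`)."""
    def repeated(u: float) -> float:
        for _ in range(n):
            u = fn(u)
        return u
    return repeated


u0 = 1.0
dt = 0.1
num_sub_steps = 5

# One big step of size dt vs. num_sub_steps substeps of size dt / num_sub_steps
one_big_step = euler_step(u0, dt)
substepped = repeat_fn(lambda u: euler_step(u, dt / num_sub_steps), num_sub_steps)(u0)
exact = math.exp(-dt)

err_big = abs(one_big_step - exact)
err_sub = abs(substepped - exact)
print(f"1 step of {dt}: error = {err_big:.2e}")
print(f"{num_sub_steps} substeps of {dt / num_sub_steps}: error = {err_sub:.2e}")
```

For this first-order method the substepped error comes out roughly five times smaller, matching the five-fold reduction in step size.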

exponax.RepeatedStepper

Bases: Module

Source code in exponax/_repeated_stepper.py
```python
# eqx, BaseStepper, Float, Array, Complex and repeat come from the module's
# imports, which are not shown in this excerpt.
class RepeatedStepper(eqx.Module):
    num_spatial_dims: int
    domain_extent: float
    num_points: int
    num_channels: int
    dt: float
    dx: float

    stepper: BaseStepper
    num_sub_steps: int

    def __init__(
        self,
        stepper: BaseStepper,
        num_sub_steps: int,
    ):
        """
        Sugarcoat the utility function `repeat` in a callable PyTree for easy
        composition with other equinox modules.

        One intended usage is to get "more accurate" or "more stable" time steppers
        that perform substeps.

        The effective time step is `self.stepper.dt * self.num_sub_steps`. In order to
        get a time step of X with Y substeps, first instantiate a stepper with a
        time step of X/Y and then wrap it in a RepeatedStepper with num_sub_steps=Y.

        **Arguments:**
            - `stepper`: The stepper to repeat.
            - `num_sub_steps`: The number of substeps to perform.
        """
        self.stepper = stepper
        self.num_sub_steps = num_sub_steps

        self.dt = stepper.dt * num_sub_steps

        self.num_spatial_dims = stepper.num_spatial_dims
        self.domain_extent = stepper.domain_extent
        self.num_points = stepper.num_points
        self.num_channels = stepper.num_channels
        self.dx = stepper.dx

    def step(
        self,
        u: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Step the PDE forward in time by self.num_sub_steps time steps given the
        current state `u`.
        """
        return repeat(self.stepper.step, self.num_sub_steps)(u)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Step the PDE forward in time by self.num_sub_steps time steps given the
        current state `u_hat` in real-valued Fourier space.
        """
        return repeat(self.stepper.step_fourier, self.num_sub_steps)(u_hat)

    def __call__(
        self,
        u: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Step the PDE forward in time by self.num_sub_steps time steps given the
        current state `u`.
        """
        return repeat(self.stepper, self.num_sub_steps)(u)
```
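The class above stores the wrapped stepper, copies its grid metadata, and multiplies `dt` by `num_sub_steps`. The following plain-Python sketch mirrors that structure without JAX or Equinox so it runs anywhere. `ToyStepper` and `ToyRepeatedStepper` are hypothetical names, a frozen `dataclass` stands in for a real `BaseStepper`, and the toy's `dx = domain_extent / num_points` is an assumption of the sketch.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToyStepper:
    """Hypothetical stand-in for a BaseStepper: forward Euler for du/dt = -u."""

    num_spatial_dims: int
    domain_extent: float
    num_points: int
    dt: float
    num_channels: int = 1

    @property
    def dx(self) -> float:
        # Spacing of a periodic grid with num_points cells (toy assumption)
        return self.domain_extent / self.num_points

    def __call__(self, u: List[float]) -> List[float]:
        return [x - self.dt * x for x in u]


class ToyRepeatedStepper:
    """Plain-Python analogue of RepeatedStepper (illustrative only)."""

    def __init__(self, stepper: ToyStepper, num_sub_steps: int):
        self.stepper = stepper
        self.num_sub_steps = num_sub_steps
        # Effective step: num_sub_steps substeps of size stepper.dt
        self.dt = stepper.dt * num_sub_steps
        # Forward the grid metadata unchanged so the wrapper can replace the inner stepper
        self.num_spatial_dims = stepper.num_spatial_dims
        self.domain_extent = stepper.domain_extent
        self.num_points = stepper.num_points
        self.num_channels = stepper.num_channels
        self.dx = stepper.dx

    def __call__(self, u: List[float]) -> List[float]:
        # Apply the inner stepper num_sub_steps times
        for _ in range(self.num_sub_steps):
            u = self.stepper(u)
        return u


inner = ToyStepper(num_spatial_dims=1, domain_extent=1.0, num_points=4, dt=0.1 / 5)
outer = ToyRepeatedStepper(inner, num_sub_steps=5)
u_next = outer([1.0, 2.0, 3.0, 4.0])
print(outer.dt, outer.dx)
```

Copying the metadata lets code that reads `dt`, `dx`, or `num_points` treat the wrapper like any other stepper; in the real class, inheriting from `eqx.Module` additionally makes the wrapper a PyTree.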

__init__

```python
__init__(stepper: BaseStepper, num_sub_steps: int)
```

Wrap the utility function `repeat` in a callable PyTree (an Equinox module) for easy composition with other Equinox modules.

One intended use is to obtain more accurate or more stable time steppers by performing several smaller substeps per call.
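The "more stable" case can be illustrated with a toy model, which is an assumption of this sketch and not an exponax stepper. Forward Euler on the stiff ODE du/dt = -λu with λ = 30 multiplies u by 1 - λΔt per step. With Δt = 0.1 that factor is -2, so the iterate grows in magnitude even though the exact solution decays; with five substeps of 0.02 the factor is 0.4 per substep and the iterate decays.

```python
LAM = 30.0  # stiffness of the toy ODE du/dt = -LAM * u (illustrative value)


def euler_step(u: float, dt: float) -> float:
    """Forward-Euler step; multiplies u by the amplification factor 1 - LAM * dt."""
    return u * (1.0 - LAM * dt)


def advance(u: float, dt: float, num_sub_steps: int, num_calls: int) -> float:
    """Make `num_calls` calls of a substepped stepper with outer step dt.

    Each call performs `num_sub_steps` Euler substeps of size dt / num_sub_steps.
    """
    sub_dt = dt / num_sub_steps
    for _ in range(num_calls):
        for _ in range(num_sub_steps):
            u = euler_step(u, sub_dt)
    return u


u_plain = advance(1.0, dt=0.1, num_sub_steps=1, num_calls=10)  # factor -2 per call
u_sub = advance(1.0, dt=0.1, num_sub_steps=5, num_calls=10)    # factor 0.4 per substep
print(u_plain)  # 1024.0: the iterate blows up
print(u_sub)    # tiny positive number: decays toward zero like the exact solution
```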

The effective time step is `self.stepper.dt * self.num_sub_steps`. To get a time step of X with Y substeps, first instantiate a stepper with a time step of X/Y and then wrap it in a `RepeatedStepper` with `num_sub_steps=Y`.

Arguments:

- `stepper`: The stepper to repeat.
- `num_sub_steps`: The number of substeps to perform.


__call__

```python
__call__(u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]
```

Step the PDE forward in time by `self.num_sub_steps` substeps of size `self.stepper.dt` (that is, by `self.dt` in total), given the current state `u`.
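Because each call applies the inner stepper `num_sub_steps` times in sequence, a trajectory obtained by calling the repeated stepper over and over coincides with every `num_sub_steps`-th state of a trajectory from the inner stepper alone. The plain-Python sketch below checks this with a hypothetical nonlinear toy step; `toy_step`, `repeated`, and `rollout_fn` are illustrative helpers, not exponax functions. Since the same floating-point operations run in the same order, the two trajectories match exactly.

```python
from typing import Callable, List


def toy_step(u: float, dt: float = 0.02) -> float:
    """Hypothetical nonlinear step: forward Euler for du/dt = u * (1 - u)."""
    return u + dt * u * (1.0 - u)


def repeated(step: Callable[[float], float], n: int) -> Callable[[float], float]:
    """Compose `step` with itself n times (illustrative)."""
    def wrapped(u: float) -> float:
        for _ in range(n):
            u = step(u)
        return u
    return wrapped


def rollout_fn(step: Callable[[float], float], u0: float, num_steps: int) -> List[float]:
    """Return [u0, step(u0), step(step(u0)), ...] with num_steps + 1 entries."""
    traj = [u0]
    for _ in range(num_steps):
        traj.append(step(traj[-1]))
    return traj


n = 5
fine = rollout_fn(toy_step, 0.1, 4 * n)            # 21 states, spacing 0.02
coarse = rollout_fn(repeated(toy_step, n), 0.1, 4)  # 5 states, spacing n * 0.02 = 0.1
print(coarse == fine[::n])  # True
```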
