
General Polynomial Stepper

exponax.stepper.generic.GeneralPolynomialStepper

Bases: BaseStepper

Source code in exponax/stepper/generic/_polynomial.py
class GeneralPolynomialStepper(BaseStepper):
    linear_coefficients: tuple[float, ...]
    polynomial_coefficients: tuple[float, ...]
    dealiasing_fraction: float

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        linear_coefficients: tuple[float, ...] = (10.0, 0.0, 1.0),
        polynomial_coefficients: tuple[float, ...] = (0.0, 0.0, -10.0),
        order=2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs
        consisting of an arbitrary combination of polynomial nonlinearities and
        (isotropic) linear derivatives. This can be used to represent a wide
        array of reaction-diffusion equations.

        In 1d, the PDE is of the form

        ```
            uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ uₓʲ
        ```

        where `pₖ` are the polynomial coefficients and `aⱼ` are the linear
        coefficients. `uᵏ` denotes `u` pointwise raised to the power of `k`
        (hence the polynomial contribution) and `uₓʲ` denotes the `j`-th
        derivative of `u`.

        The higher-dimensional generalization reads

        ```
            uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ (1⋅∇ʲ)u
        ```

        where `∇ʲ` is the `j`-th derivative operator.

        The default configuration corresponds to the Fisher-KPP equation with
        the following settings

        ```python

        exponax.stepper.reaction.FisherKPP(
            num_spatial_dims=num_spatial_dims, domain_extent=domain_extent,
            num_points=num_points, dt=dt, diffusivity=0.01, reactivity=-10.0,
            #TODO: Check this
        )
        ```

        Note that the effect of `polynomial_coefficients[1]` is similar to the
        effect of `linear_coefficients[0]`, with the difference that in ETDRK
        integration the latter is treated analytically and should therefore be
        preferred.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `domain_extent`: The size of the domain `L`; in higher dimensions
            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
        - `num_points`: The number of points `N` used to discretize the
            domain. This **includes** the left boundary point and **excludes**
            the right boundary point. In higher dimensions, the number of points
            in each dimension is the same. Hence, the total number of degrees of
            freedom is `Nᵈ`.
        - `dt`: The timestep size `Δt` between two consecutive states.
        - `linear_coefficients`: The list of coefficients `a_j` corresponding to the
            derivatives. The length of this tuple represents the highest
            occurring derivative. The default value `(10.0, 0.0, 1.0)` in
            combination with the default `polynomial_coefficients` corresponds to the
            Fisher-KPP equation.
        - `polynomial_coefficients`: The list of scales `pₖ` corresponding to the
            polynomial contributions. The length of this tuple represents the
            highest occurring polynomial. The default value `(0.0, 0.0, -10.0)` in
            combination with the default `linear_coefficients` corresponds to the
            Fisher-KPP equation.
        - `order`: The order of the Exponential Time Differencing Runge
            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
            solves the linear part of the equation. Use higher values for higher
            accuracy and stability. The default choice of `2` is a good
            compromise for single precision floats.
        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
            before evaluating the nonlinearity. The default 2/3 corresponds to
            Orszag's 2/3 rule, which is sufficient if the highest occurring
            polynomial is quadratic (i.e., there are at most three entries in
            the `polynomial_coefficients` tuple).
        - `num_circle_points`: How many points to use in the complex contour
            integral method to compute the coefficients of the exponential time
            differencing Runge Kutta method.
        - `circle_radius`: The radius of the contour used to compute the
            coefficients of the exponential time differencing Runge Kutta
            method.
        """
        self.linear_coefficients = linear_coefficients
        self.polynomial_coefficients = polynomial_coefficients
        self.dealiasing_fraction = dealiasing_fraction

        super().__init__(
            num_spatial_dims=num_spatial_dims,
            domain_extent=domain_extent,
            num_points=num_points,
            dt=dt,
            num_channels=1,
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )

    def _build_linear_operator(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> Complex[Array, "1 ... (N//2)+1"]:
        linear_operator = sum(
            jnp.sum(
                c * (derivative_operator) ** i,
                axis=0,
                keepdims=True,
            )
            for i, c in enumerate(self.linear_coefficients)
        )
        return linear_operator

    def _build_nonlinear_fun(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> PolynomialNonlinearFun:
        return PolynomialNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
            dealiasing_fraction=self.dealiasing_fraction,
            coefficients=self.polynomial_coefficients,
        )
__init__
__init__(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    linear_coefficients: tuple[float, ...] = (
        10.0,
        0.0,
        1.0,
    ),
    polynomial_coefficients: tuple[float, ...] = (
        0.0,
        0.0,
        -10.0,
    ),
    order=2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Timestepper for the d-dimensional (d ∈ {1, 2, 3}) semi-linear PDEs consisting of an arbitrary combination of polynomial nonlinearities and (isotropic) linear derivatives. This can be used to represent a wide array of reaction-diffusion equations.

In 1d, the PDE is of the form

    uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ uₓʲ

where pₖ are the polynomial coefficients and aⱼ are the linear coefficients. uᵏ denotes u pointwise raised to the power of k (hence the polynomial contribution) and uₓʲ denotes the j-th derivative of u.

The higher-dimensional generalization reads

    uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ (1⋅∇ʲ)u

where ∇ʲ is the j-th derivative operator.

The default configuration corresponds to the Fisher-KPP equation with the following settings

exponax.stepper.reaction.FisherKPP(
    num_spatial_dims=num_spatial_dims, domain_extent=domain_extent,
    num_points=num_points, dt=dt, diffusivity=0.01, reactivity=-10.0,
    #TODO: Check this
)

Note that the effect of polynomial_coefficients[1] is similar to the effect of linear_coefficients[0], with the difference that in ETDRK integration the latter is treated analytically and should therefore be preferred.
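
The difference becomes clearer in code. Below is a minimal sketch of two configurations of the same PDE, uₜ = 10u + uₓₓ - 10u² (the default configuration); the domain size, resolution, and time step are arbitrary illustrative values. The first variant keeps the 10u term in the linear part and is the preferred one.

```python
import exponax as ex

# Variant 1 (preferred): the +10u reaction term sits in linear_coefficients[0]
# and is treated analytically by the ETDRK scheme.
stepper_linear_growth = ex.stepper.generic.GeneralPolynomialStepper(
    num_spatial_dims=1,
    domain_extent=10.0,
    num_points=128,
    dt=0.001,
    linear_coefficients=(10.0, 0.0, 1.0),
    polynomial_coefficients=(0.0, 0.0, -10.0),
)

# Variant 2: the same +10u term is moved to polynomial_coefficients[1] and is
# therefore evaluated pseudo-spectrally together with the -10u² term.
stepper_polynomial_growth = ex.stepper.generic.GeneralPolynomialStepper(
    num_spatial_dims=1,
    domain_extent=10.0,
    num_points=128,
    dt=0.001,
    linear_coefficients=(0.0, 0.0, 1.0),
    polynomial_coefficients=(0.0, 10.0, -10.0),
)
```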

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᵈ.
  • num_points: The number of points N used to discretize the domain. This includes the left boundary point and excludes the right boundary point. In higher dimensions, the number of points in each dimension is the same. Hence, the total number of degrees of freedom is Nᵈ.
  • dt: The timestep size Δt between two consecutive states.
  • linear_coefficients: The list of coefficients a_j corresponding to the derivatives. The length of this tuple represents the highest occurring derivative. The default value (10.0, 0.0, 1.0) in combination with the default polynomial_coefficients corresponds to the Fisher-KPP equation.
  • polynomial_coefficients: The list of scales pₖ corresponding to the polynomial contributions. The length of this tuple represents the highest occurring polynomial. The default value (0.0, 0.0, -10.0) in combination with the default linear_coefficients corresponds to the Fisher-KPP equation.
  • order: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option 0 only solves the linear part of the equation. Use higher values for higher accuracy and stability. The default choice of 2 is a good compromise for single precision floats.
  • dealiasing_fraction: The fraction of the wavenumbers to keep before evaluating the nonlinearity. The default 2/3 corresponds to Orszag's 2/3 rule, which is sufficient if the highest occurring polynomial is quadratic (i.e., there are at most three entries in the polynomial_coefficients tuple); a cubic example with a stricter fraction is sketched after this list.
  • num_circle_points: How many points to use in the complex contour integral method to compute the coefficients of the exponential time differencing Runge Kutta method.
  • circle_radius: The radius of the contour used to compute the coefficients of the exponential time differencing Runge Kutta method.
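
As an illustration of how the tuple lengths map onto the equation, the sketch below configures an Allen-Cahn-type equation uₜ = u - u³ + ν uₓₓ. The concrete values of ν, the domain size, resolution, and time step are arbitrary and only serve the example.

```python
import exponax as ex

nu = 0.01  # illustrative diffusivity

# uₜ = 1.0·u + ν·uₓₓ - 1.0·u³
# linear_coefficients has three entries    -> highest derivative is the second
# polynomial_coefficients has four entries -> highest power is the cubic term
allen_cahn_like_stepper = ex.stepper.generic.GeneralPolynomialStepper(
    num_spatial_dims=1,
    domain_extent=10.0,
    num_points=256,
    dt=0.01,
    linear_coefficients=(1.0, 0.0, nu),
    polynomial_coefficients=(0.0, 0.0, 0.0, -1.0),
    # A cubic nonlinearity needs a stricter dealiasing rule than Orszag's 2/3,
    # e.g. keeping only half of the wavenumbers.
    dealiasing_fraction=1 / 2,
)
```
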
Source code in exponax/stepper/generic/_polynomial.py
def __init__(
    self,
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    linear_coefficients: tuple[float, ...] = (10.0, 0.0, 1.0),
    polynomial_coefficients: tuple[float, ...] = (0.0, 0.0, -10.0),
    order=2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) semi-linear PDEs
    consisting of an arbitrary combination of polynomial nonlinearities and
    (isotropic) linear derivatives. This can be used to represent a wide
    array of reaction-diffusion equations.

    In 1d, the PDE is of the form

    ```
        uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ uₓʲ
    ```

    where `pₖ` are the polynomial coefficients and `aⱼ` are the linear
    coefficients. `uᵏ` denotes `u` pointwise raised to the power of `k`
    (hence the polynomial contribution) and `uₓʲ` denotes the `j`-th
    derivative of `u`.

    The higher-dimensional generalization reads

    ```
        uₜ = ∑ₖ pₖ uᵏ + ∑ⱼ aⱼ (1⋅∇ʲ)u
    ```

    where `∇ʲ` is the `j`-th derivative operator.

    The default configuration corresponds to the Fisher-KPP equation with
    the following settings

    ```python

    exponax.stepper.reaction.FisherKPP(
        num_spatial_dims=num_spatial_dims, domain_extent=domain_extent,
        num_points=num_points, dt=dt, diffusivity=0.01, reactivity=-10.0,
        #TODO: Check this
    )
    ```

    Note that the effect of `polynomial_coefficients[1]` is similar to the
    effect of `linear_coefficients[0]`, with the difference that in ETDRK
    integration the latter is treated analytically and should therefore be
    preferred.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions `d`.
    - `domain_extent`: The size of the domain `L`; in higher dimensions
        the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
    - `num_points`: The number of points `N` used to discretize the
        domain. This **includes** the left boundary point and **excludes**
        the right boundary point. In higher dimensions, the number of points
        in each dimension is the same. Hence, the total number of degrees of
        freedom is `Nᵈ`.
    - `dt`: The timestep size `Δt` between two consecutive states.
    - `linear_coefficients`: The list of coefficients `a_j` corresponding to the
        derivatives. The length of this tuple represents the highest
        occurring derivative. The default value `(10.0, 0.0, 1.0)` in
        combination with the default `polynomial_coefficients` corresponds to the
        Fisher-KPP equation.
    - `polynomial_coefficients`: The list of scales `pₖ` corresponding to the
        polynomial contributions. The length of this tuple represents the
        highest occurring polynomial. The default value `(0.0, 0.0, -10.0)` in
        combination with the default `linear_coefficients` corresponds to the
        Fisher-KPP equation.
    - `order`: The order of the Exponential Time Differencing Runge
        Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
        solves the linear part of the equation. Use higher values for higher
        accuracy and stability. The default choice of `2` is a good
        compromise for single precision floats.
    - `dealiasing_fraction`: The fraction of the wavenumbers to keep
        before evaluating the nonlinearity. The default 2/3 corresponds to
        Orszag's 2/3 rule, which is sufficient if the highest occurring
        polynomial is quadratic (i.e., there are at most three entries in
        the `polynomial_coefficients` tuple).
    - `num_circle_points`: How many points to use in the complex contour
        integral method to compute the coefficients of the exponential time
        differencing Runge Kutta method.
    - `circle_radius`: The radius of the contour used to compute the
        coefficients of the exponential time differencing Runge Kutta
        method.
    """
    self.linear_coefficients = linear_coefficients
    self.polynomial_coefficients = polynomial_coefficients
    self.dealiasing_fraction = dealiasing_fraction

    super().__init__(
        num_spatial_dims=num_spatial_dims,
        domain_extent=domain_extent,
        num_points=num_points,
        dt=dt,
        num_channels=1,
        order=order,
        num_circle_points=num_circle_points,
        circle_radius=circle_radius,
    )
__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Perform one step of the time integration for a single state.

Arguments:

  • u: The state vector, shape (C, ..., N,).

Returns:

  • u_next: The state vector after one step, shape (C, ..., N,).

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.
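
Below is a minimal usage sketch, assuming exponax.rollout(stepper, num_steps) returns a function that maps an initial state to a stacked trajectory (as suggested by the tip above); the stepper configuration and the initial condition are purely illustrative.

```python
import jax
import jax.numpy as jnp
import exponax as ex

domain_extent = 10.0
num_points = 128

stepper = ex.stepper.generic.GeneralPolynomialStepper(
    num_spatial_dims=1,
    domain_extent=domain_extent,
    num_points=num_points,
    dt=0.001,
)

# Grid that includes the left and excludes the right boundary point
x = jnp.linspace(0.0, domain_extent, num_points, endpoint=False)

# Single state with a leading channel axis, shape (1, num_points)
u_0 = 0.5 + 0.1 * jnp.sin(2 * jnp.pi * x / domain_extent)[None, :]

# One step of size dt
u_1 = stepper(u_0)

# Trajectory of 100 steps, stacked along a leading time axis
trajectory = ex.rollout(stepper, 100)(u_0)

# Batched operation: vmap the stepper over a leading batch axis
u_0_batch = jnp.stack([u_0, jnp.roll(u_0, 16, axis=-1)])  # shape (2, 1, num_points)
u_1_batch = jax.vmap(stepper)(u_0_batch)
```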

Source code in exponax/_base_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Perform one step of the time integration for a single state.

    **Arguments:**

    - `u`: The state vector, shape `(C, ..., N,)`.

    **Returns:**

    - `u_next`: The state vector after one step, shape `(C, ..., N,)`.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
        )
    return self.step(u)