
Dispersion

In 1D:

\[ \frac{\partial u}{\partial t} = \xi \frac{\partial^3 u}{\partial x^3} \]

In higher dimensions:

\[ \frac{\partial u}{\partial t} = \xi \nabla \cdot (\nabla \odot \nabla) u \]

or with spatial mixing:

\[ \frac{\partial u}{\partial t} = \xi (1 \cdot \nabla) (\nabla \cdot \nabla) u \]
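The 1D equation can be solved exactly in Fourier space, because each mode only picks up a phase factor. The following is a minimal NumPy sketch of that idea, not the library's implementation; the function name `dispersion_step_1d` and all parameter values are assumptions chosen for illustration. It checks the result against the analytic solution for a single sine mode, which travels with phase speed \(\xi k^2\).

```python
import numpy as np


def dispersion_step_1d(u, *, domain_extent, dt, dispersivity):
    """Advance u_t = xi * u_xxx by one step of size dt (hypothetical helper)."""
    num_points = u.shape[-1]
    # Angular wavenumbers k = 2*pi*m/L in the layout of the real-valued FFT.
    k = 2 * np.pi * np.fft.rfftfreq(num_points, d=domain_extent / num_points)
    # Third derivative in Fourier space: (i k)^3 = -i k^3 (purely imaginary).
    linear_operator = dispersivity * (1j * k) ** 3
    u_hat = np.fft.rfft(u)
    # Exact integration of the linear ODE for every mode.
    u_hat_next = np.exp(linear_operator * dt) * u_hat
    return np.fft.irfft(u_hat_next, n=num_points)


# Illustrative values (assumptions, not library defaults).
L, N, xi, dt = 2 * np.pi, 64, 0.1, 0.5
x = np.linspace(0.0, L, N, endpoint=False)  # right boundary point excluded
k = 3.0  # wavenumber of the single mode in the initial condition
u0 = np.sin(k * x)

u1 = dispersion_step_1d(u0, domain_extent=L, dt=dt, dispersivity=xi)

# Analytic solution: the mode is shifted by the phase xi * k^3 * dt.
u_exact = np.sin(k * x - xi * k**3 * dt)
max_error = np.max(np.abs(u1 - u_exact))
```

Because the phase shift grows with \(k^3\), a mode with twice the wavenumber is shifted by eight times the phase over the same step.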

exponax.stepper.Dispersion

Bases: BaseStepper

Source code in exponax/stepper/_dispersion.py
class Dispersion(BaseStepper):
    dispersivity: Float[Array, "D"]
    advect_on_diffusion: bool

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        dispersivity: Union[Float[Array, "D"], float] = 1.0,
        advect_on_diffusion: bool = False,
    ):
        """
        Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) dispersion equation
        on periodic boundary conditions. Essentially, a dispersion equation is
        an advection equation with different velocities (=advection speeds) for
        different wavenumbers/modes. Higher wavenumbers/modes are advected
        faster.

        In 1d, the dispersion equation is given by

        ```
            uₜ = 𝒸 uₓₓₓ
        ```

        with `𝒸 ∈ ℝ` being the dispersivity.

        In higher dimensions, the dispersion equation can be written as

        ```
            uₜ = 𝒸 ⋅ (∇⊙∇⊙(∇u))
        ```

        or

        ```
            uₜ = 𝒸 ⋅ ∇(Δu)
        ```

        with `𝒸 ∈ ℝᵈ` being the dispersivity vector.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `domain_extent`: The size of the domain `L`; in higher dimensions
            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
        - `num_points`: The number of points `N` used to discretize the
            domain. This **includes** the left boundary point and **excludes**
            the right boundary point. In higher dimensions, the number of points
            in each dimension is the same. Hence, the total number of degrees of
            freedom is `Nᵈ`.
        - `dt`: The timestep size `Δt` between two consecutive states.
        - `dispersivity` (keyword-only): The dispersivity `𝒸`. In higher
            dimensions, this can be a scalar (=float) or a vector of length `d`.
            If a scalar is given, the dispersivity is assumed to be the same in
            all spatial dimensions. Default: `1.0`.
        - `advect_on_diffusion` (keyword-only): If `True`, the second form
            of the dispersion equation in higher dimensions is used. As a
            consequence, there will be mixing in the spatial derivatives.
            Default: `False`.

        **Notes:**

        - The stepper is unconditionally stable, no matter the choice of
            any argument because the equation is solved analytically in Fourier
            space. **However**, note that initial conditions with modes higher
            than the Nyquist frequency (`(N//2)+1` with `N` being the
            `num_points`) lead to spurious oscillations.
        - Ultimately, only the factor `𝒸 Δt / L³` affects the
            characteristics of the dynamics. See also
            [`exponax.stepper.generic.NormalizedLinearStepper`][] with
            `normalized_coefficients = [0, 0, 0, alpha_3]` with `alpha_3 =
            dispersivity * dt / domain_extent**3`.
        """
        if isinstance(dispersivity, float):
            dispersivity = jnp.ones(num_spatial_dims) * dispersivity
        self.dispersivity = dispersivity
        self.advect_on_diffusion = advect_on_diffusion
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            domain_extent=domain_extent,
            num_points=num_points,
            dt=dt,
            num_channels=1,
            order=0,
        )

    def _build_linear_operator(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> Complex[Array, "1 ... (N//2)+1"]:
        if self.advect_on_diffusion:
            laplace_operator = build_laplace_operator(derivative_operator)
            advection_operator = build_gradient_inner_product_operator(
                derivative_operator, self.dispersivity, order=1
            )
            linear_operator = advection_operator * laplace_operator
        else:
            linear_operator = build_gradient_inner_product_operator(
                derivative_operator, self.dispersivity, order=3
            )

        return linear_operator

    def _build_nonlinear_fun(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> ZeroNonlinearFun:
        return ZeroNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
        )
__init__
__init__(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    dispersivity: Union[Float[Array, D], float] = 1.0,
    advect_on_diffusion: bool = False
)

Timestepper for the d-dimensional (d ∈ {1, 2, 3}) dispersion equation on periodic boundary conditions. Essentially, a dispersion equation is an advection equation with different velocities (=advection speeds) for different wavenumbers/modes. Higher wavenumbers/modes are advected faster.

In 1d, the dispersion equation is given by

    uₜ = 𝒸 uₓₓₓ

with 𝒸 ∈ ℝ being the dispersivity.

In higher dimensions, the dispersion equation can be written as

    uₜ = 𝒸 ⋅ (∇⊙∇⊙(∇u))

or

    uₜ = 𝒸 ⋅ ∇(Δu)

with 𝒸 ∈ ℝᵈ being the dispersivity vector.

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᵈ.
  • num_points: The number of points N used to discretize the domain. This includes the left boundary point and excludes the right boundary point. In higher dimensions, the number of points in each dimension is the same. Hence, the total number of degrees of freedom is Nᵈ.
  • dt: The timestep size Δt between two consecutive states.
  • dispersivity (keyword-only): The dispersivity 𝒸. In higher dimensions, this can be a scalar (=float) or a vector of length d. If a scalar is given, the dispersivity is assumed to be the same in all spatial dimensions. Default: 1.0.
  • advect_on_diffusion (keyword-only): If True, the second form of the dispersion equation in higher dimensions is used. As a consequence, there will be mixing in the spatial derivatives. Default: False.
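
The difference between the two higher-dimensional forms is easiest to see in Fourier space. The sketch below rebuilds both linear operators with plain NumPy; it is an illustration under assumptions, not the library code. The helper name `build_dispersion_operator` and the wavenumber layout are hypothetical, modeled on the `(N//2)+1` real-FFT shape used in the source above.

```python
import numpy as np


def build_dispersion_operator(
    num_spatial_dims, domain_extent, num_points, dispersivity,
    advect_on_diffusion=False,
):
    """Fourier symbol of the dispersion operator (hypothetical helper)."""
    d, L, N = num_spatial_dims, domain_extent, num_points
    axes = []
    for axis in range(d):
        # Real-FFT layout: only the last axis is truncated to (N//2)+1 modes.
        freq_fn = np.fft.rfftfreq if axis == d - 1 else np.fft.fftfreq
        axes.append(2 * np.pi * freq_fn(N, d=L / N))
    # Derivative operator i*k, one entry per spatial direction.
    derivative_operator = 1j * np.stack(np.meshgrid(*axes, indexing="ij"))
    # Broadcast a scalar dispersivity to a vector of length d.
    c = np.broadcast_to(np.asarray(dispersivity, dtype=float), (d,))
    c = c.reshape((d,) + (1,) * d)
    if advect_on_diffusion:
        # Second form: (c . grad)(laplace), which mixes the directions.
        laplace = np.sum(derivative_operator**2, axis=0, keepdims=True)
        advection = np.sum(c * derivative_operator, axis=0, keepdims=True)
        return advection * laplace
    # First form: sum_j c_j (i k_j)^3, no mixing between directions.
    return np.sum(c * derivative_operator**3, axis=0, keepdims=True)


# In 2D the two forms differ, e.g. at the mode k = (1, 1) on L = 2*pi.
op_plain = build_dispersion_operator(2, 2 * np.pi, 16, 1.0)
op_mixed = build_dispersion_operator(2, 2 * np.pi, 16, 1.0, advect_on_diffusion=True)

# In 1D both forms reduce to c * (i k)^3.
op_plain_1d = build_dispersion_operator(1, 2 * np.pi, 16, 1.0)
op_mixed_1d = build_dispersion_operator(1, 2 * np.pi, 16, 1.0, advect_on_diffusion=True)
```

Both symbols are purely imaginary, so neither form damps or amplifies any mode; they only differ in how fast each mode is transported.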

Notes:

  • The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. However, note that initial conditions with modes higher than the Nyquist frequency ((N//2)+1 with N being the num_points) lead to spurious oscillations.
  • Ultimately, only the factor 𝒸 Δt / L³ affects the characteristics of the dynamics. See also exponax.stepper.generic.NormalizedLinearStepper with normalized_coefficients = [0, 0, 0, alpha_3] with alpha_3 = dispersivity * dt / domain_extent**3.
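
The scaling note can be checked numerically. The sketch below is a plain NumPy illustration with a hypothetical helper `step` and arbitrarily chosen values; it is not the library's `NormalizedLinearStepper`. Two different combinations of dispersivity, timestep and domain extent that share the same `alpha_3` produce the same discrete state after one step.

```python
import numpy as np


def step(u, *, domain_extent, dt, dispersivity):
    """One exact 1D dispersion step in Fourier space (hypothetical helper)."""
    N = u.shape[-1]
    k = 2 * np.pi * np.fft.rfftfreq(N, d=domain_extent / N)
    exp_term = np.exp(dispersivity * (1j * k) ** 3 * dt)
    return np.fft.irfft(exp_term * np.fft.rfft(u), n=N)


N = 64
s = np.arange(N) / N  # normalized coordinate x / L, identical for both setups
u0 = np.sin(2 * np.pi * s) + 0.5 * np.cos(2 * np.pi * 3 * s)

# Setup A and setup B differ in every argument but share alpha_3.
u_a = step(u0, domain_extent=1.0, dt=0.008, dispersivity=1.0)
u_b = step(u0, domain_extent=2.0, dt=0.032, dispersivity=2.0)

alpha_a = 1.0 * 0.008 / 1.0**3
alpha_b = 2.0 * 0.032 / 2.0**3
```

The exponent of each mode is `(2πi m)³ · 𝒸 Δt / L³`, so the three physical arguments only enter through their combined factor.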
__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Perform one step of the time integration for a single state.

Arguments:

  • u: The state vector, shape (C, ..., N,).

Returns:

  • u_next: The state vector after one step, shape (C, ..., N,).

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.
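
The calling convention (a single state of shape `(C, ..., N)`, repeated application for trajectories, explicit batching) can be sketched without the library. The example below is a NumPy stand-in under assumptions: `ToyDispersionStepper` and this `rollout` are hypothetical and only mimic the behavior described here; the batching uses a plain Python loop where the real code would use `jax.vmap`.

```python
import numpy as np


class ToyDispersionStepper:
    """Hypothetical 1D, single-channel stand-in for a stepper object."""

    def __init__(self, domain_extent, num_points, dt, dispersivity=1.0):
        self.num_channels = 1
        self.num_points = num_points
        k = 2 * np.pi * np.fft.rfftfreq(num_points, d=domain_extent / num_points)
        # Precompute the exact per-mode phase factor for one step.
        self._exp_term = np.exp(dispersivity * (1j * k) ** 3 * dt)

    def __call__(self, u):
        # A single state must have shape (C, N); batches are rejected.
        expected_shape = (self.num_channels, self.num_points)
        if u.shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape}, got {u.shape}.")
        u_hat = np.fft.rfft(u, axis=-1)
        return np.fft.irfft(self._exp_term * u_hat, n=self.num_points, axis=-1)


def rollout(stepper, num_steps, u0):
    """Apply the stepper repeatedly; this sketch excludes the initial state."""
    states, u = [], u0
    for _ in range(num_steps):
        u = stepper(u)
        states.append(u)
    return np.stack(states)


stepper = ToyDispersionStepper(domain_extent=1.0, num_points=32, dt=0.001)
x = np.arange(32) / 32
u0 = np.sin(2 * np.pi * x)[None, :]  # add the channel axis: shape (1, 32)

trajectory = rollout(stepper, 5, u0)  # shape (5, 1, 32)

# Batched operation: map the single-state call over the leading batch axis.
batch = np.stack([u0, 2 * u0, 3 * u0])  # shape (3, 1, 32)
batched_next = np.stack([stepper(u) for u in batch])

# Passing the batch directly fails the shape check.
try:
    stepper(batch)
    rejected = False
except ValueError:
    rejected = True
```

Keeping the call strictly single-state makes the shape check unambiguous and leaves batching to the caller.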

Source code in exponax/_base_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Perform one step of the time integration for a single state.

    **Arguments:**

    - `u`: The state vector, shape `(C, ..., N,)`.

    **Returns:**

    - `u_next`: The state vector after one step, shape `(C, ..., N,)`.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
        )
    return self.step(u)