
Advection

In 1D:

\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0 \]

In higher dimensions:

\[ \frac{\partial u}{\partial t} + \vec{c} \cdot \nabla u = 0 \]

(often just \(\vec{c} = c \vec{1}\))
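Why this equation admits an exact time step (a short derivation sketch; the Fourier coefficients \(\hat{u}_k\) and the wavenumbers \(k = 2 \pi m / L\), \(m \in \mathbb{Z}\), on a periodic domain of length \(L\) are notation introduced here): transforming the 1D equation in space turns it into one decoupled ODE per mode,

\[ \frac{\mathrm{d} \hat{u}_k}{\mathrm{d} t} = - i c k \, \hat{u}_k \quad \Rightarrow \quad \hat{u}_k(t + \Delta t) = e^{- i c k \Delta t} \, \hat{u}_k(t) \]

Each mode is only rotated in phase and its magnitude is unchanged. In physical space this is a pure shift of the profile:

\[ u(x, t + \Delta t) = u(x - c \Delta t, t) \]

In higher dimensions the same argument applies with the wavenumber vector \(\vec{k}\), and the phase factor becomes \(e^{- i (\vec{c} \cdot \vec{k}) \Delta t}\).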

exponax.stepper.Advection

Bases: BaseStepper

Source code in exponax/stepper/_advection.py
class Advection(BaseStepper):
    velocity: Float[Array, "D"]

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        velocity: Union[Float[Array, "D"], float] = 1.0,
    ):
        """
        Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) advection equation
        on periodic boundary conditions.

        In 1d, the advection equation is given by

        ```
            uₜ + c uₓ = 0
        ```

        with `c ∈ ℝ` being the velocity/advection speed.

        In higher dimensions, the advection equation can be written as the inner
        product between velocity vector and gradient

        ```
            uₜ + c ⋅ ∇u = 0
        ```

        with `c ∈ ℝᵈ` being the velocity/advection vector.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `domain_extent`: The size of the domain `L`; in higher dimensions
            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
        - `num_points`: The number of points `N` used to discretize the
            domain. This **includes** the left boundary point and **excludes**
            the right boundary point. In higher dimensions, the number of points
            in each dimension is the same. Hence, the total number of degrees of
            freedom is `Nᵈ`.
        - `dt`: The timestep size `Δt` between two consecutive states.
        - `velocity` (keyword-only): The advection speed `c`. In higher
            dimensions, this can be a scalar (=float) or a vector of length `d`.
            If a scalar is given, the advection speed is assumed to be the same
            in all spatial dimensions. Default: `1.0`.

        **Notes:**

        - The stepper is unconditionally stable, no matter the choice of
            any argument because the equation is solved analytically in Fourier
            space. **However**, note that initial conditions with modes higher
            than the Nyquist frequency (`(N//2)+1` with `N` being the
            `num_points`) lead to spurious oscillations.
        - Ultimately, only the factor `c Δt / L` affects the characteristic
            of the dynamics. See also
            [`exponax.stepper.generic.NormalizedLinearStepper`][] with
            `normalized_coefficients = [0, alpha_1]` with `alpha_1 = - velocity
            * dt / domain_extent`.
        """
        # TODO: better checks on the desired type of velocity
        if isinstance(velocity, float):
            velocity = jnp.ones(num_spatial_dims) * velocity
        self.velocity = velocity
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            domain_extent=domain_extent,
            num_points=num_points,
            dt=dt,
            num_channels=1,
            order=0,
        )

    def _build_linear_operator(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> Complex[Array, "1 ... (N//2)+1"]:
        # Requires minus to move term to the rhs
        return -build_gradient_inner_product_operator(
            derivative_operator, self.velocity, order=1
        )

    def _build_nonlinear_fun(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> ZeroNonlinearFun:
        return ZeroNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
        )
__init__
__init__(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    velocity: Union[Float[Array, "D"], float] = 1.0
)

Timestepper for the d-dimensional (d ∈ {1, 2, 3}) advection equation on periodic boundary conditions.

In 1d, the advection equation is given by

    uₜ + c uₓ = 0

with c ∈ ℝ being the velocity/advection speed.

In higher dimensions, the advection equation can be written as the inner product between velocity vector and gradient

    uₜ + c ⋅ ∇u = 0

with c ∈ ℝᵈ being the velocity/advection vector.

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᵈ.
  • num_points: The number of points N used to discretize the domain. This includes the left boundary point and excludes the right boundary point. In higher dimensions, the number of points in each dimension is the same. Hence, the total number of degrees of freedom is Nᵈ.
  • dt: The timestep size Δt between two consecutive states.
  • velocity (keyword-only): The advection speed c. In higher dimensions, this can be a scalar (=float) or a vector of length d. If a scalar is given, the advection speed is assumed to be the same in all spatial dimensions. Default: 1.0.
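
The grid and velocity conventions described in the arguments can be sketched in plain Python. This is only an illustration of the stated conventions: `make_periodic_grid` and `broadcast_velocity` are hypothetical helper names, not exponax functions.

```python
# Sketch only: both helpers are hypothetical and not part of exponax.


def make_periodic_grid(domain_extent: float, num_points: int) -> list:
    """N equally spaced points x_j = j * L / N on [0, L).

    The left boundary x = 0 is included; the right boundary x = L is
    excluded because it coincides with x = 0 on a periodic domain.
    """
    dx = domain_extent / num_points
    return [j * dx for j in range(num_points)]


def broadcast_velocity(velocity, num_spatial_dims: int) -> list:
    """Expand a scalar speed to one velocity component per dimension."""
    if isinstance(velocity, (int, float)):
        return [float(velocity)] * num_spatial_dims
    components = [float(v) for v in velocity]
    if len(components) != num_spatial_dims:
        raise ValueError(
            f"Expected {num_spatial_dims} velocity components, "
            f"got {len(components)}."
        )
    return components


grid = make_periodic_grid(domain_extent=1.0, num_points=4)
print(grid)                           # [0.0, 0.25, 0.5, 0.75]
print(broadcast_velocity(2.0, 3))     # [2.0, 2.0, 2.0]

# Total degrees of freedom for N = 4 points per axis in d = 2 dimensions.
print(4 ** 2)                         # 16
```

Unlike the `isinstance(velocity, float)` check in the listed source, this sketch also broadcasts integer speeds; that is a choice made for the illustration, not behavior taken from the library.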

Notes:

  • The stepper is unconditionally stable, no matter the choice of any argument because the equation is solved analytically in Fourier space. However, note that initial conditions with modes higher than the Nyquist frequency ((N//2)+1 with N being the num_points) lead to spurious oscillations.
  • Ultimately, only the factor c Δt / L affects the characteristic of the dynamics. See also exponax.stepper.generic.NormalizedLinearStepper with normalized_coefficients = [0, alpha_1] with alpha_1 = - velocity * dt / domain_extent.
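
Both notes can be checked with a small 1D experiment. The following is a dependency-free sketch, not the library's implementation: it swaps in a naive O(N²) DFT for a fast Fourier transform, and `dft`, `idft`, `advect_step`, and `sine_wave` are hypothetical names introduced here. Each Fourier mode is multiplied by a phase factor of magnitude one, so no step size can amplify the solution, and that factor depends on the parameters only through c Δt / L.

```python
import cmath
import math


def dft(values):
    """Naive O(N^2) discrete Fourier transform (stand-in for an FFT)."""
    n = len(values)
    return [
        sum(values[j] * cmath.exp(-2j * math.pi * k * j / n) for j in range(n))
        for k in range(n)
    ]


def idft(coeffs):
    """Inverse of `dft`, including the 1/N normalization."""
    n = len(coeffs)
    return [
        sum(coeffs[k] * cmath.exp(2j * math.pi * k * j / n) for k in range(n)) / n
        for j in range(n)
    ]


def advect_step(u, velocity, dt, domain_extent):
    """One exact step of u_t + c u_x = 0 on a periodic 1D grid."""
    n = len(u)
    u_hat = dft(u)
    stepped = []
    for k in range(n):
        # Signed integer wavenumber: 0, 1, ..., then the negative modes.
        m = k if k < (n + 1) // 2 else k - n
        # exp(-i * (2 pi m / L) * c * dt) has magnitude 1 for any dt.
        phase = cmath.exp(-2j * math.pi * m * velocity * dt / domain_extent)
        stepped.append(u_hat[k] * phase)
    return [value.real for value in idft(stepped)]


def sine_wave(num_points):
    """One full sine period sampled at x_j = j * L / N."""
    return [math.sin(2 * math.pi * j / num_points) for j in range(num_points)]


N = 16
u0 = sine_wave(N)

# Two parameter sets with the same c * dt / L = 0.25.
u_a = advect_step(u0, velocity=1.0, dt=0.25, domain_extent=1.0)
u_b = advect_step(u0, velocity=2.0, dt=0.5, domain_extent=4.0)

# Exact solution: the profile shifted to the right by a quarter period.
u_exact = [math.sin(2 * math.pi * (j / N - 0.25)) for j in range(N)]

print(max(abs(a - e) for a, e in zip(u_a, u_exact)))  # round-off level
print(max(abs(a - b) for a, b in zip(u_a, u_b)))      # round-off level
```

Both printed errors should be at floating-point round-off level, since the only approximation in the step is the spatial sampling of the initial condition.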
Source code in exponax/stepper/_advection.py
def __init__(
    self,
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    velocity: Union[Float[Array, "D"], float] = 1.0,
):
    """
    Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) advection equation
    on periodic boundary conditions.

    In 1d, the advection equation is given by

    ```
        uₜ + c uₓ = 0
    ```

    with `c ∈ ℝ` being the velocity/advection speed.

    In higher dimensions, the advection equation can be written as the inner
    product between velocity vector and gradient

    ```
        uₜ + c ⋅ ∇u = 0
    ```

    with `c ∈ ℝᵈ` being the velocity/advection vector.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions `d`.
    - `domain_extent`: The size of the domain `L`; in higher dimensions
        the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
    - `num_points`: The number of points `N` used to discretize the
        domain. This **includes** the left boundary point and **excludes**
        the right boundary point. In higher dimensions, the number of points
        in each dimension is the same. Hence, the total number of degrees of
        freedom is `Nᵈ`.
    - `dt`: The timestep size `Δt` between two consecutive states.
    - `velocity` (keyword-only): The advection speed `c`. In higher
        dimensions, this can be a scalar (=float) or a vector of length `d`.
        If a scalar is given, the advection speed is assumed to be the same
        in all spatial dimensions. Default: `1.0`.

    **Notes:**

    - The stepper is unconditionally stable, no matter the choice of
        any argument because the equation is solved analytically in Fourier
        space. **However**, note that initial conditions with modes higher
        than the Nyquist frequency (`(N//2)+1` with `N` being the
        `num_points`) lead to spurious oscillations.
    - Ultimately, only the factor `c Δt / L` affects the characteristic
        of the dynamics. See also
        [`exponax.stepper.generic.NormalizedLinearStepper`][] with
        `normalized_coefficients = [0, alpha_1]` with `alpha_1 = - velocity
        * dt / domain_extent`.
    """
    # TODO: better checks on the desired type of velocity
    if isinstance(velocity, float):
        velocity = jnp.ones(num_spatial_dims) * velocity
    self.velocity = velocity
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        domain_extent=domain_extent,
        num_points=num_points,
        dt=dt,
        num_channels=1,
        order=0,
    )
__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Perform one step of the time integration for a single state.

Arguments:

  • u: The state vector, shape (C, ..., N,).

Returns:

  • u_next: The state vector after one step, shape (C, ..., N,).

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.
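
The single-state shape `(C, ..., N)` and the reason a batched array needs `jax.vmap` can be sketched without any dependencies. The helpers `expected_state_shape` and `is_valid_single_state` are hypothetical names, and the sketch assumes the spatial part of the shape is `N` repeated once per spatial dimension, as the argument description suggests.

```python
# Sketch only: hypothetical helpers mirroring the documented shape rule.


def expected_state_shape(num_channels, num_spatial_dims, num_points):
    """Shape of one unbatched state: (C, N, ..., N) with d copies of N."""
    return (num_channels,) + (num_points,) * num_spatial_dims


def is_valid_single_state(shape, num_channels, num_spatial_dims, num_points):
    """True if `shape` is exactly the shape of one unbatched state."""
    return tuple(shape) == expected_state_shape(
        num_channels, num_spatial_dims, num_points
    )


# One channel on a 64 x 64 grid.
print(expected_state_shape(1, 2, 64))                    # (1, 64, 64)
print(is_valid_single_state((1, 64, 64), 1, 2, 64))      # True

# A batch of 8 states carries an extra leading axis and does not match.
print(is_valid_single_state((8, 1, 64, 64), 1, 2, 64))   # False
```

A leading batch axis changes the shape of the array, so the stepper would reject it; mapping the call over that axis with `jax.vmap` presents each state to the stepper individually.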

Source code in exponax/_base_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Perform one step of the time integration for a single state.

    **Arguments:**

    - `u`: The state vector, shape `(C, ..., N,)`.

    **Returns:**

    - `u_next`: The state vector after one step, shape `(C, ..., N,)`.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
        )
    return self.step(u)