Wave

In 1D:

\[ \frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2} \]

In higher dimensions:

\[ \frac{\partial^2 u}{\partial t^2} = c^2 \Delta u \]

with \(c \in \mathbb{R}\) the speed of sound (wave speed).

Internally, this second-order equation is rewritten as a first-order system of two coupled fields — height \(h\) and velocity \(v = h_t\):

\[ h_t = v, \quad v_t = c^2 \Delta h \]

The state therefore has two channels: u[0] is the height field \(h\) and u[1] is the velocity field \(v\).
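
As a minimal sketch of this two-channel layout, the right-hand side of the first-order system can be evaluated spectrally in 1D. This is plain NumPy, not the exponax implementation; the helper name `wave_rhs` and the grid values are illustrative assumptions.

```python
import numpy as np

def wave_rhs(u, domain_extent, speed_of_sound=1.0):
    """Right-hand side of h_t = v, v_t = c^2 * h_xx for a 1D periodic state.

    `u` has shape (2, N): u[0] is the height h, u[1] is the velocity v.
    """
    h, v = u
    num_points = h.shape[-1]
    # Angular wavenumbers 2*pi*m/L matching the real-valued FFT layout
    k = 2 * np.pi * np.fft.rfftfreq(num_points, d=domain_extent / num_points)
    # Second derivative in Fourier space: multiply by (i k)^2 = -k^2
    h_xx = np.fft.irfft(-(k**2) * np.fft.rfft(h), n=num_points)
    return np.stack([v, speed_of_sound**2 * h_xx])

# Illustrative initial state: one sine period as height, zero velocity
L, N, c = 2.0, 64, 3.0
x = np.linspace(0.0, L, N, endpoint=False)  # left boundary included, right excluded
u0 = np.stack([np.sin(2 * np.pi * x / L), np.zeros(N)])
du_dt = wave_rhs(u0, L, speed_of_sound=c)
```

With zero initial velocity, the first channel of `du_dt` vanishes and the second is the scaled second derivative of the sine profile.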

exponax.stepper.Wave

Bases: BaseStepper

Source code in exponax/stepper/_wave.py
class Wave(BaseStepper):
    speed_of_sound: float
    wavenumber_norm: Float[Array, " 1 ... (N//2)+1"]

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        speed_of_sound: float = 1.0,
    ):
        """
        Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) wave equation on
        periodic boundary conditions.

        In 1d, the wave equation is given by

        ```
            uₜₜ = c² uₓₓ
        ```

        with `c ∈ ℝ` being the speed of sound (or wave speed).

        In higher dimensions, the wave equation is written using the Laplacian

        ```
            uₜₜ = c² Δu
        ```

        Internally, the second-order equation is rewritten as a first-order
        system of two coupled fields — height `h` and velocity `v = hₜ`:

        ```
            hₜ = v
            vₜ = c² Δh
        ```

        As a result, the state has **two channels**: `u[0]` is the height field
        `h` and `u[1]` is the velocity field `v`.

        **Diagonalization:**

        The general solution of the wave equation is a superposition of
        right-traveling and left-traveling waves (d'Alembert's decomposition).
        This stepper exploits that structure: rather than time-stepping the
        coupled `(h, v)` system directly, it transforms into independent
        traveling-wave modes that each evolve as a simple phase rotation.

        In Fourier space, each wavenumber `k` gives a 2×2 ODE for `(ĥ, v̂)`
        that oscillates at frequency `ω = c|k|` — analogous to a harmonic
        oscillator trading potential and kinetic energy. Three steps
        diagonalize it:

        1. **Rescale** — `h` and `v` live on different scales (displacement vs.
           rate). Defining `w = iωĥ` puts them on equal footing. The coupled
           system becomes symmetric: `wₜ = iω v̂`, `v̂ₜ = iω w`.

        2. **Rotate** — Taking the sum and difference
           `pos = (w + v̂)/√2`, `neg = (w − v̂)/√2` decouples the system
           into two independent modes: `posₜ = +iω · pos` and
           `negₜ = −iω · neg`. Physically, `pos` is the right-traveling
           wave and `neg` the left-traveling wave.

        3. **Exponentiate** — Each decoupled mode evolves as a pure phase
           rotation: `pos(t+Δt) = exp(+iωΔt) · pos(t)`. This is what the
           ETDRK0 integrator computes exactly.

        After the exponential step, the inverse rotation and unscaling recover
        the updated `(h, v)`.

        At `k = 0` (the spatial mean), the frequency is zero and the two modes
        collapse — the system is no longer diagonalizable. There, the exact
        update is simply `h_mean += Δt · v_mean`, which is applied as a
        separate correction.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `domain_extent`: The size of the domain `L`; in higher dimensions
            the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
        - `num_points`: The number of points `N` used to discretize the
            domain. This **includes** the left boundary point and **excludes**
            the right boundary point. In higher dimensions, the number of points
            in each dimension is the same. Hence, the total number of degrees of
            freedom is `Nᵈ`.
        - `dt`: The timestep size `Δt` between two consecutive states.
        - `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`.

        **Notes:**

        - The stepper is unconditionally stable, no matter the choice of
            any argument because the equation is solved analytically in Fourier
            space.
        - Ultimately, only the factor `c Δt / L` affects the characteristic
            of the dynamics.
        - The implementation relies on a handcrafted diagonalization of the
          system in Fourier space, which is specific to the wave equation.
          Hence, wave dynamics is not part of the generic steppers like
          [`exponax.stepper.generic.GeneralLinearStepper`][]
        """
        self.speed_of_sound = speed_of_sound
        self.wavenumber_norm = jnp.linalg.norm(
            build_scaled_wavenumbers(
                num_spatial_dims=num_spatial_dims,
                domain_extent=domain_extent,
                num_points=num_points,
            ),
            axis=0,
            keepdims=True,
        )
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            domain_extent=domain_extent,
            num_points=num_points,
            dt=dt,
            num_channels=2,
            order=0,
        )

    def _forward_transform(
        self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
    ) -> Complex[Array, " 2 ... (N//2)+1"]:
        """Transform (h, v) into diagonalized forward/backward wave modes."""
        h_hat, v_hat = u_hat[0:1], u_hat[1:2]
        # Scale height to match velocity units: w = i c |k| h
        k_guard = jnp.where(self.wavenumber_norm == 0, 1.0, self.wavenumber_norm)
        w_hat = 1j * self.speed_of_sound * k_guard * h_hat

        # Orthonormal rotation into wave modes
        pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat)
        neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat)
        return jnp.concatenate([pos, neg], axis=0)

    def _inverse_transform(
        self, waves_hat: Complex[Array, " 2 ... (N//2)+1"]
    ) -> Complex[Array, " 2 ... (N//2)+1"]:
        """Transform diagonalized wave modes back into (h, v)."""
        pos, neg = waves_hat[0:1], waves_hat[1:2]
        # Inverse rotation (the rotation matrix is its own inverse)
        w_hat = (1 / jnp.sqrt(2)) * (pos + neg)
        v_hat = (1 / jnp.sqrt(2)) * (pos - neg)

        # Undo scaling to recover height
        k_guard = jnp.where(self.wavenumber_norm == 0, 1.0, self.wavenumber_norm)
        h_hat = w_hat / (1j * self.speed_of_sound * k_guard)
        return jnp.concatenate([h_hat, v_hat], axis=0)

    def _build_linear_operator(
        self, derivative_operator: Complex[Array, " D ... (N//2)+1"]
    ) -> Complex[Array, " 2 ... (N//2)+1"]:
        val = 1j * self.speed_of_sound * self.wavenumber_norm
        return jnp.concatenate(
            (
                val,
                -val,
            ),
            axis=0,
        )

    def _build_nonlinear_fun(
        self, derivative_operator: Complex[Array, " D ... (N//2)+1"]
    ) -> ZeroNonlinearFun:
        return ZeroNonlinearFun(self.num_spatial_dims, self.num_points)

    def step_fourier(
        self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
    ) -> Complex[Array, " 2 ... (N//2)+1"]:
        """
        Advance the state by one timestep in Fourier space.

        Overrides the base method to wrap the ETDRK step with the
        forward/inverse diagonalization transforms.
        """
        waves_hat = self._forward_transform(u_hat)
        waves_hat_next = super().step_fourier(waves_hat)
        u_hat_next = self._inverse_transform(waves_hat_next)

        # The k=0 (mean/DC) mode cannot be diagonalized because the two
        # eigenvalues collapse to zero and the system matrix becomes
        # [[0, 1], [0, 0]]. The diagonalization leaves this mode unchanged,
        # but the exact solution is h_mean(t+dt) = h_mean(t) + dt * v_mean(t).
        # Apply this linear drift explicitly.
        h_dc_idx = (0,) + (0,) * self.num_spatial_dims
        v_dc_idx = (1,) + (0,) * self.num_spatial_dims
        u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx])

        return u_hat_next
__init__
__init__(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    speed_of_sound: float = 1.0
)

Timestepper for the d-dimensional (d ∈ {1, 2, 3}) wave equation with periodic boundary conditions.

In 1d, the wave equation is given by

    uₜₜ = c² uₓₓ

with c ∈ ℝ being the speed of sound (or wave speed).

In higher dimensions, the wave equation is written using the Laplacian

    uₜₜ = c² Δu

Internally, the second-order equation is rewritten as a first-order system of two coupled fields — height h and velocity v = hₜ:

    hₜ = v
    vₜ = c² Δh

As a result, the state has two channels: u[0] is the height field h and u[1] is the velocity field v.

Diagonalization:

The general solution of the wave equation is a superposition of right-traveling and left-traveling waves (d'Alembert's decomposition). This stepper exploits that structure: rather than time-stepping the coupled (h, v) system directly, it transforms into independent traveling-wave modes that each evolve as a simple phase rotation.
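
The superposition property can be checked numerically. The sketch below is standard-library Python with hypothetical profiles `f` and `g` chosen only for illustration; it estimates `u_tt - c² u_xx` by central differences for `u(x, t) = f(x - ct) + g(x + ct)`.

```python
import math

def traveling_wave_sum(x, t, c=2.0):
    """u(x, t) = f(x - c t) + g(x + c t) with two smooth periodic profiles."""
    f = math.sin(2 * math.pi * (x - c * t))        # right-traveling part
    g = 0.5 * math.cos(4 * math.pi * (x + c * t))  # left-traveling part
    return f + g

def wave_residual(x, t, c=2.0, eps=1e-4):
    """Central-difference estimate of u_tt - c^2 u_xx (close to 0 for a solution)."""
    u = traveling_wave_sum
    u_tt = (u(x, t + eps, c) - 2 * u(x, t, c) + u(x, t - eps, c)) / eps**2
    u_xx = (u(x + eps, t, c) - 2 * u(x, t, c) + u(x - eps, t, c)) / eps**2
    return u_tt - c**2 * u_xx

residual = wave_residual(x=0.3, t=0.7)
```

The residual is small compared with the individual second derivatives, which are of order 100 for these profiles.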

In Fourier space, each wavenumber k gives a 2×2 ODE for (ĥ, v̂) that oscillates at frequency ω = c|k| — analogous to a harmonic oscillator trading potential and kinetic energy. Three steps diagonalize it:

  1. Rescale: h and v live on different scales (displacement vs. rate). Defining w = iωĥ puts them on equal footing. The coupled system becomes symmetric: wₜ = iω v̂, v̂ₜ = iω w.

  2. Rotate: Taking the sum and difference pos = (w + v̂)/√2, neg = (w − v̂)/√2 decouples the system into two independent modes: posₜ = +iω · pos and negₜ = −iω · neg. Physically, pos is the right-traveling wave and neg the left-traveling wave.

  3. Exponentiate: Each decoupled mode evolves as a pure phase rotation: pos(t+Δt) = exp(+iωΔt) · pos(t). This is what the ETDRK0 integrator computes exactly.

After the exponential step, the inverse rotation and unscaling recover the updated (h, v).
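
The three steps and the inverse transform can be traced for a single Fourier mode. The following is a sketch in standard-library Python, not the library code; `step_single_mode` and the numerical values are illustrative assumptions. The result is compared against the closed-form harmonic-oscillator solution.

```python
import cmath
import math

def step_single_mode(h_hat, v_hat, omega, dt):
    """Advance one Fourier mode (omega = c|k| > 0) by dt via the three steps."""
    # 1. Rescale: w = i * omega * h_hat
    w = 1j * omega * h_hat
    # 2. Rotate into the decoupled modes
    pos = (w + v_hat) / math.sqrt(2)
    neg = (w - v_hat) / math.sqrt(2)
    # 3. Exponentiate: each mode is a pure phase rotation
    pos *= cmath.exp(+1j * omega * dt)
    neg *= cmath.exp(-1j * omega * dt)
    # Inverse rotation and unscaling recover (h_hat, v_hat)
    w_next = (pos + neg) / math.sqrt(2)
    v_next = (pos - neg) / math.sqrt(2)
    return w_next / (1j * omega), v_next

# Illustrative values (assumptions, not library defaults)
h0, v0, omega, dt = 0.8 + 0.1j, -0.3 + 0.5j, 2.5, 0.4
h1, v1 = step_single_mode(h0, v0, omega, dt)

# Closed-form solution of h_tt = -omega^2 h for comparison
h_exact = h0 * math.cos(omega * dt) + v0 * math.sin(omega * dt) / omega
v_exact = -omega * h0 * math.sin(omega * dt) + v0 * math.cos(omega * dt)
```

Both routes agree to floating-point precision, which is why the step size does not introduce a time-discretization error for `k ≠ 0`.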

At k = 0 (the spatial mean), the frequency is zero and the two modes collapse — the system is no longer diagonalizable. There, the exact update is simply h_mean += Δt · v_mean, which is applied as a separate correction.
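
The k = 0 update can be derived from the matrix exponential of the degenerate system. With \(\omega = 0\), the mean mode obeys \(\hat h_t = \hat v\), \(\hat v_t = 0\), and the system matrix is nilpotent, so the exponential series terminates after the linear term:

\[ A_0 = \begin{pmatrix} 0 & 1 \\ 0 & 0 \end{pmatrix}, \quad A_0^2 = 0 \;\Rightarrow\; e^{\Delta t A_0} = I + \Delta t A_0 = \begin{pmatrix} 1 & \Delta t \\ 0 & 1 \end{pmatrix} \]

Applied to \((\hat h, \hat v)\) at \(k = 0\), this gives \(\hat h(t + \Delta t) = \hat h(t) + \Delta t \, \hat v(t)\) with \(\hat v\) unchanged, which is the drift correction described above.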

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᵈ.
  • num_points: The number of points N used to discretize the domain. This includes the left boundary point and excludes the right boundary point. In higher dimensions, the number of points in each dimension is the same. Hence, the total number of degrees of freedom is Nᵈ.
  • dt: The timestep size Δt between two consecutive states.
  • speed_of_sound (keyword-only): The wave speed c. Default: 1.0.

Notes:

  • The stepper is unconditionally stable for any choice of arguments because the equation is solved analytically in Fourier space.
  • Ultimately, only the factor c Δt / L affects the characteristics of the dynamics.
  • The implementation relies on a handcrafted diagonalization of the system in Fourier space, which is specific to the wave equation. Hence, wave dynamics are not covered by the generic steppers such as exponax.stepper.generic.GeneralLinearStepper.
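
The first two notes can be illustrated with the per-step multipliers of a single 1D mode. This is a standard-library sketch; `mode_phase_factors` is a hypothetical helper, and the angular frequency `ω = c · 2πm / L` for mode index `m` is assumed.

```python
import cmath
import math

def mode_phase_factors(mode_index, speed_of_sound, dt, domain_extent):
    """Per-step multipliers exp(+i*omega*dt) and exp(-i*omega*dt) of mode m in 1D."""
    omega = speed_of_sound * 2 * math.pi * mode_index / domain_extent
    return cmath.exp(1j * omega * dt), cmath.exp(-1j * omega * dt)

# A very large time step still gives multipliers of unit modulus (no growth)
pos_factor, neg_factor = mode_phase_factors(
    5, speed_of_sound=3.0, dt=1e3, domain_extent=1.0
)

# Two parameter sets with the same c*dt/L = 0.05 give the same multipliers
factors_a = mode_phase_factors(3, speed_of_sound=1.0, dt=0.1, domain_extent=2.0)
factors_b = mode_phase_factors(3, speed_of_sound=4.0, dt=0.05, domain_extent=4.0)
```

Since the phase per step is `2π m · c Δt / L`, any combination of `c`, `Δt`, and `L` with the same ratio produces the same discrete dynamics per mode.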
Source code in exponax/stepper/_wave.py
def __init__(
    self,
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    dt: float,
    *,
    speed_of_sound: float = 1.0,
):
    """
    Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) wave equation on
    periodic boundary conditions.

    In 1d, the wave equation is given by

    ```
        uₜₜ = c² uₓₓ
    ```

    with `c ∈ ℝ` being the speed of sound (or wave speed).

    In higher dimensions, the wave equation is written using the Laplacian

    ```
        uₜₜ = c² Δu
    ```

    Internally, the second-order equation is rewritten as a first-order
    system of two coupled fields — height `h` and velocity `v = hₜ`:

    ```
        hₜ = v
        vₜ = c² Δh
    ```

    As a result, the state has **two channels**: `u[0]` is the height field
    `h` and `u[1]` is the velocity field `v`.

    **Diagonalization:**

    The general solution of the wave equation is a superposition of
    right-traveling and left-traveling waves (d'Alembert's decomposition).
    This stepper exploits that structure: rather than time-stepping the
    coupled `(h, v)` system directly, it transforms into independent
    traveling-wave modes that each evolve as a simple phase rotation.

    In Fourier space, each wavenumber `k` gives a 2×2 ODE for `(ĥ, v̂)`
    that oscillates at frequency `ω = c|k|` — analogous to a harmonic
    oscillator trading potential and kinetic energy. Three steps
    diagonalize it:

    1. **Rescale** — `h` and `v` live on different scales (displacement vs.
       rate). Defining `w = iωĥ` puts them on equal footing. The coupled
       system becomes symmetric: `wₜ = iω v̂`, `v̂ₜ = iω w`.

    2. **Rotate** — Taking the sum and difference
       `pos = (w + v̂)/√2`, `neg = (w − v̂)/√2` decouples the system
       into two independent modes: `posₜ = +iω · pos` and
       `negₜ = −iω · neg`. Physically, `pos` is the right-traveling
       wave and `neg` the left-traveling wave.

    3. **Exponentiate** — Each decoupled mode evolves as a pure phase
       rotation: `pos(t+Δt) = exp(+iωΔt) · pos(t)`. This is what the
       ETDRK0 integrator computes exactly.

    After the exponential step, the inverse rotation and unscaling recover
    the updated `(h, v)`.

    At `k = 0` (the spatial mean), the frequency is zero and the two modes
    collapse — the system is no longer diagonalizable. There, the exact
    update is simply `h_mean += Δt · v_mean`, which is applied as a
    separate correction.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions `d`.
    - `domain_extent`: The size of the domain `L`; in higher dimensions
        the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
    - `num_points`: The number of points `N` used to discretize the
        domain. This **includes** the left boundary point and **excludes**
        the right boundary point. In higher dimensions, the number of points
        in each dimension is the same. Hence, the total number of degrees of
        freedom is `Nᵈ`.
    - `dt`: The timestep size `Δt` between two consecutive states.
    - `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`.

    **Notes:**

    - The stepper is unconditionally stable, no matter the choice of
        any argument because the equation is solved analytically in Fourier
        space.
    - Ultimately, only the factor `c Δt / L` affects the characteristic
        of the dynamics.
    - The implementation relies on a handcrafted diagonalization of the
      system in Fourier space, which is specific to the wave equation.
      Hence, wave dynamics is not part of the generic steppers like
      [`exponax.stepper.generic.GeneralLinearStepper`][]
    """
    self.speed_of_sound = speed_of_sound
    self.wavenumber_norm = jnp.linalg.norm(
        build_scaled_wavenumbers(
            num_spatial_dims=num_spatial_dims,
            domain_extent=domain_extent,
            num_points=num_points,
        ),
        axis=0,
        keepdims=True,
    )
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        domain_extent=domain_extent,
        num_points=num_points,
        dt=dt,
        num_channels=2,
        order=0,
    )
__call__
__call__(
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]

Perform one step of the time integration for a single state.

Arguments:

  • u: The state vector, shape (C, ..., N,).

Returns:

  • u_next: The state vector after one step, shape (C, ..., N,).

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.

Source code in exponax/_base_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Perform one step of the time integration for a single state.

    **Arguments:**

    - `u`: The state vector, shape `(C, ..., N,)`.

    **Returns:**

    - `u_next`: The state vector after one step, shape `(C, ..., N,)`.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"""Expected shape {expected_shape}, got {u.shape}. For batched
             operation use `jax.vmap` on this function."""
        )
    return self.step(u)
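
The shape check above can be sketched in isolation. This assumes that `spatial_shape` returns one axis of length `N` per spatial dimension; the two helper names below are hypothetical and not part of exponax.

```python
def expected_state_shape(num_channels, num_spatial_dims, num_points):
    """Shape (C, N, ..., N) with one axis of length N per spatial dimension."""
    return (num_channels,) + (num_points,) * num_spatial_dims

def check_state_shape(shape, num_channels, num_spatial_dims, num_points):
    """Raise ValueError for a mismatched shape, mirroring the check above."""
    expected = expected_state_shape(num_channels, num_spatial_dims, num_points)
    if tuple(shape) != expected:
        raise ValueError(f"Expected shape {expected}, got {tuple(shape)}.")
    return expected

# For the wave stepper the state has two channels (height and velocity)
shape_1d = expected_state_shape(2, 1, 64)
shape_2d = expected_state_shape(2, 2, 64)

# A leading batch axis is rejected; batched states need a mapped call instead
try:
    check_state_shape((8, 2, 64), 2, 1, 64)
    batched_rejected = False
except ValueError:
    batched_rejected = True
```

This is why a batch of states has to go through `jax.vmap` rather than being passed directly.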
_forward_transform
_forward_transform(
    u_hat: Complex[Array, " 2 ... (N//2)+1"],
) -> Complex[Array, " 2 ... (N//2)+1"]

Transform (h, v) into diagonalized forward/backward wave modes.

Source code in exponax/stepper/_wave.py
def _forward_transform(
    self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
    """Transform (h, v) into diagonalized forward/backward wave modes."""
    h_hat, v_hat = u_hat[0:1], u_hat[1:2]
    # Scale height to match velocity units: w = i c |k| h
    k_guard = jnp.where(self.wavenumber_norm == 0, 1.0, self.wavenumber_norm)
    w_hat = 1j * self.speed_of_sound * k_guard * h_hat

    # Orthonormal rotation into wave modes
    pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat)
    neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat)
    return jnp.concatenate([pos, neg], axis=0)
_inverse_transform
_inverse_transform(
    waves_hat: Complex[Array, " 2 ... (N//2)+1"],
) -> Complex[Array, " 2 ... (N//2)+1"]

Transform diagonalized wave modes back into (h, v).

Source code in exponax/stepper/_wave.py
def _inverse_transform(
    self, waves_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
    """Transform diagonalized wave modes back into (h, v)."""
    pos, neg = waves_hat[0:1], waves_hat[1:2]
    # Inverse rotation (the rotation matrix is its own inverse)
    w_hat = (1 / jnp.sqrt(2)) * (pos + neg)
    v_hat = (1 / jnp.sqrt(2)) * (pos - neg)

    # Undo scaling to recover height
    k_guard = jnp.where(self.wavenumber_norm == 0, 1.0, self.wavenumber_norm)
    h_hat = w_hat / (1j * self.speed_of_sound * k_guard)
    return jnp.concatenate([h_hat, v_hat], axis=0)
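
A round-trip check shows that the two transforms invert each other, including at `k = 0` where the guard replaces the zero wavenumber by one. The sketch below is a NumPy port of the two listings above, written as free functions for illustration; the grid values and random test data are assumptions.

```python
import numpy as np

def forward_transform(u_hat, wavenumber_norm, c):
    """(h_hat, v_hat) -> (pos, neg), following the listing above."""
    h_hat, v_hat = u_hat[0:1], u_hat[1:2]
    k_guard = np.where(wavenumber_norm == 0, 1.0, wavenumber_norm)
    w_hat = 1j * c * k_guard * h_hat
    return np.concatenate([w_hat + v_hat, w_hat - v_hat], axis=0) / np.sqrt(2)

def inverse_transform(waves_hat, wavenumber_norm, c):
    """(pos, neg) -> (h_hat, v_hat), following the listing above."""
    pos, neg = waves_hat[0:1], waves_hat[1:2]
    k_guard = np.where(wavenumber_norm == 0, 1.0, wavenumber_norm)
    w_hat = (pos + neg) / np.sqrt(2)
    v_hat = (pos - neg) / np.sqrt(2)
    return np.concatenate([w_hat / (1j * c * k_guard), v_hat], axis=0)

# Illustrative 1D setup with a leading singleton axis, shape (1, N//2 + 1)
N, L, c = 16, 1.0, 2.0
wavenumber_norm = np.abs(2 * np.pi * np.fft.rfftfreq(N, d=L / N))[None, :]
rng = np.random.default_rng(0)
u_hat = rng.normal(size=(2, N // 2 + 1)) + 1j * rng.normal(size=(2, N // 2 + 1))
u_hat_roundtrip = inverse_transform(
    forward_transform(u_hat, wavenumber_norm, c), wavenumber_norm, c
)
```

The guard only keeps the division well defined at the mean mode; the physical update of that mode is handled separately in `step_fourier`.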
step_fourier
step_fourier(
    u_hat: Complex[Array, " 2 ... (N//2)+1"],
) -> Complex[Array, " 2 ... (N//2)+1"]

Advance the state by one timestep in Fourier space.

Overrides the base method to wrap the ETDRK step with the forward/inverse diagonalization transforms.

Source code in exponax/stepper/_wave.py
def step_fourier(
    self, u_hat: Complex[Array, " 2 ... (N//2)+1"]
) -> Complex[Array, " 2 ... (N//2)+1"]:
    """
    Advance the state by one timestep in Fourier space.

    Overrides the base method to wrap the ETDRK step with the
    forward/inverse diagonalization transforms.
    """
    waves_hat = self._forward_transform(u_hat)
    waves_hat_next = super().step_fourier(waves_hat)
    u_hat_next = self._inverse_transform(waves_hat_next)

    # The k=0 (mean/DC) mode cannot be diagonalized because the two
    # eigenvalues collapse to zero and the system matrix becomes
    # [[0, 1], [0, 0]]. The diagonalization leaves this mode unchanged,
    # but the exact solution is h_mean(t+dt) = h_mean(t) + dt * v_mean(t).
    # Apply this linear drift explicitly.
    h_dc_idx = (0,) + (0,) * self.num_spatial_dims
    v_dc_idx = (1,) + (0,) * self.num_spatial_dims
    u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx])

    return u_hat_next
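
The whole step, including the mean-mode drift, can be reproduced in a few lines for the 1D case. The following is a NumPy re-implementation sketch under stated assumptions (NumPy's unnormalized FFT convention, a hypothetical function `wave_step_1d`, illustrative parameters); it is not the exponax code path. The result is compared against an analytic standing wave with a constant mean velocity.

```python
import numpy as np

def wave_step_1d(u, domain_extent, dt, speed_of_sound=1.0):
    """One exact step of the 1D periodic wave equation for a state of shape (2, N)."""
    num_points = u.shape[-1]
    k = 2 * np.pi * np.fft.rfftfreq(num_points, d=domain_extent / num_points)
    omega = speed_of_sound * np.abs(k)
    omega_guard = np.where(omega == 0, 1.0, omega)  # avoid dividing by zero at k = 0

    h_hat, v_hat = np.fft.rfft(u[0]), np.fft.rfft(u[1])
    # Forward transform into the decoupled modes
    w_hat = 1j * omega_guard * h_hat
    pos = (w_hat + v_hat) / np.sqrt(2)
    neg = (w_hat - v_hat) / np.sqrt(2)
    # Exact exponential step (omega = 0 leaves the mean mode unchanged)
    pos = np.exp(+1j * omega * dt) * pos
    neg = np.exp(-1j * omega * dt) * neg
    # Inverse transform back to (h_hat, v_hat)
    h_hat_next = (pos + neg) / np.sqrt(2) / (1j * omega_guard)
    v_hat_next = (pos - neg) / np.sqrt(2)
    # Mean-mode drift: h_mean += dt * v_mean
    h_hat_next[0] += dt * v_hat[0]
    return np.stack([
        np.fft.irfft(h_hat_next, n=num_points),
        np.fft.irfft(v_hat_next, n=num_points),
    ])

# Standing wave plus a constant mean velocity (illustrative values)
L, N, c, dt = 1.0, 32, 2.0, 0.37
x = np.linspace(0.0, L, N, endpoint=False)
k1 = 2 * np.pi / L
v_mean = 0.5
u0 = np.stack([np.sin(k1 * x), np.full(N, v_mean)])
u1 = wave_step_1d(u0, L, dt, speed_of_sound=c)

# Analytic reference for this initial condition
h_ref = np.sin(k1 * x) * np.cos(c * k1 * dt) + v_mean * dt
v_ref = -c * k1 * np.sin(k1 * x) * np.sin(c * k1 * dt) + v_mean
```

Without the drift line, the oscillating part would still match, but the height would miss the `v_mean * dt` offset.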