Gradient Norm

exponax.stepper.generic.DifficultyGradientNormStepper

Bases: NormalizedGradientNormStepper

Source code in exponax/stepper/generic/_gradient_norm.py
class DifficultyGradientNormStepper(NormalizedGradientNormStepper):
    linear_difficulties: tuple[float, ...]
    gradient_norm_difficulty: float

    def __init__(
        self,
        num_spatial_dims: int = 1,
        num_points: int = 48,
        *,
        linear_difficulties: tuple[float, ...] = (0.0, 0.0, -0.128, 0.0, -0.32768),
        gradient_norm_difficulty: float = 0.064,
        maximum_absolute: float = 1.0,
        order: int = 2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Timestepper for the **difficulty-based** d-dimensional (`d ∈ {1, 2, 3}`)
        semi-linear PDEs consisting of a gradient norm nonlinearity and an
        arbitrary combination of (isotropic) linear operators. Uses a
        difficulty-based interface where the "intensity" of the dynamics reduces
        with increasing resolution. The intent is that emulator learning
        problems at two different resolutions are comparably difficult.

        In contrast to `exponax.stepper.generic.NormalizedGradientNormStepper`,
        the dynamics are defined by difficulties. A difficulty combines the
        normalized coefficient with `num_spatial_dims` and `num_points`:

            γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d

        with `d` the number of spatial dimensions, `N` the number of points, and
        `αᵢ` the normalized coefficient.

        The difficulty of the gradient norm scale is defined by

            δ₂ = β₂ M N² d

        with `M` the maximum absolute value of the input state (typically `1.0`
        if one uses the `exponax.ic` random generators with the `max_one=True`
        argument).

        This interface is more natural than the normalized interface because the
        difficulties for all orders (given by `i`) are around 1.0. Additionally,
        they relate to the stability conditions of explicit finite difference
        schemes for the particular equations. For example, for advection
        (`i=1`), the absolute value of the difficulty is the
        Courant-Friedrichs-Lewy (CFL) number.

        Under the default settings, this timestepper represents the
        Kuramoto-Sivashinsky equation (in combustion format).

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `num_points`: The number of points `N` used to discretize the
            domain. This **includes** the left boundary point and **excludes**
            the right boundary point. In higher dimensions, the number of points
            in each dimension is the same. Hence, the total number of degrees of
            freedom is `Nᵈ`.
        - `linear_difficulties`: The list of difficulties `γᵢ` corresponding to
            the derivatives. The length of this tuple represents the highest
            occurring derivative. The default value `(0.0, 0.0, -0.128, 0.0,
            -0.32768)` corresponds to the Kuramoto-Sivashinsky equation in
            combustion format (because it contains both a negative diffusion and
            a negative hyperdiffusion term).
        - `gradient_norm_difficulty`: The difficulty of the gradient norm term
            `δ₂`.
        - `maximum_absolute`: The maximum absolute value of the input state. This
            is used to scale the gradient norm term.
        - `order`: The order of the Exponential Time Differencing Runge
            Kutta method. Must be one of {0, 1, 2, 3, 4}. The option `0` only
            solves the linear part of the equation. Use higher values for higher
            accuracy and stability. The default choice of `2` is a good
            compromise for single precision floats.
        - `dealiasing_fraction`: The fraction of the wavenumbers to keep
            before evaluating the nonlinearity. The default 2/3 corresponds to
            Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default:
            2/3.
        - `num_circle_points`: How many points to use in the complex contour
            integral method to compute the coefficients of the exponential time
            differencing Runge Kutta method. Default: 16.
        - `circle_radius`: The radius of the contour used to compute the
            coefficients of the exponential time differencing Runge Kutta
            method. Default: 1.0.
        """
        self.linear_difficulties = linear_difficulties
        self.gradient_norm_difficulty = gradient_norm_difficulty

        normalized_coefficients = extract_normalized_coefficients_from_difficulty(
            linear_difficulties,
            num_spatial_dims=num_spatial_dims,
            num_points=num_points,
        )
        normalized_gradient_norm_scale = (
            extract_normalized_gradient_norm_scale_from_difficulty(
                gradient_norm_difficulty,
                num_spatial_dims=num_spatial_dims,
                num_points=num_points,
                maximum_absolute=maximum_absolute,
            )
        )

        super().__init__(
            num_spatial_dims=num_spatial_dims,
            num_points=num_points,
            normalized_linear_coefficients=normalized_coefficients,
            normalized_gradient_norm_scale=normalized_gradient_norm_scale,
            order=order,
            dealiasing_fraction=dealiasing_fraction,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )
__init__
__init__(
    num_spatial_dims: int = 1,
    num_points: int = 48,
    *,
    linear_difficulties: tuple[float, ...] = (
        0.0,
        0.0,
        -0.128,
        0.0,
        -0.32768,
    ),
    gradient_norm_difficulty: float = 0.064,
    maximum_absolute: float = 1.0,
    order: int = 2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Timestepper for the difficulty-based d-dimensional (d ∈ {1, 2, 3}) semi-linear PDEs consisting of a gradient norm nonlinearity and an arbitrary combination of (isotropic) linear operators. Uses a difficulty-based interface where the "intensity" of the dynamics reduces with increasing resolution. The intent is that emulator learning problems at two different resolutions are comparably difficult.

In contrast to exponax.stepper.generic.NormalizedGradientNormStepper, the dynamics are defined by difficulties. A difficulty combines the normalized coefficient with num_spatial_dims and num_points:

γᵢ = αᵢ Nⁱ 2ⁱ⁻¹ d

with d the number of spatial dimensions, N the number of points, and αᵢ the normalized coefficient.

The difficulty of the gradient norm scale is defined by

δ₂ = β₂ M N² d

with M the maximum absolute value of the input state (typically 1.0 if one uses the exponax.ic random generators with the max_one=True argument).
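
The two conversion formulas above can be restated as a minimal pure-Python sketch. This is a didactic restatement of the documented formulas, not the `exponax` implementation, and the function names are illustrative; `d` denotes the number of spatial dimensions:

```python
def linear_difficulty(
    alpha_i: float, i: int, num_points: int, num_spatial_dims: int
) -> float:
    """gamma_i = alpha_i * N**i * 2**(i - 1) * d"""
    return alpha_i * num_points**i * 2 ** (i - 1) * num_spatial_dims


def gradient_norm_difficulty(
    beta_2: float, maximum_absolute: float, num_points: int, num_spatial_dims: int
) -> float:
    """delta_2 = beta_2 * M * N**2 * d"""
    return beta_2 * maximum_absolute * num_points**2 * num_spatial_dims


# For advection (i = 1) in 1D, the linear difficulty reduces to alpha_1 * N,
# which is why its absolute value plays the role of a CFL number.
```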

This interface is more natural than the normalized interface because the difficulties for all orders (given by i) are around 1.0. Additionally, they relate to the stability conditions of explicit finite difference schemes for the particular equations. For example, for advection (i=1), the absolute value of the difficulty is the Courant-Friedrichs-Lewy (CFL) number.

Under the default settings, this timestepper represents the Kuramoto-Sivashinsky equation (in combustion format).

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • num_points: The number of points N used to discretize the domain. This includes the left boundary point and excludes the right boundary point. In higher dimensions, the number of points in each dimension is the same. Hence, the total number of degrees of freedom is Nᵈ.
  • linear_difficulties: The list of difficulties γᵢ corresponding to the derivatives. The length of this tuple represents the highest occurring derivative. The default value (0.0, 0.0, -0.128, 0.0, -0.32768) corresponds to the Kuramoto-Sivashinsky equation in combustion format (because it contains both a negative diffusion and a negative hyperdiffusion term).
  • gradient_norm_difficulty: The difficulty of the gradient norm term δ₂.
  • maximum_absolute: The maximum absolute value of the input state. This is used to scale the gradient norm term.
  • order: The order of the Exponential Time Differencing Runge Kutta method. Must be one of {0, 1, 2, 3, 4}. The option 0 only solves the linear part of the equation. Use higher values for higher accuracy and stability. The default choice of 2 is a good compromise for single precision floats.
  • dealiasing_fraction: The fraction of the wavenumbers to keep before evaluating the nonlinearity. The default 2/3 corresponds to Orszag's 2/3 rule. To fully eliminate aliasing, use 1/2. Default: 2/3.
  • num_circle_points: How many points to use in the complex contour integral method to compute the coefficients of the exponential time differencing Runge Kutta method. Default: 16.
  • circle_radius: The radius of the contour used to compute the coefficients of the exponential time differencing Runge Kutta method. Default: 1.0.
__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Perform one step of the time integration for a single state.

Arguments:

  • u: The state vector, shape (C, ..., N,).

Returns:

  • u_next: The state vector after one step, shape (C, ..., N,).

Tip

Use this call method together with exponax.rollout to efficiently produce temporal trajectories.

Info

For batched operation, use jax.vmap on this function.

Source code in exponax/_base_stepper.py
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Perform one step of the time integration for a single state.

    **Arguments:**

    - `u`: The state vector, shape `(C, ..., N,)`.

    **Returns:**

    - `u_next`: The state vector after one step, shape `(C, ..., N,)`.

    !!! tip
        Use this call method together with `exponax.rollout` to efficiently
        produce temporal trajectories.

    !!! info
        For batched operation, use `jax.vmap` on this function.
    """
    expected_shape = (self.num_channels,) + spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if u.shape != expected_shape:
        raise ValueError(
            f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
        )
    return self.step(u)