Skip to content

Nonlinear

exponax.normalized.NormalizedGeneralNonlinearStepper

Bases: BaseStepper

Source code in exponax/normalized/_general_nonlinear.py (lines 12–74)
class NormalizedGeneralNonlinearStepper(BaseStepper):
    """
    Stepper for a normalized PDE with a polynomial linear part and a general
    nonlinear part built by ``GeneralNonlinearFun``.

    The domain extent and the time step are both fixed to 1.0, so the
    dynamics are controlled entirely through the normalized coefficients.
    By default Burgers.
    """

    # Entry ``i`` multiplies the ``i``-th power of the derivative operator
    # when the linear operator is assembled (see ``_build_linear_operator``).
    normalized_coefficients_linear: tuple[float, ...]
    # Exactly three scales, forwarded to ``GeneralNonlinearFun`` as
    # ``scale_list``; their individual meaning is defined there.
    normalized_coefficients_nonlinear: tuple[float, ...]
    # Forwarded unchanged to ``GeneralNonlinearFun``.
    dealiasing_fraction: float

    def __init__(
        self,
        num_spatial_dims: int,
        num_points: int,
        *,
        normalized_coefficients_linear: tuple[float, ...] = (0.0, 0.0, 0.1 * 0.1),
        normalized_coefficients_nonlinear: tuple[float, ...] = (0.0, -1.0 * 0.1, 0.0),
        order=2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Configure the normalized stepper. By default Burgers.

        Args:
            num_spatial_dims: Forwarded to ``BaseStepper``.
            num_points: Forwarded to ``BaseStepper``.
            normalized_coefficients_linear: Coefficient ``c_i`` of the
                ``i``-th power of the derivative operator in the linear
                operator.
            normalized_coefficients_nonlinear: Exactly three scales passed
                to ``GeneralNonlinearFun`` as ``scale_list``.
            order: Forwarded to ``BaseStepper`` (presumably the order of the
                time integrator — confirm against ``BaseStepper``).
            dealiasing_fraction: Forwarded to ``GeneralNonlinearFun``.
            num_circle_points: Forwarded to ``BaseStepper``.
            circle_radius: Forwarded to ``BaseStepper``.

        Raises:
            ValueError: If ``normalized_coefficients_nonlinear`` does not have
                exactly three entries.
        """
        num_nonlinear_scales = len(normalized_coefficients_nonlinear)
        if num_nonlinear_scales != 3:
            raise ValueError(
                "The nonlinear coefficients list must have exactly 3 elements"
            )

        # Store the configuration before the base initializer runs.
        self.normalized_coefficients_linear = normalized_coefficients_linear
        self.normalized_coefficients_nonlinear = normalized_coefficients_nonlinear
        self.dealiasing_fraction = dealiasing_fraction

        base_config = dict(
            num_spatial_dims=num_spatial_dims,
            # Unit domain: the derivative operator is just scaled with 2 * jnp.pi.
            domain_extent=1.0,
            num_points=num_points,
            # Unit time step: the coefficients already absorb dt.
            dt=1.0,
            num_channels=1,
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )
        super().__init__(**base_config)

    def _build_linear_operator(self, derivative_operator: Array) -> Array:
        """
        Assemble the linear operator as a polynomial in the derivative operator.

        Term ``i`` is ``c_i * derivative_operator**i``, summed over axis 0
        (which is kept as a singleton axis). All terms are then added together.
        """
        # Start from int 0, the same start value the builtin ``sum`` uses.
        linear_operator = 0
        for power, coefficient in enumerate(self.normalized_coefficients_linear):
            term = coefficient * (derivative_operator) ** power
            linear_operator = linear_operator + jnp.sum(
                term,
                axis=0,
                keepdims=True,
            )
        return linear_operator

    def _build_nonlinear_fun(
        self,
        derivative_operator: Complex[Array, "D ... (N//2)+1"],
    ) -> GeneralNonlinearFun:
        """
        Create the evaluator for the nonlinear term.

        Forwards the spatial layout, the derivative operator, the dealiasing
        fraction and the three nonlinear scales to ``GeneralNonlinearFun``.
        """
        nonlinear_fun = GeneralNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
            derivative_operator=derivative_operator,
            dealiasing_fraction=self.dealiasing_fraction,
            scale_list=self.normalized_coefficients_nonlinear,
            # TODO: confirm whether the zero-mode fix is needed here
            # (marked as unverified in the original code).
            zero_mode_fix=True,
        )
        return nonlinear_fun
__init__
__init__(
    num_spatial_dims: int,
    num_points: int,
    *,
    normalized_coefficients_linear: tuple[float, ...] = (
        0.0,
        0.0,
        0.1 * 0.1,
    ),
    normalized_coefficients_nonlinear: tuple[float, ...] = (
        0.0,
        -1.0 * 0.1,
        0.0,
    ),
    order=2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

By default, the stepper is configured for the Burgers equation.

Source code in exponax/normalized/_general_nonlinear.py (lines 17–50)
def __init__(
    self,
    num_spatial_dims: int,
    num_points: int,
    *,
    normalized_coefficients_linear: tuple[float, ...] = (0.0, 0.0, 0.1 * 0.1),
    normalized_coefficients_nonlinear: tuple[float, ...] = (0.0, -1.0 * 0.1, 0.0),
    order=2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Configure the normalized stepper. By default Burgers.

    Args:
        num_spatial_dims: Forwarded to the base initializer.
        num_points: Forwarded to the base initializer.
        normalized_coefficients_linear: Coefficient ``c_i`` of the ``i``-th
            power of the derivative operator in the linear operator.
        normalized_coefficients_nonlinear: Exactly three nonlinear scales.
        order: Forwarded to the base initializer.
        dealiasing_fraction: Stored on the instance.
        num_circle_points: Forwarded to the base initializer.
        circle_radius: Forwarded to the base initializer.

    Raises:
        ValueError: If ``normalized_coefficients_nonlinear`` does not have
            exactly three entries.
    """
    num_nonlinear_scales = len(normalized_coefficients_nonlinear)
    if num_nonlinear_scales != 3:
        raise ValueError(
            "The nonlinear coefficients list must have exactly 3 elements"
        )

    # Store the configuration before the base initializer runs.
    self.normalized_coefficients_linear = normalized_coefficients_linear
    self.normalized_coefficients_nonlinear = normalized_coefficients_nonlinear
    self.dealiasing_fraction = dealiasing_fraction

    base_config = dict(
        num_spatial_dims=num_spatial_dims,
        # Unit domain: the derivative operator is just scaled with 2 * jnp.pi.
        domain_extent=1.0,
        num_points=num_points,
        dt=1.0,
        num_channels=1,
        order=order,
        num_circle_points=num_circle_points,
        circle_radius=circle_radius,
    )
    super().__init__(**base_config)
__call__
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

Checks that the input has the expected shape (channels followed by the spatial dimensions), then performs a step.

Source code in exponax/_base_stepper.py (lines 231–245)
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Validate the shape of a single, unbatched state and step it.

    The expected shape is ``(self.num_channels,)`` followed by
    ``spatial_shape(self.num_spatial_dims, self.num_points)``. If ``u``
    matches, the result of ``self.step(u)`` is returned.

    Raises:
        ValueError: If ``u.shape`` differs from the expected shape, for
            example when a batch axis is present.
    """
    channel_axis = (self.num_channels,)
    spatial_axes = spatial_shape(self.num_spatial_dims, self.num_points)
    expected_shape = channel_axis + spatial_axes

    if u.shape == expected_shape:
        return self.step(u)

    raise ValueError(
        f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
    )