Skip to content

Gradient Norm¤

exponax.normalized.DifficultyGradientNormStepper ¤

Bases: NormalizedGradientNormStepper

Source code in exponax/normalized/_gradient_norm.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class DifficultyGradientNormStepper(NormalizedGradientNormStepper):
    """
    Gradient-norm stepper that is configured through difficulty values.

    The difficulties are translated into the normalized coefficients and the
    normalized gradient-norm scale that `NormalizedGradientNormStepper`
    takes; the difficulty inputs themselves are kept on the instance.
    """

    # Linear difficulties exactly as handed to the constructor.
    linear_difficulties: tuple[float, ...]
    # Gradient-norm difficulty exactly as handed to the constructor.
    gradient_norm_difficulty: float

    def __init__(
        self,
        num_spatial_dims: int = 1,
        num_points: int = 48,
        *,
        linear_difficulties: tuple[float, ...] = (0.0, 0.0, -0.128, 0.0, -0.32768),
        gradient_norm_difficulty: float = 0.064,
        maximum_absolute: float = 1.0,
        order: int = 2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Build the stepper from difficulty values. By default: KS equation.

        **Arguments:**

        - `num_spatial_dims`: Forwarded to both difficulty conversions and
            to the parent constructor.
        - `num_points`: Forwarded to both difficulty conversions and to the
            parent constructor.
        - `linear_difficulties`: Stored on the instance and converted with
            `extract_normalized_coefficients_from_difficulty`.
        - `gradient_norm_difficulty`: Stored on the instance and converted
            with `extract_normalized_gradient_norm_scale_from_difficulty`.
        - `maximum_absolute`: Only used by the gradient-norm scale
            conversion; it is not passed to the parent constructor.
        - `order`, `dealiasing_fraction`, `num_circle_points`,
            `circle_radius`: Passed to the parent constructor unchanged.
        """
        self.linear_difficulties = linear_difficulties
        self.gradient_norm_difficulty = gradient_norm_difficulty

        # Both conversions and the parent constructor receive the same grid
        # description, so it is bundled once and unpacked where needed.
        grid = {"num_spatial_dims": num_spatial_dims, "num_points": num_points}

        coefficients = extract_normalized_coefficients_from_difficulty(
            linear_difficulties, **grid
        )
        gradient_norm_scale = extract_normalized_gradient_norm_scale_from_difficulty(
            gradient_norm_difficulty,
            maximum_absolute=maximum_absolute,
            **grid,
        )

        super().__init__(
            normalized_coefficients=coefficients,
            normalized_gradient_norm_scale=gradient_norm_scale,
            order=order,
            dealiasing_fraction=dealiasing_fraction,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
            **grid,
        )
__init__ ¤
__init__(
    num_spatial_dims: int = 1,
    num_points: int = 48,
    *,
    linear_difficulties: tuple[float, ...] = (
        0.0,
        0.0,
        -0.128,
        0.0,
        -0.32768,
    ),
    gradient_norm_difficulty: float = 0.064,
    maximum_absolute: float = 1.0,
    order: int = 2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

With the default arguments, the stepper is configured for the KS equation.

Source code in exponax/normalized/_gradient_norm.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def __init__(
    self,
    num_spatial_dims: int = 1,
    num_points: int = 48,
    *,
    linear_difficulties: tuple[float, ...] = (0.0, 0.0, -0.128, 0.0, -0.32768),
    gradient_norm_difficulty: float = 0.064,
    maximum_absolute: float = 1.0,
    order: int = 2,
    dealiasing_fraction: float = 2 / 3,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Build the stepper from difficulty values. By default: KS equation.

    **Arguments:**

    - `num_spatial_dims`: Forwarded to both difficulty conversions and to
        the parent constructor.
    - `num_points`: Forwarded to both difficulty conversions and to the
        parent constructor.
    - `linear_difficulties`: Stored on the instance and converted with
        `extract_normalized_coefficients_from_difficulty`.
    - `gradient_norm_difficulty`: Stored on the instance and converted with
        `extract_normalized_gradient_norm_scale_from_difficulty`.
    - `maximum_absolute`: Only used by the gradient-norm scale conversion;
        it is not passed to the parent constructor.
    - `order`, `dealiasing_fraction`, `num_circle_points`, `circle_radius`:
        Passed to the parent constructor unchanged.
    """
    self.linear_difficulties = linear_difficulties
    self.gradient_norm_difficulty = gradient_norm_difficulty

    # Both conversions and the parent constructor receive the same grid
    # description, so it is bundled once and unpacked where needed.
    grid = {"num_spatial_dims": num_spatial_dims, "num_points": num_points}

    coefficients = extract_normalized_coefficients_from_difficulty(
        linear_difficulties, **grid
    )
    gradient_norm_scale = extract_normalized_gradient_norm_scale_from_difficulty(
        gradient_norm_difficulty,
        maximum_absolute=maximum_absolute,
        **grid,
    )

    super().__init__(
        normalized_coefficients=coefficients,
        normalized_gradient_norm_scale=gradient_norm_scale,
        order=order,
        dealiasing_fraction=dealiasing_fraction,
        num_circle_points=num_circle_points,
        circle_radius=circle_radius,
        **grid,
    )
__call__ ¤
__call__(
    u: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]

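Checks that `u` has the expected shape, i.e. `(num_channels,)` followed by the spatial shape for `num_spatial_dims` and `num_points`, raises a `ValueError` otherwise, and returns the result of `step(u)`.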

Source code in exponax/_base_stepper.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def __call__(
    self,
    u: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Validate the shape of `u` and return `self.step(u)`.

    **Arguments:**

    - `u`: The input array. Its shape must equal `(self.num_channels,)`
        followed by `spatial_shape(self.num_spatial_dims, self.num_points)`.

    **Returns:**

    - The value returned by `self.step(u)`.

    **Raises:**

    - `ValueError`: If `u.shape` differs from the expected shape, e.g.
        because a batch axis is present.
    """
    channel_axis = (self.num_channels,)
    grid_axes = spatial_shape(self.num_spatial_dims, self.num_points)
    expected_shape = channel_axis + grid_axes

    if u.shape == expected_shape:
        return self.step(u)

    # Anything else (including a leading batch axis) is rejected; the
    # message points callers at `jax.vmap` for batching.
    raise ValueError(
        f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function."
    )