Skip to content

Gaussian Random Field

exponax.ic.GaussianRandomField

Bases: BaseRandomICGenerator

Source code in exponax/ic/_gaussian_random_field.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class GaussianRandomField(BaseRandomICGenerator):
    """
    Random initial-condition generator with power-law Fourier amplitudes.

    Every Fourier coefficient is complex Gaussian white noise multiplied by
    `|k|^(-powerlaw_exponent / 2)`, where `|k|` is the magnitude of the
    domain-scaled wavenumber vector. The coefficients are then mapped back to
    physical space with a real inverse FFT.
    """

    num_spatial_dims: int
    domain_extent: float
    powerlaw_exponent: float
    zero_mean: bool
    std_one: bool
    max_one: bool

    def __init__(
        self,
        num_spatial_dims: int,
        *,
        domain_extent: float = 1.0,
        powerlaw_exponent: float = 3.0,
        zero_mean: bool = True,
        std_one: bool = False,
        max_one: bool = False,
    ):
        """
        Random generator for initial states following a power-law spectrum in
        Fourier space.

        **Arguments:**
            - `num_spatial_dims`: The number of spatial dimensions.
            - `domain_extent`: The extent of the domain in each spatial
                direction. Defaults to `1.0`.
            - `powerlaw_exponent`: The exponent of the power-law spectrum;
                Fourier amplitudes are scaled by
                `|k|^(-powerlaw_exponent / 2)`. Defaults to `3.0`.
            - `zero_mean`: Whether to subtract the spatial mean from the
                generated state. Defaults to `True`.
            - `std_one`: Whether to normalize the state to have a standard
                deviation of one. Defaults to `False`. Requires
                `zero_mean=True`.
            - `max_one`: Whether to normalize the state to have the maximum
                absolute value of one. Defaults to `False`. Only one of
                `std_one` and `max_one` can be `True`.

        **Raises:**
            - `ValueError`: If `std_one=True` while `zero_mean=False`, or if
                `std_one` and `max_one` are both `True`. The zero-mean
                conflict is checked (and reported) first.
        """
        if std_one and not zero_mean:
            raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
        if max_one and std_one:
            raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

        # Post-processing switches consumed by `__call__`, in application order.
        self.zero_mean = zero_mean
        self.std_one = std_one
        self.max_one = max_one

        # Geometry and spectrum parameters.
        self.num_spatial_dims = num_spatial_dims
        self.domain_extent = domain_extent
        self.powerlaw_exponent = powerlaw_exponent

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """
        Sample one random initial state.

        **Arguments:**
            - `num_points`: The number of grid points per spatial dimension.
            - `key`: The JAX random key. It is split into two subkeys, one
                for the real part and one for the imaginary part of the
                Fourier noise.

        **Returns:**
            - The sampled state: a leading singleton channel axis followed by
                the spatial axes.
        """
        dims = self.num_spatial_dims

        # |k| per Fourier mode; keepdims leaves a leading singleton axis that
        # lines up with the channel axis of the noise below.
        k_magnitude = jnp.linalg.norm(
            build_scaled_wavenumbers(dims, self.domain_extent, num_points),
            axis=0,
            keepdims=True,
        )

        # Power-law amplitude |k|^(-alpha/2). The grid origin (presumably the
        # k = 0 mode, where the power law is singular) is zeroed so the mean
        # mode carries no energy.
        spectral_amplitude = jnp.power(k_magnitude, -self.powerlaw_exponent / 2.0)
        spectral_amplitude = spectral_amplitude.at[
            (0,) * spectral_amplitude.ndim
        ].set(0.0)

        # Complex white noise on the half-spectrum grid consumed by irfftn.
        fourier_shape = (1,) + wavenumber_shape(dims, num_points)
        key_real, key_imag = jr.split(key, 2)
        coefficients = jr.normal(key_real, shape=fourier_shape) + 1j * jr.normal(
            key_imag, shape=fourier_shape
        )
        coefficients = coefficients * spectral_amplitude
        coefficients = coefficients * build_scaling_array(dims, num_points)

        field = jnp.fft.irfftn(
            coefficients,
            s=spatial_shape(dims, num_points),
            axes=space_indices(dims),
        )

        # Optional normalization; `std_one` and `max_one` are mutually
        # exclusive (enforced in `__init__`).
        if self.zero_mean:
            field = field - jnp.mean(field)
        if self.std_one:
            field = field / jnp.std(field)
        if self.max_one:
            field = field / jnp.max(jnp.abs(field))
        return field
__init__
__init__(
    num_spatial_dims: int,
    *,
    domain_extent: float = 1.0,
    powerlaw_exponent: float = 3.0,
    zero_mean: bool = True,
    std_one: bool = False,
    max_one: bool = False
)

Random generator for initial states following a power-law spectrum in Fourier space.

Arguments:

- `num_spatial_dims`: The number of spatial dimensions.
- `domain_extent`: The extent of the domain in each spatial direction.
- `powerlaw_exponent`: The exponent of the power-law spectrum.
- `zero_mean`: Whether the field should have zero mean.
- `std_one`: Whether to normalize the state to have a standard deviation of one. Defaults to `False`. Requires `zero_mean=True`.
- `max_one`: Whether to normalize the state to have the maximum absolute value of one. Defaults to `False`. Only one of `std_one` and `max_one` can be `True`.

Source code in exponax/ic/_gaussian_random_field.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    num_spatial_dims: int,
    *,
    domain_extent: float = 1.0,
    powerlaw_exponent: float = 3.0,
    zero_mean: bool = True,
    std_one: bool = False,
    max_one: bool = False,
):
    """
    Random generator for initial states following a power-law spectrum in
    Fourier space.

    **Arguments:**
        - `num_spatial_dims`: The number of spatial dimensions.
        - `domain_extent`: The extent of the domain in each spatial
            direction. Defaults to `1.0`.
        - `powerlaw_exponent`: The exponent of the power-law spectrum.
            Defaults to `3.0`.
        - `zero_mean`: Whether to subtract the spatial mean from the
            generated state. Defaults to `True`.
        - `std_one`: Whether to normalize the state to have a standard
            deviation of one. Defaults to `False`. Requires
            `zero_mean=True`.
        - `max_one`: Whether to normalize the state to have the maximum
            absolute value of one. Defaults to `False`. Only one of
            `std_one` and `max_one` can be `True`.

    **Raises:**
        - `ValueError`: If `std_one=True` while `zero_mean=False`, or if
            `std_one` and `max_one` are both `True`. The zero-mean conflict
            is checked (and reported) first.
    """
    if std_one and not zero_mean:
        raise ValueError("Cannot have `zero_mean=False` and `std_one=True`.")
    if max_one and std_one:
        raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

    # Post-processing switches, in the order they are applied when sampling.
    self.zero_mean = zero_mean
    self.std_one = std_one
    self.max_one = max_one

    # Geometry and spectrum parameters.
    self.num_spatial_dims = num_spatial_dims
    self.domain_extent = domain_extent
    self.powerlaw_exponent = powerlaw_exponent
__call__
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_gaussian_random_field.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """
    Sample one random initial state.

    **Arguments:**
        - `num_points`: The number of grid points per spatial dimension.
        - `key`: The JAX random key. It is split into two subkeys, one for
            the real part and one for the imaginary part of the Fourier
            noise.

    **Returns:**
        - The sampled state: a leading singleton channel axis followed by the
            spatial axes.
    """
    dims = self.num_spatial_dims

    # |k| per Fourier mode; keepdims leaves a leading singleton axis that
    # lines up with the channel axis of the noise below.
    k_magnitude = jnp.linalg.norm(
        build_scaled_wavenumbers(dims, self.domain_extent, num_points),
        axis=0,
        keepdims=True,
    )

    # Power-law amplitude |k|^(-alpha/2). The grid origin (presumably the
    # k = 0 mode, where the power law is singular) is zeroed so the mean mode
    # carries no energy.
    spectral_amplitude = jnp.power(k_magnitude, -self.powerlaw_exponent / 2.0)
    spectral_amplitude = spectral_amplitude.at[
        (0,) * spectral_amplitude.ndim
    ].set(0.0)

    # Complex white noise on the half-spectrum grid consumed by irfftn.
    fourier_shape = (1,) + wavenumber_shape(dims, num_points)
    key_real, key_imag = jr.split(key, 2)
    coefficients = jr.normal(key_real, shape=fourier_shape) + 1j * jr.normal(
        key_imag, shape=fourier_shape
    )
    coefficients = coefficients * spectral_amplitude
    coefficients = coefficients * build_scaling_array(dims, num_points)

    field = jnp.fft.irfftn(
        coefficients,
        s=spatial_shape(dims, num_points),
        axes=space_indices(dims),
    )

    # Optional normalization; `std_one` and `max_one` are mutually exclusive
    # (enforced in `__init__`).
    if self.zero_mean:
        field = field - jnp.mean(field)
    if self.std_one:
        field = field / jnp.std(field)
    if self.max_one:
        field = field / jnp.max(jnp.abs(field))
    return field