Skip to content

Truncated Fourier Series

exponax.ic.RandomTruncatedFourierSeries

Bases: BaseRandomICGenerator

Source code in exponax/ic/_truncated_fourier_series.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class RandomTruncatedFourierSeries(BaseRandomICGenerator):
    """Random initial states built from low-pass filtered white noise."""

    num_spatial_dims: int
    cutoff: int
    # (min, max) bounds of the uniform draw for the offset written into the
    # zero-index Fourier coefficient.
    offset_range: tuple[float, float]
    std_one: bool
    max_one: bool
    white_noise: WhiteNoise

    def __init__(
        self,
        num_spatial_dims: int,
        *,
        cutoff: int = 5,
        offset_range: tuple[float, float] = (0.0, 0.0),  # no offset by default
        std_one: bool = False,
        max_one: bool = False,
    ):
        """
        Random generator for initial states consisting of a truncated Fourier
        series. White noise is drawn in physical space, transformed to Fourier
        space, low-pass filtered up to ``cutoff``, and transformed back.

        **Arguments**:

        - `num_spatial_dims`: The number of spatial dimensions `d`.
        - `cutoff`: The cutoff of the wavenumbers. This limits the
            "complexity" of the initial state. Note that some dynamics are very
            sensitive to high-frequency information.
        - `offset_range`: The `(min, max)` range from which the constant
            offset is drawn uniformly. Any length-2 sequence is accepted and
            stored as a tuple. Defaults to `(0.0, 0.0)`, meaning **zero-mean**
            by default.
        - `std_one`: Whether to normalize the state to have a standard
            deviation of one. Defaults to `False`. Only works if the offset is
            zero.
        - `max_one`: Whether to normalize the state to have the maximum
            absolute value of one. Defaults to `False`. Only one of `std_one`
            and `max_one` can be `True`.
        """
        # A list such as ``[0.0, 0.0]`` never compares equal to a tuple, so
        # normalize first; otherwise the zero-offset default would go
        # unrecognized and zero-mean handling would be silently disabled.
        offset_range = tuple(offset_range)
        zero_mean = offset_range == (0.0, 0.0)
        validate_normalization_options(
            zero_mean=zero_mean, std_one=std_one, max_one=max_one
        )
        self.num_spatial_dims = num_spatial_dims
        self.cutoff = cutoff
        self.offset_range = offset_range
        self.std_one = std_one
        self.max_one = max_one
        self.white_noise = WhiteNoise(num_spatial_dims)

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """
        Draw one random initial state.

        **Arguments**:

        - `num_points`: The number of points `N` used for the spatial grid.
        - `key`: A JAX PRNG key. It is split into one key for the white noise
            and one key for the offset.

        **Returns**:

        - The initial state, annotated as shape `(1, ..., N)`.
        """
        noise_key, offset_key = jr.split(key)

        noise = self.white_noise(num_points, key=noise_key)

        noise_hat = fft(noise, num_spatial_dims=self.num_spatial_dims)

        # Presumably a per-axis cutoff mask (axis_separate=True); confirm in
        # low_pass_filter_mask.
        low_pass_filter = low_pass_filter_mask(
            self.num_spatial_dims, num_points, cutoff=self.cutoff, axis_separate=True
        )

        noise_hat = noise_hat * low_pass_filter

        offset = jr.uniform(
            offset_key,
            shape=(1,),
            minval=self.offset_range[0],
            maxval=self.offset_range[1],
        )[0]
        # Flat index 0 is index (0, ..., 0). In standard FFT layout that is the
        # zero-wavenumber (DC) coefficient of the leading (channel) axis.
        # NOTE(review): whether the physical-space mean equals ``offset``
        # depends on the normalization used by ``fft``/``ifft``. With an
        # unnormalized forward transform it would be ``offset / N**d``.
        # Confirm against exponax's spectral helpers.
        fourier_noise_shape = noise_hat.shape
        noise_hat = noise_hat.flatten().at[0].set(offset).reshape(fourier_noise_shape)

        ic = ifft(
            noise_hat,
            num_spatial_dims=self.num_spatial_dims,
            num_points=num_points,
        )

        # tuple() keeps this robust even if offset_range was set to a list.
        zero_mean = tuple(self.offset_range) == (0.0, 0.0)
        ic = normalize_ic(
            ic, zero_mean=zero_mean, std_one=self.std_one, max_one=self.max_one
        )

        return ic
__init__
__init__(
    num_spatial_dims: int,
    *,
    cutoff: int = 5,
    offset_range: tuple[int, int] = (0.0, 0.0),
    std_one: bool = False,
    max_one: bool = False
)

Random generator for initial states consisting of a truncated Fourier series. White noise is drawn in physical space, transformed to Fourier space, low-pass filtered up to cutoff, and transformed back.

Arguments:

  • num_spatial_dims: The number of spatial dimensions d.
  • cutoff: The cutoff of the wavenumbers. This limits the "complexity" of the initial state. Note that some dynamics are very sensitive to high-frequency information.
  • offset_range: The (min, max) range from which the constant offset is drawn uniformly. Defaults to (0.0, 0.0), meaning zero-mean by default.
  • std_one: Whether to normalize the state to have a standard deviation of one. Defaults to False. Only works if the offset is zero, i.e. offset_range is (0.0, 0.0).
  • max_one: Whether to normalize the state to have the maximum absolute value of one. Defaults to False. Only one of std_one and max_one can be True.
Source code in exponax/ic/_truncated_fourier_series.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
    self,
    num_spatial_dims: int,
    *,
    cutoff: int = 5,
    offset_range: tuple[float, float] = (0.0, 0.0),  # no offset by default
    std_one: bool = False,
    max_one: bool = False,
):
    """
    Random generator for initial states consisting of a truncated Fourier
    series. White noise is drawn in physical space, transformed to Fourier
    space, low-pass filtered up to ``cutoff``, and transformed back.

    **Arguments**:

    - `num_spatial_dims`: The number of spatial dimensions `d`.
    - `cutoff`: The cutoff of the wavenumbers. This limits the
        "complexity" of the initial state. Note that some dynamics are very
        sensitive to high-frequency information.
    - `offset_range`: The `(min, max)` range from which the constant
        offset is drawn uniformly. Any length-2 sequence is accepted and
        stored as a tuple. Defaults to `(0.0, 0.0)`, meaning **zero-mean**
        by default.
    - `std_one`: Whether to normalize the state to have a standard
        deviation of one. Defaults to `False`. Only works if the offset is
        zero.
    - `max_one`: Whether to normalize the state to have the maximum
        absolute value of one. Defaults to `False`. Only one of `std_one`
        and `max_one` can be `True`.
    """
    # A list such as ``[0.0, 0.0]`` never compares equal to a tuple, so
    # normalize first; otherwise the zero-offset default would go
    # unrecognized and zero-mean handling would be silently disabled.
    offset_range = tuple(offset_range)
    zero_mean = offset_range == (0.0, 0.0)
    validate_normalization_options(
        zero_mean=zero_mean, std_one=std_one, max_one=max_one
    )
    self.num_spatial_dims = num_spatial_dims
    self.cutoff = cutoff
    self.offset_range = offset_range
    self.std_one = std_one
    self.max_one = max_one
    self.white_noise = WhiteNoise(num_spatial_dims)
__call__
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_truncated_fourier_series.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """
    Sample one initial state. The state is low-pass filtered white noise whose
    zero-index Fourier coefficient is replaced by a random offset, optionally
    followed by normalization.

    **Arguments**:

    - `num_points`: The number of points `N` used for the spatial grid.
    - `key`: A JAX PRNG key, split into a noise key and an offset key.

    **Returns**:

    - The initial state, annotated as shape `(1, ..., N)`.
    """
    dims = self.num_spatial_dims
    noise_subkey, offset_subkey = jr.split(key)

    # White noise drawn in physical space, then moved to Fourier space.
    spectrum = fft(
        self.white_noise(num_points, key=noise_subkey),
        num_spatial_dims=dims,
    )

    # Apply the cutoff-limited low-pass mask.
    spectrum = spectrum * low_pass_filter_mask(
        dims, num_points, cutoff=self.cutoff, axis_separate=True
    )

    lower_bound = self.offset_range[0]
    upper_bound = self.offset_range[1]
    sampled_offset = jr.uniform(
        offset_subkey, shape=(1,), minval=lower_bound, maxval=upper_bound
    )[0]

    # Overwrite index (0, ..., 0), the first entry in row-major order. In
    # standard FFT layout this is the zero-wavenumber coefficient.
    origin = (0,) * spectrum.ndim
    spectrum = spectrum.at[origin].set(sampled_offset)

    state = ifft(spectrum, num_spatial_dims=dims, num_points=num_points)

    return normalize_ic(
        state,
        zero_mean=self.offset_range == (0.0, 0.0),
        std_one=self.std_one,
        max_one=self.max_one,
    )