Skip to content

Sine Waves in 1D

exponax.ic.RandomSineWaves1d

Bases: BaseRandomICGenerator

Source code in exponax/ic/_sine_waves_1d.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class RandomSineWaves1d(BaseRandomICGenerator):
    num_spatial_dims: int
    domain_extent: float
    cutoff: int
    amplitude_range: tuple[float, float]
    phase_range: tuple[float, float]
    offset_range: tuple[float, float]

    std_one: bool
    max_one: bool

    def __init__(
        self,
        num_spatial_dims: int,
        *,
        domain_extent: float = 1.0,
        cutoff: int = 5,
        amplitude_range: tuple[float, float] = (-1.0, 1.0),
        phase_range: tuple[float, float] = (0.0, 2 * jnp.pi),
        offset_range: tuple[float, float] = (0.0, 0.0),
        std_one: bool = False,
        max_one: bool = False,
    ):
        """
        Random generator for initial states described by a collection of sine
        waves. Only works in 1d.

        This is a simplified, 1d-only variant of the
        `RandomTruncatedFourierSeries` generator (which works in arbitrary
        dimensions). However, unlike that generator, this one can also produce
        a functional representation of the initial condition (see
        `gen_ic_fun`).

        **Arguments**:

        - `num_spatial_dims`: The number of spatial dimensions. Must be `1`.
        - `domain_extent`: The extent of the domain.
        - `cutoff`: The cutoff of the wavenumbers; the wavenumbers
            `1, 2, ..., cutoff` are used. This limits the "complexity" of the
            initial state. Note that some dynamics are very sensitive to
            high-frequency information.
        - `amplitude_range`: The `(min, max)` range from which each amplitude
            is drawn uniformly. Defaults to `(-1.0, 1.0)`.
        - `phase_range`: The `(min, max)` range from which each phase is drawn
            uniformly. Defaults to `(0.0, 2π)`.
        - `offset_range`: The `(min, max)` range from which the constant
            offset is drawn uniformly. Defaults to `(0.0, 0.0)`, meaning
            **zero-mean** by default.
        - `std_one`: Whether to normalize the state to have a standard
            deviation of one. Defaults to `False`. Only allowed together with
            `offset_range=(0.0, 0.0)`.
        - `max_one`: Whether to normalize the state to have the maximum
            absolute value of one. Defaults to `False`. Only one of `std_one`
            and `max_one` can be `True`.

        **Raises**:

        - `ValueError`: If `num_spatial_dims != 1`, if `std_one=True` is
            combined with `offset_range != (0.0, 0.0)`, or if `std_one` and
            `max_one` are both `True`.
        """
        if num_spatial_dims != 1:
            raise ValueError("RandomSineWaves1d only works in 1d.")
        if std_one and offset_range != (0.0, 0.0):
            raise ValueError("Cannot have non-zero offset and `std_one=True`.")
        if std_one and max_one:
            raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

        self.num_spatial_dims = num_spatial_dims
        self.domain_extent = domain_extent
        self.cutoff = cutoff
        self.amplitude_range = amplitude_range
        self.phase_range = phase_range
        self.offset_range = offset_range
        self.std_one = std_one
        self.max_one = max_one

    def gen_ic_fun(self, *, key: PRNGKeyArray) -> SineWaves1d:
        """
        Sample a random `SineWaves1d` function.

        The wavenumbers are fixed to `1, 2, ..., cutoff`. One amplitude and
        one phase per wavenumber are drawn uniformly from `amplitude_range`
        and `phase_range`. A scalar offset is drawn uniformly from
        `offset_range`.

        **Arguments**:

        - `key`: A jax random key.

        **Returns**:

        - A `SineWaves1d` instance that can be evaluated on a 1d grid.
        """
        # Always split into three keys, so amplitudes and phases depend on
        # `key` in the same way regardless of how the offset is obtained.
        amplitude_key, phase_key, offset_key = jr.split(key, 3)

        amplitudes = jr.uniform(
            amplitude_key,
            shape=(self.cutoff,),
            minval=self.amplitude_range[0],
            maxval=self.amplitude_range[1],
        )
        phases = jr.uniform(
            phase_key,
            shape=(self.cutoff,),
            minval=self.phase_range[0],
            maxval=self.phase_range[1],
        )

        offset_low, offset_high = self.offset_range
        if offset_low == offset_high:
            # A degenerate range (e.g. the zero-mean default) always yields
            # the same value, so use it directly as a concrete float. Sampling
            # would give a (possibly traced) array. `SineWaves1d.__init__`
            # converts `offset != 0.0` to a Python bool, which fails on
            # tracers, i.e. when this runs under `jax.jit`, `jax.vmap` or
            # `jax.lax.scan` over keys.
            offset = float(offset_low)
        else:
            offset = jr.uniform(
                offset_key,
                shape=(),
                minval=offset_low,
                maxval=offset_high,
            )

        return SineWaves1d(
            domain_extent=self.domain_extent,
            amplitudes=amplitudes,
            wavenumbers=jnp.arange(1, self.cutoff + 1),
            phases=phases,
            offset=offset,
            std_one=self.std_one,
            max_one=self.max_one,
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    *,
    domain_extent: float = 1.0,
    cutoff: int = 5,
    amplitude_range: tuple[float, float] = (-1.0, 1.0),
    phase_range: tuple[float, float] = (0.0, 2 * jnp.pi),
    offset_range: tuple[float, float] = (0.0, 0.0),
    std_one: bool = False,
    max_one: bool = False
)

Random generator for initial states described by a collection of sine waves. Only works in 1d.

This is a simplified, 1d-only variant of the RandomTruncatedFourierSeries generator (which works in arbitrary dimensions). However, unlike RandomTruncatedFourierSeries, this generator can also produce a functional representation of the initial condition.

Arguments:

  • num_spatial_dims: The number of spatial dimensions.
  • domain_extent: The extent of the domain.
  • cutoff: The cutoff of the wavenumbers. This limits the "complexity" of the initial state. Note that some dynamics are very sensitive to high-frequency information.
  • amplitude_range: The range of the amplitudes. Defaults to (-1.0, 1.0).
  • phase_range: The range of the phases. Defaults to (0.0, 2π).
  • offset_range: The range of the offsets. Defaults to (0.0, 0.0), meaning zero-mean by default.
  • std_one: Whether to normalize the state to have a standard deviation of one. Defaults to False. Only works if the offset is zero.
  • max_one: Whether to normalize the state to have the maximum absolute value of one. Defaults to False. Only one of std_one and max_one can be True.
Source code in exponax/ic/_sine_waves_1d.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def __init__(
    self,
    num_spatial_dims: int,
    *,
    domain_extent: float = 1.0,
    cutoff: int = 5,
    amplitude_range: tuple[float, float] = (-1.0, 1.0),
    phase_range: tuple[float, float] = (0.0, 2 * jnp.pi),
    offset_range: tuple[float, float] = (0.0, 0.0),
    std_one: bool = False,
    max_one: bool = False,
):
    """
    Configure a random generator of 1d sine-wave initial states.

    This is the 1d-only, simplified counterpart of the
    `RandomTruncatedFourierSeries` generator (which supports arbitrary
    dimensions). Unlike that generator, it can also produce a functional
    representation of the initial condition.

    **Arguments**:

    - `num_spatial_dims`: The number of spatial dimensions; must be `1`.
    - `domain_extent`: The extent of the domain.
    - `cutoff`: The cutoff of the wavenumbers. It limits the "complexity"
        of the initial state; keep in mind that some dynamics react
        strongly to high-frequency content.
    - `amplitude_range`: The range of the amplitudes. Defaults to
        `(-1.0, 1.0)`.
    - `phase_range`: The range of the phases. Defaults to `(0.0, 2π)`.
    - `offset_range`: The range of the offsets. Defaults to `(0.0, 0.0)`,
        i.e. **zero-mean** states.
    - `std_one`: Whether to rescale the state to unit standard deviation.
        Defaults to `False`. Requires `offset_range == (0.0, 0.0)`.
    - `max_one`: Whether to rescale the state to a maximum absolute value
        of one. Defaults to `False`. Mutually exclusive with `std_one`.

    **Raises**:

    - `ValueError`: If `num_spatial_dims != 1`, if `std_one` is requested
        with a non-zero `offset_range`, or if `std_one` and `max_one` are
        both requested.
    """
    if num_spatial_dims != 1:
        raise ValueError("RandomSineWaves1d only works in 1d.")
    if std_one:
        # Both remaining constraints only matter when `std_one` is set;
        # the offset constraint is checked first.
        if offset_range != (0.0, 0.0):
            raise ValueError("Cannot have non-zero offset and `std_one=True`.")
        if max_one:
            raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

    # Domain description.
    self.num_spatial_dims = num_spatial_dims
    self.domain_extent = domain_extent
    # Sampling configuration.
    self.cutoff = cutoff
    self.amplitude_range = amplitude_range
    self.phase_range = phase_range
    self.offset_range = offset_range
    # Post-hoc normalization flags.
    self.std_one = std_one
    self.max_one = max_one
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]

Generate a random initial condition on a grid with num_points points.

Arguments:

  • num_points: The number of grid points in each dimension.
  • key: A jax random key.

Returns:

  • u: The initial condition evaluated at the grid points.
Source code in exponax/ic/_base_ic.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __call__(
    self,
    num_points: int,
    *,
    key: PRNGKeyArray,
) -> Float[Array, "1 ... N"]:
    """
    Sample an initial condition and evaluate it on a regular grid.

    A random function is drawn via `self.gen_ic_fun(key=key)` and then
    evaluated on the grid produced by `make_grid` for this generator's
    `num_spatial_dims`, `domain_extent` and `indexing`.

    **Arguments**:

    - `num_points`: The number of grid points in each dimension.
    - `key`: A jax random key.

    **Returns**:

    - `u`: The sampled initial condition evaluated at the grid points.
    """
    # Sample the function first, then build the grid it is evaluated on.
    sampled_fun = self.gen_ic_fun(key=key)
    points = make_grid(
        self.num_spatial_dims,
        self.domain_extent,
        num_points,
        indexing=self.indexing,
    )
    u = sampled_fun(points)
    return u

exponax.ic.SineWaves1d

Bases: BaseIC

Source code in exponax/ic/_sine_waves_1d.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class SineWaves1d(BaseIC):
    domain_extent: float
    amplitudes: tuple[float, ...]
    wavenumbers: tuple[float, ...]
    phases: tuple[float, ...]
    offset: float

    std_one: bool
    max_one: bool

    def __init__(
        self,
        domain_extent: float,
        amplitudes: tuple[float, ...],
        wavenumbers: tuple[float, ...],
        phases: tuple[float, ...],
        offset: float = 0.0,
        std_one: bool = False,
        max_one: bool = False,
    ):
        """
        A state described by a collection of sine waves. Only works in 1d.

        Evaluated on a grid `x`, the state is

            u(x) = offset + Σ_i amplitudes[i]
                   * sin(wavenumbers[i] * (2π / domain_extent) * x + phases[i])

        optionally rescaled afterwards according to `std_one` / `max_one`.

        **Arguments**:

        - `domain_extent`: The extent of the domain.
        - `amplitudes`: A tuple of amplitudes.
        - `wavenumbers`: A tuple of wavenumbers.
        - `phases`: A tuple of phases.
        - `offset`: A constant offset.
        - `std_one`: Whether to normalize the state to have a standard
            deviation of one. Defaults to `False`. Only allowed with a zero
            offset.
        - `max_one`: Whether to normalize the state to have the maximum
            absolute value of one. Defaults to `False`. Only one of
            `std_one` and `max_one` can be `True`.

        **Raises**:

        - `ValueError`: If `std_one=True` is combined with a non-zero
            `offset`, if `std_one` and `max_one` are both `True`, or if
            `amplitudes`, `wavenumbers` and `phases` differ in length.
        """
        # Test the plain-bool flag first. `offset` may be a traced JAX array
        # (e.g. when this object is built inside `jax.jit`, `jax.vmap` or
        # `jax.lax.scan`), and converting `offset != 0.0` to a Python bool
        # would then raise. With short-circuiting, the default
        # `std_one=False` path never inspects the offset's value.
        if std_one and offset != 0.0:
            raise ValueError("Cannot have non-zero offset and `std_one=True`.")
        if std_one and max_one:
            raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

        if len(amplitudes) != len(wavenumbers) or len(wavenumbers) != len(phases):
            raise ValueError(
                "The number of amplitudes, wavenumbers, and phases must be the same."
            )

        self.domain_extent = domain_extent
        self.amplitudes = amplitudes
        self.wavenumbers = wavenumbers
        self.phases = phases
        self.offset = offset
        self.std_one = std_one
        self.max_one = max_one

    def __call__(self, x: Float[Array, "1 N"]) -> Float[Array, "1 N"]:
        """
        Evaluate the state on the grid `x`.

        **Arguments**:

        - `x`: The grid points. The leading axis must have size one (the
            single spatial dimension).

        **Returns**:

        - The state values, with the same shape as `x`.

        **Raises**:

        - `ValueError`: If `x.shape[0] != 1`.

        NOTE: if the un-normalized state is identically zero (e.g. all
        amplitudes are zero and the offset is zero), `std_one` / `max_one`
        divide by zero and yield NaNs.
        """
        if x.shape[0] != 1:
            raise ValueError("SineWaves1d only works in 1d.")

        # 2π / L: wavenumber `k` maps to `k` full periods over the domain.
        base_wavenumber = 2 * jnp.pi / self.domain_extent

        result = jnp.zeros_like(x)
        for a, k, p in zip(self.amplitudes, self.wavenumbers, self.phases):
            result += a * jnp.sin(k * base_wavenumber * x + p)
        result += self.offset

        if self.std_one:
            result = result / jnp.std(result)

        if self.max_one:
            result = result / jnp.max(jnp.abs(result))

        return result
__init__ ¤
__init__(
    domain_extent: float,
    amplitudes: tuple[float, ...],
    wavenumbers: tuple[float, ...],
    phases: tuple[float, ...],
    offset: float = 0.0,
    std_one: bool = False,
    max_one: bool = False,
)

A state described by a collection of sine waves. Only works in 1d.

Arguments:

  • domain_extent: The extent of the domain.
  • amplitudes: A tuple of amplitudes.
  • wavenumbers: A tuple of wavenumbers.
  • phases: A tuple of phases.
  • offset: A constant offset.
  • std_one: Whether to normalize the state to have a standard deviation of one. Defaults to False. Only works if the offset is zero.
  • max_one: Whether to normalize the state to have the maximum absolute value of one. Defaults to False. Only one of std_one and max_one can be True.
Source code in exponax/ic/_sine_waves_1d.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    domain_extent: float,
    amplitudes: tuple[float, ...],
    wavenumbers: tuple[float, ...],
    phases: tuple[float, ...],
    offset: float = 0.0,
    std_one: bool = False,
    max_one: bool = False,
):
    """
    A state described by a collection of sine waves. Only works in 1d.

    **Arguments**:

    - `domain_extent`: The extent of the domain.
    - `amplitudes`: A tuple of amplitudes.
    - `wavenumbers`: A tuple of wavenumbers.
    - `phases`: A tuple of phases.
    - `offset`: A constant offset.
    - `std_one`: Whether to normalize the state to have a standard
        deviation of one. Defaults to `False`. Only allowed with a zero
        offset.
    - `max_one`: Whether to normalize the state to have the maximum
        absolute value of one. Defaults to `False`. Only one of
        `std_one` and `max_one` can be `True`.

    **Raises**:

    - `ValueError`: If `std_one=True` is combined with a non-zero `offset`,
        if `std_one` and `max_one` are both `True`, or if `amplitudes`,
        `wavenumbers` and `phases` differ in length.
    """
    # Test the plain-bool flag first. `offset` may be a traced JAX array
    # (e.g. when this object is built inside `jax.jit`, `jax.vmap` or
    # `jax.lax.scan`), and converting `offset != 0.0` to a Python bool would
    # then raise. With short-circuiting, the default `std_one=False` path
    # never inspects the offset's value.
    if std_one and offset != 0.0:
        raise ValueError("Cannot have non-zero offset and `std_one=True`.")
    if std_one and max_one:
        raise ValueError("Cannot have `std_one=True` and `max_one=True`.")

    if len(amplitudes) != len(wavenumbers) or len(wavenumbers) != len(phases):
        raise ValueError(
            "The number of amplitudes, wavenumbers, and phases must be the same."
        )

    self.domain_extent = domain_extent
    self.amplitudes = amplitudes
    self.wavenumbers = wavenumbers
    self.phases = phases
    self.offset = offset
    self.std_one = std_one
    self.max_one = max_one
__call__ ¤
__call__(x: Float[Array, '1 N']) -> Float[Array, '1 N']
Source code in exponax/ic/_sine_waves_1d.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __call__(self, x: Float[Array, "1 N"]) -> Float[Array, "1 N"]:
    """
    Evaluate the superposition of sine waves on the grid `x`.

    Computes `offset + Σ_i amplitudes[i] * sin(wavenumbers[i] * (2π / L) *
    x + phases[i])` with `L = domain_extent`, then optionally divides by the
    standard deviation (`std_one`) and/or the maximum absolute value
    (`max_one`).

    **Arguments**:

    - `x`: The grid points; the leading axis must have size one.

    **Returns**:

    - The state values, with the same shape as `x`.

    **Raises**:

    - `ValueError`: If `x.shape[0] != 1`.
    """
    if x.shape[0] != 1:
        raise ValueError("SineWaves1d only works in 1d.")

    # 2π / L: wavenumber `k` corresponds to `k` full periods on the domain.
    fundamental = 2 * jnp.pi / self.domain_extent

    u = jnp.zeros_like(x)
    for amplitude, wavenumber, phase in zip(
        self.amplitudes, self.wavenumbers, self.phases
    ):
        u = u + amplitude * jnp.sin(wavenumber * fundamental * x + phase)
    u = u + self.offset

    # The two rescalings are applied independently, in this order.
    if self.std_one:
        u = u / jnp.std(u)
    if self.max_one:
        u = u / jnp.max(jnp.abs(u))
    return u