Sample Data

trainax.sample_data.advection_1d_periodic

advection_1d_periodic(
    num_points: int = 30,
    num_samples: int = 20,
    *,
    cfl: float = 0.75,
    highest_init_mode: int = 5,
    temporal_horizon: int = 100,
    key: PRNGKeyArray
) -> Float[
    Array, "num_samples temporal_horizon 1 num_points"
]

Produces reference trajectories of 1D advection with periodic boundary conditions. A Fourier spectral solver advances the solution exactly, provided the initial condition is resolved on the grid (requires highest_init_mode < num_points//2).

Arguments:

  • num_points: The number of grid points.
  • num_samples: The number of samples to generate, i.e., how many different trajectories.
  • cfl: The Courant-Friedrichs-Lewy number.
  • highest_init_mode: The highest Fourier mode present in the random initial condition.
  • temporal_horizon: The number of timesteps to simulate.
  • key: The random key.

Returns:

  • A tensor of shape (num_samples, temporal_horizon, 1, num_points). The singleton axis represents a single channel, which gives the format expected by convolutional networks.
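
To illustrate why a Fourier spectral step can be exact for advection, here is a minimal sketch. It is not trainax's internal `_advect_analytical`; the name `advect_fourier` is hypothetical, and the sketch assumes a unit-length domain and unit advection speed, so that one step shifts the solution by `cfl` grid cells.

```python
import jax.numpy as jnp


def advect_fourier(u, cfl):
    """Shift a periodic signal by `cfl` grid cells via a Fourier phase shift.

    Hypothetical helper: assumes a unit-length domain and unit advection speed.
    """
    n = u.shape[-1]
    # Integer wavenumbers 0, 1, ..., n // 2 of the real-valued FFT
    k = jnp.fft.rfftfreq(n, d=1.0 / n)
    # Translating by s = cfl / n multiplies mode k by exp(-2 pi i k s)
    phase = jnp.exp(-2j * jnp.pi * k * cfl / n)
    return jnp.fft.irfft(jnp.fft.rfft(u) * phase, n=n)


n = 30
x = jnp.arange(n) / n
# Band-limited initial condition: highest mode 5 < n // 2 = 15
u0 = jnp.sin(2 * jnp.pi * 3 * x) + 0.5 * jnp.cos(2 * jnp.pi * 5 * x)

u1 = advect_fourier(u0, cfl=0.75)

# Analytical solution: the initial profile translated by 0.75 grid cells
shift = 0.75 / n
u1_exact = jnp.sin(2 * jnp.pi * 3 * (x - shift)) + 0.5 * jnp.cos(
    2 * jnp.pi * 5 * (x - shift)
)
max_err = float(jnp.max(jnp.abs(u1 - u1_exact)))  # floating-point round-off only
```

With `cfl=1.0` the step reduces to `jnp.roll(u, 1)`. Non-integer CFL values shift the profile between grid points without interpolation error, as long as every mode of the signal lies below the Nyquist limit.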
Source code in trainax/_sample_data.py
def advection_1d_periodic(
    num_points: int = 30,
    num_samples: int = 20,
    *,
    cfl: float = 0.75,
    highest_init_mode: int = 5,
    temporal_horizon: int = 100,
    key: PRNGKeyArray,
) -> Float[Array, "num_samples temporal_horizon 1 num_points"]:
    """
    Produces reference trajectories of 1D advection with periodic boundary
    conditions. A Fourier spectral solver advances the solution exactly,
    provided the initial condition is resolved on the grid (requires
    `highest_init_mode` < `num_points//2`).

    **Arguments**:

    - `num_points`: The number of grid points.
    - `num_samples`: The number of samples to generate, i.e., how many different
        trajectories.
    - `cfl`: The Courant-Friedrichs-Lewy number.
    - `highest_init_mode`: The highest mode of the initial condition.
    - `temporal_horizon`: The number of timesteps to simulate.
    - `key`: The random key.

    **Returns**:

    - A tensor of shape `(num_samples, temporal_horizon, 1, num_points)`. The
        singleton axis represents a single channel, which gives the format
        expected by convolutional networks.
    """
    init_keys = jax.random.split(key, num_samples)

    u_0 = jax.vmap(
        lambda k: _random_truncated_fourier_series_1d(
            num_points, highest_init_mode, key=k
        )
    )(init_keys)

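    def scan_fn(u, _):
        # Carry the advanced state forward and record the state before the
        # step, so the trajectory starts with the initial condition.
        u_next = _advect_analytical(u, cfl=cfl)
        return u_next, u

    def rollout(init):
        _, u_trj = jax.lax.scan(scan_fn, init, None, length=temporal_horizon)
        return u_trj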

    u_trj = jax.vmap(rollout)(u_0)

    u_trj_with_singleton_channel = u_trj[..., None, :]

    return u_trj_with_singleton_channel
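
The rollout pattern in the source above can be tried in isolation. The following is a minimal sketch in which `step` is a hypothetical stand-in for the solver step (a shift by one grid cell), chosen only so the result is easy to check by hand.

```python
import jax
import jax.numpy as jnp


def step(u):
    # Hypothetical stand-in for one solver step: shift by one grid cell
    return jnp.roll(u, 1)


def rollout(init, num_steps):
    def scan_fn(u, _):
        # Carry the new state, record the old one
        return step(u), u

    _, trj = jax.lax.scan(scan_fn, init, None, length=num_steps)
    return trj


u_0 = jnp.arange(12.0).reshape(3, 4)  # 3 samples, 4 grid points
trj = jax.vmap(lambda u: rollout(u, 5))(u_0)  # shape (3, 5, 4)
trj = trj[..., None, :]  # insert channel axis: shape (3, 5, 1, 4)
```

Because `scan_fn` returns the state before the step as its output, the first snapshot of every trajectory is the initial condition, and the state after the final step is not part of the returned array.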

trainax.sample_data.lorenz_rk4

lorenz_rk4(
    num_samples: int = 20,
    *,
    temporal_horizon: int = 1000,
    dt: float = 0.01,
    num_warmup_steps: int = 500,
    sigma: float = 10.0,
    rho: float = 28.0,
    beta: float = 8.0 / 3.0,
    init_std: float = 1.0,
    key: PRNGKeyArray
) -> Float[Array, "num_samples temporal_horizon 3"]

Produces reference trajectories of the three-equation Lorenz system, integrated with a fixed-step fourth-order Runge-Kutta (RK4) scheme.

\[ \begin{aligned} \frac{dx}{dt} &= \sigma (y - x) \\ \frac{dy}{dt} &= x (\rho - z) - y \\ \frac{dz}{dt} &= x y - \beta z \end{aligned} \]

The initial conditions for each of the three variables are drawn independently from a zero-mean normal distribution with standard deviation init_std.

Arguments:

  • num_samples: The number of samples to generate, i.e., how many different trajectories.
  • temporal_horizon: The number of timesteps to simulate.
  • dt: The timestep size. Depending on the values of sigma, rho, and beta, the system might be hard to integrate. Usually, a time step \(\Delta t \in [0.01, 0.1]\) is a good choice.
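  • num_warmup_steps: The number of steps to discard from the beginning of the trajectory. These steps are simulated in addition to temporal_horizon, so the returned trajectory length is unaffected.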
  • sigma: The \(\sigma\) parameter of the Lorenz system.
  • rho: The \(\rho\) parameter of the Lorenz system.
  • beta: The \(\beta\) parameter of the Lorenz system.
  • init_std: The standard deviation of the initial conditions.
  • key: The random key.

Returns:

  • A tensor of shape (num_samples, temporal_horizon, 3).
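
For reference, the Lorenz right-hand side given above and one classical RK4 step can be sketched as follows. The names `lorenz_rhs` and `rk4_step` are hypothetical and need not match trainax's internal helpers; the defaults mirror those documented above.

```python
import jax.numpy as jnp


def lorenz_rhs(u, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Right-hand side of the Lorenz system for a state u = (x, y, z)."""
    x, y, z = u
    return jnp.array(
        [
            sigma * (y - x),
            x * (rho - z) - y,
            x * y - beta * z,
        ]
    )


def rk4_step(rhs, u, dt):
    """One step of the classical fourth-order Runge-Kutta scheme."""
    k1 = rhs(u)
    k2 = rhs(u + 0.5 * dt * k1)
    k3 = rhs(u + 0.5 * dt * k2)
    k4 = rhs(u + dt * k3)
    return u + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


u = jnp.array([1.0, 1.0, 1.0])
u_next = rk4_step(lorenz_rhs, u, dt=0.01)

# Non-trivial fixed point for the default parameters:
# x = y = sqrt(beta * (rho - 1)), z = rho - 1
c = jnp.sqrt(8.0 / 3.0 * 27.0)
fixed_point = jnp.array([c, c, 27.0])
```

A quick sanity check for such a stepper is that the origin and the fixed point above stay in place, since the right-hand side vanishes there up to round-off.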
Source code in trainax/_sample_data.py
def lorenz_rk4(
    num_samples: int = 20,
    *,
    temporal_horizon: int = 1000,
    dt: float = 0.01,
    num_warmup_steps: int = 500,
    sigma: float = 10.0,
    rho: float = 28.0,
    beta: float = 8.0 / 3.0,
    init_std: float = 1.0,
    key: PRNGKeyArray,
) -> Float[Array, "num_samples temporal_horizon 3"]:
    r"""
    Produces reference trajectories of the three-equation Lorenz system,
    integrated with a fixed-step fourth-order Runge-Kutta (RK4) scheme.

    $$
    \begin{aligned}
    \frac{dx}{dt} &= \sigma (y - x) \\
    \frac{dy}{dt} &= x (\rho - z) - y \\
    \frac{dz}{dt} &= x y - \beta z
    \end{aligned}
    $$

    The initial conditions for each of the three variables are drawn
    independently from a zero-mean normal distribution with standard deviation
    `init_std`.

    **Arguments**:

    - `num_samples`: The number of samples to generate, i.e., how many different
        trajectories.
    - `temporal_horizon`: The number of timesteps to simulate.
    - `dt`: The timestep size. Depending on the values of `sigma`, `rho`, and
        `beta`, the system might be hard to integrate. Usually, a time step
        $\Delta t \in [0.01, 0.1]$ is a good choice.
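    - `num_warmup_steps`: The number of steps to discard from the beginning of
        the trajectory. These steps are simulated in addition to
        `temporal_horizon`, so the returned trajectory length is unaffected.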
    - `sigma`: The $\sigma$ parameter of the Lorenz system.
    - `rho`: The $\rho$ parameter of the Lorenz system.
    - `beta`: The $\beta$ parameter of the Lorenz system.
    - `init_std`: The standard deviation of the initial conditions.
    - `key`: The random key.

    **Returns**:

    - A tensor of shape `(num_samples, temporal_horizon, 3)`.
    """

    u_0_set = jax.random.normal(key, shape=(num_samples, 3)) * init_std

    # lorenz_rhs_params_fixed = lambda u: _lorenz_rhs(u, sigma=sigma, rho=rho, beta=beta)
    # lorenz_stepper = lambda u: _step_rk4(lorenz_rhs_params_fixed, u, dt=dt)

    lorenz_stepper = make_lorenz_stepper_rk4(dt=dt, sigma=sigma, rho=rho, beta=beta)

    def scan_fn(u, _):
        u_next = lorenz_stepper(u)
        return u_next, u

    def rollout(init):
        _, u_trj = jax.lax.scan(
            scan_fn, init, None, length=temporal_horizon + num_warmup_steps
        )
        return u_trj

    trj_set = jax.vmap(rollout)(u_0_set)

    # Slice away the warmup steps
    trj_set = trj_set[:, num_warmup_steps:]

    return trj_set
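
The effect of the warm-up slicing can be checked with a small sketch. Here `rollout_with_warmup` is a hypothetical helper, and `step` is a stand-in that halves the state instead of integrating the Lorenz system, so the expected values are easy to verify.

```python
import jax
import jax.numpy as jnp


def rollout_with_warmup(step, init, temporal_horizon, num_warmup_steps):
    def scan_fn(u, _):
        return step(u), u

    # Simulate the warm-up steps in addition to the requested horizon
    _, trj = jax.lax.scan(
        scan_fn, init, None, length=temporal_horizon + num_warmup_steps
    )
    # Discard the warm-up part of the trajectory
    return trj[num_warmup_steps:]


key = jax.random.PRNGKey(0)
init_std = 2.0
# Zero-mean normal initial conditions scaled to the requested standard deviation
u_0_set = jax.random.normal(key, shape=(4, 3)) * init_std


def step(u):
    # Hypothetical stand-in for the RK4 stepper
    return 0.5 * u


trj_set = jax.vmap(lambda u: rollout_with_warmup(step, u, 10, 3))(u_0_set)
# trj_set has shape (4, 10, 3)
```

The first returned snapshot is the state after `num_warmup_steps` steps rather than the sampled initial condition. For the Lorenz system this means the random initial draw itself never appears in the returned data.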