Showcase of Initial Conditions in 1d
Initial conditions shape the dynamics produced by the time steppers. This notebook showcases how to create them in Exponax for one spatial dimension.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
Introduction
In the simplest case, we can define an initial condition by creating a grid and discretizing a function on it. For example, here is the first sine mode, \(u_0(x) = \sin(2\pi x / L)\), with \(L\) the domain extent:
DOMAIN_EXTENT = 10.0
NUM_POINTS = 50
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * grid)
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
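For periodic problems solved with Fourier pseudo-spectral methods, the grid usually includes the left boundary and excludes the right one, because the two points are identified. The hypothetical helper make_periodic_grid_1d below builds such a grid in plain JAX as a sketch. We assume, but do not show here, that this matches the default layout of ex.make_grid.

```python
import jax.numpy as jnp


def make_periodic_grid_1d(domain_extent: float, num_points: int) -> jnp.ndarray:
    """Hypothetical helper: uniform periodic grid on [0, L) with a leading channel axis."""
    # Take num_points + 1 equispaced nodes on [0, L] and drop the last one,
    # because under periodicity it coincides with the node at x = 0.
    x = jnp.linspace(0.0, domain_extent, num_points + 1)[:-1]
    return x[None, :]  # shape (1, num_points)


L, N = 10.0, 50
grid = make_periodic_grid_1d(L, N)
u_0 = jnp.sin(2 * jnp.pi / L * grid)  # first sine mode, periodic on [0, L)
print(grid.shape)  # (1, 50)
```

With this layout, the spacing is exactly \(L / N\) and the sampled sine mode wraps around seamlessly.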
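This can be randomized by changing the amplitude and phase of the sine mode. Here, we draw both from uniform distributions,
\[ a \sim \mathcal{U}(-1, 1), \qquad \varphi \sim \mathcal{U}(0, 2\pi), \]
so that, with \(L\) the domain extent, the initial condition is
\[ u_0(x) = a \sin\left(\frac{2\pi}{L} x + \varphi\right). \]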
main_key = jax.random.PRNGKey(0)
amplitude_key, phase_key = jax.random.split(main_key)
amplitude = jax.random.uniform(amplitude_key, (), minval=-1.0, maxval=1.0)
phase = jax.random.uniform(phase_key, (), minval=0.0, maxval=2 * jnp.pi)
print(f"Amplitude: {amplitude:.3f}")
print(f"Phase: {phase:.3f}")
ic_fun = lambda x: amplitude * jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * x + phase)
u_0 = ic_fun(grid)
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
Let's also create an initial condition that consists of two sine modes, the first and the second. We will again draw the amplitudes and phases from the same uniform distribution.
amplitude_s = jax.random.uniform(amplitude_key, (2,), minval=-1.0, maxval=1.0)
phase_s = jax.random.uniform(phase_key, (2,), minval=0.0, maxval=2 * jnp.pi)
print(f"Amplitudes: {amplitude_s}")
print(f"Phases: {phase_s}")
ic_fun = lambda x: amplitude_s[0] * jnp.sin(
2 * jnp.pi / DOMAIN_EXTENT * x + phase_s[0]
) + amplitude_s[1] * jnp.sin(4 * jnp.pi / DOMAIN_EXTENT * x + phase_s[1])
u_0 = ic_fun(grid)
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
Notice that in both cases we first created a lambda function representing the "continuous" version of the initial condition, which we subsequently discretized on the grid.
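The lambda pattern generalizes to any number of modes. As a minimal sketch, the hypothetical factory make_sine_sum below (not part of Exponax) returns the continuous function \(u(x) = \sum_k a_k \sin(2\pi k x / L + \varphi_k)\). We then discretize it by evaluating it on a periodic grid.

```python
import jax.numpy as jnp


def make_sine_sum(domain_extent, amplitudes, wavenumbers, phases):
    """Hypothetical factory returning x -> sum_k a_k * sin(2*pi*k*x/L + phi_k)."""
    a = jnp.asarray(amplitudes)
    k = jnp.asarray(wavenumbers)
    phi = jnp.asarray(phases)

    def ic_fun(x):
        x = jnp.asarray(x)
        # Reshape the per-mode parameters to (K, 1, ..., 1) so they broadcast
        # against x, then sum over the K modes.
        shape = (-1,) + (1,) * x.ndim
        arg = 2 * jnp.pi * k.reshape(shape) * x / domain_extent + phi.reshape(shape)
        return jnp.sum(a.reshape(shape) * jnp.sin(arg), axis=0)

    return ic_fun


L = 10.0
grid = jnp.linspace(0.0, L, 51)[:-1][None, :]  # periodic grid, shape (1, 50)
amps = jnp.array([0.5, -0.3])
phases = jnp.array([0.1, 1.2])

ic_fun = make_sine_sum(L, amps, jnp.arange(1, 2 + 1), phases)
u_0 = ic_fun(grid)  # discretize the continuous function on the grid

# Same state as writing the two modes out by hand.
u_manual = amps[0] * jnp.sin(2 * jnp.pi / L * grid + phases[0]) + amps[1] * jnp.sin(
    4 * jnp.pi / L * grid + phases[1]
)
print(jnp.allclose(u_0, u_manual, atol=1e-5))
```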
In Exponax, there is a representation for such sine functions that is (like the time steppers or any other equinox.Module) simply a callable PyTree.
sine_waves = ex.ic.SineWaves1d(
DOMAIN_EXTENT, amplitude_s, jnp.arange(1, 2 + 1), phase_s
)
sine_waves
We can also evaluate it on the grid:
u_0 = sine_waves(grid)
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
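A callable PyTree is an object that can be called like a function and is also a JAX pytree whose array attributes are leaves. This lets it pass through transformations such as jax.jit. The class below is a hedged, hand-rolled illustration that uses jax.tree_util.register_pytree_node_class instead of equinox. SineSum is hypothetical and not Exponax's SineWaves1d.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves


@register_pytree_node_class
class SineSum:
    """Hypothetical callable PyTree, a hand-rolled analogue of an equinox.Module."""

    def __init__(self, domain_extent, amplitudes, wavenumbers, phases):
        self.domain_extent = domain_extent
        self.amplitudes = amplitudes
        self.wavenumbers = wavenumbers
        self.phases = phases

    def __call__(self, x):
        # Evaluate sum_k a_k * sin(2*pi*k*x/L + phi_k) for x of shape (1, N).
        arg = (
            2 * jnp.pi * self.wavenumbers[:, None, None] * x[None] / self.domain_extent
            + self.phases[:, None, None]
        )
        return jnp.sum(self.amplitudes[:, None, None] * jnp.sin(arg), axis=0)

    def tree_flatten(self):
        # The arrays are the leaves; the domain extent is static auxiliary data.
        children = (self.amplitudes, self.wavenumbers, self.phases)
        return children, self.domain_extent

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)


sine_sum = SineSum(10.0, jnp.array([0.5, -0.3]), jnp.arange(1, 3), jnp.array([0.1, 1.2]))
grid = jnp.linspace(0.0, 10.0, 51)[:-1][None, :]

print(len(tree_leaves(sine_sum)))  # 3
u_eager = sine_sum(grid)
# Because it is a PyTree, the object itself can be a traced argument of jax.jit.
u_jit = jax.jit(lambda ic, x: ic(x))(sine_sum, grid)
```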
Exponax also offers a way to randomly generate initial conditions of this sine type. For this, we first instantiate a RandomSineWaves1d generator. You can think of it as a "distribution over initial conditions" of the aforementioned type.
sine_ic_gen = ex.ic.RandomSineWaves1d(
1, # num_spatial_dims fixed to 1
domain_extent=DOMAIN_EXTENT,
cutoff=2, # how many modes are included
)
This generator can now be queried with a random key to produce a new sine-wave initial condition (as a "continuous function"):
new_sine_waves_ic = sine_ic_gen.gen_ic_fun(key=jax.random.PRNGKey(1))
new_sine_waves_ic
Let us evaluate it on the grid:
u_0 = new_sine_waves_ic(grid)
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
Oftentimes, when creating initial conditions for the dynamics in Exponax, we are not interested in their continuous representation but rather in their discretized version. For this, we can directly call the generator with the number of points as an argument. Internally, it generates a function and a grid and returns the discretized function. Both ways should result in the same initial state (given the same random key).
u_0 = sine_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
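The relation between gen_ic_fun and calling the generator directly can be mimicked with a toy class. This sketch rests on our own assumptions: amplitudes drawn from \(\mathcal{U}(-1, 1)\), phases from \(\mathcal{U}(0, 2\pi)\), wavenumbers 1 to cutoff, and a periodic grid without the right endpoint. ToySineWaveGenerator is hypothetical and not Exponax's implementation. Because the key fully determines the drawn parameters, both routes produce the same state.

```python
import jax
import jax.numpy as jnp


class ToySineWaveGenerator:
    """Hypothetical stand-in for a random sine-wave IC generator."""

    def __init__(self, domain_extent: float, cutoff: int):
        self.domain_extent = domain_extent
        self.cutoff = cutoff

    def gen_ic_fun(self, *, key):
        # Draw the random parameters once; they fully define the continuous function.
        amp_key, phase_key = jax.random.split(key)
        amps = jax.random.uniform(amp_key, (self.cutoff,), minval=-1.0, maxval=1.0)
        phases = jax.random.uniform(
            phase_key, (self.cutoff,), minval=0.0, maxval=2 * jnp.pi
        )
        ks = jnp.arange(1, self.cutoff + 1)

        def ic_fun(x):
            # x has shape (1, num_points); modes broadcast along a new leading axis.
            arg = 2 * jnp.pi * ks[:, None, None] * x[None] / self.domain_extent
            return jnp.sum(amps[:, None, None] * jnp.sin(arg + phases[:, None, None]), axis=0)

        return ic_fun

    def __call__(self, num_points: int, *, key):
        # Generate the function, build the periodic grid, and evaluate.
        grid = jnp.linspace(0.0, self.domain_extent, num_points + 1)[:-1][None, :]
        return self.gen_ic_fun(key=key)(grid)


gen = ToySineWaveGenerator(domain_extent=10.0, cutoff=2)
key = jax.random.PRNGKey(1)
grid = jnp.linspace(0.0, 10.0, 51)[:-1][None, :]

via_fun = gen.gen_ic_fun(key=key)(grid)  # continuous function, then discretize
via_call = gen(50, key=key)  # discretized state in one call
print(jnp.array_equal(via_fun, via_call))  # True
```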
For many applications, we are interested in rolling out dynamics for multiple initial conditions. For this, we can split a key into several subkeys, call the generator once per subkey, and stack the results into a batch.
main_key = jax.random.PRNGKey(42)
five_keys = jax.random.split(main_key, 5)
five_u_0 = jnp.stack([sine_ic_gen(NUM_POINTS, key=key) for key in five_keys])
print(five_u_0.shape)
ex.viz.plot_state_1d_facet(five_u_0, domain_extent=DOMAIN_EXTENT)
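The list comprehension calls the generator once per key. If the sampling code consists only of traceable JAX operations, jax.vmap can map over the keys instead. The sketch below shows this with a hypothetical pure-JAX sampler, sample_sine_ic. Whether a given Exponax generator can be vmapped the same way is an assumption we do not check here.

```python
import jax
import jax.numpy as jnp

L, N, CUTOFF = 10.0, 50, 2


def sample_sine_ic(key):
    """Hypothetical pure-JAX sampler: one random sine-wave IC of shape (1, N)."""
    amp_key, phase_key = jax.random.split(key)
    amps = jax.random.uniform(amp_key, (CUTOFF,), minval=-1.0, maxval=1.0)
    phases = jax.random.uniform(phase_key, (CUTOFF,), minval=0.0, maxval=2 * jnp.pi)
    ks = jnp.arange(1, CUTOFF + 1)
    grid = jnp.linspace(0.0, L, N + 1)[:-1][None, :]
    arg = 2 * jnp.pi * ks[:, None, None] * grid[None] / L + phases[:, None, None]
    return jnp.sum(amps[:, None, None] * jnp.sin(arg), axis=0)


keys = jax.random.split(jax.random.PRNGKey(42), 5)
batched = jax.vmap(sample_sine_ic)(keys)  # one vectorized call over all keys
looped = jnp.stack([sample_sine_ic(k) for k in keys])  # list-comprehension route
print(batched.shape)  # (5, 1, 50)
```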
Gaussian Blobs
gaussian_blob_ic_gen = ex.ic.RandomGaussianBlobs(
1,
domain_extent=DOMAIN_EXTENT,
num_blobs=1,
)
u_0 = gaussian_blob_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
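To convey the idea behind a Gaussian blob, here is a sketch of a single periodic bump with a random center and width. The width distribution and the name gaussian_blob_ic are our own choices, not RandomGaussianBlobs' parameterization. Distances are measured with wrap-around, so a blob near the boundary continues on the other side.

```python
import jax
import jax.numpy as jnp


def gaussian_blob_ic(key, domain_extent=10.0, num_points=50, width_range=(0.05, 0.1)):
    """Hypothetical sketch: one periodic Gaussian bump with random center and width.

    The center is uniform on [0, L); the width is uniform in width_range,
    measured as a fraction of L (our assumption).
    """
    c_key, w_key = jax.random.split(key)
    x = jnp.linspace(0.0, domain_extent, num_points + 1)[:-1][None, :]
    center = jax.random.uniform(c_key, (), minval=0.0, maxval=domain_extent)
    width = domain_extent * jax.random.uniform(
        w_key, (), minval=width_range[0], maxval=width_range[1]
    )
    # Shortest signed distance on the periodic domain, in [-L/2, L/2).
    d = (x - center + domain_extent / 2) % domain_extent - domain_extent / 2
    return jnp.exp(-0.5 * (d / width) ** 2)


u_0 = gaussian_blob_ic(jax.random.PRNGKey(1))
print(u_0.shape)  # (1, 50)
```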
Discontinuities
discontinuity_ic_gen = ex.ic.RandomDiscontinuities(
1,
domain_extent=DOMAIN_EXTENT,
num_discontinuities=1,
)
u_0 = discontinuity_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
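A single random discontinuity can be pictured as a box function. Two jump locations and a height are drawn at random, and the state is constant inside the interval and zero outside it. This is a hedged sketch: the hypothetical box_discontinuity_ic does not claim to reproduce the exact construction or value range of RandomDiscontinuities.

```python
import jax
import jax.numpy as jnp


def box_discontinuity_ic(key, domain_extent=10.0, num_points=50, value_range=(-1.0, 1.0)):
    """Hypothetical sketch: piecewise-constant state with one raised interval."""
    pos_key, val_key = jax.random.split(key)
    x = jnp.linspace(0.0, domain_extent, num_points + 1)[:-1][None, :]
    # Two jump locations, sorted so that lower <= upper.
    lower, upper = jnp.sort(
        jax.random.uniform(pos_key, (2,), minval=0.0, maxval=domain_extent)
    )
    height = jax.random.uniform(
        val_key, (), minval=value_range[0], maxval=value_range[1]
    )
    inside = (x >= lower) & (x < upper)
    return jnp.where(inside, height, 0.0)


u_0 = box_discontinuity_ic(jax.random.PRNGKey(1))
print(u_0.shape)  # (1, 50)
```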
Gaussian Random Field
grf_ic_gen = ex.ic.GaussianRandomField(
1,
)
u_0 = grf_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
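A common way to sample a periodic Gaussian random field is to filter white noise in Fourier space so that the power spectrum decays like \(|k|^{-\alpha}\). Larger \(\alpha\) gives smoother fields. The sketch below follows that recipe, removes the mean mode, and normalizes the maximum absolute value to one. The function name, the exponent argument, and the normalization are our assumptions, not GaussianRandomField's documented defaults.

```python
import jax
import jax.numpy as jnp


def gaussian_random_field_1d(key, num_points=50, powerlaw_exponent=3.0):
    """Hypothetical sketch: periodic 1D GRF with power spectrum ~ |k|^(-alpha)."""
    noise = jax.random.normal(key, (num_points,))
    noise_hat = jnp.fft.rfft(noise)
    k = jnp.fft.rfftfreq(num_points, d=1.0 / num_points)  # integer wavenumbers 0, 1, ...
    # Amplitude filter sqrt(P(k)) = |k|^(-alpha/2); the mean mode k = 0 is removed.
    safe_k = jnp.where(k == 0, 1.0, k)
    amplitude = jnp.where(k == 0, 0.0, safe_k ** (-powerlaw_exponent / 2))
    field = jnp.fft.irfft(noise_hat * amplitude, n=num_points)
    field = field / jnp.max(jnp.abs(field))  # normalize to max |u| = 1 (our choice)
    return field[None, :]


u_0 = gaussian_random_field_1d(jax.random.PRNGKey(1))
print(u_0.shape)  # (1, 50)
```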
Diffused Noise
diffused_noise_ic_gen = ex.ic.DiffusedNoise(
1,
)
u_0 = diffused_noise_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
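Diffused noise can be understood as white noise smoothed by the heat equation \(u_t = \nu u_{xx}\). The exact solution of that equation damps each Fourier mode by \(\exp(-\nu k^2 t)\). The sketch below applies this damping once, with a hypothetical intensity parameter standing in for \(\nu t\). It mirrors the idea, not necessarily how DiffusedNoise is parameterized.

```python
import jax
import jax.numpy as jnp


def diffused_noise_1d(key, domain_extent=10.0, num_points=50, intensity=1.0):
    """Hypothetical sketch: white noise damped by the heat-equation solution operator."""
    noise = jax.random.normal(key, (num_points,))
    noise_hat = jnp.fft.rfft(noise)
    # Angular wavenumbers 2*pi*j/L for a periodic domain of length L.
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=domain_extent / num_points)
    # exp(-intensity * k^2) with intensity = nu * t; the mean (k = 0) is untouched.
    diffused_hat = noise_hat * jnp.exp(-intensity * k**2)
    return jnp.fft.irfft(diffused_hat, n=num_points)[None, :]


u_0 = diffused_noise_1d(jax.random.PRNGKey(1))
print(u_0.shape)  # (1, 50)
```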
Truncated Fourier Series
This is conceptually very similar to the sine waves, but it also works correctly in higher dimensions. However, it does not provide an explicit continuous function that we could evaluate on an arbitrary grid. The produced state also differs between resolutions (num_points), even for the same random key.
truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
1,
cutoff=2,
)
u_0 = truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
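One way to build a truncated Fourier series state is to draw random coefficients directly in Fourier space and keep only the wavenumbers 1 to cutoff. In the sketch below (the hypothetical truncated_fourier_1d), coefficients are drawn for all rfft modes of the requested resolution, and no continuous function is ever formed. The random draws are therefore tied to num_points. This is one plausible source of the resolution dependence described above, not a description of Exponax's internals.

```python
import jax
import jax.numpy as jnp


def truncated_fourier_1d(key, num_points, cutoff=2):
    """Hypothetical sketch: random Fourier coefficients, modes above cutoff zeroed."""
    re_key, im_key = jax.random.split(key)
    num_modes = num_points // 2 + 1
    # The number of drawn coefficients depends on the resolution.
    coef = jax.random.normal(re_key, (num_modes,)) + 1j * jax.random.normal(
        im_key, (num_modes,)
    )
    k = jnp.arange(num_modes)
    # Keep wavenumbers 1..cutoff; drop the mean mode and everything above cutoff.
    coef = jnp.where((k >= 1) & (k <= cutoff), coef, 0.0)
    u = jnp.fft.irfft(coef, n=num_points)
    return (u / jnp.max(jnp.abs(u)))[None, :]  # normalize to max |u| = 1 (our choice)


u_0 = truncated_fourier_1d(jax.random.PRNGKey(1), 50, cutoff=2)
print(u_0.shape)  # (1, 50)
```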
Meta Generators
Multi-Channel IC Generator
This is relevant, for example, for dynamics with multiple species. In higher dimensions, it is also relevant for convection problems like the Burgers equation, whose state has one velocity component per spatial dimension.
truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
1,
cutoff=4,
)
multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
[
truncated_fourier_ic_gen,
truncated_fourier_ic_gen, # Can also use two different generators
]
)
u_0 = multi_channel_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
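Conceptually, a multi-channel generator only has to give each sub-generator its own key and stack the single-channel results along the channel axis. The sketch below assumes each sub-generator returns an array of shape (1, num_points). Both multi_channel_ic and random_phase_sine are hypothetical helpers. Splitting the key matters: with a shared key, two identical sub-generators would produce identical channels.

```python
import jax
import jax.numpy as jnp

L = 10.0


def random_phase_sine(num_points, *, key):
    """Hypothetical single-channel generator: first sine mode with a random phase."""
    x = jnp.linspace(0.0, L, num_points + 1)[:-1][None, :]
    phase = jax.random.uniform(key, (), minval=0.0, maxval=2 * jnp.pi)
    return jnp.sin(2 * jnp.pi / L * x + phase)  # shape (1, num_points)


def multi_channel_ic(generators, num_points, *, key):
    """Hypothetical meta generator: one independent key per channel."""
    keys = jax.random.split(key, len(generators))
    # Concatenate the (1, num_points) outputs along the channel axis.
    return jnp.concatenate(
        [gen(num_points, key=k) for gen, k in zip(generators, keys)], axis=0
    )


u_0 = multi_channel_ic([random_phase_sine, random_phase_sine], 50, key=jax.random.PRNGKey(1))
print(u_0.shape)  # (2, 50)
```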
Clamping IC Generator
Clamps the initial condition to a given range. This is useful, for example, for dynamics that expect initial conditions within \([0, 1]\), like the Fisher-KPP equation.
truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
1,
cutoff=4,
)
clamped_truncated_fourier_ic_gen = ex.ic.ClampingICGenerator(
truncated_fourier_ic_gen,
limits=(0.0, 0.75),
)
u_0 = clamped_truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
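There are two common ways to force a state into a range. Hard clipping flattens everything outside the limits, while an affine min-max rescaling preserves the shape of the state. This notebook does not state which one ClampingICGenerator uses, so both functions below are hypothetical sketches. Check the Exponax docstring before relying on either behavior.

```python
import jax.numpy as jnp


def hard_clip(u, limits):
    # Values outside [lower, upper] are cut off; values inside are unchanged.
    return jnp.clip(u, limits[0], limits[1])


def minmax_rescale(u, limits):
    # Affine map sending [min(u), max(u)] onto [lower, upper]; the shape is kept.
    lower, upper = limits
    u_unit = (u - jnp.min(u)) / (jnp.max(u) - jnp.min(u))
    return lower + (upper - lower) * u_unit


x = jnp.linspace(0.0, 10.0, 51)[:-1][None, :]
u = jnp.sin(2 * jnp.pi / 10.0 * x)  # values in [-1, 1]
u_clipped = hard_clip(u, (0.0, 0.75))
u_rescaled = minmax_rescale(u, (0.0, 0.75))
```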
Scaled IC Generator
Scales the initial condition by a fixed factor.
truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
1,
cutoff=4,
)
scaled_truncated_fourier_ic_gen = ex.ic.ScaledICGenerator(
truncated_fourier_ic_gen,
scale=0.1,
)
u_0 = scaled_truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
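Meta generators like this one follow a simple wrapper pattern. They take a generator and return a new callable with the same (num_points, key) interface. The sketch below shows the pattern for scaling with hypothetical base_gen and scaled_gen functions. It is not Exponax's ScaledICGenerator code.

```python
import jax
import jax.numpy as jnp


def base_gen(num_points, *, key):
    """Hypothetical base generator: standard normal white noise of shape (1, num_points)."""
    return jax.random.normal(key, (1, num_points))


def scaled_gen(gen, scale):
    """Hypothetical meta generator: wraps gen and multiplies every state by scale."""

    def wrapped(num_points, *, key):
        return scale * gen(num_points, key=key)

    return wrapped


small_gen = scaled_gen(base_gen, 0.1)
key = jax.random.PRNGKey(1)
u_0 = small_gen(50, key=key)
```

Because the wrapper keeps the interface, wrappers can be nested, for example scaling a multi-channel generator.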
Modifications
Almost all IC generators support the following modifications:
zero_mean: Shifts the mean of the initial condition to zero.
std_one: Scales the standard deviation of the initial condition to one (this only works with zero_mean=True).
max_one: Scales the maximum absolute value of the initial condition to one.
The last option is especially helpful for limiting the order of magnitude of the initial condition. This matters because many dynamics are sensitive to it (e.g., the Burgers equation).
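The three options amount to simple array operations. The sketch below uses a hypothetical apply_modifications function and applies them in the order zero_mean, std_one, max_one. That order, and raising an error for std_one without zero_mean, are our assumptions based on the description above.

```python
import jax
import jax.numpy as jnp


def apply_modifications(u, *, zero_mean=False, std_one=False, max_one=False):
    """Hypothetical sketch of the three options; the order is our assumption."""
    if std_one and not zero_mean:
        # Mirrors the stated restriction: std_one needs zero_mean.
        raise ValueError("std_one=True requires zero_mean=True")
    if zero_mean:
        u = u - jnp.mean(u)
    if std_one:
        u = u / jnp.std(u)
    if max_one:
        u = u / jnp.max(jnp.abs(u))
    return u


u = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (1, 50))
u_zm = apply_modifications(u, zero_mean=True)
u_std = apply_modifications(u, zero_mean=True, std_one=True)
u_max = apply_modifications(u, max_one=True)
```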
disc_ic_gen = ex.ic.RandomDiscontinuities(
1,
domain_extent=DOMAIN_EXTENT,
num_discontinuities=1,
zero_mean=True,
)
u_0 = disc_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
disc_ic_gen_nonzero = ex.ic.RandomDiscontinuities(
1,
domain_extent=DOMAIN_EXTENT,
num_discontinuities=1,
zero_mean=False,
)
u_0 = disc_ic_gen_nonzero(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
disc_std_one = ex.ic.RandomDiscontinuities(
1,
domain_extent=DOMAIN_EXTENT,
num_discontinuities=1,
zero_mean=True,
std_one=True,  # requires zero_mean=True
)
u_0 = disc_std_one(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)
disc_max_one = ex.ic.RandomDiscontinuities(
1,
domain_extent=DOMAIN_EXTENT,
num_discontinuities=1,
max_one=True,
)
u_0 = disc_max_one(NUM_POINTS, key=jax.random.PRNGKey(1))
ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT)