
Showcase of Initial Conditions in 1d¤

Initial conditions affect the dynamics produced by the timesteppers.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex

Introduction¤

In the simplest case, we can define an initial condition by creating a grid and discretizing a function on it. As an example, we use the first sine mode:

DOMAIN_EXTENT = 10.0
NUM_POINTS = 50

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)

u_0 = jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * grid)

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
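To make the discretization step concrete, here is a minimal sketch in plain JAX. The helper `periodic_grid_1d` is hypothetical and is not the Exponax API. It assumes a periodic grid that includes the left boundary but excludes the right one, and a leading channel axis so that a 1d state has shape `(1, NUM_POINTS)`.

```python
import jax.numpy as jnp

DOMAIN_EXTENT = 10.0
NUM_POINTS = 50


def periodic_grid_1d(domain_extent, num_points):
    """Hypothetical stand-in for a 1d grid helper (not the Exponax API)."""
    # Assumption: the right boundary is left out because, on a periodic
    # domain, it coincides with the left boundary.
    x = jnp.linspace(0.0, domain_extent, num_points, endpoint=False)
    # Add a leading axis so that the result has shape (1, num_points).
    return x[None, :]


grid = periodic_grid_1d(DOMAIN_EXTENT, NUM_POINTS)

# Discretize the first sine mode by evaluating it at every grid point.
u_0 = jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * grid)

print(grid.shape, u_0.shape)
```

Under these assumptions the first sine mode completes exactly one period over the grid, so its discrete mean is zero up to floating-point error.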

This can be randomized by changing the amplitude and phase of the sine mode. Here, we will draw these values from a uniform distribution:

\[ a \sim \mathcal{U}(-1, 1) \]
\[ \phi \sim \mathcal{U}(0, 2\pi) \]

Then, with \(L\) denoting the domain extent, the initial condition is

\[ u(x) = a \sin\left(\frac{2\pi}{L} x + \phi\right) \]

main_key = jax.random.PRNGKey(0)

amplitude_key, phase_key = jax.random.split(main_key)

amplitude = jax.random.uniform(amplitude_key, (), minval=-1.0, maxval=1.0)
phase = jax.random.uniform(phase_key, (), minval=0.0, maxval=2 * jnp.pi)

print(f"Amplitude: {amplitude:.3f}")
print(f"Phase: {phase:.3f}")

ic_fun = lambda x: amplitude * jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * x + phase)

u_0 = ic_fun(grid)

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
Amplitude: 0.114
Phase: 0.662


Let's also create an initial condition that consists of two sine modes, the first and the second. We will again draw the amplitudes and phases from the same uniform distributions.

amplitude_s = jax.random.uniform(amplitude_key, (2,), minval=-1.0, maxval=1.0)
phase_s = jax.random.uniform(phase_key, (2,), minval=0.0, maxval=2 * jnp.pi)

print(f"Amplitudes: {amplitude_s}")
print(f"Phases: {phase_s}")

ic_fun = lambda x: amplitude_s[0] * jnp.sin(
    2 * jnp.pi / DOMAIN_EXTENT * x + phase_s[0]
) + amplitude_s[1] * jnp.sin(4 * jnp.pi / DOMAIN_EXTENT * x + phase_s[1])

u_0 = ic_fun(grid)

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
Amplitudes: [-0.03947949 -0.52326775]
Phases: [3.622575  1.8797542]
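Note that the cell above reuses `amplitude_key` and `phase_key` from the single-mode example. JAX generally recommends consuming each key only once. The following sketch shows one way to derive a fresh sub-key for every random quantity; the key names are made up for illustration, and the drawn numbers will differ from the ones printed above.

```python
import jax
import jax.numpy as jnp

main_key = jax.random.PRNGKey(0)

# One sub-key per random quantity and per example; no key is consumed twice.
(
    single_amp_key,
    single_phase_key,
    multi_amp_key,
    multi_phase_key,
) = jax.random.split(main_key, 4)

# Single-mode example: scalar amplitude and phase.
amplitude = jax.random.uniform(single_amp_key, (), minval=-1.0, maxval=1.0)
phase = jax.random.uniform(single_phase_key, (), minval=0.0, maxval=2 * jnp.pi)

# Two-mode example: one amplitude and one phase per mode.
amplitude_s = jax.random.uniform(multi_amp_key, (2,), minval=-1.0, maxval=1.0)
phase_s = jax.random.uniform(multi_phase_key, (2,), minval=0.0, maxval=2 * jnp.pi)

print(amplitude.shape, amplitude_s.shape)
```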


Notice that in both cases we first created a lambda function representing the "continuous" version of the initial condition, which we subsequently discretized on the grid.

Exponax provides a representation for such sine functions that is, like the time steppers or any other equinox.Module, simply a callable PyTree.

sine_waves = ex.ic.SineWaves1d(
    DOMAIN_EXTENT, amplitude_s, jnp.arange(1, 2 + 1), phase_s
)

sine_waves
SineWaves1d(
  domain_extent=10.0,
  amplitudes=f32[2],
  wavenumbers=i32[2],
  phases=f32[2],
  offset=0.0,
  std_one=False,
  max_one=False
)
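To illustrate what "callable PyTree" means, here is a minimal sketch of a sine-wave object that is both callable and registered as a PyTree. The class `ToySineWaves` is hypothetical and is not the Exponax class. Equinox performs the PyTree registration automatically; the sketch does it by hand with `jax.tree_util` so that it only depends on JAX.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySineWaves:
    """Hypothetical stand-in for a sum of sine waves (not the Exponax class)."""

    def __init__(self, domain_extent, amplitudes, wavenumbers, phases):
        self.domain_extent = domain_extent
        self.amplitudes = amplitudes
        self.wavenumbers = wavenumbers
        self.phases = phases

    def __call__(self, x):
        # x has shape (1, N); the modes live on a new trailing axis.
        angles = (
            2 * jnp.pi * self.wavenumbers / self.domain_extent * x[..., None]
            + self.phases
        )
        # Sum over the modes to get back to shape (1, N).
        return jnp.sum(self.amplitudes * jnp.sin(angles), axis=-1)

    def tree_flatten(self):
        # Arrays are leaves; the domain extent is static auxiliary data.
        children = (self.amplitudes, self.wavenumbers, self.phases)
        return children, self.domain_extent

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        amplitudes, wavenumbers, phases = children
        return cls(aux_data, amplitudes, wavenumbers, phases)


waves = ToySineWaves(
    10.0, jnp.array([0.5, -0.25]), jnp.arange(1, 3), jnp.array([0.0, 1.0])
)
x = jnp.linspace(0.0, 10.0, 50, endpoint=False)[None, :]

u_0 = waves(x)  # callable like a function
leaves = jax.tree_util.tree_leaves(waves)  # inspectable like a PyTree
u_0_jit = jax.jit(lambda w, grid: w(grid))(waves, x)  # can be passed into jit
```

Because the object is a PyTree, it can be handed to transformed functions such as `jax.jit` just like an array.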

We can also evaluate it on the grid:

u_0 = sine_waves(grid)

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Exponax offers a way to randomly generate initial conditions of this sine type. For this, we first instantiate a RandomSineWaves1d generator. You can think of it as a "distribution over initial conditions" of the aforementioned type.

sine_ic_gen = ex.ic.RandomSineWaves1d(
    1,  # num_spatial_dims fixed to 1
    domain_extent=DOMAIN_EXTENT,
    cutoff=2,  # how many modes are included
)

This generator can now be queried to produce another sine wave initial condition (as a "continuous function"):

new_sine_waves_ic = sine_ic_gen.gen_ic_fun(key=jax.random.PRNGKey(1))

new_sine_waves_ic
SineWaves1d(
  domain_extent=10.0,
  amplitudes=f32[2],
  wavenumbers=i32[2],
  phases=f32[2],
  offset=f32[],
  std_one=False,
  max_one=False
)

Let us evaluate it on the grid:

u_0 = new_sine_waves_ic(grid)

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Oftentimes, when creating initial conditions for the dynamics in Exponax, we are not interested in their continuous representation but rather in their discretized version. For this, we can call the generator directly with the number of points as an argument. Internally, it generates a function and a grid and returns the function evaluated on that grid.

Both ways should result in the same initial state (given the same random key).

u_0 = sine_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

For many applications, we are interested in rolling out dynamics for multiple initial conditions. For this, we can batch execute the generator on a range of keys.

main_key = jax.random.PRNGKey(42)

five_keys = jax.random.split(main_key, 5)

five_u_0 = jnp.stack([sine_ic_gen(NUM_POINTS, key=key) for key in five_keys])

print(five_u_0.shape)

ex.viz.plot_state_1d_facet(five_u_0, domain_extent=DOMAIN_EXTENT);
(5, 1, 50)
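The list comprehension above can, in principle, be replaced by a vectorized map over the keys. The sketch below uses a hypothetical generator function `toy_sine_ic` written in plain JAX. Whether an Exponax generator can be passed to `jax.vmap` in the same way depends on the generator being traceable, which this page does not confirm.

```python
import jax
import jax.numpy as jnp

DOMAIN_EXTENT = 10.0
NUM_POINTS = 50


def toy_sine_ic(key):
    """Hypothetical generator: one random first sine mode, shape (1, N)."""
    amplitude_key, phase_key = jax.random.split(key)
    amplitude = jax.random.uniform(amplitude_key, (), minval=-1.0, maxval=1.0)
    phase = jax.random.uniform(phase_key, (), minval=0.0, maxval=2 * jnp.pi)
    x = jnp.linspace(0.0, DOMAIN_EXTENT, NUM_POINTS, endpoint=False)[None, :]
    return amplitude * jnp.sin(2 * jnp.pi / DOMAIN_EXTENT * x + phase)


five_keys = jax.random.split(jax.random.PRNGKey(42), 5)

# Python loop followed by a stack along a new leading batch axis.
looped = jnp.stack([toy_sine_ic(key) for key in five_keys])

# Vectorized alternative: map the generator over the leading axis of the keys.
vmapped = jax.vmap(toy_sine_ic)(five_keys)

print(looped.shape, vmapped.shape)
```

Both variants produce a batch with a leading sample axis, followed by the channel axis and the spatial axis.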


Since evaluating a generator for multiple initial conditions is a common task, Exponax provides a dedicated function for it.

more_u_0 = ex.build_ic_set(
    sine_ic_gen,
    num_points=NUM_POINTS,
    num_samples=5,
    key=jax.random.PRNGKey(73),
)

print(more_u_0.shape)

ex.viz.plot_state_1d_facet(more_u_0, domain_extent=DOMAIN_EXTENT);
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[13], line 1
----> 1 more_u_0 = ex.build_ic_set(
      2     sine_ic_gen,
      3     num_points=NUM_POINTS,
      4     num_samples=5,
      5     key=jax.random.PRNGKey(73),
      6 )
      8 print(more_u_0.shape)
     10 ex.viz.plot_state_1d_facet(more_u_0, domain_extent=DOMAIN_EXTENT);

File ~/Repos/exponax/exponax/_utils.py:334, in build_ic_set(ic_generator, num_points, num_samples, key)
    331     ic = ic_generator(num_points, key=sub_k)
    332     return k, ic
--> 334 _, ic_set = jax.lax.scan(scan_fn, key, None, length=num_samples)
    336 return ic_set

    [... skipping hidden 9 frame]

File ~/Repos/exponax/exponax/_utils.py:331, in build_ic_set.<locals>.scan_fn(k, _)
    329 def scan_fn(k, _):
    330     k, sub_k = jr.split(k)
--> 331     ic = ic_generator(num_points, key=sub_k)
    332     return k, ic

File ~/Repos/exponax/exponax/ic/_base_ic.py:60, in BaseRandomICGenerator.__call__(self, num_points, key)
     43 def __call__(
     44     self,
     45     num_points: int,
     46     *,
     47     key: PRNGKeyArray,
     48 ) -> Float[Array, "1 ... N"]:
     49     """
     50     Generate a random initial condition.
     51 
   (...)
     58         - `u`: The initial condition evaluated at the grid points.
     59     """
---> 60     ic_fun = self.gen_ic_fun(key=key)
     61     grid = make_grid(
     62         self.num_spatial_dims,
     63         self.domain_extent,
     64         num_points,
     65         indexing=self.indexing,
     66     )
     67     return ic_fun(grid)

    [... skipping hidden 1 frame]

File ~/Repos/exponax/exponax/ic/_sine_waves_1d.py:162, in RandomSineWaves1d.gen_ic_fun(self, key)
    149 phases = jr.uniform(
    150     phase_key,
    151     shape=(self.cutoff,),
    152     minval=self.phase_range[0],
    153     maxval=self.phase_range[1],
    154 )
    155 offset = jr.uniform(
    156     offset_key,
    157     shape=(),
    158     minval=self.offset_range[0],
    159     maxval=self.offset_range[1],
    160 )
--> 162 return SineWaves1d(
    163     domain_extent=self.domain_extent,
    164     amplitudes=amplitudes,
    165     wavenumbers=jnp.arange(1, self.cutoff + 1),
    166     phases=phases,
    167     offset=offset,
    168     std_one=self.std_one,
    169     max_one=self.max_one,
    170 )

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax_fresh/lib/python3.10/site-packages/equinox/_better_abstract.py:225, in ABCMeta.__call__(cls, *args, **kwargs)
    220     abstract_class_vars = set(cls.__abstractclassvars__)  # pyright: ignore
    221     raise TypeError(
    222         f"Can't instantiate abstract class {cls.__name__} with abstract class "
    223         f"attributes {abstract_class_vars}"
    224     )
--> 225 self = super().__call__(*args, **kwargs)
    226 if len(cls.__abstractvars__) > 0:  # pyright: ignore
    227     abstract_class_vars = set(cls.__abstractvars__)  # pyright: ignore

    [... skipping hidden 1 frame]

File ~/Repos/exponax/exponax/ic/_sine_waves_1d.py:44, in SineWaves1d.__init__(self, domain_extent, amplitudes, wavenumbers, phases, offset, std_one, max_one)
     18 def __init__(
     19     self,
     20     domain_extent: float,
   (...)
     26     max_one: bool = False,
     27 ):
     28     """
     29     A state described by a collection of sine waves. Only works in 1d.
     30 
   (...)
     42             `std_one` and `max_one` can be `True`.
     43     """
---> 44     if offset != 0.0 and std_one:
     45         raise ValueError("Cannot have non-zero offset and `std_one=True`.")
     46     if std_one and max_one:

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/jax_fresh/lib/python3.10/site-packages/jax/_src/core.py:1443, in concretization_function_error.<locals>.error(self, arg)
   1442 def error(self, arg):
-> 1443   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function scan_fn at /home/koehler/Repos/exponax/exponax/_utils.py:329 for scan. This concrete value was not available in Python because it depends on the value of the argument k.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
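The traceback shows that the check `if offset != 0.0 and std_one` runs while `jax.lax.scan` is tracing `scan_fn`. At that point `offset` is a traced value, so Python cannot turn the comparison into a concrete boolean. The following sketch reproduces the same class of error without Exponax and shows an eager Python loop as one possible way around it. The functions `checked_scale` and `scan_fn` are hypothetical and only mimic the structure seen in the traceback.

```python
import jax
import jax.numpy as jnp


def checked_scale(offset):
    # A Python `if` needs a concrete boolean, which a traced value cannot give.
    if offset != 0.0:
        return offset * 2.0
    return offset


def scan_fn(key, _):
    key, sub_key = jax.random.split(key)
    offset = jax.random.uniform(sub_key, ())
    return key, checked_scale(offset)


key = jax.random.PRNGKey(73)

try:
    jax.lax.scan(scan_fn, key, None, length=5)
    scan_failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError is a subclass of this error type.
    scan_failed = True

# Workaround sketch: an eager Python loop keeps every value concrete.
values = []
for sub_key in jax.random.split(key, 5):
    values.append(checked_scale(jax.random.uniform(sub_key, ())))
values = jnp.stack(values)

print(scan_failed, values.shape)
```

The earlier cell that stacks the results of a list comprehension over `five_keys` follows the same eager pattern and did run successfully in this notebook.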

Gaussian Blobs¤

gaussian_blob_ic_gen = ex.ic.RandomGaussianBlobs(
    1,
    domain_extent=DOMAIN_EXTENT,
    num_blobs=1,
)

u_0 = gaussian_blob_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Discontinuities¤

discontinuity_ic_gen = ex.ic.RandomDiscontinuities(
    1,
    domain_extent=DOMAIN_EXTENT,
    num_discontinuities=1,
)

u_0 = discontinuity_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Gaussian Random Field¤

grf_ic_gen = ex.ic.GaussianRandomField(
    1,
)

u_0 = grf_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Diffused Noise¤

diffused_noise_ic_gen = ex.ic.DiffusedNoise(
    1,
)

u_0 = diffused_noise_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Truncated Fourier Series¤

This generator is conceptually very similar to the sine waves, but it also operates correctly in higher dimensions. However, we cannot explicitly obtain a continuous function from it. The produced state also differs between resolutions (num_points).

truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
    1,
    cutoff=2,
)

u_0 = truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
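As a rough illustration of the idea behind a truncated Fourier series, the sketch below draws random complex coefficients for the lowest modes, leaves all higher modes at zero, and transforms back to physical space. The function `toy_truncated_fourier` is hypothetical; the coefficient distribution and the normalization are assumptions and need not match the Exponax implementation.

```python
import jax
import jax.numpy as jnp


def toy_truncated_fourier(key, num_points, cutoff):
    """Hypothetical sketch of a truncated Fourier series (not Exponax code)."""
    real_key, imag_key = jax.random.split(key)
    # Random complex coefficients for the modes 0..cutoff only.
    shape = (cutoff + 1,)
    coeffs = jax.random.normal(real_key, shape) + 1j * jax.random.normal(
        imag_key, shape
    )
    # All higher modes of the real-valued FFT stay zero.
    num_modes = num_points // 2 + 1
    u_hat = jnp.zeros((num_modes,), dtype=coeffs.dtype)
    u_hat = u_hat.at[: cutoff + 1].set(coeffs)
    # Back to physical space; add the leading channel axis.
    return jnp.fft.irfft(u_hat, n=num_points)[None, :]


CUTOFF = 2
u_0 = toy_truncated_fourier(jax.random.PRNGKey(1), 50, CUTOFF)

# The spectrum of the result vanishes above the cutoff.
spectrum = jnp.fft.rfft(u_0[0])
print(u_0.shape, jnp.max(jnp.abs(spectrum[CUTOFF + 1 :])))
```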

Meta Generators¤

Multi-Channel IC Generator¤

This is relevant, for example, for dynamics with multiple species. In higher dimensions, it is also relevant for convection problems such as the Burgers equation.

truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
    1,
    cutoff=4,
)

multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
    [
        truncated_fourier_ic_gen,
        truncated_fourier_ic_gen,  # Can also use two different generators
    ]
)

u_0 = multi_channel_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
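The multi-channel idea can be sketched as follows: call one single-channel generator per channel and concatenate the results along the channel axis. The functions `toy_ic` and `toy_multi_channel` are hypothetical. In particular, giving every channel its own sub-key is an assumption made here so that passing the same generator twice yields two different channels.

```python
import jax
import jax.numpy as jnp


def toy_ic(num_points, *, key):
    """Hypothetical single-channel generator returning shape (1, N)."""
    return jax.random.normal(key, (1, num_points))


def toy_multi_channel(generators, num_points, *, key):
    # Assumption: one independent sub-key per channel.
    keys = jax.random.split(key, len(generators))
    channels = [gen(num_points, key=k) for gen, k in zip(generators, keys)]
    # Concatenate along the leading channel axis: (num_channels, N).
    return jnp.concatenate(channels, axis=0)


u_0 = toy_multi_channel([toy_ic, toy_ic], 50, key=jax.random.PRNGKey(1))

print(u_0.shape)
```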

Clamping IC Generator¤

Clamps the initial condition to a certain range. For example, this is useful for some dynamics that expect initial conditions to be within \([0, 1]\) (like the Fisher-KPP equation).

truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
    1,
    cutoff=4,
)

clamped_truncated_fourier_ic_gen = ex.ic.ClampingICGenerator(
    truncated_fourier_ic_gen,
    limits=(0.0, 0.75),
)

u_0 = clamped_truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Scaled IC Generator¤

Scales the initial condition by a fixed value.

truncated_fourier_ic_gen = ex.ic.RandomTruncatedFourierSeries(
    1,
    cutoff=4,
)

scaled_truncated_fourier_ic_gen = ex.ic.ScaledICGenerator(
    truncated_fourier_ic_gen,
    scale=0.1,
)

u_0 = scaled_truncated_fourier_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);

Modifications¤

Almost all IC generators support the following modifications:

  • zero_mean: Shifts the mean of the initial condition to zero.
  • std_one: Scales the standard deviation of the initial condition to one. (This only works with zero_mean=True.)
  • max_one: Scales the maximum absolute value of the initial condition to one.

The last one is especially helpful for limiting the order of magnitude of the initial condition. This is important because many dynamics are sensitive to it (e.g., the Burgers equation).
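The three modifications amount to simple array operations. The sketch below applies them to an arbitrary state; the helper `normalize` is hypothetical, and the order in which the operations are applied is an assumption rather than a description of the Exponax code.

```python
import jax
import jax.numpy as jnp


def normalize(u, *, zero_mean=False, std_one=False, max_one=False):
    """Hypothetical helper mirroring the three flags (not the Exponax code)."""
    if zero_mean:
        # Shift so that the mean over all grid points is zero.
        u = u - jnp.mean(u)
    if std_one:
        # Rescale so that the standard deviation is one.
        u = u / jnp.std(u)
    if max_one:
        # Rescale so that the largest absolute value is one.
        u = u / jnp.max(jnp.abs(u))
    return u


# A state with a clearly non-zero mean and a large spread.
u = 3.0 + 5.0 * jax.random.normal(jax.random.PRNGKey(0), (1, 50))

u_zero_mean = normalize(u, zero_mean=True)
u_std_one = normalize(u, zero_mean=True, std_one=True)
u_max_one = normalize(u, max_one=True)

print(jnp.mean(u_zero_mean), jnp.std(u_std_one), jnp.max(jnp.abs(u_max_one)))
```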

grf_ic_gen = ex.ic.RandomDiscontinuities(
    1,
    num_discontinuities=1,
    zero_mean=True,
)

u_0 = grf_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
grf_ic_gen_nonzero = ex.ic.RandomDiscontinuities(
    1,
    num_discontinuities=1,
    zero_mean=False,
)

u_0 = grf_ic_gen_nonzero(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
grf_std_one = ex.ic.RandomDiscontinuities(
    1,
    num_discontinuities=1,
    zero_mean=True,
    std_one=True,
)

u_0 = grf_std_one(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);
grf_max_one = ex.ic.RandomDiscontinuities(
    1,
    num_discontinuities=1,
    max_one=True,
)

u_0 = grf_max_one(NUM_POINTS, key=jax.random.PRNGKey(1))

ex.viz.plot_state_1d(u_0, domain_extent=DOMAIN_EXTENT);