
Additional features of Exponax

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex

Repeated stepper for more accurate (or more stable) integration

The conservative, second-order Kuramoto-Sivashinsky stepper is unstable for \(\Delta t = 1.0\) under its default configuration.

DOMAIN_EXTENT = 60.0
NUM_POINTS = 100
DT = 1.0

ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    1, DOMAIN_EXTENT, NUM_POINTS, DT
)

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jax.random.normal(
    jax.random.PRNGKey(0), (1, NUM_POINTS)
)  # Initial condition does not matter for KS equation

unstable_trj = ex.rollout(ks_stepper, 500, include_init=True)(u_0)

plt.imshow(
    unstable_trj[:, 0, :].T,
    aspect="auto",
    cmap="RdBu_r",
    vmin=-2,
    vmax=2,
    origin="lower",
)

print(unstable_trj[-1, 0, 1:5])  # All NaNs

[nan nan nan nan]


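Instead, we can instantiate a stepper with half the time step, \(\Delta t / 2\), and wrap it in `ex.RepeatedStepper` so that it is applied twice per call. The composite stepper advances the state by the same \(\Delta t = 1.0\) per step, but now the integration is stable.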

ks_stepper_half_step = ex.stepper.KuramotoSivashinskyConservative(
    1, DOMAIN_EXTENT, NUM_POINTS, DT / 2
)
ks_stepper_substepper = ex.RepeatedStepper(ks_stepper_half_step, 2)

stable_trj = ex.rollout(ks_stepper_substepper, 500, include_init=True)(u_0)

plt.imshow(
    stable_trj[:, 0, :].T, aspect="auto", cmap="RdBu_r", vmin=-2, vmax=2, origin="lower"
)
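The mechanism behind `ex.RepeatedStepper` can be illustrated with a minimal NumPy sketch: a wrapper that calls an inner stepper `num_sub_steps` times per outer step. As an assumption for illustration, the inner stepper here is explicit Euler for the linear decay ODE \(u' = -\lambda u\), not the KS stepper; it only mimics the generic effect of a step-size-dependent stability limit, not the specific mechanism inside the ETDRK-based KS stepper. With \(\lambda \Delta t = 3\), one full step multiplies the state by \(1 - 3 = -2\), so the rollout blows up, whereas two half steps multiply it by \((1 - 1.5)^2 = 0.25\) per outer step. All function names are hypothetical.

```python
import numpy as np


def make_euler_decay_stepper(lam, dt):
    """Explicit Euler step for u' = -lam * u (toy stepper, not from Exponax)."""
    def step(u):
        return u + dt * (-lam * u)
    return step


def repeated_stepper(stepper, num_sub_steps):
    """Hypothetical stand-in for the idea of ex.RepeatedStepper:
    apply `stepper` num_sub_steps times per outer call."""
    def step(u):
        for _ in range(num_sub_steps):
            u = stepper(u)
        return u
    return step


def rollout(stepper, n, u_0):
    """Return an array of shape (n + 1, ...) that includes the initial state."""
    trj = [u_0]
    for _ in range(n):
        trj.append(stepper(trj[-1]))
    return np.stack(trj)


LAM, DT = 3.0, 1.0
u_0 = np.array([1.0])

full_step = make_euler_decay_stepper(LAM, DT)      # amplification factor -2
half_step = make_euler_decay_stepper(LAM, DT / 2)  # amplification factor -0.5
sub_stepped = repeated_stepper(half_step, 2)       # amplification factor 0.25

unstable = rollout(full_step, 10, u_0)  # grows like (-2)**n
stable = rollout(sub_stepped, 10, u_0)  # decays like 0.25**n
print(unstable[-1], stable[-1])
```

The price of sub-stepping is twice as many inner steps per outer step, while the trajectory is still only stored at the outer time levels.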

Adding a forcing term

So far, the steppers only consider transient PDEs without external forcing. Exponax has a quick hack to turn a stepper with the signature \(\mathcal{P}: \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\) into a stepper with the signature \(\mathcal{P}: \mathbb{R}^{1 \times N} \times \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\) that takes an additional forcing term. It is implemented as a simple Euler step with the forcing before the actual ETDRK step. As a consequence, the overall temporal integration order drops to one.

DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01

diffusion_stepper = ex.stepper.Diffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)
forced_diffusion_stepper = ex.ForcedStepper(diffusion_stepper)

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)

# Use a constant forcing term here; a time-dependent forcing trajectory could
# also be supplied
forcing = 0.5 * jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)

original_trj = ex.rollout(diffusion_stepper, 200, include_init=True)(u_0)
# The forced stepper takes an additional (forcing) argument, so the rollout
# transformation must be told via takes_aux=True; constant_aux=True reuses the
# same forcing at every step
forced_trj = ex.rollout(
    forced_diffusion_stepper, 200, include_init=True, takes_aux=True, constant_aux=True
)(u_0, forcing)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(original_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[0].set_title("Original Heat Equation")
ax[1].imshow(forced_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[1].set_title("Forced Heat Equation")
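The forcing mechanism described above, an Euler update with the forcing followed by the unforced step, i.e. \(u^{[t+1]} = \mathcal{P}(u^{[t]} + \Delta t\, f)\), can be sketched in NumPy. This sketch assumes a 1D periodic heat equation whose linear part is integrated exactly in Fourier space (for a purely linear PDE the exponential step is exact); the helper names are hypothetical and not part of Exponax. It also checks the first-order claim: for a single forced Fourier mode, the error against the exact solution roughly halves when \(\Delta t\) is halved.

```python
import numpy as np

L_DOMAIN = 1.0
N = 100
NU = 0.01
x = L_DOMAIN * np.arange(N) / N                       # periodic grid, right endpoint excluded
k = 2 * np.pi * np.fft.rfftfreq(N, d=L_DOMAIN / N)    # angular wavenumbers


def make_heat_stepper(dt):
    """Exact exponential step for u_t = NU * u_xx (hypothetical helper)."""
    decay = np.exp(-NU * k**2 * dt)

    def step(u):
        return np.fft.irfft(decay * np.fft.rfft(u), n=N)
    return step


def make_forced_stepper(stepper, dt):
    """Sketch of the forcing wrapper: Euler-add the forcing, then take the unforced step."""
    def step(u, f):
        return stepper(u + dt * f)
    return step


def forced_error(dt, t_end=1.0):
    """Max error at t_end for u_0 = 0 and constant forcing f = sin(2 pi x)."""
    f = np.sin(2 * np.pi * x / L_DOMAIN)
    forced = make_forced_stepper(make_heat_stepper(dt), dt)
    u = np.zeros(N)
    for _ in range(round(t_end / dt)):
        u = forced(u, f)
    # Exact solution of u_t = NU * u_xx + f for this single Fourier mode
    lam = NU * (2 * np.pi / L_DOMAIN) ** 2
    u_exact = (1 - np.exp(-lam * t_end)) / lam * f
    return np.max(np.abs(u - u_exact))


err_coarse = forced_error(0.01)
err_fine = forced_error(0.005)
print(err_coarse, err_fine, err_coarse / err_fine)  # a ratio near 2 indicates first order
```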

Spectral derivatives

Any field discretized on the periodic domain can be differentiated spectrally.

DOMAIN_EXTENT = 1.0
NUM_POINTS = 100

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)
u_prime_exact = jnp.cos(6 * jnp.pi * grid / DOMAIN_EXTENT) * 6 * jnp.pi / DOMAIN_EXTENT

# Select the order of the derivative (first derivative, second derivative,
# etc.) with the `order` keyword argument; here the first derivative is computed
u_prime_spectrally = ex.derivative(u, DOMAIN_EXTENT)

# Root-mean-square error; results in a very small value (~1e-5 when computed in
# single precision)
jnp.sqrt(jnp.mean(jnp.square(u_prime_exact - u_prime_spectrally)))
Array(4.3673816e-05, dtype=float32)
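The principle behind a spectral derivative is to multiply each Fourier coefficient by \((i k)^{\text{order}}\) and transform back. Below is a minimal NumPy sketch for the 1D periodic case; `spectral_derivative` is a hypothetical name, not Exponax's `ex.derivative`, and zeroing the Nyquist mode for odd orders is a convention chosen here. Run in double precision, the RMSE lands far below the single-precision value above.

```python
import numpy as np


def spectral_derivative(u, domain_extent, order=1):
    """Differentiate a real, periodic 1D signal `order` times via the FFT."""
    n = u.shape[-1]
    k = 2 * np.pi * np.fft.rfftfreq(n, d=domain_extent / n)  # angular wavenumbers
    multiplier = (1j * k) ** order
    if order % 2 == 1 and n % 2 == 0:
        multiplier[-1] = 0.0  # the Nyquist mode has no well-defined odd derivative
    return np.fft.irfft(multiplier * np.fft.rfft(u), n=n)


DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
x = DOMAIN_EXTENT * np.arange(NUM_POINTS) / NUM_POINTS  # periodic grid
kx = 6 * np.pi / DOMAIN_EXTENT
u = np.sin(kx * x)

du = spectral_derivative(u, DOMAIN_EXTENT)            # first derivative
d2u = spectral_derivative(u, DOMAIN_EXTENT, order=2)  # second derivative

rmse_1 = np.sqrt(np.mean((du - kx * np.cos(kx * x)) ** 2))
rmse_2 = np.sqrt(np.mean((d2u + kx**2 * u) ** 2))
print(rmse_1, rmse_2)
```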

Moving between resolutions (Upsampling and Downsampling)

Using the Fourier representation, Exponax can upsample (increase the resolution of) or downsample (decrease the resolution of) a state. This comes with the following benefits:

  • For upsampling: If the state is bandlimited on the initial (lower) resolution, the interpolation is exact.
  • For downsampling: If the state is bandlimited on both the initial (higher) and the new (lower) resolution, the decimation is exact.

DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS_LOW = 10
NUM_POINTS_HIGH = 100
grid_1d = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_LOW)

u = jnp.sin(grid_1d)

ex.viz.plot_state_1d(u, domain_extent=DOMAIN_EXTENT)
u_upsampled = ex.map_between_resolutions(u, NUM_POINTS_HIGH)
ex.viz.plot_state_1d(u_upsampled, domain_extent=DOMAIN_EXTENT)
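The exactness claim for upsampling can be checked with a small NumPy sketch of Fourier interpolation by zero-padding: take the real FFT on the coarse grid, pad the spectrum with zeros up to the new resolution, rescale by the ratio of point counts, and transform back. This illustrates the principle only and is not the code of `ex.map_between_resolutions`; the helper name is hypothetical, and dropping the Nyquist mode for even point counts is a convention chosen here.

```python
import numpy as np


def upsample_fourier(u, new_num_points):
    """Upsample a real periodic 1D signal by zero-padding its spectrum (hypothetical helper)."""
    n = u.shape[-1]
    u_hat = np.fft.rfft(u)
    if n % 2 == 0:
        u_hat[-1] = 0.0  # drop the ambiguous Nyquist mode of the coarse grid
    u_hat_new = np.zeros(new_num_points // 2 + 1, dtype=complex)
    u_hat_new[: u_hat.shape[0]] = u_hat
    # numpy's FFT is unnormalized, so rescale to keep physical amplitudes
    return np.fft.irfft(u_hat_new * (new_num_points / n), n=new_num_points)


DOMAIN_EXTENT = 2 * np.pi
NUM_POINTS_LOW, NUM_POINTS_HIGH = 10, 100
x_low = DOMAIN_EXTENT * np.arange(NUM_POINTS_LOW) / NUM_POINTS_LOW
x_high = DOMAIN_EXTENT * np.arange(NUM_POINTS_HIGH) / NUM_POINTS_HIGH

u_up = upsample_fourier(np.sin(x_low), NUM_POINTS_HIGH)
# sin(x) is bandlimited on the coarse grid, so the error is at round-off level
print(np.max(np.abs(u_up - np.sin(x_high))))
```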

A useful application of this is compressing states or entire trajectories that are highly bandlimited.

grid_1d_high = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_HIGH)

u_complicated_high = jnp.exp(-((grid_1d_high - jnp.pi) ** 2))

ex.viz.plot_state_1d(u_complicated_high, domain_extent=DOMAIN_EXTENT)
u_complicated_low = ex.map_between_resolutions(u_complicated_high, NUM_POINTS_LOW)

ex.viz.plot_state_1d(u_complicated_low, domain_extent=DOMAIN_EXTENT)
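How much such a compression loses can be estimated by mapping down and back up again. The NumPy sketch below uses a hypothetical `resample_fourier` helper, a simplified stand-in for `ex.map_between_resolutions` that truncates or zero-pads the real FFT and drops the Nyquist mode of the smaller grid. It stores the Gaussian with 10 instead of 100 values and measures the round-trip error. Since the Gaussian's Fourier coefficients decay like \(e^{-k^2/4}\) rather than vanishing, the round trip is accurate but not exact.

```python
import numpy as np


def resample_fourier(u, new_num_points):
    """Map a real periodic 1D signal to a new resolution via its spectrum.

    Truncates (downsampling) or zero-pads (upsampling) the rfft coefficients.
    """
    n = u.shape[-1]
    u_hat = np.fft.rfft(u)
    m = min(n, new_num_points)  # the smaller grid limits which modes survive
    num_keep = m // 2 + 1
    u_hat_new = np.zeros(new_num_points // 2 + 1, dtype=complex)
    u_hat_new[:num_keep] = u_hat[:num_keep]
    if m % 2 == 0:
        u_hat_new[m // 2] = 0.0  # drop the ambiguous Nyquist mode of the smaller grid
    # numpy's FFT is unnormalized, so rescale by the ratio of point counts
    return np.fft.irfft(u_hat_new * (new_num_points / n), n=new_num_points)


DOMAIN_EXTENT = 2 * np.pi
NUM_POINTS_LOW, NUM_POINTS_HIGH = 10, 100
x_high = DOMAIN_EXTENT * np.arange(NUM_POINTS_HIGH) / NUM_POINTS_HIGH
u_high = np.exp(-((x_high - np.pi) ** 2))

u_low = resample_fourier(u_high, NUM_POINTS_LOW)            # "compressed": 10 values
u_reconstructed = resample_fourier(u_low, NUM_POINTS_HIGH)  # "decompressed": 100 values
round_trip_error = np.max(np.abs(u_reconstructed - u_high))
print(u_low.size / u_high.size, round_trip_error)
```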

Interpolation

Exponax can use the Fourier representation to interpolate a state at any point in space, including points that are not on the original grid. For example, this makes it possible to evaluate a state on an unstructured grid.

If the state is bandlimited, the interpolation is exact (but costly; check the documentation of exponax.FourierInterpolator for more details).

u = ex.ic.RandomTruncatedFourierSeries(2, cutoff=5, max_one=True)(
    50, key=jax.random.PRNGKey(0)
)
ex.viz.plot_state_2d(u)
interpolator = ex.FourierInterpolator(u)
unstructured_coordinates = jax.random.uniform(
    jax.random.PRNGKey(0), (300, 2), minval=0.3, maxval=0.8
)
values = jax.vmap(interpolator)(unstructured_coordinates)
import matplotlib.tri as tri

triang = tri.Triangulation(
    unstructured_coordinates[:, 0], unstructured_coordinates[:, 1]
)
plt.tricontourf(triang, values[:, 0], cmap="RdBu_r")
plt.scatter(unstructured_coordinates[:, 0], unstructured_coordinates[:, 1], c="black")
plt.xlim(0, 1)
plt.ylim(0, 1)
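One way to see why exact Fourier interpolation is costly: a direct evaluation sums over all Fourier modes for every query point, i.e. \(\mathcal{O}(N)\) work per point in 1D and \(\mathcal{O}(N^d)\) in \(d\) dimensions. The NumPy sketch below is reduced to 1D for brevity; `fourier_interpolate` is a hypothetical name, not the interface of `ex.FourierInterpolator`, and dropping the Nyquist mode is a convention chosen here. For a bandlimited field, it reproduces the exact values at arbitrary query points.

```python
import numpy as np


def fourier_interpolate(u, domain_extent, x_query):
    """Evaluate the trigonometric interpolant of periodic samples `u` at `x_query`.

    Direct summation over all N modes, so O(N) work per query point.
    """
    n = u.shape[-1]
    u_hat = np.fft.fft(u)
    wavenumbers = np.fft.fftfreq(n, d=1.0 / n)  # integer mode numbers 0, 1, ..., -1
    if n % 2 == 0:
        u_hat[n // 2] = 0.0  # drop the ambiguous Nyquist mode
    phases = np.exp(2j * np.pi * np.outer(x_query, wavenumbers) / domain_extent)
    return (phases @ u_hat).real / n


DOMAIN_EXTENT = 1.0
NUM_POINTS = 16
x = DOMAIN_EXTENT * np.arange(NUM_POINTS) / NUM_POINTS


def field(x):
    # Bandlimited test field: only modes 1 and 3 are present
    return np.sin(2 * np.pi * x) + 0.5 * np.cos(6 * np.pi * x)


rng = np.random.default_rng(0)
x_query = rng.uniform(0.3, 0.8, size=300)  # unstructured query points
values = fourier_interpolate(field(x), DOMAIN_EXTENT, x_query)
print(np.max(np.abs(values - field(x_query))))
```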