Additional features of Exponax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
Repeated stepper for more accurate (or more stable) integration
The (conservative) second-order Kuramoto-Sivashinsky stepper is unstable for \(\Delta t = 1.0\) under the default configuration.
DOMAIN_EXTENT = 60.0
NUM_POINTS = 100
DT = 1.0
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    1, DOMAIN_EXTENT, NUM_POINTS, DT
)
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jax.random.normal(
    jax.random.PRNGKey(0),
    (1, NUM_POINTS),
)  # The exact initial condition hardly matters: KS quickly develops its chaotic pattern
unstable_trj = ex.rollout(ks_stepper, 500, include_init=True)(u_0)
plt.imshow(
    unstable_trj[:, 0, :].T,
    aspect="auto",
    cmap="RdBu_r",
    vmin=-2,
    vmax=2,
    origin="lower",
)
print(unstable_trj[-1, 0, 1:5]) # All NaNs
We can instead instantiate a stepper that advances by only half of that \(\Delta t\) and apply it twice, so that each combined step still covers \(\Delta t = 1.0\). With this substepping, the rollout is stable.
ks_stepper_half_step = ex.stepper.KuramotoSivashinskyConservative(
    1, DOMAIN_EXTENT, NUM_POINTS, DT / 2
)
ks_stepper_substepper = ex.RepeatedStepper(ks_stepper_half_step, 2)
stable_trj = ex.rollout(ks_stepper_substepper, 500, include_init=True)(u_0)
plt.imshow(
    stable_trj[:, 0, :].T, aspect="auto", cmap="RdBu_r", vmin=-2, vmax=2, origin="lower"
)
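Conceptually, `ex.RepeatedStepper` applies the finer stepper several times within one call. The sketch below illustrates why substepping can restore stability, using a deliberately simple stand-in: an explicit Euler stepper for the scalar ODE \(du/dt = -3u\), not the ETDRK Kuramoto-Sivashinsky stepper used above. The helpers `make_euler_stepper`, `repeat_stepper`, and `toy_rollout` are hypothetical and not part of Exponax.

```python
import jax
import jax.numpy as jnp

LAMBDA = 3.0  # decay rate of the toy ODE du/dt = -LAMBDA * u


def make_euler_stepper(dt):
    """Hypothetical helper: explicit Euler step for du/dt = -LAMBDA * u."""
    def step(u):
        return u + dt * (-LAMBDA * u)
    return step


def repeat_stepper(step, num_sub_steps):
    """Hypothetical helper: apply `step` num_sub_steps times per call."""
    def repeated(u):
        return jax.lax.fori_loop(0, num_sub_steps, lambda _, v: step(v), u)
    return repeated


def toy_rollout(step, num_steps, u_0):
    """Hypothetical helper: trajectory [u_0, step(u_0), ...] with num_steps + 1 entries."""
    def body(u, _):
        u_next = step(u)
        return u_next, u_next
    _, trj = jax.lax.scan(body, u_0, None, length=num_steps)
    return jnp.concatenate([u_0[None], trj])


DT = 1.0
u_0 = jnp.asarray(1.0, dtype=jnp.float32)

# One full Euler step: amplification factor 1 - 3 * 1.0 = -2 (unstable)
coarse_trj = toy_rollout(make_euler_stepper(DT), 10, u_0)
# Two half steps: amplification factor (1 - 3 * 0.5) ** 2 = 0.25 (stable)
substepped_trj = toy_rollout(repeat_stepper(make_euler_stepper(DT / 2), 2), 10, u_0)

print(coarse_trj[-1], substepped_trj[-1])
```

With \(\Delta t = 1\), every Euler step multiplies the state by \(-2\), so the trajectory blows up; two half steps multiply it by \(0.25\) per full step, so it decays like the true solution (although not yet accurately).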
Adding a forcing term
So far, the steppers only consider "transient PDEs" that are not externally
forced. Exponax has a quick hack to turn a stepper with the signature
\(\mathcal{P}: \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\) into a stepper
with the signature
\(\mathcal{P}: \mathbb{R}^{1 \times N} \times \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\),
which takes an additional forcing term as its second argument. It is implemented
by a simple Euler integration of the forcing term before the actual ETDRK step.
As a consequence, it reduces the overall temporal integration order to one.
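In other words, each forced step first adds \(\Delta t\, f\) to the state and then applies the unforced stepper. A minimal sketch of that splitting in plain JAX, assuming a 1D periodic heat equation whose unforced step is the exact Fourier-space update \(\hat{u} \leftarrow e^{-\nu k^2 \Delta t} \hat{u}\). The names `diffusion_step` and `forced_step` are hypothetical, and this is not Exponax's source code.

```python
import jax
import jax.numpy as jnp

DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01

# Periodic grid (right endpoint excluded) and matching angular wavenumbers
x = jnp.arange(NUM_POINTS) * DOMAIN_EXTENT / NUM_POINTS
k = 2 * jnp.pi * jnp.fft.rfftfreq(NUM_POINTS, d=DOMAIN_EXTENT / NUM_POINTS)


def diffusion_step(u):
    """Unforced step for u_t = NU * u_xx, exact in Fourier space since the PDE is linear."""
    u_hat = jnp.fft.rfft(u)
    return jnp.fft.irfft(jnp.exp(-NU * k**2 * DT) * u_hat, n=NUM_POINTS)


def forced_step(u, forcing):
    """Euler update with the forcing term, then the unforced step (first order in DT)."""
    return diffusion_step(u + DT * forcing)


u_0 = jnp.sin(6 * jnp.pi * x / DOMAIN_EXTENT)
forcing = 0.5 * jnp.sin(2 * jnp.pi * x / DOMAIN_EXTENT)  # constant in time

# 200 forced steps, reusing the same forcing every step
u_final = jax.lax.fori_loop(0, 200, lambda _, u: forced_step(u, forcing), u_0)
```

For a single Fourier mode with forcing amplitude \(A\), the sketch reduces to \(a_{n+1} = e^{-\nu k^2 \Delta t}(a_n + \Delta t\, A)\), which tracks the exact forced solution only up to an \(\mathcal{O}(\Delta t)\) error.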
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01
diffusion_stepper = ex.stepper.Diffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)
forced_diffusion_stepper = ex.ForcedStepper(diffusion_stepper)
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)
# A constant forcing term is used here, but a time-dependent forcing
# trajectory could be supplied instead
forcing = 0.5 * jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
original_trj = ex.rollout(diffusion_stepper, 200, include_init=True)(u_0)
# The forced stepper takes the forcing as a second (auxiliary) argument, so the
# rollout needs `takes_aux=True`; `constant_aux=True` reuses the same forcing at
# every step
forced_trj = ex.rollout(
    forced_diffusion_stepper, 200, include_init=True, takes_aux=True, constant_aux=True
)(u_0, forcing)
fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(original_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[0].set_title("Original Heat Equation")
ax[1].imshow(forced_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[1].set_title("Forced Heat Equation")
Spectral derivatives
Any field discretized on the periodic grid can be differentiated spectrally, i.e., via its Fourier coefficients.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)
u_prime_exact = jnp.cos(6 * jnp.pi * grid / DOMAIN_EXTENT) * 6 * jnp.pi / DOMAIN_EXTENT
# You can select the order of differencing (first derivative, second derivative,
# etc.) with the `order` keyword argument
u_prime_spectrally = ex.derivative(u, DOMAIN_EXTENT)
# Root-mean-square error; results in a very small value (~1e-5 in single precision)
jnp.sqrt(jnp.mean(jnp.square(u_prime_exact - u_prime_spectrally)))
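Spectral differentiation multiplies each Fourier coefficient \(\hat{u}_k\) by \((i k)^{n}\) for the \(n\)-th derivative, where \(k = 2\pi m / L\) is the angular wavenumber of mode \(m\) on a domain of length \(L\). The plain-JAX sketch below, for a single 1D field, illustrates that idea and is not Exponax's implementation; `spectral_derivative_1d` is a hypothetical helper. For odd orders it zeroes the Nyquist mode, a common convention for real signals on grids with an even number of points.

```python
import jax.numpy as jnp


def spectral_derivative_1d(u, domain_extent, order=1):
    """Hypothetical helper: `order`-th derivative of a periodic, real 1D signal via the FFT."""
    n = u.shape[-1]
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)  # angular wavenumbers
    if order % 2 == 1 and n % 2 == 0:
        # Convention for real signals: drop the Nyquist mode for odd derivatives
        k = k.at[-1].set(0.0)
    u_hat = jnp.fft.rfft(u)
    return jnp.fft.irfft((1j * k) ** order * u_hat, n=n)


DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
x = jnp.arange(NUM_POINTS) * DOMAIN_EXTENT / NUM_POINTS
u = jnp.sin(6 * jnp.pi * x / DOMAIN_EXTENT)
du_exact = 6 * jnp.pi / DOMAIN_EXTENT * jnp.cos(6 * jnp.pi * x / DOMAIN_EXTENT)

du = spectral_derivative_1d(u, DOMAIN_EXTENT)
rmse = jnp.sqrt(jnp.mean(jnp.square(du - du_exact)))
print(rmse)
```

Passing `order=2` to this helper gives the second derivative, \(-(6\pi/L)^2 \sin(6\pi x / L)\) for this field.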
Moving between resolutions (Upsampling and Downsampling)
Using the Fourier representation, Exponax can upsample (increase the resolution of) or downsample (decrease the resolution of) a state. This has the following benefits:
- For upsampling: If the state on the initial (lower) resolution is bandlimited, the interpolation is exact.
- For downsampling: If the state is bandlimited on both the initial (higher) and the new (lower) resolution, the decimation is exact.
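Both guarantees follow from how Fourier resampling works in general: upsampling zero-pads the spectrum with modes that were absent anyway, and downsampling discards only modes the coarser grid cannot represent. A plain-JAX sketch for real 1D signals, with a hypothetical `fourier_resample_1d` helper that is not Exponax's `map_between_resolutions`. It assumes the signal carries no energy in the Nyquist mode of either grid, which sidesteps the extra bookkeeping that mode needs.

```python
import jax.numpy as jnp


def fourier_resample_1d(u, new_n):
    """Hypothetical helper: resample a periodic, real 1D signal to `new_n` points.

    Assumes no energy in the Nyquist mode of either grid.
    """
    old_n = u.shape[-1]
    u_hat = jnp.fft.rfft(u)
    new_len = new_n // 2 + 1
    if new_len >= u_hat.shape[-1]:
        u_hat_new = jnp.pad(u_hat, (0, new_len - u_hat.shape[-1]))  # upsample: add zero modes
    else:
        u_hat_new = u_hat[:new_len]  # downsample: drop modes the coarse grid cannot hold
    # rfft is unnormalized and irfft divides by new_n, so rescale by new_n / old_n
    return jnp.fft.irfft(u_hat_new, n=new_n) * (new_n / old_n)


x_low = jnp.arange(10) * 2 * jnp.pi / 10
x_high = jnp.arange(100) * 2 * jnp.pi / 100

# Upsampling a bandlimited state: matches sampling the function on the fine grid
u_up = fourier_resample_1d(jnp.sin(x_low), 100)
up_error = jnp.max(jnp.abs(u_up - jnp.sin(x_high)))


def g(x):
    # Highest mode 3 lies below the Nyquist mode 5 of the 10-point grid
    return jnp.sin(x) + 0.5 * jnp.cos(3 * x)


# Downsampling a state that is bandlimited on both grids: matches sampling on the coarse grid
u_down = fourier_resample_1d(g(x_high), 10)
down_error = jnp.max(jnp.abs(u_down - g(x_low)))
print(up_error, down_error)
```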
DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS_LOW = 10
NUM_POINTS_HIGH = 100
grid_1d = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_LOW)
u = jnp.sin(grid_1d)
ex.viz.plot_state_1d(u, domain_extent=DOMAIN_EXTENT)
u_upsampled = ex.map_between_resolutions(u, NUM_POINTS_HIGH)
ex.viz.plot_state_1d(u_upsampled, domain_extent=DOMAIN_EXTENT)
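A useful application of this is compressing states, or entire trajectories, that are strongly bandlimited. The Gaussian below is not strictly bandlimited, but its Fourier coefficients decay rapidly, so little is lost on the coarse grid.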
grid_1d_high = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_HIGH)
u_complicated_high = jnp.exp(-((grid_1d_high - jnp.pi) ** 2))
ex.viz.plot_state_1d(u_complicated_high, domain_extent=DOMAIN_EXTENT)
u_complicated_low = ex.map_between_resolutions(u_complicated_high, NUM_POINTS_LOW)
ex.viz.plot_state_1d(u_complicated_low, domain_extent=DOMAIN_EXTENT)
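Whether a state is bandlimited enough for this kind of compression can be checked by measuring how much of its spectral energy sits at or above the Nyquist mode of the coarser grid. A minimal plain-JAX sketch, applied to the same Gaussian on 100 points; `energy_above_cutoff` is a hypothetical helper, not an Exponax function.

```python
import jax.numpy as jnp


def energy_above_cutoff(u, new_n):
    """Hypothetical helper: fraction of spectral energy at or above the Nyquist mode
    of a `new_n`-point grid, i.e. the part a coarser grid could not represent exactly."""
    n = u.shape[-1]
    u_hat = jnp.fft.rfft(u)
    # Parseval weights: modes 0 < m < n/2 stand for a +m/-m pair of a real signal
    weights = jnp.full(u_hat.shape, 2.0).at[0].set(1.0)
    if n % 2 == 0:
        weights = weights.at[-1].set(1.0)  # the fine-grid Nyquist mode has no partner
    energy = weights * jnp.abs(u_hat) ** 2
    return jnp.sum(energy[new_n // 2 :]) / jnp.sum(energy)


DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS_HIGH = 100
x = jnp.arange(NUM_POINTS_HIGH) * DOMAIN_EXTENT / NUM_POINTS_HIGH
u_gauss = jnp.exp(-((x - jnp.pi) ** 2))

for new_n in (4, 10, 20):
    print(new_n, float(energy_above_cutoff(u_gauss, new_n)))
```

A fraction near zero, as for 10 points, means the coarse representation keeps almost all of the Gaussian; a large fraction, as for 4 points, means downsampling discards visible detail.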
Interpolation
Exponax can use the Fourier representation to interpolate a state at any point in space, including points that do not lie on the original grid. For example, this allows representing a state on an unstructured grid.
If the state is bandlimited, the interpolation is exact (but costly; check the documentation of exponax.FourierInterpolator for more details).
u = ex.ic.RandomTruncatedFourierSeries(2, cutoff=5, max_one=True)(
    50, key=jax.random.PRNGKey(0)
)
ex.viz.plot_state_2d(u)
interpolator = ex.FourierInterpolator(u)
unstructured_coordinates = jax.random.uniform(
    jax.random.PRNGKey(0), (300, 2), minval=0.3, maxval=0.8
)
values = jax.vmap(interpolator)(unstructured_coordinates)
import matplotlib.tri as tri
triang = tri.Triangulation(
    unstructured_coordinates[:, 0], unstructured_coordinates[:, 1]
)
plt.tricontourf(triang, values[:, 0], cmap="RdBu_r")
plt.scatter(unstructured_coordinates[:, 0], unstructured_coordinates[:, 1], c="black")
plt.xlim(0, 1)
plt.ylim(0, 1)
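Fourier interpolation evaluates the trigonometric series \(u(x) = \sum_k \hat{u}_k e^{i k x}\) of the sampled state at the query point, with normalized coefficients \(\hat{u}_k\). This is why it is exact for bandlimited states, and every query touches all modes, which is one reason such interpolation is comparatively expensive. A 1D plain-JAX sketch of that evaluation, with a hypothetical `fourier_interpolate_1d` helper rather than `ex.FourierInterpolator` (which also handles the 2D case used above):

```python
import jax
import jax.numpy as jnp


def fourier_interpolate_1d(u, domain_extent, x):
    """Hypothetical helper: evaluate the trigonometric interpolant of periodic real samples at x."""
    n = u.shape[-1]
    u_hat = jnp.fft.rfft(u) / n  # normalized Fourier coefficients
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)  # angular wavenumbers
    # Modes strictly between 0 and Nyquist stand for a +k/-k pair of a real signal
    weights = jnp.full(k.shape, 2.0).at[0].set(1.0)
    if n % 2 == 0:
        weights = weights.at[-1].set(1.0)  # the Nyquist mode has no separate partner
    # Every evaluation sums over all n // 2 + 1 modes, i.e. O(n) work per query point
    return jnp.sum(weights * jnp.real(u_hat * jnp.exp(1j * k * x)))


DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS = 16
grid = jnp.arange(NUM_POINTS) * DOMAIN_EXTENT / NUM_POINTS


def f(x):
    # Bandlimited test function: highest mode 7 is below the Nyquist mode 8
    return jnp.sin(x) + 0.3 * jnp.cos(4 * x) + 0.1 * jnp.sin(7 * x)


u = f(grid)
points = jax.random.uniform(jax.random.PRNGKey(0), (50,), maxval=DOMAIN_EXTENT)
values = jax.vmap(lambda p: fourier_interpolate_1d(u, DOMAIN_EXTENT, p))(points)
print(jnp.max(jnp.abs(values - f(points))))
```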