Additional features of Exponax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
Repeated stepper for more accurate (or more stable) integration
The (conservative) second-order Kuramoto-Sivashinsky stepper is unstable for \(\Delta t = 1.0\) under the default configuration.
DOMAIN_EXTENT = 60.0
NUM_POINTS = 100
DT = 1.0
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
1, DOMAIN_EXTENT, NUM_POINTS, DT
)
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
# The specific initial condition does not matter for the (chaotic) KS equation
u_0 = jax.random.normal(jax.random.PRNGKey(0), (1, NUM_POINTS))
unstable_trj = ex.rollout(ks_stepper, 500, include_init=True)(u_0)
plt.imshow(
unstable_trj[:, 0, :].T,
aspect="auto",
cmap="RdBu_r",
vmin=-2,
vmax=2,
origin="lower",
)
print(unstable_trj[-1, 0, 1:5]) # All NaNs
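To see when a rollout like the one above diverges, you can search the trajectory for the first time step that contains non-finite values. The helper below is a hypothetical sketch, not part of Exponax. It uses NumPy and assumes the (time, channels, points) layout that ex.rollout(..., include_init=True) produces here.

```python
import numpy as np


def first_nonfinite_step(trj):
    """Return the index of the first time step containing NaN/inf, or None.

    Hypothetical helper; assumes `trj` has shape (time, channels, points).
    """
    trj = np.asarray(trj)
    # Reduce over all non-time axes: True where a snapshot has a non-finite value
    bad = ~np.all(np.isfinite(trj), axis=tuple(range(1, trj.ndim)))
    if not bad.any():
        return None
    return int(np.argmax(bad))  # argmax picks the first True entry


# Synthetic stand-in for a blown-up trajectory: finite for 6 steps, then NaN
fake_trj = np.zeros((10, 1, 100))
fake_trj[6:] = np.nan
print(first_nonfinite_step(fake_trj))  # 6
```

Applied to unstable_trj, this reports the step at which the KS rollout first produced NaNs.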
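Instead, we can instantiate a stepper that advances by only half of that \(\Delta t\) and wrap it in ex.RepeatedStepper so that it is applied twice per call. The combined stepper still advances the state by \(\Delta t = 1.0\) per call, but now the integration is stable.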
ks_stepper_half_step = ex.stepper.KuramotoSivashinskyConservative(
1, DOMAIN_EXTENT, NUM_POINTS, DT / 2
)
ks_stepper_substepper = ex.RepeatedStepper(ks_stepper_half_step, 2)
stable_trj = ex.rollout(ks_stepper_substepper, 500, include_init=True)(u_0)
plt.imshow(
stable_trj[:, 0, :].T, aspect="auto", cmap="RdBu_r", vmin=-2, vmax=2, origin="lower"
)
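A toy problem shows why sub-stepping helps. The sketch below does not use Exponax. It applies explicit Euler to the linear decay equation \(u_t = -\lambda u\), whose amplification factor per step is \(1 - \lambda \Delta t\). With \(\lambda = 2.5\) and \(\Delta t = 1.0\) that factor is \(-1.5\), so the scheme is unstable. Two half steps give \((1 - 1.25)^2 = 0.0625\) per call, which is stable. The helper repeated is a hypothetical stand-in for the idea behind ex.RepeatedStepper, not its implementation.

```python
from typing import Callable, List

Stepper = Callable[[float], float]


def make_euler_decay_step(lam: float, dt: float) -> Stepper:
    """Explicit Euler step for du/dt = -lam * u with step size dt."""
    def step(u: float) -> float:
        return u + dt * (-lam * u)
    return step


def repeated(step: Stepper, num_sub_steps: int) -> Stepper:
    """Hypothetical helper: one call applies `step` num_sub_steps times."""
    def stepper(u: float) -> float:
        for _ in range(num_sub_steps):
            u = step(u)
        return u
    return stepper


def rollout(stepper: Stepper, u_0: float, num_steps: int) -> List[float]:
    """Return [u_0, u_1, ..., u_num_steps], including the initial state."""
    trj = [u_0]
    for _ in range(num_steps):
        trj.append(stepper(trj[-1]))
    return trj


LAM, DT = 2.5, 1.0
full_step = make_euler_decay_step(LAM, DT)                     # factor 1 - 2.5 = -1.5
sub_stepper = repeated(make_euler_decay_step(LAM, DT / 2), 2)  # factor (-0.25)**2 = 0.0625

unstable = rollout(full_step, 1.0, 50)
stable = rollout(sub_stepper, 1.0, 50)
print(unstable[1], stable[1])  # -1.5 0.0625
print(abs(unstable[-1]) > 1e6, abs(stable[-1]) < 1e-6)  # True True
```

Both steppers advance by the same \(\Delta t\) per call. Only the size of the internal steps differs.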
Adding a forcing term
So far, the steppers only consider "transient PDEs" that are not externally
forced. Exponax has a quick hack to turn a stepper with the signature
\(\mathcal{P}: \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\) into a stepper with the signature
\(\mathcal{P}: \mathbb{R}^{1 \times N} \times \mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\), which takes an additional forcing term. It is implemented as a
simple (explicit) Euler step with the forcing, performed before the actual ETDRK step.
As a consequence, the overall temporal order of accuracy drops to one, regardless of
the order of the underlying ETDRK scheme.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01
diffusion_stepper = ex.stepper.Diffusion(
1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)
forced_diffusion_stepper = ex.ForcedStepper(diffusion_stepper)
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)
# Use a constant forcing term here; a time-dependent forcing trajectory could
# also be supplied
forcing = 0.5 * jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
original_trj = ex.rollout(diffusion_stepper, 200, include_init=True)(u_0)
# Inform the rollout transformation that the stepper takes an auxiliary
# (forcing) input and that this input stays constant over time
forced_trj = ex.rollout(
forced_diffusion_stepper, 200, include_init=True, takes_aux=True, constant_aux=True
)(u_0, forcing)
fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(original_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[0].set_title("Original Heat Equation")
ax[1].imshow(forced_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[1].set_title("Forced Heat Equation")
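You can mimic the splitting described above (an explicit Euler update with the forcing, followed by the unforced step) in a few lines of NumPy. This is a sketch under stated assumptions, not Exponax's code:

- The unforced step is the exact exponential integrator for 1D periodic diffusion.
- The grid is \(x_j = jL/N\).
- forced_step is a hypothetical name.

For the constant forcing \(f = 0.5 \sin(2\pi x / L)\), the PDE's steady state is \(f / (\nu k_1^2)\) with \(k_1 = 2\pi / L\). The split scheme instead converges to a fixed point whose amplitude differs by a relative error of about \(\nu k_1^2 \Delta t / 2\).

```python
import numpy as np

L, N, DT, NU = 1.0, 100, 0.01, 0.01
x = np.arange(N) * L / N                      # periodic grid, right end excluded
k = 2 * np.pi * np.fft.rfftfreq(N, d=L / N)   # angular wavenumbers


def diffusion_step(u):
    """Exact step of u_t = NU * u_xx over DT, done in Fourier space."""
    u_hat = np.fft.rfft(u) * np.exp(-NU * k**2 * DT)
    return np.fft.irfft(u_hat, n=N)


def forced_step(u, f):
    """Hypothetical sketch: explicit Euler with the forcing, then the unforced step."""
    return diffusion_step(u + DT * f)


forcing = 0.5 * np.sin(2 * np.pi * x / L)
u = np.sin(6 * np.pi * x / L)
for _ in range(20_000):  # 200 time units, long enough to reach the steady state
    u = forced_step(u, forcing)

k1 = 2 * np.pi / L
decay = np.exp(-NU * k1**2 * DT)
split_steady = forcing * DT * decay / (1 - decay)   # fixed point of forced_step
exact_steady = forcing / (NU * k1**2)               # steady state of the forced PDE
rel_err = np.max(np.abs(split_steady - exact_steady)) / np.max(np.abs(exact_steady))
print(rel_err)  # roughly NU * k1**2 * DT / 2, i.e. about 2e-3
```

The leading error term is proportional to \(\Delta t\), so halving DT roughly halves rel_err. This is consistent with the first-order accuracy stated above.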
Spectral derivatives
Any discretized field can be differentiated spectrally.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)
u_prime_exact = jnp.cos(6 * jnp.pi * grid / DOMAIN_EXTENT) * 6 * jnp.pi / DOMAIN_EXTENT
# You can select the order of differentiation (first derivative, second
# derivative, etc.) with the `order` keyword argument
u_prime_spectrally = ex.derivative(u, DOMAIN_EXTENT)
# The RMS error is very small (~1e-5 when computed in single precision)
jnp.sqrt(jnp.mean(jnp.square(u_prime_exact - u_prime_spectrally)))
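A spectral derivative works in three steps:

1. Transform the field to Fourier space.
2. Multiply each mode by \((i k)^p\) for a derivative of order \(p\), with angular wavenumber \(k = 2\pi m / L\).
3. Transform back.

The NumPy sketch below illustrates this for a 1D periodic field on the grid \(x_j = jL/N\). spectral_derivative is hypothetical, not Exponax's implementation. For even N it zeroes the Nyquist mode on odd orders, a common convention, because the derivative of that mode cannot be represented on the grid.

```python
import numpy as np


def spectral_derivative(u, domain_extent, order=1):
    """Differentiate a real periodic 1D field `order` times via the FFT.

    Hypothetical helper; `u` has shape (..., N) on the grid x_j = j * L / N.
    """
    n = u.shape[-1]
    k = 2 * np.pi * np.fft.rfftfreq(n, d=domain_extent / n)  # angular wavenumbers
    multiplier = (1j * k) ** order
    if n % 2 == 0 and order % 2 == 1:
        multiplier[-1] = 0.0  # drop the Nyquist mode for odd derivatives
    return np.fft.irfft(multiplier * np.fft.rfft(u, axis=-1), n=n, axis=-1)


L, N = 1.0, 100
x = np.arange(N) * L / N
u = np.sin(6 * np.pi * x / L)
du_exact = np.cos(6 * np.pi * x / L) * 6 * np.pi / L
d2u_exact = -np.sin(6 * np.pi * x / L) * (6 * np.pi / L) ** 2

rmse_1 = np.sqrt(np.mean((spectral_derivative(u, L) - du_exact) ** 2))
rmse_2 = np.sqrt(np.mean((spectral_derivative(u, L, order=2) - d2u_exact) ** 2))
```

In float64 both errors are at round-off level, because the sine is fully resolved on the grid.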
Moving between resolutions (Upsampling and Downsampling)
Using its Fourier representation, Exponax can upsample (increase the
resolution of) or downsample (decrease the resolution of) a state, with the
following benefits:
- For upsampling: If the state at the initial (lower) resolution is bandlimited, the interpolation is exact.
- For downsampling: If the state is bandlimited at both the initial (higher) and the new (lower) resolution, the decimation is exact.
DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS_LOW = 10
NUM_POINTS_HIGH = 100
grid_1d = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_LOW)
u = jnp.sin(grid_1d)
ex.viz.plot_state_1d(u, domain_extent=DOMAIN_EXTENT)
u_upsampled = ex.map_between_resolutions(u, NUM_POINTS_HIGH)
ex.viz.plot_state_1d(u_upsampled, domain_extent=DOMAIN_EXTENT)
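In the Fourier representation, up- and downsampling amount to zero-padding or truncating the spectrum and rescaling for the new number of points. The NumPy sketch below is hypothetical: resample_periodic is not an Exponax function. It assumes a grid that starts at zero and excludes the right endpoint. When the smaller grid has an even number of points, it simply zeroes that grid's Nyquist mode. It is therefore exact for states with no content at or above that Nyquist frequency.

```python
import numpy as np


def resample_periodic(u, n_new):
    """Change the resolution of a real periodic field along its last axis.

    Hypothetical helper: zero-pads (upsampling) or truncates (downsampling)
    the rFFT coefficients and rescales so that amplitudes are preserved.
    """
    n_old = u.shape[-1]
    n_min = min(n_old, n_new)
    n_keep = n_min // 2 + 1  # number of rFFT modes both grids share
    u_hat = np.fft.rfft(u, axis=-1)
    u_hat_new = np.zeros(u.shape[:-1] + (n_new // 2 + 1,), dtype=complex)
    u_hat_new[..., :n_keep] = u_hat[..., :n_keep]
    if n_min % 2 == 0:
        u_hat_new[..., n_min // 2] = 0.0  # ambiguous Nyquist mode of the smaller grid
    return np.fft.irfft(u_hat_new * (n_new / n_old), n=n_new, axis=-1)


DOMAIN_EXTENT = 2 * np.pi
x_low = np.arange(10) * DOMAIN_EXTENT / 10
x_high = np.arange(100) * DOMAIN_EXTENT / 100

u_low = np.sin(x_low)[None, :]           # shape (1, 10): one channel
u_up = resample_periodic(u_low, 100)     # exact, since sin(x) is bandlimited
u_down = resample_periodic(np.sin(3 * x_high)[None, :], 10)
```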
A useful application of this is compressing states, or even entire trajectories, that are highly bandlimited.
grid_1d_high = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS_HIGH)
u_complicated_high = jnp.exp(-((grid_1d_high - jnp.pi) ** 2))
ex.viz.plot_state_1d(u_complicated_high, domain_extent=DOMAIN_EXTENT)
u_complicated_low = ex.map_between_resolutions(u_complicated_high, NUM_POINTS_LOW)
ex.viz.plot_state_1d(u_complicated_low, domain_extent=DOMAIN_EXTENT)
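Before downsampling, you can check whether such a compression is lossless by measuring how much of the state's spectral energy lies in modes the coarser grid cannot keep. The helper below is a hypothetical NumPy sketch. It counts every integer wavenumber with \(|m| > \lfloor (N_{\text{new}} - 1)/2 \rfloor\) as discarded, which includes the Nyquist mode when the target size is even. A Gaussian is not strictly bandlimited, so for the Gaussian in this example the fraction is small but nonzero. Its 10-point version is a close approximation rather than an exact one.

```python
import numpy as np


def discarded_energy_fraction(u, n_new):
    """Fraction of spectral energy in modes a grid with n_new points cannot keep.

    Hypothetical helper for a real periodic field along its last axis; modes with
    |m| > (n_new - 1) // 2 count as discarded (incl. Nyquist for even n_new).
    """
    n = u.shape[-1]
    u_hat = np.fft.fft(u, axis=-1)
    m = np.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers 0, 1, ..., -1
    energy = np.abs(u_hat) ** 2
    discarded = energy[..., np.abs(m) > (n_new - 1) // 2].sum()
    return discarded / energy.sum()


DOMAIN_EXTENT = 2 * np.pi
x_high = np.arange(100) * DOMAIN_EXTENT / 100
gaussian = np.exp(-((x_high - np.pi) ** 2))

frac_sine = discarded_energy_fraction(np.sin(x_high), 10)  # essentially zero
frac_gauss = discarded_energy_fraction(gaussian, 10)       # small, but not zero
```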
Interpolation
Exponax can use the Fourier representation to interpolate a state at any point
in space, including points that are not on the original grid. For example, this
allows representing a state on an unstructured grid.
If the state is bandlimited, the interpolation is exact (but costly; check the
documentation of exponax.FourierInterpolator for more details).
u = ex.ic.RandomTruncatedFourierSeries(2, cutoff=5, max_one=True)(
50, key=jax.random.PRNGKey(0)
)
ex.viz.plot_state_2d(u)
interpolator = ex.FourierInterpolator(u)
unstructured_coordinates = jax.random.uniform(
jax.random.PRNGKey(0), (300, 2), minval=0.3, maxval=0.8
)
values = jax.vmap(interpolator)(unstructured_coordinates)
import matplotlib.tri as tri
triang = tri.Triangulation(
unstructured_coordinates[:, 0], unstructured_coordinates[:, 1]
)
plt.tricontourf(triang, values[:, 0], cmap="RdBu_r")
plt.scatter(unstructured_coordinates[:, 0], unstructured_coordinates[:, 1], c="black")
plt.xlim(0, 1)
plt.ylim(0, 1)
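Fourier interpolation is both exact and costly for the same reason: it evaluates the trigonometric polynomial defined by the state's Fourier coefficients directly at each query point. Every query therefore touches all N modes in 1D (all \(N^D\) modes in D dimensions). The 1D NumPy sketch below illustrates this. fourier_interpolate_1d is a hypothetical helper, not how exponax.FourierInterpolator is implemented. It assumes the grid \(x_j = jL/N\) and treats the Nyquist mode (for even N) as a cosine.

```python
import numpy as np


def fourier_interpolate_1d(u, points, domain_extent=1.0):
    """Evaluate the trigonometric interpolant of a real periodic 1D field.

    Hypothetical helper; `u` has shape (N,), `points` holds query coordinates
    anywhere in space (not necessarily grid points).
    """
    n = u.shape[-1]
    u_hat = np.fft.rfft(u) / n
    m = np.arange(u_hat.shape[-1])
    # Modes 0 < m < N/2 stand for both +m and -m, hence the factor of 2
    weights = np.full(m.shape, 2.0)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[-1] = 1.0  # the Nyquist mode appears only once
    phase = np.exp(2j * np.pi * np.outer(points, m) / domain_extent)  # (P, M)
    return np.real(phase * u_hat) @ weights  # O(N) work per query point


def exact_field(x):
    return np.sin(2 * np.pi * 3 * x) + 0.5 * np.cos(2 * np.pi * 5 * x)


N = 16
grid = np.arange(N) / N
rng = np.random.default_rng(0)
query = rng.uniform(0.3, 0.8, size=300)

values = fourier_interpolate_1d(exact_field(grid), query)
max_err = np.max(np.abs(values - exact_field(query)))  # round-off level: the field is bandlimited
```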