Performance Hints¤
How to run Exponax
even faster than it already is 😉.
Beyond the specific techniques shown below, here are some general insights:
- Whenever the exponax.rollout or exponax.repeat function transformations are used, they internally perform a jax.jit over the timestepper. Hence, there is no need to wrap the resulting function in a jax.jit again. However, when using the timesteppers directly, it can be advantageous to Just-In-Time compile them (see the sketch below).
- The number of total degrees of freedom scales exponentially with the number of dimensions; so does the cost of the spatial FFT and hence the cost of simulation. As a good guideline on a modern GPU:
    - 1d: The largest num_points that is still pleasant to work with is between 10'000 and 100'000. For most problems, 50-500 points are likely sufficient.
    - 2d: The largest num_points that is still pleasant to work with is around 500 (-> 250k total DoF per channel). For most problems, 50-256 points are likely sufficient.
    - 3d: The largest num_points that is still pleasant to work with is around 48 (-> 110k total DoF per channel). In general, 3d sims will be tough.
- The produced trajectory array is as large as the number of time steps performed. Hence, if the underlying discretization already has many total DoF, the trajectory array can become quite large. If you are only interested in every n-th step, consider wrapping the time stepper in a RepeatedStepper (also shown in the sketch below).
- Some usages of jax.vmap only work efficiently on GPUs & TPUs; on CPUs, JAX resorts to sequential looping.
import jax
import jax.numpy as jnp
import exponax as ex
import equinox as eqx
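As a small illustration of the first and third hint, here is a minimal sketch (using the imports above) of JIT-compiling a timestepper that is called directly, and of wrapping it in a RepeatedStepper so that the stored trajectory only contains every 10th state. It assumes a RepeatedStepper(stepper, n) constructor that applies the wrapped stepper n times per call; check the API documentation for the exact signature.
stepper = ex.stepper.Burgers(1, 3.0, 100, 0.1)
u_0 = ex.ic.RandomTruncatedFourierSeries(1, cutoff=5, max_one=True)(
    100, key=jax.random.PRNGKey(0)
)
# Calling the stepper directly benefits from an explicit JIT;
# ex.rollout/ex.repeat already JIT internally.
jitted_stepper = eqx.filter_jit(stepper)
u_1 = jitted_stepper(u_0)
# Sub-step inside one stepper so only every 10th state is stored
# (assumed signature: RepeatedStepper(stepper, n)).
coarse_stepper = ex.RepeatedStepper(stepper, 10)
coarse_trj = ex.rollout(coarse_stepper, 20, include_init=True)(u_0)  # (21, 1, 100)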
Temporal Rollout in Fourier space¤
The methods of Exponax
advance a state to the next time step in Fourier space.
If the stepper is called with a state in physical space, it is first transformed
to Fourier space, then advanced, and finally transformed back to physical space.
This is done for each time step. We can also integrate it directly in Fourier
space and then backtransform the entire trajectory.
This especially saves compute for lower-order ETDRK integrators (most of all for linear PDEs), which perform fewer FFTs per time step.
NUM_SPATIAL_DIMS = 1
DOMAIN_EXTENT = 3.0
NUM_POINTS = 100
DT = 0.1
burgers_stepper = ex.stepper.Burgers(NUM_SPATIAL_DIMS, DOMAIN_EXTENT, NUM_POINTS, DT)
u_0 = ex.ic.RandomTruncatedFourierSeries(
NUM_SPATIAL_DIMS,
cutoff=5,
max_one=True,
)(NUM_POINTS, key=jax.random.PRNGKey(0))
ex.viz.plot_state_1d(u_0)
The state of the initial condition is a 1x100 tensor with real floating-point values.
u_0.shape, u_0.dtype
For an integration in Fourier space, we first have to transform the state to Fourier space. Importantly, whenever we are dealing with FFTs in Exponax, we need to use the real-valued FFT, rfft.
u_0_hat = jnp.fft.rfft(u_0)
Its shape is (1, 51), and it has complex values. (JAX's complex64 type is composed of two float32 values.)
u_0_hat.shape, u_0_hat.dtype
We can use the familiar ex.rollout function transformation, but we need to roll out the stepper's Fourier-space step method, step_fourier, instead of the stepper itself.
trj_hat = ex.rollout(burgers_stepper.step_fourier, 100, include_init=True)(u_0_hat)
Using jnp.fft.irfft will batch over all time steps. (We need to supply the number of points because we used the real-valued FFT.)
trj = jnp.fft.irfft(trj_hat, n=NUM_POINTS)
ex.viz.plot_spatio_temporal(trj)
The trajectory matches the one obtained by simulating in physical space (up to floating-point tolerance).
jnp.allclose(
trj,
ex.rollout(burgers_stepper, 100, include_init=True)(u_0),
atol=1e-5,
)
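To quantify the savings on your own hardware, a rough timing comparison can be done with Python's timeit; this is only a sketch, the numbers are hardware-dependent, and it reuses the steppers and states defined above.
import timeit

rollout_physical = ex.rollout(burgers_stepper, 100, include_init=True)
rollout_fourier = ex.rollout(burgers_stepper.step_fourier, 100, include_init=True)

# Trigger JIT compilation once before timing
rollout_physical(u_0).block_until_ready()
rollout_fourier(u_0_hat).block_until_ready()

t_physical = timeit.timeit(lambda: rollout_physical(u_0).block_until_ready(), number=10)
t_fourier = timeit.timeit(
    lambda: jnp.fft.irfft(rollout_fourier(u_0_hat), n=NUM_POINTS).block_until_ready(),
    number=10,
)
print(f"physical: {t_physical:.3f} s, fourier (incl. final irfft): {t_fourier:.3f} s")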
Ensemble simulation¤
One particular feature of Exponax
that is highly relevant for the integration
with deep learning is the batched execution.
Rather straightforwardly, we can jax.vmap a timestepper to operate on multiple
states at once.
NUM_SPATIAL_DIMS = 1
DOMAIN_EXTENT = 3.0
NUM_POINTS = 100
DT = 0.1
burgers_stepper = ex.stepper.Burgers(NUM_SPATIAL_DIMS, DOMAIN_EXTENT, NUM_POINTS, DT)
ic_gen = ex.ic.RandomTruncatedFourierSeries(
NUM_SPATIAL_DIMS,
cutoff=5,
max_one=True,
)
one_u_0 = ic_gen(NUM_POINTS, key=jax.random.PRNGKey(0))
multiple_u_0 = ex.build_ic_set(
ic_gen, num_points=NUM_POINTS, num_samples=10, key=jax.random.PRNGKey(0)
)
one_u_0.shape, multiple_u_0.shape
one_u_1 = burgers_stepper(one_u_0)
# burgers_stepper(multiple_u_0)  # This will fail because the vanilla timestepper is single-batch only
multiple_u_1 = jax.vmap(burgers_stepper)(multiple_u_0)
one_u_1.shape, multiple_u_1.shape
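For generating training data, one usually wants entire trajectories for the whole batch, not just a single step. jax.vmap composes directly with ex.rollout; a small sketch using the quantities defined above:
# Roll out every initial condition in the batch for 50 steps at once.
batched_rollout = jax.vmap(ex.rollout(burgers_stepper, 50, include_init=True))
multiple_trjs = batched_rollout(multiple_u_0)
multiple_trjs.shape  # (batch, time, channel, space) = (10, 51, 1, 100)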
When using jax.vmap, we essentially share the same dynamics across all initial
states. What if we wanted to simulate the same state but with three different
dynamics? We could create a list of timesteppers and then loop over them (for
example, with a list comprehension). However, sequential looping is slow. There
is an easy way to also use JAX's automatic vectorization for that. For this, we
create an ensemble of three different Burgers steppers (this will only work if
the parameter we vmap over does not change the shape of the timestepper's
attribute arrays).
DIFFUSIVITIES = jnp.array([0.1, 0.3, 0.7])
burgers_stepper_ensemble = eqx.filter_vmap(
lambda nu: ex.stepper.Burgers(
NUM_SPATIAL_DIMS, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=nu
)
)(DIFFUSIVITIES)
If we inspect the single timestepper PyTree structure next to the ensemble timestepper PyTree structure, we see an additional batch axis in the internal arrays.
burgers_stepper, burgers_stepper_ensemble
The first task is to make three different predictions from the single initial condition.
ensembled_u_1 = eqx.filter_vmap(lambda stepper: stepper(one_u_0))(
burgers_stepper_ensemble
)
This adds a leading batch axis of size three to the state.
ensembled_u_1.shape
We can also combine vmapping over the ensemble with vmapping over the multiple initial states.
ensembled_multiple_u_1 = eqx.filter_vmap(
lambda stepper: jax.vmap(stepper)(multiple_u_0)
)(burgers_stepper_ensemble)
ensembled_multiple_u_1.shape
Be mindful of the order of nested vmaps, as it affects the order of the batch axes in the returned arrays; see the sketch below.
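For example, swapping the nesting from above (outer vmap over the initial conditions, inner filter_vmap over the stepper ensemble) puts the initial-condition axis first; a minimal sketch reusing the arrays defined above:
# Axes are now (initial_condition, stepper, channel, space) = (10, 3, 1, 100)
# instead of (stepper, initial_condition, channel, space) = (3, 10, 1, 100).
multiple_ensembled_u_1 = jax.vmap(
    lambda u: eqx.filter_vmap(lambda stepper: stepper(u))(burgers_stepper_ensemble)
)(multiple_u_0)
multiple_ensembled_u_1.shape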