Simple Advection example
The advection equation

\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0 \]

is one of the simplest partial differential equations. It describes a hyperbolic transport process. Under periodic boundary conditions on the domain \([0, L]\), the solution at a later point in time is given by the shifted initial condition

\[ u(t, x) = u_0\big((x - c t) \bmod L\big), \]

using some loose notation of the modulo operator to enforce periodicity.
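The closed-form solution translates directly into code. The following is a minimal sketch in JAX; `exact_advection_solution` is a hypothetical helper introduced here for illustration and is not part of Exponax. It evaluates \(u_0\) at the back-traced, wrapped position \((x - c t) \bmod L\).

```python
import jax.numpy as jnp


def exact_advection_solution(u_0, x, t, c, L):
    """Hypothetical helper (not part of Exponax): u(t, x) = u_0((x - c t) mod L)."""
    # Trace each point back along its characteristic and wrap it into [0, L)
    return u_0(jnp.mod(x - c * t, L))


L = 1.0
c = 1.0
u_0 = lambda x: jnp.sin(2 * jnp.pi * x / L)
x = jnp.linspace(0.0, L, 21)[:-1]

# After t = 0.2 the profile is the initial condition shifted right by c * t
u_shifted = exact_advection_solution(u_0, x, t=0.2, c=c, L=L)
# After t = L / c it has travelled exactly once around the periodic domain
u_one_period = exact_advection_solution(u_0, x, t=L / c, c=c, L=L)
```

Because this particular \(u_0\) is already periodic, the wrap is not strictly needed here, but it is what makes the formula valid for an initial condition that is only defined on \([0, L)\).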
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
Let's first create a grid that discretizes the domain \([0, L]\) into \(N\) degrees
of freedom. Since we work with periodic domains, one of the boundary points,
either \(x=0\) or \(x=L\), is redundant. Hence, we linearly space \(N+1\) points
over the domain and discard the last one. (In fact, discarding one boundary
point is a prerequisite for the FFT algorithm that is built into all the
solvers in Exponax.)
DOMAIN_EXTENT = 1.0
NUM_POINTS = 20
grid = jnp.linspace(0, DOMAIN_EXTENT, NUM_POINTS + 1)[:-1]
plt.scatter(grid, jnp.zeros_like(grid))
plt.xlim(-0.1, 1.1)
plt.grid()
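Dropping the last of the \(N+1\) linearly spaced points leaves exactly the points \(x_j = j \, \Delta x\), \(j = 0, \dots, N-1\), with spacing \(\Delta x = L / N\). A quick check of this in plain JAX (the names `dx` and `grid_from_spacing` are introduced here for illustration):

```python
import jax.numpy as jnp

DOMAIN_EXTENT = 1.0
NUM_POINTS = 20

# N + 1 equispaced points including both boundaries, then drop x = L
grid = jnp.linspace(0, DOMAIN_EXTENT, NUM_POINTS + 1)[:-1]

# The remaining N points are x_j = j * dx with dx = L / N
dx = DOMAIN_EXTENT / NUM_POINTS
grid_from_spacing = jnp.arange(NUM_POINTS) * dx

print(grid.shape)  # (20,)
print(bool(jnp.allclose(grid, grid_from_spacing, atol=1e-6)))  # True
```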
We can also print out the `jax.numpy` array containing the grid:
grid
Conveniently, Exponax provides a function that produces such a grid:
grid_from_exponax = ex.make_grid(
1,
DOMAIN_EXTENT,
NUM_POINTS,
)
grid_from_exponax
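This function takes three positional arguments:
- `num_spatial_dims`, defining the dimensionality of the domain. For this tutorial, we work in 1D, so we pass `1`.
- `domain_extent`, the length \(L\) of the domain.
- `num_points`, the number of grid points \(N\) (excluding the redundant boundary point).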
Note that the shape of the grid array differs from the shape we got from `jax.numpy.linspace`.
grid.shape, grid_from_exponax.shape
It has an additional singleton dimension at the front, which represents the
dimensionality of the grid. In 1D, the shape of the grid is `(1, N)`. In 2D, it
would be `(2, N, N)`, and so on.
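One way to build an array with this shape convention is to stack one coordinate array per dimension with `jnp.meshgrid`. The sketch below is a hypothetical re-implementation (`make_grid_sketch` is our own name) meant only to illustrate the shapes; it is not Exponax's source code and may differ from it in details such as axis ordering.

```python
import jax.numpy as jnp


def make_grid_sketch(num_spatial_dims, domain_extent, num_points):
    """Hypothetical illustration (not Exponax's code).

    Returns an array of shape (num_spatial_dims, num_points, ..., num_points)
    whose leading axis holds the x, y, ... coordinate of every grid point.
    """
    # 1D coordinates without the redundant right boundary point
    coords_1d = jnp.linspace(0, domain_extent, num_points + 1)[:-1]
    # One coordinate array per spatial dimension; "ij" keeps the axis order
    mesh = jnp.meshgrid(*([coords_1d] * num_spatial_dims), indexing="ij")
    return jnp.stack(mesh, axis=0)


g1 = make_grid_sketch(1, 1.0, 20)
g2 = make_grid_sketch(2, 1.0, 20)
print(g1.shape)  # (1, 20)
print(g2.shape)  # (2, 20, 20)
```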
Let's work with `grid_from_exponax` from now on.
grid = grid_from_exponax
On this domain, we can discretize a function. Let's, for instance, use the first sine mode \(u_0(x) = \sin(2 \pi x / L)\).
Notice that we have to index both the grid array and the function array at `[0]`
to remove the singleton dimension and obtain 1D arrays for plotting.
Also notice that there is no function value at \(x=L=1\) since we discarded the last grid point.
ic_fun = lambda x: jnp.sin(2 * jnp.pi * x / DOMAIN_EXTENT)
ic = ic_fun(grid)
plt.plot(grid[0], ic[0])
plt.xlim(-0.1, 1.1)
plt.grid()
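Why the "first" sine mode? Because the grid excludes the redundant point, the sampled signal is exactly periodic, and its discrete Fourier transform has a single nonzero coefficient at index \(m = 1\), i.e., wavenumber \(2\pi / L\), with magnitude \(N/2\). A small check with `jnp.fft.rfft`, working on plain 1D arrays without the singleton dimension:

```python
import jax.numpy as jnp

DOMAIN_EXTENT = 1.0
NUM_POINTS = 20

grid = jnp.linspace(0, DOMAIN_EXTENT, NUM_POINTS + 1)[:-1]
ic = jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)

# Real-valued FFT: index m corresponds to the wavenumber 2 * pi * m / L
ic_hat = jnp.fft.rfft(ic)
magnitudes = jnp.abs(ic_hat)

print(int(jnp.argmax(magnitudes)))  # 1, i.e. the first mode
```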
If we want to plot the function including its periodic extension, we have to
wrap around the domain. This can be done with the `wrap_bc` function. Note that
in order to plot the wrapped function, we also need the "full grid" including the
redundant point.
Notice again that we index at `[0]` to remove the singleton dimension.
full_grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS, full=True)
full_ic = ex.wrap_bc(ic)
plt.plot(full_grid[0], full_ic[0])
plt.xlim(-0.1, 1.1)
plt.grid()
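In 1D, wrapping amounts to appending a copy of the first spatial value at the end of the array. A minimal sketch with `jnp.pad` in `"wrap"` mode; `wrap_periodic_1d` is a hypothetical stand-in that covers only the 1D case and is not necessarily how `ex.wrap_bc` is implemented:

```python
import jax.numpy as jnp


def wrap_periodic_1d(u):
    """Hypothetical 1D stand-in for ex.wrap_bc (illustration only).

    Takes a state of shape (channels, N) and appends a copy of the first
    spatial value at the end, giving shape (channels, N + 1).
    """
    return jnp.pad(u, ((0, 0), (0, 1)), mode="wrap")


grid = jnp.linspace(0, 1.0, 21)[:-1][None, :]  # shape (1, 20)
ic = jnp.sin(2 * jnp.pi * grid)

full_ic = wrap_periodic_1d(ic)
print(full_ic.shape)                     # (1, 21)
print(bool(full_ic[0, -1] == ic[0, 0]))  # True
```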
Next, let's instantiate the advection timestepper. We need to specify the velocity \(c\) and the timestep \(\Delta t\).
Note that (almost) all timesteppers in Exponax take the following first four positional arguments, in this order:
- The `num_spatial_dims`, defining the dimensionality of the domain. For this tutorial, we work in 1D, so we pass `1`.
- The `domain_extent`
- The `num_points` (Important: exclude the redundant point!)
- The timestep `dt`

Other options, such as coefficients/constitutive parameters or numerical parameters, are passed as keyword arguments.
An advection problem is defined by its external velocity, which we assume to be a constant \(c=1.0\).
VELOCITY = 1.0
DT = 0.2
advection_stepper = ex.stepper.Advection(
1,
DOMAIN_EXTENT,
NUM_POINTS,
DT,
velocity=VELOCITY,
)
advection_stepper
Once instantiated, the timesteppers behave like a Python callable, taking a discretized function at time \(t\) (now called a state) and returning the discretized function at time \(t + \Delta t\) (i.e., the next state).
u_1 = advection_stepper(ic)
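Since advection is linear and the domain is periodic, a spectral method can advance a state exactly by shifting the phase of each Fourier coefficient, \(\hat{u}_k(t + \Delta t) = e^{-i k c \Delta t} \hat{u}_k(t)\) with \(k = 2\pi m / L\). The sketch below builds such a callable by hand. `make_fourier_advection_step` is a hypothetical name, and we only assume that an exact Fourier phase shift is a reasonable mental model of one step; we do not claim this is how `ex.stepper.Advection` works internally.

```python
import jax.numpy as jnp


def make_fourier_advection_step(domain_extent, num_points, dt, velocity):
    """Hypothetical spectral advection step (illustration, not Exponax's code).

    Returns a callable mapping a state of shape (1, N) at time t to the state
    at time t + dt by multiplying each Fourier coefficient with the exact
    phase factor exp(-i * k * c * dt) of u_t + c u_x = 0.
    """
    dx = domain_extent / num_points
    # Angular wavenumbers k = 2 * pi * m / L for the real-valued FFT
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=dx)
    phase_shift = jnp.exp(-1j * k * velocity * dt)

    def step(u):
        u_hat = jnp.fft.rfft(u, axis=-1)
        return jnp.fft.irfft(u_hat * phase_shift, n=num_points, axis=-1)

    return step


step = make_fourier_advection_step(1.0, 20, 0.2, 1.0)
x = jnp.linspace(0, 1.0, 21)[:-1][None, :]  # shape (1, 20)
ic = jnp.sin(2 * jnp.pi * x)

# One step: the initial condition shifted right by c * dt = 0.2
u_1_sketch = step(ic)
```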
Let's plot the solution at \(t=0\) and \(t=0.2\) next to each other.
(Again, notice the indexing at `[0]` to remove the singleton dimension.)
plt.plot(full_grid[0], ex.wrap_bc(ic)[0], label="ic")
plt.plot(full_grid[0], ex.wrap_bc(u_1)[0], label="1 step")
plt.xlim(-0.1, 1.1)
plt.grid()
plt.legend()
Notice that the initial condition moved by \(0.2\) space units to the right. Additionally, whatever left the domain at the right boundary re-entered at the left boundary. The shape of the function is preserved!
Moving by \(0.2\) space units in \(0.2\) time units corresponds to a velocity of \(c = 1\) which is exactly what we prescribed. Feel free to play around with the velocity \(c\), the timestep \(\Delta t\), and the domain size \(L\) to see how the solution changes. You will notice that ultimately only the value \(\frac{c \Delta t}{L}\) matters.
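One way to see this: for a fixed number of grid points \(N\), the displacement per step measured in grid cells is \(c \Delta t / \Delta x = c \Delta t N / L\), which depends on \(c\), \(\Delta t\), and \(L\) only through \(c \Delta t / L\). In the setting above this is \(4\) cells, so on this grid the exact one-step solution coincides with a cyclic shift of the array by four entries. A small check, where `cells_per_step` is a hypothetical helper:

```python
import jax.numpy as jnp


def cells_per_step(velocity, dt, domain_extent, num_points):
    """Hypothetical helper: displacement per step in grid cells, c * dt * N / L."""
    return velocity * dt * num_points / domain_extent


# The tutorial's setting: c = 1, dt = 0.2, L = 1, N = 20
shift_a = cells_per_step(1.0, 0.2, 1.0, 20)
# A different setting with the same ratio c * dt / L = 0.2
shift_b = cells_per_step(2.0, 0.5, 5.0, 20)
print(shift_a, shift_b)  # 4.0 4.0

# With an integer shift, the exact solution on the grid is a cyclic roll
x = jnp.linspace(0, 1.0, 21)[:-1]
ic = jnp.sin(2 * jnp.pi * x)
rolled = jnp.roll(ic, int(round(shift_a)))
```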
For any time-dependent phenomenon, we might be interested in computing not just one timestep into the future, but many. This is an inherently sequential process because we need the solution at time \(t\) to compute the solution at time \(t + \Delta t\). Let's compute the next three solutions.
u_2 = advection_stepper(u_1)
u_3 = advection_stepper(u_2)
u_4 = advection_stepper(u_3)
plt.plot(full_grid[0], ex.wrap_bc(ic)[0], label="ic")
plt.plot(full_grid[0], ex.wrap_bc(u_1)[0], label="1 step")
plt.plot(full_grid[0], ex.wrap_bc(u_2)[0], label="2 steps")
plt.plot(full_grid[0], ex.wrap_bc(u_3)[0], label="3 steps")
plt.plot(full_grid[0], ex.wrap_bc(u_4)[0], label="4 steps")
plt.xlim(-0.1, 1.1)
plt.grid()
plt.legend()
If we had performed one more timestep, we would have returned to the initial condition: \(5\) steps, each moving the solution by \(c \Delta t = 0.2\) space units, amount to a total displacement of \(1\) space unit. This is exactly the length of the domain \(L\)!
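In general, the number of steps after which the state first returns to the initial condition is the domain length divided by the displacement per step; with the values used here:

\[ n_{\text{period}} = \frac{L}{c \, \Delta t} = \frac{1.0}{1.0 \cdot 0.2} = 5. \]

More generally, the continuous solution is periodic in time with period \(L / c\), so an exact stepper returns to the initial condition after \(n\) steps whenever \(n \Delta t\) is an integer multiple of \(L / c\).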
Now, we want to stack the states into a trajectory. Manually, we could do this as follows.
short_trajectory = jnp.stack(
[
ic,
u_1,
u_2,
u_3,
u_4,
]
)
short_trajectory.shape
There is also a handy function within Exponax that does exactly this. In the
spirit of JAX, it is a function transformation. So far, `advection_stepper`
is a mapping \(\mathbb{R}^{1 \times N} \to \mathbb{R}^{1 \times N}\). The
`ex.rollout` transformation turns it into a mapping
\(\mathbb{R}^{1 \times N} \to \mathbb{R}^{T \times 1 \times N}\), where \(T\) is the
number of time steps we want to perform into the future. It has an additional
keyword flag, `include_init`, to include the initial condition in the trajectory.
Then, the mapping is \(\mathbb{R}^{1 \times N} \to \mathbb{R}^{(T+1) \times 1 \times N}\).
short_rollout_advection_stepper = ex.rollout(advection_stepper, 4, include_init=True)
jnp.allclose(short_trajectory, short_rollout_advection_stepper(ic))
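Conceptually, such a rollout is a scan over time that carries the current state and records every new one. The following is a minimal sketch using `jax.lax.scan`; `rollout_sketch` is a hypothetical name, and this is not Exponax's implementation of `ex.rollout`. As a stand-in stepper it uses a cyclic shift by four grid cells.

```python
import jax
import jax.numpy as jnp


def rollout_sketch(stepper, num_steps, include_init=False):
    """Hypothetical rollout transformation (not Exponax's code).

    Turns stepper: (1, N) -> (1, N) into a function (1, N) -> (T, 1, N),
    or (1, N) -> (T + 1, 1, N) if include_init is True.
    """

    def scan_fn(u, _):
        u_next = stepper(u)
        # Carry the new state to the next iteration and also record it
        return u_next, u_next

    def rolled_out(u_0):
        _, trajectory = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
        if include_init:
            # Prepend the initial condition along the leading time axis
            trajectory = jnp.concatenate([u_0[None], trajectory], axis=0)
        return trajectory

    return rolled_out


# Hypothetical stand-in stepper: cyclic shift by 4 grid cells
stepper = lambda u: jnp.roll(u, 4, axis=-1)
ic = jnp.sin(2 * jnp.pi * jnp.linspace(0, 1.0, 21)[:-1])[None, :]

trajectory = rollout_sketch(stepper, 4, include_init=True)(ic)

# Manual stacking for comparison
states = [ic]
for _ in range(4):
    states.append(stepper(states[-1]))
manual = jnp.stack(states)

print(trajectory.shape)  # (5, 1, 20)
```

One reason to prefer `jax.lax.scan` over a Python loop is that the step function is traced only once, so the whole rollout can be compiled with `jax.jit` as a single operation.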
Let's instantiate another advection timestepper with a smaller timestep and use the `rollout` transformation to produce a longer trajectory of \(200\) steps.
SMALLER_DT = 0.01
slower_advection_stepper = ex.stepper.Advection(
1, DOMAIN_EXTENT, NUM_POINTS, SMALLER_DT, velocity=VELOCITY
)
longer_rollout_advection_stepper = ex.rollout(
slower_advection_stepper, 200, include_init=True
)
longer_trajectory = longer_rollout_advection_stepper(ic)
Then, we can `jax.vmap` the `wrap_bc` function over the trajectory.
longer_trajectory_wrapped = jax.vmap(ex.wrap_bc)(longer_trajectory)
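`jax.vmap` applies a function written for a single state to every snapshot along the leading (time) axis. A small sketch confirming that this matches an explicit loop; `wrap_periodic_1d` is a hypothetical 1D stand-in for `ex.wrap_bc`, and the trajectory is placeholder data:

```python
import jax
import jax.numpy as jnp


def wrap_periodic_1d(u):
    # Hypothetical 1D stand-in for ex.wrap_bc: shape (1, N) -> (1, N + 1)
    return jnp.pad(u, ((0, 0), (0, 1)), mode="wrap")


# Placeholder trajectory of 5 snapshots, each of shape (1, 20)
trajectory = jnp.arange(5 * 20, dtype=jnp.float32).reshape(5, 1, 20)

# vmap maps the per-snapshot function over the leading (time) axis
wrapped = jax.vmap(wrap_periodic_1d)(trajectory)

# Equivalent explicit loop, one call per snapshot
wrapped_loop = jnp.stack([wrap_periodic_1d(u) for u in trajectory])

print(wrapped.shape)  # (5, 1, 21)
print(bool(jnp.array_equal(wrapped, wrapped_loop)))  # True
```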
Then, we can visualize the trajectory as a spatio-temporal plot.
The indexing `[:, 0, :]` takes all temporal snapshots, removes the singleton
dimension, and takes all spatial points. We transpose with `.T` to put the
time dimension on the horizontal axis.
plt.imshow(
longer_trajectory_wrapped[:, 0, :].T,
origin="lower",
cmap="RdBu",
vmin=-1,
vmax=1,
extent=[0, 200 * SMALLER_DT, 0, DOMAIN_EXTENT],
)
plt.colorbar()
plt.xlabel("time")
plt.ylabel("space")
plt.title("advection")
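The array manipulations for this plot only change shapes. Below is a sketch with a placeholder array of the same shape as `longer_trajectory_wrapped`, i.e., \((T+1) \times 1 \times (N+1) = 201 \times 1 \times 21\). `plt.imshow` draws the first axis as rows (vertical) and the second as columns (horizontal), so after the transpose space runs vertically and time horizontally, and `origin="lower"` puts \(x = 0\) at the bottom.

```python
import jax.numpy as jnp

# Placeholder with the shape of the wrapped trajectory: (T + 1, 1, N + 1)
longer_trajectory_wrapped = jnp.zeros((201, 1, 21))

snapshots = longer_trajectory_wrapped[:, 0, :]  # drop the singleton axis
image = snapshots.T                             # rows: space, columns: time

print(snapshots.shape)  # (201, 21)
print(image.shape)      # (21, 201)
```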