Creating custom solvers
This is a showcase of how the sample steppers in exponax.sample_stepper
are implemented.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
from exponax import BaseStepper
from exponax.nonlin_fun import BaseNonlinearFun, ConvectionNonlinearFun
An ultra-viscous(?) Korteweg-de Vries equation
By default the KdV stepper "only" supports regular diffusion (second-order) and hyper-diffusion (fourth-order). Let's create a custom stepper that has an additional sixth-order term. To be fancy, let's call it the "ultra-viscous KdV" stepper.
from jaxtyping import Array, Complex
class KdV_ultra_diffusive(BaseStepper):
    """
    1D Korteweg-de Vries stepper with an additional sixth-order
    ("ultra-diffusive") linear term.
    """

    # Declare your constitutive parameters here
    convection_scale: float
    diffusivity: float
    hyper_diffusivity: float
    ultra_diffusivity: float  # This is the term not present in the pre-made KdV stepper
    dispersivity: float
    # This is necessary to correctly instantiate the nonlinear function
    dealiasing_fraction: float

    def __init__(
        self,
        # The first four positional arguments are always the same, i.e.,
        # dimensions, domain size, number of degrees of freedom and time step
        # size
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        # The first range of keyword arguments should be constitutive
        # parameters, make sure to set reasonable defaults
        convection_scale: float = -6,
        diffusivity: float = 0.0,
        hyper_diffusivity: float = 0.0,
        ultra_diffusivity: float = 0.0003,
        dispersivity: float = 1.0,
        # The second range of keyword arguments should be numerical parameters;
        # oftentimes it is fine to use the same as in other methods
        order: int = 2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Set up the ultra-diffusive KdV stepper.

        The stepper advances

            u_t = b/2 (u^2)_x - a u_xxx + nu u_xx - mu u_xxxx + rho u_xxxxxx

        with b = ``convection_scale``, a = ``dispersivity``,
        nu = ``diffusivity``, mu = ``hyper_diffusivity`` and
        rho = ``ultra_diffusivity``. With the defaults (b = -6, a = 1) this is
        the classical KdV equation u_t + 6 u u_x + u_xxx = 0 plus a small
        sixth-order damping term.

        Args:
            num_spatial_dims: Number of spatial dimensions. The linear
                operator built below is only correct for 1.
            domain_extent: Size of the domain.
            num_points: Number of degrees of freedom (grid points).
            dt: Time step size.
            convection_scale: Coefficient b of the convection term
                b/2 (u^2)_x.
            diffusivity: Coefficient of the second-order diffusion term.
            hyper_diffusivity: Coefficient of the fourth-order hyper-diffusion
                term.
            ultra_diffusivity: Coefficient of the sixth-order term, which is
                the addition over the pre-made KdV stepper.
            dispersivity: Coefficient of the third-order dispersion term.
            order: Forwarded to ``BaseStepper`` (presumably the
                time-integration order).
            dealiasing_fraction: Passed to ``ConvectionNonlinearFun`` to
                control dealiasing of the convection term.
            num_circle_points: Forwarded to ``BaseStepper``.
            circle_radius: Forwarded to ``BaseStepper``.
        """
        # First, save the constitutive parameters
        self.convection_scale = convection_scale
        self.diffusivity = diffusivity
        self.hyper_diffusivity = hyper_diffusivity
        self.ultra_diffusivity = ultra_diffusivity
        self.dispersivity = dispersivity
        # Then save the dealiasing fraction
        self.dealiasing_fraction = dealiasing_fraction
        # Then, call the parent constructor
        super().__init__(
            num_spatial_dims,
            domain_extent,
            num_points,
            dt,
            # The KdV equation uses a convection nonlinearity for which the
            # number of channels grows with the spatial dimension
            num_channels=num_spatial_dims,
            # Pass the numerical parameters to the parent constructor
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )

    # We have to implement two methods `_build_linear_operator` and `_build_nonlinear_fun`
    def _build_linear_operator(self, derivative_operator: Array) -> Array:
        """
        Return the Fourier symbol of all linear terms.

        IMPORTANT: Only correct in 1D. The powers of ``derivative_operator``
        are taken directly, without summing over spatial axes.
        """
        # Even-order terms: assuming derivative_operator holds i*k, D^2 = -k^2,
        # D^4 = k^4 and D^6 = -k^6. The signs below therefore make every
        # positive coefficient damp high wavenumbers.
        diffusion_part = self.diffusivity * derivative_operator**2
        hyper_diffusion_part = -self.hyper_diffusivity * derivative_operator**4
        ultra_diffusion_part = self.ultra_diffusivity * derivative_operator**6
        # We need a minus here because this term of the equation is moved to the
        # right hand side
        dispersion_part = -self.dispersivity * derivative_operator**3
        linear_operator = (
            diffusion_part
            + dispersion_part
            + hyper_diffusion_part
            + ultra_diffusion_part
        )
        return linear_operator

    def _build_nonlinear_fun(self, derivative_operator: Array) -> BaseNonlinearFun:
        """Return the convection nonlinearity b/2 (u^2)_x, with b = convection_scale."""
        # The nonlinearity given by convection, i.e., $b * 1/2 * (u^2)_x$, is so
        # common among (semi-linear/) nonlinear PDEs that there is a shared
        # nonlinear function for it.
        nonlinear_fun = ConvectionNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
            derivative_operator=derivative_operator,
            dealiasing_fraction=self.dealiasing_fraction,
            scale=self.convection_scale,
        )
        return nonlinear_fun
# Shared discretization: 1D, domain extent 20, 100 points, time step 0.05.
DOMAIN_EXTENT = 20.0
NUM_POINTS = 100
DT = 0.05

# Reference: the pre-made KdV stepper with the classical coefficients.
original_kdv_stepper = ex.stepper.KortewegDeVries(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, dispersivity=1.0, convection_scale=-6
)

# NOTE: besides the default ultra_diffusivity, this also switches on regular
# (second-order) diffusion, so the lower panel shows the combined damping.
ultra_diffusive_kdv_stepper = KdV_ultra_diffusive(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    convection_scale=-6,
    dispersivity=1.0,
    diffusivity=0.1,
)

# Initial condition cos(4*pi*x / L), i.e. two periods over the domain.
mesh = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.cos(4 * jnp.pi * mesh / DOMAIN_EXTENT)

# Roll out 200 steps with each stepper, including the initial state.
original_kdv_trj, ultra_diffusive_kdv_trj = (
    ex.rollout(stepper, 200, include_init=True)(u_0)
    for stepper in (original_kdv_stepper, ultra_diffusive_kdv_stepper)
)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
for axis, trj, title in zip(
    ax,
    (original_kdv_trj, ultra_diffusive_kdv_trj),
    ("Original KdV", "Ultra-Diffusive KdV"),
):
    # Channel 0, transposed so that (assuming a (time, channel, space)
    # trajectory layout) time runs horizontally. Colors are fixed to [-1, 1];
    # values outside that range saturate.
    axis.imshow(trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
    axis.set_title(title)
Heat equation with a nonlinear source term
\[ \partial_t u = \nu \partial_x^2 u + \frac{1}{1 + u^2} \]
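For such strongly nonlinear terms, it is hard to say how strong the dealiasing should be. Unlike a quadratic nonlinearity, for which the 2/3 rule removes aliasing exactly, \(1/(1 + u^2)\) is not a polynomial, so no finite dealiasing fraction eliminates aliasing completely. A fraction of \(1/2\) is a good first guess.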
class NonlinearSourceFun(BaseNonlinearFun):
    """
    Pseudo-spectral evaluation of the pointwise source term
    f(u) = 1 / (1 + u^2).

    As for every nonlinear function, the constructor takes the number of
    spatial dimensions and the number of points as positional arguments.
    Nonlinear functions that need derivatives also accept a
    ``derivative_operator`` keyword. This source term has no derivatives, so
    that keyword is omitted here.
    """

    def __init__(
        self,
        num_spatial_dims: int,
        num_points: int,
        *,
        dealiasing_fraction: float = 1 / 2,
    ):
        super().__init__(
            num_spatial_dims, num_points, dealiasing_fraction=dealiasing_fraction
        )

    def __call__(self, u_hat: Array) -> Array:
        # `fft`/`ifft` are inherited from BaseNonlinearFun and perform the
        # matching (real-valued) Fourier transforms. Dealias before going to
        # physical space, because the nonlinearity is evaluated there.
        u_physical = self.ifft(self.dealias(u_hat))
        return self.fft(1 / (1 + u_physical**2))
class Diffusion_nonlinear_source(BaseStepper):
    """
    Heat equation with a nonlinear source term,

        u_t = nu * Laplace(u) + 1 / (1 + u^2),

    usable in any number of spatial dimensions. The state always has a
    single channel.
    """

    diffusivity: float
    dealiasing_fraction: float

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        diffusivity: float = 0.01,
        # The below numerics parameters normally would not be necessary for the
        # linear heat PDE, because any linear PDE on periodic boundaries can be
        # exactly integrated in time with the Fourier spectral method (given it
        # is band-limited). However, here we added a nonlinear source term.
        order: int = 2,
        dealiasing_fraction: float = 1 / 2,  # Using a stronger dealiasing
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Args:
            num_spatial_dims: Number of spatial dimensions.
            domain_extent: Size of the domain.
            num_points: Number of degrees of freedom (grid points).
            dt: Time step size.
            diffusivity: Diffusion coefficient nu.
            order: Forwarded to ``BaseStepper`` (presumably the
                time-integration order).
            dealiasing_fraction: Passed to ``NonlinearSourceFun``; the default
                1/2 is stronger than the 2/3 typically used for quadratic
                nonlinearities.
            num_circle_points: Forwarded to ``BaseStepper``.
            circle_radius: Forwarded to ``BaseStepper``.
        """
        self.diffusivity = diffusivity
        self.dealiasing_fraction = dealiasing_fraction
        super().__init__(
            num_spatial_dims,
            domain_extent,
            num_points,
            dt,
            # The heat equation's channel will always be 1, no matter the
            # spatial dimension
            num_channels=1,
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )

    def _build_linear_operator(self, derivative_operator: Array) -> Array:
        """Return the symbol of nu * Laplace(u), shaped for a single channel."""
        # Sum the squared derivative operators over axis 0 to obtain the
        # Laplacian symbol. This assumes exponax's layout of one operator per
        # spatial axis along axis 0; verify if the base class changes.
        # keepdims leaves a leading axis of size 1 that matches the single
        # channel. In 1D this equals `derivative_operator**2`. Without the sum,
        # higher dimensions would yield one operator per spatial axis instead of
        # a single Laplacian.
        laplacian = jnp.sum(derivative_operator**2, axis=0, keepdims=True)
        return self.diffusivity * laplacian

    def _build_nonlinear_fun(self, derivative_operator: Array) -> BaseNonlinearFun:
        """Return the derivative-free source term 1 / (1 + u^2)."""
        # The source term needs no derivatives, so derivative_operator is unused.
        return NonlinearSourceFun(
            self.num_spatial_dims,
            self.num_points,
            dealiasing_fraction=self.dealiasing_fraction,
        )
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.05
NU = 0.01

# Reference: the pre-made (purely linear) diffusion stepper.
original_diffusion_stepper = ex.stepper.Diffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)
# The same diffusion plus the nonlinear source term 1 / (1 + u^2).
nonlinear_diffusion_stepper = Diffusion_nonlinear_source(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    diffusivity=NU,
)

# Initial condition cos(4*pi*x / L), i.e. two periods over the domain.
mesh = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.cos(4 * jnp.pi * mesh / DOMAIN_EXTENT)

original_diffusion_trj = ex.rollout(original_diffusion_stepper, 200, include_init=True)(
    u_0
)
nonlinear_diffusion_trj = ex.rollout(
    nonlinear_diffusion_stepper, 200, include_init=True
)(u_0)

# The source 1 / (1 + u^2) is strictly positive and diffusion conserves the
# mean, so the mean of u keeps growing. While |u| <= 1 it grows at a rate of
# at least 1/2, so u leaves [-1, 1] early in the 200 * DT = 10 time units.
# Fixed color limits of +-1 would saturate most of the lower panel. Instead,
# use one symmetric range that covers both trajectories, so the panels stay
# comparable, and show that range in a colorbar.
color_limit = max(
    1.0,
    float(jnp.max(jnp.abs(original_diffusion_trj))),
    float(jnp.max(jnp.abs(nonlinear_diffusion_trj))),
)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(
    original_diffusion_trj[:, 0, :].T,
    aspect="auto",
    cmap="RdBu",
    vmin=-color_limit,
    vmax=color_limit,
)
ax[0].set_title("Original Diffusion")
image = ax[1].imshow(
    nonlinear_diffusion_trj[:, 0, :].T,
    aspect="auto",
    cmap="RdBu",
    vmin=-color_limit,
    vmax=color_limit,
)
# The diffusion itself is linear; only the source term is nonlinear.
ax[1].set_title("Diffusion with Nonlinear Source")
ax[1].set_xlabel("time step")
fig.colorbar(image, ax=ax)