Skip to content

Creating custom solvers¤

This is a showcase how the sample steppers in exponax.sample_stepper are implemented.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
from exponax import BaseStepper
from exponax.nonlin_fun import BaseNonlinearFun, ConvectionNonlinearFun

An ultra-viscous? Korteveg-de Vries equation¤

By default the KdV stepper "only" supports regular diffusion (second-order) and hyper-diffusion (fourth-order). Let's create a custom stepper that has an additional sixth-order term. To be fancy, let's call it the "ultra-viscous KdV" stepper.

from jaxtyping import Array, Complex


class KdV_ultra_diffusive(BaseStepper):
    # Declare your constitutive parameters here
    convection_scale: float
    diffusivity: float
    hyper_diffusivity: float
    ultra_diffusivity: float  # This is the term not present in the pre-made KdV stepper
    dispersivity: float

    # This is necessary to correctly instantiate the nonlinear function
    dealiasing_fraction: float

    def __init__(
        self,
        # The first four positional arguments are always the same, i.e.,
        # dimensions, domain size, number of degrees of freedom and time step
        # size
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        # The first range of keyword arguments should be constitutive
        # parameters, make sure to set reasonable defaults
        convection_scale: float = -6,
        diffusivity: float = 0.0,
        hyper_diffusivity: float = 0.0,
        ultra_diffusivity: float = 0.0003,
        dispersivity: float = 1.0,
        # The second range of keyword arguments should be numerical parameters;
        # oftentimes it is fine to use the same as in other methods
        order: int = 2,
        dealiasing_fraction: float = 2 / 3,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Add a docstring here
        """

        # First, save the constitutive parameters
        self.convection_scale = convection_scale
        self.diffusivity = diffusivity
        self.hyper_diffusivity = hyper_diffusivity
        self.ultra_diffusivity = ultra_diffusivity
        self.dispersivity = dispersivity

        # Then save the dealiasing fraction
        self.dealiasing_fraction = dealiasing_fraction

        # Then, call the parent constructor
        super().__init__(
            num_spatial_dims,
            domain_extent,
            num_points,
            dt,
            # The KdV equation uses a convection nonlinearity for which the
            # number of channels grows with the spatial dimension
            num_channels=num_spatial_dims,
            # Pass the numerical parameters to the parent constructor
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )

    # We have to implement two methods `_build_linear_operator` and `_build_nonlinear_fun`

    def _build_linear_operator(self, derivative_operator: Array) -> Array:
        # IMPORTANT: The implementations below only work in 1D
        diffusion_part = self.diffusivity * derivative_operator**2
        hyper_diffusion_part = -self.hyper_diffusivity * derivative_operator**4
        ultra_diffusion_part = self.ultra_diffusivity * derivative_operator**6
        # We need a minus here because this term of the equation is moved to the
        # right hand side
        dispersion_part = -self.dispersivity * derivative_operator**3

        linear_operator = (
            diffusion_part
            + dispersion_part
            + hyper_diffusion_part
            + ultra_diffusion_part
        )
        return linear_operator

    def _build_nonlinear_fun(self, derivative_operator: Array) -> BaseNonlinearFun:
        # The nonlinear given by convection, i.e., $b * 1/2 * (u^2)_x$ is so common
        # among (semi-linear/) nonlinear PDEs that there is shared nonlinear
        # function.
        nonlinear_fun = ConvectionNonlinearFun(
            self.num_spatial_dims,
            self.num_points,
            derivative_operator=derivative_operator,
            dealiasing_fraction=self.dealiasing_fraction,
            scale=self.convection_scale,
        )
        return nonlinear_fun
DOMAIN_EXTENT = 20.0
NUM_POINTS = 100
DT = 0.05

original_kdv_stepper = ex.stepper.KortewegDeVries(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    dispersivity=1.0,
    convection_scale=-6,
)
ultra_diffusive_kdv_stepper = KdV_ultra_diffusive(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    diffusivity=0.1,
    dispersivity=1.0,
    convection_scale=-6,
)

mesh = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.cos(4 * jnp.pi * mesh / DOMAIN_EXTENT)

original_kdv_trj = ex.rollout(original_kdv_stepper, 200, include_init=True)(u_0)
ultra_diffusive_kdv_trj = ex.rollout(
    ultra_diffusive_kdv_stepper, 200, include_init=True
)(u_0)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(original_kdv_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1)
ax[0].set_title("Original KdV")
ax[1].imshow(
    ultra_diffusive_kdv_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1
)
ax[1].set_title("Ultra-Diffusive KdV")
2024-10-22 17:12:25.544002: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.68). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

Text(0.5, 1.0, 'Ultra-Diffusive KdV')
No description has been provided for this image

Heat Equation with nonlinear source term¤

\[ \partial_t u = \nu \partial_x^2 u + \frac{1}{1 + u^2} \]

For these highly nonlinear equations, it is hard to say how strong the dealiasing should be. Having a fraction of \(1/2\) is a good first guess.

class NonlinearSourceFun(BaseNonlinearFun):
    # The constructor of a nonlinear function should take as positional
    # arguments the number of spatial dimensions and the number of points. The
    # derivative operator should be a keyword argument. Here, we do not need; so
    # let's not include it.
    def __init__(
        self,
        num_spatial_dims: int,
        num_points: int,
        *,
        dealiasing_fraction: float = 1 / 2,
    ):
        super().__init__(
            num_spatial_dims,
            num_points,
            dealiasing_fraction=dealiasing_fraction,
        )

    def __call__(self, u_hat: Array) -> Array:
        # When inheriting from BaseNonlinearFun, we can use the the class
        # methods `fft` and `ifft` to perform the correct (real-valued) Fourier
        # transforms. Make sure to dealias before evaluating nonlinear terms.
        u = self.ifft(self.dealias(u_hat))
        u_nonlin = 1 / (1 + u**2)
        u_nonlin_hat = self.fft(u_nonlin)
        return u_nonlin_hat


class Diffusion_nonlinear_source(BaseStepper):
    diffusivity: float

    dealiasing_fraction: float

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        dt: float,
        *,
        diffusivity: float = 0.01,
        # The below numerics parameters normally would not be necessary for the
        # linear heat PDE, because any linear PDE on periodic boundaries can be
        # exactly integrated in time with the Fourier spectral method (given it
        # is band-limited). However, here we added a nonlinear source term.
        order: int = 2,
        dealiasing_fraction: float = 1 / 2,  # Using a stronger dealiasing
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        self.diffusivity = diffusivity
        self.dealiasing_fraction = dealiasing_fraction
        super().__init__(
            num_spatial_dims,
            domain_extent,
            num_points,
            dt,
            # The heat equation's channel will always be 1, no matter the
            # spatial dimension
            num_channels=1,
            order=order,
            num_circle_points=num_circle_points,
            circle_radius=circle_radius,
        )

    def _build_linear_operator(self, derivative_operator: Array) -> Array:
        return self.diffusivity * derivative_operator**2

    def _build_nonlinear_fun(self, derivative_operator: Array) -> BaseNonlinearFun:
        return NonlinearSourceFun(
            self.num_spatial_dims,
            self.num_points,
            dealiasing_fraction=self.dealiasing_fraction,
        )
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.05
NU = 0.01

original_diffusion_stepper = ex.stepper.Diffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)
nonlinear_diffusion_stepper = Diffusion_nonlinear_source(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    diffusivity=NU,
)

mesh = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.cos(4 * jnp.pi * mesh / DOMAIN_EXTENT)

original_diffusion_trj = ex.rollout(original_diffusion_stepper, 200, include_init=True)(
    u_0
)
nonlinear_diffusion_trj = ex.rollout(
    nonlinear_diffusion_stepper, 200, include_init=True
)(u_0)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
ax[0].imshow(
    original_diffusion_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1
)
ax[0].set_title("Original Diffusion")
ax[1].imshow(
    nonlinear_diffusion_trj[:, 0, :].T, aspect="auto", cmap="RdBu", vmin=-1, vmax=1
)
ax[1].set_title("Nonlinear Diffusion")
Text(0.5, 1.0, 'Nonlinear Diffusion')
No description has been provided for this image