Skip to content

Presenting all the built-in solvers working in 3d¤

This notebook is a work in progress (WIP). It requires the VAPE volume renderer (pip install vape4d), which only runs on GPUs.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
from IPython.display import HTML

Linear PDEs¤

Advection¤

\[ \frac{\partial u}{\partial t} + \vec{c} \cdot \nabla u = 0 \]
# Advection of a smooth, periodic 3D scalar field by a constant velocity.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.02
# One advection speed per spatial axis (x, y, z). A scalar can also be supplied
# to use the same speed in all three dimensions.
VELOCITY = jnp.array([-0.4, 1.0, 0.1])

advection_stepper = ex.stepper.Advection(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, velocity=VELOCITY
)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Product of sines/cosines with integer numbers of periods (1, 2, 3 along
# x, y, z), so the initial condition is periodic on the domain. Slicing with
# ``i:i+1`` keeps a leading axis of size 1, which acts as the channel axis.
u_0 = (
    jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
)

# 40 steps; include_init=True also keeps the initial state in the trajectory.
advection_trj = ex.rollout(advection_stepper, 40, include_init=True)(u_0)

advection_ani = ex.viz.animate_state_3d(advection_trj)

HTML(advection_ani.to_html5_video())
2024-10-22 18:00:45.919128: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.68). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

Diffusion¤

\[ \frac{\partial u}{\partial t} = \nu \Delta u \]
# Isotropic diffusion of a periodic 3D scalar field.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
# Scalar diffusivity (same in every direction). See the next section for
# anisotropic diffusion.
NU = 0.01

# Named ``diffusion_stepper`` because this cell is the isotropic case; the
# anisotropic stepper is built in its own section.
diffusion_stepper = ex.stepper.Diffusion(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Periodic product of modes with 1, 2, 3 periods along x, y, z; the ``i:i+1``
# slices keep a leading channel axis of size 1.
u_0 = (
    jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
)

# 40 steps plus the initial state.
diffusion_trj = ex.rollout(diffusion_stepper, 40, include_init=True)(u_0)

diffusion_ani = ex.viz.animate_state_3d(diffusion_trj)

HTML(diffusion_ani.to_html5_video())

Anisotropic Diffusion¤

\[ \frac{\partial u}{\partial t} = \nabla \cdot \left( A \nabla u \right) \]
# Anisotropic diffusion of a periodic 3D scalar field with a full 3x3
# diffusivity matrix.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
# A 3-element vector can also be supplied for diagonal diffusivity (one value
# per axis). For full anisotropy, the matrix must be positive definite. The
# matrix below is symmetric and strictly diagonally dominant with a positive
# diagonal, hence positive definite.
NU = jnp.array([[0.005, 0.003, 0.0001], [0.003, 0.04, 0.0002], [0.0001, 0.0002, 0.01]])

anisotropic_diffusion_stepper = ex.stepper.Diffusion(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Periodic product of modes with 1, 2, 3 periods along x, y, z; the ``i:i+1``
# slices keep a leading channel axis of size 1.
u_0 = (
    jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
)

# 40 steps plus the initial state.
anisotropic_diffusion_trj = ex.rollout(
    anisotropic_diffusion_stepper, 40, include_init=True
)(u_0)

anisotropic_diffusion_ani = ex.viz.animate_state_3d(anisotropic_diffusion_trj)

HTML(anisotropic_diffusion_ani.to_html5_video())

Advection-Diffusion¤

\[ \frac{\partial u}{\partial t} + \vec{c} \cdot \nabla u = \nu \Delta u \]
# Advection combined with diffusion of a periodic 3D scalar field.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
# Scalars are used here. The velocity can also be a 3-element vector (one speed
# per axis). The diffusivity can be a 3-element vector (diagonal anisotropy) or
# a 3x3 matrix (full anisotropy).
VELOCITY = 1.0
DIFFUSIVITY = 0.01

advection_diffusion_stepper = ex.stepper.AdvectionDiffusion(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, velocity=VELOCITY, diffusivity=DIFFUSIVITY
)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Periodic product of modes with 1, 2, 3 periods along x, y, z; the ``i:i+1``
# slices keep a leading channel axis of size 1.
u_0 = (
    jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
)

# 40 steps plus the initial state.
advection_diffusion_trj = ex.rollout(
    advection_diffusion_stepper, 40, include_init=True
)(u_0)

advection_diffusion_ani = ex.viz.animate_state_3d(advection_diffusion_trj)

HTML(advection_diffusion_ani.to_html5_video())

Dispersion¤

\[ \frac{\partial u}{\partial t} + \vec{\xi} \cdot (\nabla \odot \nabla \odot \nabla) u = 0 \]
# Dispersion of a superposition of two periodic 3D modes.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
# A scalar uses the same dispersivity in every direction; a vector can be
# supplied instead for one dispersivity per dimension.
DISPERSIVITY = 0.001

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Coordinate components, each keeping a leading axis of size 1.
coord_x, coord_y, coord_z = grid[0:1], grid[1:2], grid[2:3]

# First mode: 1, 2, 3 periods along x, y, z.
first_mode = (
    jnp.sin(2 * jnp.pi * coord_x / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * coord_y / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * coord_z / DOMAIN_EXTENT)
)
# Second mode: 3, 4, 4 periods along x, y, z.
second_mode = (
    jnp.sin(3 * 2 * jnp.pi * coord_x / DOMAIN_EXTENT)
    * jnp.sin(4 * 2 * jnp.pi * coord_y / DOMAIN_EXTENT)
    * jnp.cos(4 * 2 * jnp.pi * coord_z / DOMAIN_EXTENT)
)
u_0 = first_mode + second_mode

dispersion_stepper = ex.stepper.Dispersion(
    3,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    dispersivity=DISPERSIVITY,
)

# 40 steps plus the initial state.
dispersion_trj = ex.rollout(
    dispersion_stepper, 40, include_init=True
)(u_0)

dispersion_ani = ex.viz.animate_state_3d(dispersion_trj)

HTML(dispersion_ani.to_html5_video())

Hyper-Diffusion¤

\[ \frac{\partial u}{\partial t} = - \nu \Delta^2 u \]
# Hyper-diffusion (fourth-order, bi-Laplacian operator) of a periodic 3D field.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
HYPER_DIFFUSIVITY = 0.0001

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# One factor per axis: 1, 2, 3 periods along x, y, z; the ``i:i+1`` slices keep
# a leading channel axis of size 1.
x_factor = jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
y_factor = jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
z_factor = jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
u_0 = x_factor * y_factor * z_factor

hyper_diffusion_stepper = ex.stepper.HyperDiffusion(
    3,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    hyper_diffusivity=HYPER_DIFFUSIVITY,
)

# 40 steps plus the initial state.
hyper_diffusion_trj = ex.rollout(
    hyper_diffusion_stepper, 40, include_init=True
)(u_0)

hyper_diffusion_ani = ex.viz.animate_state_3d(hyper_diffusion_trj)

HTML(hyper_diffusion_ani.to_html5_video())

Nonlinear PDEs¤

Burgers (non-conservative)¤

\[ \frac{\partial u}{\partial t} + (u \cdot \nabla) u = \nu \Delta u \]
# Burgers equation (non-conservative form) with one velocity channel per
# spatial dimension.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.005

burgers_stepper = ex.stepper.Burgers(3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)

# In 3D, Burgers has three channels (one velocity component per spatial
# dimension), stacked along the leading axis. Every factor uses an integer
# number of periods per domain length. A non-integer count (e.g. 1.5 periods)
# would make the field jump across the periodic boundary.
u_0 = jnp.concatenate(
    [
        jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
        jnp.cos(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.sin(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.cos(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
        jnp.sin(3 * 2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.cos(3 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.sin(4 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
    ]
)

# 40 steps plus the initial state.
burgers_trj = ex.rollout(burgers_stepper, 40, include_init=True)(u_0)

# One panel per channel.
burgers_ani = ex.viz.animate_state_3d_facet(
    burgers_trj,
    grid=(1, 3),
    figsize=(8, 3),
)

HTML(burgers_ani.to_html5_video())

Burgers (conservative)¤

\[ \frac{\partial u}{\partial t} + \frac{1}{2} \nabla \cdot \left( u \otimes u \right) = \nu \Delta u \]
# Burgers equation in conservative form with one velocity channel per spatial
# dimension.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01

burgers_stepper_conservative = ex.stepper.Burgers(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU, conservative=True
)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)

# In 3D, Burgers has three channels (one velocity component per spatial
# dimension), stacked along the leading axis. Every factor uses an integer
# number of periods per domain length. A non-integer count (e.g. 1.5 periods)
# would make the field jump across the periodic boundary.
u_0 = jnp.concatenate(
    [
        jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
        jnp.cos(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.sin(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.cos(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
        jnp.sin(3 * 2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
        * jnp.cos(3 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
        * jnp.sin(4 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT),
    ]
)

# 40 steps plus the initial state.
burgers_trj_conservative = ex.rollout(
    burgers_stepper_conservative, 40, include_init=True
)(u_0)

# One panel per channel.
burgers_ani_conservative = ex.viz.animate_state_3d_facet(
    burgers_trj_conservative,
    grid=(1, 3),
    figsize=(8, 3),
)

HTML(burgers_ani_conservative.to_html5_video())

Single-Channel Burgers¤

This is a hack that keeps the number of channels from growing with the number of spatial dimensions: a single scalar channel is transported along the diagonal direction \(\vec{1}\).

\[ \frac{\partial u}{\partial t} + \frac{1}{2} (\vec{1} \cdot \nabla) (u^2) = \nu \Delta u \]
# Single-channel Burgers: one scalar field instead of one channel per axis.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 0.01
NU = 0.01

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Periodic product of modes with 1, 2, 3 periods along x, y, z. Each ``i:i+1``
# slice keeps a leading axis of size 1, giving the single channel.
factor_x = jnp.sin(2 * jnp.pi * grid[0:1] / DOMAIN_EXTENT)
factor_y = jnp.cos(2 * 2 * jnp.pi * grid[1:2] / DOMAIN_EXTENT)
factor_z = jnp.sin(3 * 2 * jnp.pi * grid[2:3] / DOMAIN_EXTENT)
u_0 = factor_x * factor_y * factor_z

single_channel_burgers_stepper = ex.stepper.Burgers(
    3,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    diffusivity=NU,
    single_channel=True,
)

# 40 steps plus the initial state.
rollout_fn = ex.rollout(single_channel_burgers_stepper, 40, include_init=True)
single_channel_burgers_trj = rollout_fn(u_0)

single_channel_burgers_ani = ex.viz.animate_state_3d(single_channel_burgers_trj)

HTML(single_channel_burgers_ani.to_html5_video())

Kuramoto-Sivashinsky (KS)¤

The combustion form of the KS equation (using the squared gradient norm) generalizes naturally to higher dimensions.

\[ \frac{\partial u}{\partial t} + \frac{1}{2} \| \nabla u \|_2^2 + \Delta u + \Delta^2 u = 0 \]
# Kuramoto-Sivashinsky (combustion form) in 3D, started from random noise and
# warmed up before the recorded trajectory.
DOMAIN_EXTENT = 30.0
NUM_POINTS = 100
DT = 0.1

ks_stepper = ex.stepper.KuramotoSivashinsky(3, DOMAIN_EXTENT, NUM_POINTS, DT)

# The specific IC is not important: the warm-up below is meant to let the
# dynamics forget it. Single-channel Gaussian noise on the 3D grid (no spatial
# coordinates are needed, so no grid is built here).
u_0 = jax.random.normal(jax.random.PRNGKey(0), (1,) + (NUM_POINTS,) * 3)
# Remove the spatial mean so the initial state has zero average.
u_0 -= jnp.mean(u_0)

# Advance 500 steps without recording, then record 40 steps plus the start.
warmed_up_u_0 = ex.repeat(ks_stepper, 500)(u_0)

ks_trj = ex.rollout(ks_stepper, 40, include_init=True)(warmed_up_u_0)

# Fixed value limits (-6, 6) for the rendering.
ks_ani = ex.viz.animate_state_3d(ks_trj, vlim=(-6, 6))

HTML(ks_ani.to_html5_video())

Korteweg-de Vries (KdV)¤

Works best with the single-channel hack (see the Single-Channel Burgers section).

\[ \frac{\partial u}{\partial t} + \frac{1}{2} (\vec{1} \cdot \nabla) u^2 + \vec{1} \cdot (\nabla \odot \nabla \odot \nabla) u = - \nu \Delta^2 u \]
# Korteweg-de Vries in 3D, using the single-channel formulation and a
# hyper-diffusion term with coefficient HYPER_NU.
DOMAIN_EXTENT = 20.0
NUM_POINTS = 100
DT = 0.05
HYPER_NU = 0.03

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)
# Coordinate components, each keeping a leading (channel) axis of size 1.
coord_x, coord_y, coord_z = grid[0:1], grid[1:2], grid[2:3]
# Periodic product of modes with 1, 2, 3 periods along x, y, z.
u_0 = (
    jnp.sin(2 * jnp.pi * coord_x / DOMAIN_EXTENT)
    * jnp.cos(2 * 2 * jnp.pi * coord_y / DOMAIN_EXTENT)
    * jnp.sin(3 * 2 * jnp.pi * coord_z / DOMAIN_EXTENT)
)

kdv_stepper = ex.stepper.KortewegDeVries(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, single_channel=True, hyper_diffusivity=HYPER_NU
)

# 40 steps plus the initial state.
kdv_trj = ex.rollout(
    kdv_stepper,
    40,
    include_init=True,
)(u_0)

kdv_ani = ex.viz.animate_state_3d(kdv_trj)

HTML(kdv_ani.to_html5_video())

Reaction-Diffusion PDEs¤

Fisher-KPP¤

\[ \frac{\partial u}{\partial t} = \nu \Delta u + r u (1 - u) \]
# Fisher-KPP reaction-diffusion in 3D from a random initial state.
DOMAIN_EXTENT = 10.0
NUM_POINTS = 100
DT = 0.01
DIFFUSIVITY = 0.01
REACTIVITY = 10.0

fisher_kpp_stepper = ex.stepper.reaction.FisherKPP(
    3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=DIFFUSIVITY, reactivity=REACTIVITY
)

# Random truncated Fourier series clamped to [0, 1]. u = 0 and u = 1 are the
# two zeros of the reaction term r u (1 - u).
ic_gen = ex.ic.ClampingICGenerator(ex.ic.RandomTruncatedFourierSeries(3), limits=(0, 1))
# Sample at NUM_POINTS (not a hard-coded value) so the IC resolution always
# matches the stepper's grid.
u_0 = ic_gen(NUM_POINTS, key=jax.random.PRNGKey(0))

# 40 steps plus the initial state.
fisher_kpp_trj = ex.rollout(fisher_kpp_stepper, 40, include_init=True)(u_0)

fisher_kpp_ani = ex.viz.animate_state_3d(fisher_kpp_trj)

HTML(fisher_kpp_ani.to_html5_video())

Gray-Scott¤

\[ \begin{aligned} \frac{\partial u_0}{\partial t} &= \nu_0 \Delta u_0 - u_0 u_1^2 + f (1 - u_0) \\ \frac{\partial u_1}{\partial t} &= \nu_1 \Delta u_1 + u_0 u_1^2 - (f + k) u_1 \end{aligned} \]
# Gray-Scott two-species reaction-diffusion in 3D.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 100
DT = 30.0
# Number of substeps per recorded step. The inner stepper advances by
# DT / NUM_SUBSTEPS and is repeated NUM_SUBSTEPS times, so one call covers DT.
# Using a single constant keeps the two values from drifting apart.
NUM_SUBSTEPS = 30
DIFFUSIVITY_0 = 2e-5
DIFFUSIVITY_1 = 1e-5
FEED_RATE = 0.04
KILL_RATE = 0.06

gray_scott_stepper = ex.RepeatedStepper(
    ex.stepper.reaction.GrayScott(
        3,
        DOMAIN_EXTENT,
        NUM_POINTS,
        DT / NUM_SUBSTEPS,
        # The stepper's diffusivity_1 / diffusivity_2 correspond to
        # nu_0 / nu_1 (species u_0 / u_1) in the equation above.
        diffusivity_1=DIFFUSIVITY_0,
        diffusivity_2=DIFFUSIVITY_1,
        feed_rate=FEED_RATE,
        kill_rate=KILL_RATE,
    ),
    NUM_SUBSTEPS,
)

# One IC generator per species, combined into a two-channel state. The first
# uses one_complement=True (presumably 1 minus the blobs; verify in exponax).
u_0 = ex.ic.RandomMultiChannelICGenerator(
    [
        ex.ic.RandomGaussianBlobs(3, one_complement=True),
        ex.ic.RandomGaussianBlobs(3),
    ]
)(NUM_POINTS, key=jax.random.PRNGKey(0))

# 40 recorded steps (each of size DT) plus the initial state.
gray_scott_trj = ex.rollout(gray_scott_stepper, 40, include_init=True)(u_0)

# One panel per species.
gray_scott_ani = ex.viz.animate_state_3d_facet(
    gray_scott_trj,
    grid=(1, 2),
    figsize=(7, 3),
)

HTML(gray_scott_ani.to_html5_video())

Swift-Hohenberg¤

\[ \frac{\partial u}{\partial t} = r u - (1 + \Delta)^2 u + u^2 - u^3 \]
# Swift-Hohenberg pattern formation in 3D from a random periodic initial state.
DOMAIN_EXTENT = 20.0 * jnp.pi
NUM_POINTS = 100
DT = 1.0
# The inner stepper advances by DT / SH_SUBSTEPS and is repeated SH_SUBSTEPS
# times per call, so each recorded step covers DT.
SH_SUBSTEPS = 10

inner_sh_stepper = ex.stepper.reaction.SwiftHohenberg(
    3, DOMAIN_EXTENT, NUM_POINTS, DT / SH_SUBSTEPS
)
swift_hohenberg_stepper = ex.RepeatedStepper(inner_sh_stepper, SH_SUBSTEPS)

# Random truncated Fourier series; max_one=True presumably rescales the
# amplitude to at most one (verify in exponax).
sh_ic_generator = ex.ic.RandomTruncatedFourierSeries(3, max_one=True)
u_0 = sh_ic_generator(NUM_POINTS, key=jax.random.PRNGKey(0))

# 40 recorded steps plus the initial state.
swift_hohenberg_trj = ex.rollout(
    swift_hohenberg_stepper, 40, include_init=True
)(u_0)

swift_hohenberg_ani = ex.viz.animate_state_3d(swift_hohenberg_trj)

HTML(swift_hohenberg_ani.to_html5_video())

All together¤

# Stack every trajectory along the channel axis (axis=1) and render all of them
# in one faceted animation. Each trajectory is paired with one title per
# channel, so the facet titles cannot drift out of alignment with the stacked
# channels.
trajectories_with_titles = [
    (advection_trj, ["Advection"]),
    (diffusion_trj, ["Diffusion"]),
    (anisotropic_diffusion_trj, ["Anisotropic Diffusion"]),
    (advection_diffusion_trj, ["Advection-Diffusion"]),
    (dispersion_trj, ["Dispersion"]),
    (hyper_diffusion_trj, ["Hyper-Diffusion"]),
    (burgers_trj, ["Burgers channel 1", "Burgers channel 2", "Burgers channel 3"]),
    (
        burgers_trj_conservative,
        [
            "Burgers (Conservative)\nchannel 1",
            "Burgers (Conservative)\nchannel 2",
            "Burgers (Conservative)\nchannel 3",
        ],
    ),
    (single_channel_burgers_trj, ["Burgers single channel"]),
    (kdv_trj, ["KdV"]),
    (ks_trj, ["KS"]),
    (fisher_kpp_trj, ["Fisher-KPP"]),
    (gray_scott_trj, ["Gray-Scott 1", "Gray-Scott 2"]),
    (swift_hohenberg_trj, ["Swift-Hohenberg"]),
]

joint_trj = jnp.concatenate(
    [trj for trj, _ in trajectories_with_titles],
    axis=1,
)
joint_titles = [title for _, titles in trajectories_with_titles for title in titles]
# Guard against a title list that does not match the stacked channels.
if len(joint_titles) != joint_trj.shape[1]:
    raise ValueError(
        f"Got {len(joint_titles)} titles for {joint_trj.shape[1]} channels"
    )

FACET_GRID = (3, 7)
# Pad with empty titles for the facets of the grid that hold no channel.
joint_titles += [""] * (FACET_GRID[0] * FACET_GRID[1] - len(joint_titles))

joint_ani = ex.viz.animate_state_3d_facet(
    joint_trj,
    titles=joint_titles,
    grid=FACET_GRID,
    figsize=(17, 10),
)

HTML(joint_ani.to_html5_video())