
Presenting all the built-in solvers working in 1D

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex

Linear PDEs

Advection

\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = 0 \]
# 1D linear advection of a single sine period on a periodic domain.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
VELOCITY = 1.0  # advection speed c

# Initial condition: one full sine period across the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)

advection_stepper = ex.stepper.Advection(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    velocity=VELOCITY,
)

# 200 steps; include_init=True keeps the initial state in the trajectory.
advection_trj = ex.rollout(advection_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0. The transpose puts the trajectory's first axis
# (presumably time) on the horizontal axis and space on the vertical axis.
plt.imshow(
    advection_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709dd875a9c0>
No description has been provided for this image

Diffusion

\[ \frac{\partial u}{\partial t} = \nu \frac{\partial^2 u}{\partial x^2} \]
# 1D diffusion of a box (top-hat) profile on a periodic domain.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
NU = 0.01  # diffusivity nu

diffusion_stepper = ex.stepper.Diffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU
)

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
# Box initial condition: 1 where 0.3 < x < 0.5, 0 elsewhere. Element-wise
# `&` combines the two boolean masks (Python's `and` does not work on arrays).
u_0 = jnp.where((grid > 0.3) & (grid < 0.5), 1.0, 0.0)

diffusion_trj = ex.rollout(diffusion_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(diffusion_trj[:, 0, :].T, origin="lower", cmap="RdBu_r", vmin=-1, vmax=1)
<matplotlib.image.AxesImage at 0x709dd86661b0>
No description has been provided for this image

Advection-Diffusion

\[ \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2} \]
# 1D advection-diffusion of a higher-frequency sine wave.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
VELOCITY = 1.0  # advection speed c
NU = 0.001  # diffusivity nu

# Initial condition: sin(10 pi x / L), i.e. five periods over the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(10 * jnp.pi * grid / DOMAIN_EXTENT)

advection_diffusion_stepper = ex.stepper.AdvectionDiffusion(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    velocity=VELOCITY,
    diffusivity=NU,
)

advection_diffusion_trj = ex.rollout(
    advection_diffusion_stepper,
    200,
    include_init=True,
)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    advection_diffusion_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709dd83f6060>
No description has been provided for this image

Dispersion

\[ \frac{\partial u}{\partial t} + \xi \frac{\partial^3 u}{\partial x^3} = 0 \]
# 1D dispersion of a two-mode sine superposition.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
DISPERSIVITY = 0.01  # dispersion coefficient xi

# Initial condition: sum of the first two sine modes of the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = (
    jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
    + jnp.sin(4 * jnp.pi * grid / DOMAIN_EXTENT)
)

dispersion_stepper = ex.stepper.Dispersion(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    dispersivity=DISPERSIVITY,
)

dispersion_trj = ex.rollout(dispersion_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    dispersion_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709dd8397e90>
No description has been provided for this image

Hyper-Diffusion

\[ \frac{\partial u}{\partial t} = - \mu \frac{\partial^4 u}{\partial x^4} \]
# 1D hyper-diffusion of a box (top-hat) profile on a periodic domain.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
HYPER_DIFFUSIVITY = 0.0001  # hyper-diffusivity mu

hyper_diffusion_stepper = ex.stepper.HyperDiffusion(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, hyper_diffusivity=HYPER_DIFFUSIVITY
)

grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
# Box initial condition: 1 where 0.3 < x < 0.5, 0 elsewhere. Element-wise
# `&` combines the two boolean masks (Python's `and` does not work on arrays).
u_0 = jnp.where((grid > 0.3) & (grid < 0.5), 1.0, 0.0)

hyper_diffusion_trj = ex.rollout(hyper_diffusion_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    hyper_diffusion_trj[:, 0, :].T, origin="lower", cmap="RdBu_r", vmin=-1, vmax=1
)
<matplotlib.image.AxesImage at 0x709dd8344e90>
No description has been provided for this image

General linear stepper

\[ \frac{\partial u}{\partial t} = a_0 u + a_1 \frac{\partial u}{\partial x} + a_2 \frac{\partial^2 u}{\partial x^2} + a_3 \frac{\partial^3 u}{\partial x^3} + a_4 \frac{\partial^4 u}{\partial x^4} \]
# Generic linear PDE: a_j multiplies the j-th spatial derivative of u.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
a_0 = -0.01  # coefficient of u
a_1 = -0.3  # coefficient of du/dx
a_2 = 0.001  # coefficient of d^2u/dx^2
a_3 = -0.0001  # coefficient of d^3u/dx^3
a_4 = -0.00001  # coefficient of d^4u/dx^4

# Initial condition: sum of the first two sine modes of the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = (
    jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
    + jnp.sin(4 * jnp.pi * grid / DOMAIN_EXTENT)
)

# Coefficients are passed ordered by derivative degree, 0 through 4.
general_linear_stepper = ex.stepper.generic.GeneralLinearStepper(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, linear_coefficients=[a_0, a_1, a_2, a_3, a_4]
)

general_linear_trj = ex.rollout(
    general_linear_stepper, 200, include_init=True
)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    general_linear_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709dd82e1160>
No description has been provided for this image

Nonlinear PDEs

Burgers (nonconservative)

\[ \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2} \]
# Viscous Burgers equation (default, nonconservative form).
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
NU = 0.01  # viscosity nu

# Initial condition: one sine period shifted up by a constant 0.5.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT) + 0.5

burgers_stepper = ex.stepper.Burgers(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
    diffusivity=NU,
)

burgers_trj = ex.rollout(burgers_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0; values above 1 are clipped by vmax.
plt.imshow(
    burgers_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709dd810b170>
No description has been provided for this image

Burgers (conservative)

\[ \frac{\partial u}{\partial t} + \frac{1}{2} \frac{\partial u^2}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2} \]
# Viscous Burgers equation in conservative form, (1/2) d(u^2)/dx.
DOMAIN_EXTENT = 1.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size
NU = 0.01  # viscosity nu

burgers_stepper_conservative = ex.stepper.Burgers(
    1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU, conservative=True
)

# Same initial condition as the nonconservative example, for comparison.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT) + 0.5

# Roll out the *conservative* stepper built above. Using `burgers_stepper`
# here would silently re-plot the nonconservative result.
burgers_trj_conservative = ex.rollout(
    burgers_stepper_conservative, 200, include_init=True
)(u_0)

# Space-time plot of channel 0; values above 1 are clipped by vmax.
plt.imshow(
    burgers_trj_conservative[:, 0, :].T, origin="lower", cmap="RdBu_r", vmin=-1, vmax=1
)
<matplotlib.image.AxesImage at 0x709da0715a30>
No description has been provided for this image

Kuramoto-Sivashinsky (KS)

\[ \frac{\partial u}{\partial t} + \frac{1}{2} \left(\frac{\partial u}{\partial x}\right)^2 + \frac{\partial^2 u}{\partial x^2} + \frac{\partial^4 u}{\partial x^4} = 0 \]

This is the gradient-norm form, \(\frac{1}{2}\left(\frac{\partial u}{\partial x}\right)^2\), commonly used in combustion research.

# Kuramoto-Sivashinsky equation in its combustion (gradient-norm) form.
DOMAIN_EXTENT = 60.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.5  # time step size

# Gaussian random state of shape (1, NUM_POINTS); the IC is irrelevant here.
u_0 = jax.random.normal(jax.random.PRNGKey(0), (1, NUM_POINTS))

ks_stepper = ex.stepper.KuramotoSivashinsky(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
)

ks_trj = ex.rollout(ks_stepper, 200, include_init=True)(u_0)

# Color range of +-6, wider than the +-1 used for most other solvers.
plt.imshow(
    ks_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-6,
    vmax=6,
)
<matplotlib.image.AxesImage at 0x709dd8264cb0>
No description has been provided for this image

Kuramoto-Sivashinsky (KS) Conservative

\[ \frac{\partial u}{\partial t} + \frac{1}{2} \frac{\partial u^2}{\partial x} + \frac{\partial^2 u}{\partial x^2} + \frac{\partial^4 u}{\partial x^4} = 0 \]

This is the conservative form, \(\frac{1}{2}\frac{\partial u^2}{\partial x}\), commonly used in fluid dynamics.

# Kuramoto-Sivashinsky equation in its conservative (fluid dynamics) form.
DOMAIN_EXTENT = 60.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.5  # time step size

# Gaussian random state of shape (1, NUM_POINTS); the IC is irrelevant here.
u_0 = jax.random.normal(jax.random.PRNGKey(0), (1, NUM_POINTS))

conservative_ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
)

conservative_ks_trj = ex.rollout(
    conservative_ks_stepper,
    200,
    include_init=True,
)(u_0)

# Color range of +-2.5 for this form.
plt.imshow(
    conservative_ks_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-2.5,
    vmax=2.5,
)
<matplotlib.image.AxesImage at 0x709dd83e6ba0>
No description has been provided for this image

Korteweg-de Vries (KdV)

\[ \frac{\partial u}{\partial t} - \frac{6}{2} \frac{\partial u^2}{\partial x} + \frac{\partial^3 u}{\partial x^3} = \nu \frac{\partial^2 u}{\partial x^2} - \mu \frac{\partial^4 u}{\partial x^4} \]

Note the negative (and scaled) convection term.

# Korteweg-de Vries equation with negative, scaled convection.
DOMAIN_EXTENT = 30.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.05  # time step size

# Check the documentation: the KdV stepper uses a "-6/2" convection scale by
# default to admit simple soliton solutions.
# Bound to `kdv_stepper` (not `ks_stepper`) so the Kuramoto-Sivashinsky
# stepper from the KS example is not silently overwritten.
kdv_stepper = ex.stepper.KortewegDeVries(1, DOMAIN_EXTENT, NUM_POINTS, DT)

# Initial condition: two cosine periods over the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = jnp.cos(4 * jnp.pi * grid / DOMAIN_EXTENT)

kdv_trj = ex.rollout(kdv_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(kdv_trj[:, 0, :].T, origin="lower", cmap="RdBu_r", vmin=-1, vmax=1)
<matplotlib.image.AxesImage at 0x709dd8278e60>
No description has been provided for this image

General Gradient Norm Stepper

\[ \frac{\partial u}{\partial t} + b \frac{1}{2} \left(\frac{\partial u}{\partial x}\right)^2 = a_0 u + a_1 \frac{\partial u}{\partial x} + a_2 \frac{\partial^2 u}{\partial x^2} + a_3 \frac{\partial^3 u}{\partial x^3} + a_4 \frac{\partial^4 u}{\partial x^4} \]
# Generic gradient-norm PDE: b * (1/2) (du/dx)^2 plus a linear part where a_j
# multiplies the j-th spatial derivative of u.
DOMAIN_EXTENT = 40.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.05  # time step size

b = 1.4  # gradient-norm scale
a_0 = -0.01  # coefficient of u
a_1 = -0.3  # coefficient of du/dx
a_2 = +0.1  # coefficient of d^2u/dx^2
a_3 = -0.0001  # coefficient of d^3u/dx^3
a_4 = -0.01  # coefficient of d^4u/dx^4

# Initial condition: sum of the first two sine modes of the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = (
    jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
    + jnp.sin(4 * jnp.pi * grid / DOMAIN_EXTENT)
)

general_gradient_norm_stepper = ex.stepper.generic.GeneralGradientNormStepper(
    1, DOMAIN_EXTENT, NUM_POINTS, DT,
    gradient_norm_scale=b,
    linear_coefficients=[a_0, a_1, a_2, a_3, a_4],
)

general_gradient_norm_trj = ex.rollout(
    general_gradient_norm_stepper, 200, include_init=True
)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    general_gradient_norm_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709da1ef1280>
No description has been provided for this image

General Convection

\[ \frac{\partial u}{\partial t} + b \frac{1}{2} \frac{\partial u^2}{\partial x} = a_0 u + a_1 \frac{\partial u}{\partial x} + a_2 \frac{\partial^2 u}{\partial x^2} + a_3 \frac{\partial^3 u}{\partial x^3} + a_4 \frac{\partial^4 u}{\partial x^4} \]
# Generic convection PDE: b * (1/2) d(u^2)/dx plus a linear part where a_j
# multiplies the j-th spatial derivative of u.
DOMAIN_EXTENT = 40.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.05  # time step size

b = 1.4  # convection scale
a_0 = -0.01  # coefficient of u
a_1 = -0.3  # coefficient of du/dx
a_2 = +0.1  # coefficient of d^2u/dx^2
a_3 = -0.0001  # coefficient of d^3u/dx^3
a_4 = -0.01  # coefficient of d^4u/dx^4

# Initial condition: sum of the first two sine modes of the domain.
grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)
u_0 = (
    jnp.sin(2 * jnp.pi * grid / DOMAIN_EXTENT)
    + jnp.sin(4 * jnp.pi * grid / DOMAIN_EXTENT)
)

general_convection_stepper = ex.stepper.generic.GeneralConvectionStepper(
    1, DOMAIN_EXTENT, NUM_POINTS, DT,
    convection_scale=b,
    linear_coefficients=[a_0, a_1, a_2, a_3, a_4],
)

general_convection_trj = ex.rollout(
    general_convection_stepper, 200, include_init=True
)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    general_convection_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709da26a9340>
No description has been provided for this image

Reaction-Diffusion PDEs

Fisher-KPP

\[ \frac{\partial u}{\partial t} = \nu \frac{\partial^2 u}{\partial x^2} + u(1-u) \]
# Fisher-KPP reaction-diffusion equation with the stepper's default parameters.
DOMAIN_EXTENT = 10.0  # domain length L
NUM_POINTS = 100  # number of grid points
DT = 0.01  # time step size

# Random 1D truncated-Fourier-series IC wrapped in a clamping generator
# (presumably clamps into the [0, 1] range Fisher-KPP expects; confirm).
u_0 = ex.ic.ClampingICGenerator(
    ex.ic.RandomTruncatedFourierSeries(1)
)(NUM_POINTS, key=jax.random.PRNGKey(0))

fisher_kpp_stepper = ex.stepper.reaction.FisherKPP(
    1,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT,
)

fisher_kpp_trj = ex.rollout(fisher_kpp_stepper, 200, include_init=True)(u_0)

# Space-time plot of channel 0 (transposed so space runs vertically).
plt.imshow(
    fisher_kpp_trj[:, 0, :].T,
    origin="lower",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
<matplotlib.image.AxesImage at 0x709da25fb050>
No description has been provided for this image

All together

fig, ax_s = plt.subplots(4, 4, figsize=(15, 10))

# One record per panel: (trajectory, title, (vmin, vmax)). Keeping each
# trajectory next to its title and color limits avoids the silent
# misalignment/truncation risk of zipping several parallel lists.
panels = [
    (advection_trj, "Advection", (-1, 1)),
    (diffusion_trj, "Diffusion", (-1, 1)),
    (advection_diffusion_trj, "Advection-Diffusion", (-1, 1)),
    (dispersion_trj, "Dispersion", (-1, 1)),
    (hyper_diffusion_trj, "Hyper-Diffusion", (-1, 1)),
    (general_linear_trj, "General Linear", (-1, 1)),
    (burgers_trj, "Burgers", (-1, 1)),
    (burgers_trj_conservative, "Conservative Burgers", (-1, 1)),
    (ks_trj, "Kuramoto-Sivashinsky", (-6, 6)),
    (conservative_ks_trj, "Conservative Kuramoto-Sivashinsky", (-2.5, 2.5)),
    (kdv_trj, "Korteweg-de Vries", (-1, 1)),
    (general_gradient_norm_trj, "General Gradient Norm", (-1, 1)),
    (general_convection_trj, "General Convection", (-1, 1)),
    (fisher_kpp_trj, "Fisher-KPP", (-1, 1)),
]

all_axes = ax_s.flatten()

for ax, (trj, title, (vmin, vmax)) in zip(all_axes, panels):
    # Channel 0, transposed so time runs horizontally and space vertically.
    ax.imshow(trj[:, 0, :].T, origin="lower", cmap="RdBu_r", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_xlabel("time")
    ax.set_ylabel("space")

# The 4x4 grid has more axes than panels; hide the leftover empty frames.
for ax in all_axes[len(panels):]:
    ax.axis("off")

# Prevent titles and axis labels of neighboring subplots from overlapping.
fig.tight_layout()
No description has been provided for this image