
Getting Started

Installation

Install the package from PyPI with pip (no need to clone the repository):

pip install pdequinox

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
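
The version requirements above can be checked before installing. The following is a minimal sketch that uses only the standard library; the helper names (`parse_version`, `meets_requirements`, `installed_jax_version`) are hypothetical and not part of pdequinox.

```python
import sys
from importlib import metadata


def parse_version(text):
    """Turn a version string such as '0.4.13' into a tuple of ints (0, 4, 13)."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or 'dev'
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_requirements(python_version, jax_version):
    """Compare against the stated minimums: Python 3.10 and JAX 0.4.13."""
    python_ok = tuple(python_version[:2]) >= (3, 10)
    jax_ok = parse_version(jax_version) >= (0, 4, 13)
    return python_ok and jax_ok


def installed_jax_version():
    """Return the installed JAX version string, or None if JAX is missing."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_version = installed_jax_version()
    if jax_version is None:
        print("JAX is not installed")
    else:
        print("requirements met:", meets_requirements(sys.version_info, jax_version))
```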

Quickstart

Train a UNet to become an emulator for the 1D Poisson equation.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # `pip install optax`
import pdequinox as pdeqx
from tqdm import tqdm  # `pip install tqdm`

force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
    key=jax.random.PRNGKey(0)
)

force_fields_train = force_fields[:800]
force_fields_test = force_fields[800:]
displacement_fields_train = displacement_fields[:800]
displacement_fields_test = displacement_fields[800:]

# Positional arguments: number of spatial dimensions, input channels, output channels
unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))

def loss_fn(model, x, y):
    # The model acts on a single sample; vmap maps it over the batch axis
    y_pred = jax.vmap(model)(x)
    return jnp.mean((y_pred - y) ** 2)

opt = optax.adam(3e-4)
opt_state = opt.init(eqx.filter(unet, eqx.is_array))

@eqx.filter_jit
def update_fn(model, state, x, y):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, new_state = opt.update(grad, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

loss_history = []
shuffle_key = jax.random.PRNGKey(151)
for epoch in tqdm(range(100)):
    shuffle_key, subkey = jax.random.split(shuffle_key)

    for batch in pdeqx.dataloader(
        (force_fields_train, displacement_fields_train),
        batch_size=32,
        key=subkey
    ):
        unet, opt_state, loss = update_fn(
            unet,
            opt_state,
            *batch,
        )
        loss_history.append(loss)
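
The quickstart sets aside a test split but does not evaluate on it. Below is a minimal sketch of how a held-out error could be computed. It uses a stand-in model and synthetic arrays so that it runs without pdequinox; the array shapes (samples, channels, grid points) and the helper names `mean_squared_error` and `normalized_rmse` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def mean_squared_error(model, x, y):
    """MSE over a batch; `model` maps one sample to one prediction."""
    y_pred = jax.vmap(model)(x)  # vectorize the single-sample model over axis 0
    return jnp.mean((y_pred - y) ** 2)


def normalized_rmse(model, x, y):
    """Per-sample RMSE divided by the RMS of the reference, averaged over samples."""
    y_pred = jax.vmap(model)(x)
    sample_axes = tuple(range(1, y.ndim))  # all axes except the batch axis
    err = jnp.sqrt(jnp.mean((y_pred - y) ** 2, axis=sample_axes))
    ref = jnp.sqrt(jnp.mean(y**2, axis=sample_axes))
    return jnp.mean(err / ref)


# Stand-in data and model so the sketch runs on its own
# (hypothetical shapes: 200 samples, 1 channel, 64 grid points).
x_test = jax.random.normal(jax.random.PRNGKey(0), (200, 1, 64))
y_test = 2.0 * x_test


def stand_in_model(x):
    """A perfect 'emulator' for the stand-in data above."""
    return 2.0 * x


print(mean_squared_error(stand_in_model, x_test, y_test))
print(normalized_rmse(stand_in_model, x_test, y_test))
```

With the objects from the quickstart, the same helper would be called as `mean_squared_error(unet, force_fields_test, displacement_fields_test)`.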

Features

  • Based on JAX:
    • Automatic differentiation in both forward and reverse mode
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Built on top of Equinox:
    • Single-batch by design: models act on one sample without a batch axis; use jax.vmap to batch
    • Integration into the Equinox SciML ecosystem
  • Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
  • Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic BCs)
  • Composability
  • Tools to count parameters and assess receptive fields
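
Two of the points above, the single-batch design and parameter counting, can be illustrated with plain JAX. This is a sketch under stated assumptions: the toy parameter dictionary, `apply_single`, and `count_array_elements` are hypothetical stand-ins and not pdequinox's own architectures or counting tools.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameters: a pointwise linear map over channels.
params = {
    "weight": jnp.ones((3, 1)),  # (out_channels, in_channels)
    "bias": jnp.zeros((3,)),
}


def apply_single(params, x):
    """Act on ONE sample of shape (in_channels, num_points); no batch axis."""
    return params["weight"] @ x + params["bias"][:, None]


def count_array_elements(tree):
    """Sum the sizes of all array leaves in a pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


single = jnp.ones((1, 16))    # one sample: 1 channel, 16 grid points
batch = jnp.ones((8, 1, 16))  # eight samples stacked on a leading axis

out_single = apply_single(params, single)
# jax.vmap adds the batch axis; the parameters are shared (in_axes=None)
out_batch = jax.vmap(apply_single, in_axes=(None, 0))(params, batch)

print(out_single.shape, out_batch.shape, count_array_elements(params))
```

Writing the model for a single sample keeps shape handling simple; batching is added from the outside with `jax.vmap`, as `loss_fn` in the quickstart does.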

Citation

This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180), accepted at NeurIPS 2024. If you find it useful for your research, please consider citing it:

@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}

(Feel free to also give the project a star on GitHub if you like it.)

Here you can find the APEBench benchmark suite.

License

MIT, see here


fkoehler.site  ·  GitHub @ceyron  ·  X @felix_m_koehler