Skip to content

Getting started¤

Installation¤

Clone the repository, navigate to the folder and install the package with pip:

pip install trainax

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Quickstart¤

Train a kernel size 2 linear convolution (no bias) to become an emulator for the 1D advection problem.

import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import trainax as tx

CFL = -0.75

ref_data = tx.sample_data.advection_1d_periodic(
    cfl = CFL,
    key = jax.random.PRNGKey(0),
)

linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)

sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

FOU_STENCIL = jnp.array([1+CFL, -CFL])

print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017

Increasing the supervised unrolling steps during training makes the learned stencil come closer to the numerical FOU stencil.

Features¤

  • Wide collection of unrolled training methodologies:
    • Supervised
    • Diverted Chain
    • Mix Chain
    • Residuum
  • Based on JAX:
    • One of the best Automatic Differentiation engines (forward & reverse)
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Build on top and compatible with Equinox
  • Batch-Parallel Training
  • Collection of Callbacks
  • Composability

Citation¤

This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180) (accepted at Neurips 2024). If you find it useful for your research, please consider citing it:

@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"}udiger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}

(Feel free to also give the project a star on GitHub if you like it.)

Here you can find the APEBench benchmark suite.

Funding¤

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.

License¤

MIT, see here


fkoehler.site  ·  GitHub @ceyron  ·  X @felix_m_koehler  ·  LinkedIn Felix Köhler