The metrics of comparing fields in Exponax¤
There are four major classes of metrics:

1. Spatial-based (working in physical space)
2. Fourier-based (working in coefficient space)
3. Correlation-based
4. Derivative-based (thin wrappers around the Fourier-based approaches that achieve Sobolev-like norms)
Classes 1, 2, and 4 can be further divided into:

1. Absolute metrics (i.e., related to the MAE)
2. Absolute squared metrics (i.e., related to the MSE)
3. Rooted metrics (i.e., related to the RMSE)

Each of these three comes in both an absolute and a relative/normalized version. For the spatial-based metrics, MAE, MSE, and RMSE additionally come with a symmetric version.
All metric computations work on single state arrays, i.e., arrays with a leading channel axis and one, two, or three subsequent spatial axes. The arrays must not have a leading batch axis. To work with batched arrays, use jax.vmap and then reduce, e.g., with jnp.mean. Alternatively, use the convenience wrapper exponax.metrics.mean_metric.
All metrics sum over the channel axis.
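To make these shape conventions concrete, here is a from-scratch sketch in plain NumPy (a hand-rolled MSE stand-in, not the Exponax implementation):

```python
import numpy as np

# A single state has shape (channels, space); a batch has a leading
# batch axis on top of that.
def mse_single(u_pred, u_ref):
    # Averages over the spatial axis and sums over the channel axis,
    # mirroring the convention described above (hand-rolled stand-in).
    return np.sum(np.mean((u_pred - u_ref) ** 2, axis=-1), axis=0)

single_pred = np.ones((1, 100))   # (channel, space)
single_ref = np.zeros((1, 100))
print(mse_single(single_pred, single_ref))  # scalar for one state

batch_pred = np.ones((8, 1, 100))  # leading batch axis
batch_ref = np.zeros((8, 1, 100))
# Batched arrays: apply per sample, then reduce -- the role that
# jax.vmap + jnp.mean (or exponax.metrics.mean_metric) plays below.
per_sample = np.array(
    [mse_single(p, r) for p, r in zip(batch_pred, batch_ref)]
)
print(per_sample.mean())
```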
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import exponax as ex
Setup: Reference and Predicted States¤
Let's create two 1D fields using DiffusedNoise. We use the same random key
so the signals share large-scale structure, but different intensity values
produce different levels of smoothing.
NUM_POINTS = 100
u_ref = ex.ic.DiffusedNoise(1, intensity=0.001)(NUM_POINTS, key=jax.random.PRNGKey(0))
u_pred = ex.ic.DiffusedNoise(1, intensity=0.0005)(NUM_POINTS, key=jax.random.PRNGKey(0))
print("u_ref shape:", u_ref.shape)
print("u_pred shape:", u_pred.shape)
fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharey=True)
ex.viz.plot_state_1d(u_ref, ax=axes[0])
axes[0].set_title("Reference")
ex.viz.plot_state_1d(u_pred, ax=axes[1])
axes[1].set_title("Prediction")
plt.tight_layout()
plt.show()
The Standard Candidates: MAE, MSE, RMSE¤
Absolute Metrics¤
The three workhorses of error measurement:
- MAE (Mean Absolute Error) — related to the L1 norm of the error
- MSE (Mean Squared Error) — related to the squared L2 norm of the error
- RMSE (Root Mean Squared Error) — related to the L2 norm of the error
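These definitions can be sketched from scratch on a plain error vector (hand-rolled, ignoring the channel-sum and quadrature conventions discussed above):

```python
import numpy as np

# Hand-rolled versions of the three error norms.
err = np.array([1.0, -2.0, 2.0, -1.0])  # some error vector

mae = np.mean(np.abs(err))   # L1-type
mse = np.mean(err ** 2)      # squared L2-type
rmse = np.sqrt(mse)          # L2-type, i.e., sqrt of the MSE

print(mae, mse, rmse)
```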
mae = ex.metrics.MAE(u_pred, u_ref)
mse = ex.metrics.MSE(u_pred, u_ref)
rmse = ex.metrics.RMSE(u_pred, u_ref)
print(f"MAE = {mae:.6f}")
print(f"MSE = {mse:.6f}")
print(f"RMSE = {rmse:.6f}")
As expected, RMSE = sqrt(MSE):
print(f"sqrt(MSE) = {jnp.sqrt(mse):.6f}")
print(f"RMSE = {rmse:.6f}")
If no reference is provided, the metric computes the norm of the state itself (i.e., the error against zero):
print(f"RMSE(u_ref, ref=None) = {ex.metrics.RMSE(u_ref):.6f} (norm of u_ref)")
print(f"RMSE(u_ref, ref=0) = {ex.metrics.RMSE(u_ref, jnp.zeros_like(u_ref)):.6f}")
Normalized/Relative Metrics¤
The normalized variants divide the absolute metric by the norm of the reference. This makes them scale-invariant: the error is expressed relative to the magnitude of the reference signal.
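A minimal sketch of the normalization, assuming the relative metric divides by the same metric of the reference measured against zero (consistent with the ref=None behavior shown above):

```python
import numpy as np

def rmse(a, b):
    # Hand-rolled RMSE stand-in (no quadrature weighting).
    return np.sqrt(np.mean((a - b) ** 2))

ref = np.array([1.0, 2.0, 3.0, 4.0])
pred = np.array([1.1, 1.9, 3.2, 3.8])

# Relative metric: absolute metric divided by the norm of the reference.
rel_err = rmse(pred, ref) / rmse(ref, np.zeros_like(ref))
# Scaling both signals leaves the ratio unchanged.
rel_err_scaled = rmse(10 * pred, 10 * ref) / rmse(
    10 * ref, np.zeros_like(ref)
)
print(rel_err, rel_err_scaled)
```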
nmae = ex.metrics.nMAE(u_pred, u_ref)
nmse = ex.metrics.nMSE(u_pred, u_ref)
nrmse = ex.metrics.nRMSE(u_pred, u_ref)
print(f"nMAE = {nmae:.6f}")
print(f"nMSE = {nmse:.6f}")
print(f"nRMSE = {nrmse:.6f}")
Scale invariance: multiplying both signals by the same constant leaves nRMSE unchanged:
scale = 10.0
nrmse_original = ex.metrics.nRMSE(u_pred, u_ref)
nrmse_scaled = ex.metrics.nRMSE(scale * u_pred, scale * u_ref)
print(f"nRMSE (original) = {nrmse_original:.6f}")
print(f"nRMSE (x{scale:.0f}) = {nrmse_scaled:.6f}")
There are also symmetric variants (sMAE, sMSE, sRMSE) that normalize
by the sum of the norms of both signals. These are bounded between 0 and the
number of channels C:
print(f"sRMSE = {ex.metrics.sRMSE(u_pred, u_ref):.6f} (bounded in [0, C])")
Why does it need the domain size?¤
All metrics approximate continuous integrals via the trapezoidal rule, e.g., for the squared error

\(\int_\Omega |u_\text{pred}(x) - u_\text{ref}(x)|^2 \, \mathrm{d}x \approx \left(\frac{L}{N}\right)^D \sum_i |u_{\text{pred},i} - u_{\text{ref},i}|^2\)

Because of the \((L/N)^D\) grid spacing factor, changing the domain_extent changes the absolute metric values.
for L in [1.0, 2.0, 5.0]:
mse_L = ex.metrics.MSE(u_pred, u_ref, domain_extent=L)
nrmse_L = ex.metrics.nRMSE(u_pred, u_ref, domain_extent=L)
print(f"L={L:.1f}: MSE={mse_L:.6f} nRMSE={nrmse_L:.6f}")
Notice that MSE changes with domain extent, but nRMSE stays constant because
the domain-extent factor cancels in the ratio.
Practical advice: When using metrics as optimization objectives (e.g., for
training neural emulators), the domain extent is just a constant factor and does
not affect the optimizer. But when comparing metrics across different setups,
make sure domain_extent is set correctly.
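The grid-spacing factor and its cancellation can also be checked from scratch (a NumPy sketch assuming the squared-error sum is weighted by \((L/N)^D\), here with D = 1):

```python
import numpy as np

N = 100
rng = np.random.default_rng(0)
diff = rng.normal(size=N)      # stand-in for u_pred - u_ref
ref_vals = rng.normal(size=N)  # stand-in for u_ref

mses, nrmses = [], []
for L in [1.0, 2.0, 5.0]:
    dx = L / N  # quadrature weight (L/N)^D with D = 1
    mses.append(dx * np.sum(diff ** 2))
    nrmses.append(
        np.sqrt(dx * np.sum(diff ** 2))
        / np.sqrt(dx * np.sum(ref_vals ** 2))
    )
print(mses)    # grows linearly with the domain extent
print(nrmses)  # constant: the dx factor cancels in the ratio
```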
Correlation¤
Correlation measures the shape similarity between two fields via a normalized dot product. It ranges from -1 (anti-correlated) to +1 (identical shape). Crucially, it does not capture amplitude differences.
corr = ex.metrics.correlation(u_pred, u_ref)
print(f"correlation(u_pred, u_ref) = {corr:.6f}")
# Identical fields -> correlation = 1.0
print(f"correlation(u_ref, u_ref) = {ex.metrics.correlation(u_ref, u_ref):.6f}")
# Amplitude scaling does not change correlation
print(
f"correlation(10*u_pred, u_ref) = {ex.metrics.correlation(10 * u_pred, u_ref):.6f}"
)
# Two uncorrelated random fields -> correlation ~ 0
u_random = ex.ic.DiffusedNoise(1, intensity=0.001)(
NUM_POINTS, key=jax.random.PRNGKey(42)
)
print(f"correlation(u_ref, u_random) = {ex.metrics.correlation(u_ref, u_random):.6f}")
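The normalized dot product behind this metric can be sketched in plain NumPy (a from-scratch illustration, not Exponax's implementation):

```python
import numpy as np

def correlation_sketch(a, b):
    # Normalized dot product of the flattened fields.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

a = np.array([1.0, 2.0, 3.0])
print(correlation_sketch(a, a))       # identical shape -> 1.0
print(correlation_sketch(a, 10 * a))  # amplitude cancels -> still 1.0
print(correlation_sketch(a, -a))      # anti-correlated -> -1.0
```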
Fourier-based Metrics¤
Wait? Isn't that my MSE? A quick intro to Parseval's theorem¤
Parseval's theorem tells us that the L2 norm in physical space equals the
L2 norm in Fourier space. So fourier_MSE should numerically agree with MSE:
mse_spatial = ex.metrics.MSE(u_pred, u_ref)
mse_fourier = ex.metrics.fourier_MSE(u_pred, u_ref)
print(f"MSE (spatial) = {mse_spatial:.8f}")
print(f"fourier_MSE = {mse_fourier:.8f}")
print(f"Difference = {abs(mse_spatial - mse_fourier):.2e}")
# Same holds for RMSE and nRMSE
print(f"RMSE = {ex.metrics.RMSE(u_pred, u_ref):.8f}")
print(f"fourier_RMSE = {ex.metrics.fourier_RMSE(u_pred, u_ref):.8f}")
print()
print(f"nRMSE = {ex.metrics.nRMSE(u_pred, u_ref):.8f}")
print(f"fourier_nRMSE = {ex.metrics.fourier_nRMSE(u_pred, u_ref):.8f}")
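The agreement rests on Parseval's theorem, which can be verified from scratch with NumPy's FFT (under NumPy's unnormalized-forward convention, the physical-space sum of squares equals 1/N times the sum of squared coefficient magnitudes):

```python
import numpy as np

N = 64
u = np.random.default_rng(0).normal(size=N)
u_hat = np.fft.fft(u)

lhs = np.sum(np.abs(u) ** 2)            # energy in physical space
rhs = np.sum(np.abs(u_hat) ** 2) / N    # energy in Fourier space
print(lhs, rhs)  # agree to floating-point precision
```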
Filtering and Scale-Specific Metrics¤
The Fourier-based metrics accept low and high parameters to restrict the
error computation to specific frequency (wavenumber) ranges. This lets you
measure the error at different spatial scales.
Let's demonstrate: we add high-frequency noise to a signal and show that the low-frequency error remains small while the full-spectrum error is large.
# Add high-frequency noise
noise = 0.3 * ex.ic.DiffusedNoise(1, intensity=0.00001)(
NUM_POINTS, key=jax.random.PRNGKey(1)
)
u_noisy = u_ref + noise
fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharey=True)
ex.viz.plot_state_1d(u_ref, ax=axes[0])
axes[0].set_title("Reference")
ex.viz.plot_state_1d(u_noisy, ax=axes[1])
axes[1].set_title("Reference + High-Freq Noise")
plt.tight_layout()
plt.show()
print("Full spectrum:")
print(f" fourier_nRMSE = {ex.metrics.fourier_nRMSE(u_noisy, u_ref):.6f}")
print()
print("Low frequencies only (wavenumbers 0..5):")
print(f" fourier_nRMSE = {ex.metrics.fourier_nRMSE(u_noisy, u_ref, high=5):.6f}")
print()
print("High frequencies only (wavenumbers 5..):")
print(f" fourier_nRMSE = {ex.metrics.fourier_nRMSE(u_noisy, u_ref, low=5):.6f}")
This also works in higher dimensions. Here is a 2D example:
u_ref_2d = ex.ic.DiffusedNoise(2, intensity=0.001)(64, key=jax.random.PRNGKey(0))
u_pred_2d = ex.ic.DiffusedNoise(2, intensity=0.0005)(64, key=jax.random.PRNGKey(0))
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5))
ex.viz.plot_state_2d(u_ref_2d, ax=axes[0])
axes[0].set_title("Reference (2D)")
ex.viz.plot_state_2d(u_pred_2d, ax=axes[1])
axes[1].set_title("Prediction (2D)")
plt.tight_layout()
plt.show()
print("2D fourier_nRMSE (full spectrum):")
print(f" {ex.metrics.fourier_nRMSE(u_pred_2d, u_ref_2d):.6f}")
print("2D fourier_nRMSE (low freq, high=5):")
print(f" {ex.metrics.fourier_nRMSE(u_pred_2d, u_ref_2d, high=5):.6f}")
print("2D fourier_nRMSE (high freq, low=5):")
print(f" {ex.metrics.fourier_nRMSE(u_pred_2d, u_ref_2d, low=5):.6f}")
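The idea behind the low/high masking can be sketched from scratch (a hypothetical hand-rolled version; Exponax's exact band-edge conventions may differ): zero out all Fourier coefficients of the error outside a wavenumber band, then measure the remaining error energy.

```python
import numpy as np

N = 100
x = np.linspace(0, 2 * np.pi, N, endpoint=False)
base = np.sin(2 * x)                    # low-frequency signal
noisy = base + 0.3 * np.sin(40 * x)     # plus high-frequency noise

def banded_error_energy(a, b, low=0, high=None):
    # Error energy restricted to wavenumbers low <= k < high.
    err_hat = np.fft.rfft(a - b)
    k = np.arange(err_hat.shape[0])
    mask = (k >= low) if high is None else (k >= low) & (k < high)
    return np.sum(np.abs(err_hat * mask) ** 2)

print(banded_error_energy(noisy, base, high=5))  # ~0: low band is clean
print(banded_error_energy(noisy, base, low=5))   # large: the noise lives here
```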
Metrics with derivatives¤
The Fourier-based metrics support a derivative_order parameter. Setting
derivative_order=1 computes the error of the first derivative (done spectrally
by multiplying with ik in Fourier space — no finite differences needed).
rmse_0 = ex.metrics.fourier_RMSE(u_pred, u_ref)
rmse_1 = ex.metrics.fourier_RMSE(u_pred, u_ref, derivative_order=1)
rmse_2 = ex.metrics.fourier_RMSE(u_pred, u_ref, derivative_order=2)
print(f"fourier_RMSE (0th derivative) = {rmse_0:.6f}")
print(f"fourier_RMSE (1st derivative) = {rmse_1:.6f}")
print(f"fourier_RMSE (2nd derivative) = {rmse_2:.6f}")
Higher derivative orders amplify high-frequency errors, making these metrics sensitive to fine-scale discrepancies.
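The spectral-derivative trick itself can be sketched with plain NumPy: multiply the Fourier coefficients by \(ik\), transform back, and compare against the analytic derivative of a sine (a sketch, not Exponax's internals):

```python
import numpy as np

N = 64
x = np.linspace(0, 2 * np.pi, N, endpoint=False)
u = np.sin(x)

k = np.fft.rfftfreq(N, d=1.0 / N)   # integer wavenumbers 0..N/2
du_hat = 1j * k * np.fft.rfft(u)    # differentiate in Fourier space
du = np.fft.irfft(du_hat, n=N)      # back to physical space

print(np.max(np.abs(du - np.cos(x))))  # ~0: matches d/dx sin(x) = cos(x)
```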
Sobolev-like Metrics¤
Wait? Who is Sobolev?¤
The \(H^1\) (Sobolev) norm combines the \(L^2\) norm of the function itself with the \(L^2\) norm of its gradient:

\(\|u\|_{H^1}^2 = \|u\|_{L^2}^2 + \|\nabla u\|_{L^2}^2\)
This means \(H^1\) metrics penalize both value errors and derivative/smoothness errors. They are especially useful for detecting predictions that have the right large-scale structure but wrong fine-scale details.
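A from-scratch sketch of the idea (assuming the discrete \(H^1\) norm combines the mean square of the function with that of its spectral derivative; not Exponax's exact implementation):

```python
import numpy as np

N = 64
x = np.linspace(0, 2 * np.pi, N, endpoint=False)

def spectral_derivative(u):
    # Differentiate by multiplying with i*k in Fourier space.
    k = np.fft.rfftfreq(len(u), d=1.0 / len(u))
    return np.fft.irfft(1j * k * np.fft.rfft(u), n=len(u))

def h1_norm(u):
    # Combines the value content and the gradient content.
    return np.sqrt(np.mean(u ** 2) + np.mean(spectral_derivative(u) ** 2))

smooth = np.sin(x)       # same amplitude...
wiggly = np.sin(10 * x)  # ...but ten times the gradient content
print(h1_norm(smooth), h1_norm(wiggly))  # the wiggly field scores far higher
```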
u_strongly_diffused = ex.ic.DiffusedNoise(1, intensity=0.001)(
100, key=jax.random.PRNGKey(0)
)
u_less_diffused = ex.ic.DiffusedNoise(1, intensity=0.0003)(
100, key=jax.random.PRNGKey(0)
)
ex.viz.plot_state_1d(jnp.concatenate([u_strongly_diffused, u_less_diffused]))
nrmse_val = ex.metrics.nRMSE(u_strongly_diffused, u_less_diffused)
h1_nrmse_val = ex.metrics.H1_nRMSE(u_strongly_diffused, u_less_diffused)
print(f"nRMSE = {nrmse_val:.6f}")
print(f"H1_nRMSE = {h1_nrmse_val:.6f}")
The H1_nRMSE is significantly larger than nRMSE because the two fields
differ substantially in their gradient (smoothness) content.
Application: Detecting Blurry Predictions of Neural Emulators¤
A common failure mode of neural PDE emulators is producing blurry
predictions: the large-scale structure is correct, but fine details are smeared
out. Standard nRMSE may look acceptable, but H1_nRMSE reveals the
smoothness mismatch.
Let's simulate this scenario with the 1D Burgers equation.
# Run a Burgers simulation to get a "ground truth" state
burgers = ex.stepper.Burgers(
num_spatial_dims=1,
domain_extent=1.0,
num_points=100,
dt=0.1,
diffusivity=0.01,
)
u0 = ex.ic.DiffusedNoise(1, intensity=0.0005, max_one=True)(
100, key=jax.random.PRNGKey(0)
)
# Step forward a few times to develop sharp gradients
trj = ex.rollout(burgers, 5, include_init=True)(u0)
u_truth = trj[-1] # ground truth at t=0.5
print("u_truth shape:", u_truth.shape)
# Create a "blurry" prediction by running with much higher diffusivity
burgers_blurry = ex.stepper.Burgers(
num_spatial_dims=1,
domain_extent=1.0,
num_points=100,
dt=0.1,
diffusivity=0.05, # 5x more diffusion
)
trj_blurry = ex.rollout(burgers_blurry, 5, include_init=True)(u0)
u_blurry = trj_blurry[-1]
fig, ax = plt.subplots(figsize=(6, 3))
ex.viz.plot_state_1d(
jnp.concatenate([u_truth, u_blurry]),
ax=ax,
labels=["Ground Truth", "Blurry Prediction"],
)
ax.legend()
ax.set_title("Burgers @ t=0.5")
plt.tight_layout()
plt.show()
print(f"nRMSE = {ex.metrics.nRMSE(u_blurry, u_truth):.6f}")
print(f"H1_nRMSE = {ex.metrics.H1_nRMSE(u_blurry, u_truth):.6f}")
print()
print("The H1 metric is more sensitive to the blurriness because it also")
print("penalizes the smoothed-out gradients.")
Working with Batches¤
All metrics operate on single states (no batch axis). To compute metrics over a
batch, use exponax.metrics.mean_metric which vmaps and averages for you:
# Generate a small batch of initial conditions
keys = jax.random.split(jax.random.PRNGKey(0), 8)
ic_gen_ref = ex.ic.DiffusedNoise(1, intensity=0.001)
ic_gen_pred = ex.ic.DiffusedNoise(1, intensity=0.0005)
batch_ref = jax.vmap(lambda k: ic_gen_ref(NUM_POINTS, key=k))(keys)
batch_pred = jax.vmap(lambda k: ic_gen_pred(NUM_POINTS, key=k))(keys)
print(f"batch_ref shape: {batch_ref.shape}")
print(f"batch_pred shape: {batch_pred.shape}")
mean_nrmse = ex.metrics.mean_metric(ex.metrics.nRMSE, batch_pred, batch_ref)
mean_h1 = ex.metrics.mean_metric(ex.metrics.H1_nRMSE, batch_pred, batch_ref)
print(f"Mean nRMSE over batch = {mean_nrmse:.6f}")
print(f"Mean H1_nRMSE over batch = {mean_h1:.6f}")