Spatial-based

exponax.metrics.MSE

MSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0
) -> float

Compute the mean squared error (MSE) between two states.

∑_(channels) ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²

Given the correct domain_extent, this is consistent with the following functional norm:

‖ u - uʳ ‖²_L²(Ω) = ∫_Ω |u(x) - uʳ(x)|² dx

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred. If not specified, the MSE is computed against zero, i.e., the norm of u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def MSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the mean squared error (MSE) between two states.

        ∑_(channels) ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        ‖ u - uʳ ‖²_L²(Ω) = ∫_Ω |u(x) - uʳ(x)|² dx

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
        If not specified, the MSE is computed against zero, i.e., the norm of
        `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="absolute",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=1.0,
    )
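
For instance, a minimal sketch of batched evaluation (the array shapes, random data, and domain_extent value are illustrative; the exact signature of exponax.metrics.mean_metric is not shown on this page, so the sketch uses jax.vmap directly):

import jax
import jax.numpy as jnp

import exponax as ex

# Hypothetical batch of states: 8 samples, 1 channel, 64 grid points on
# the domain Omega = (0, 3).
key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred = jax.random.normal(key_pred, (8, 1, 64))
u_ref = jax.random.normal(key_ref, (8, 1, 64))

# Map the per-sample metric over the leading batch axis, then reduce.
mse_per_sample = jax.vmap(
    lambda pred, ref: ex.metrics.MSE(pred, ref, domain_extent=3.0)
)(u_pred, u_ref)
batch_mse = jnp.mean(mse_per_sample)  # scalar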

exponax.metrics.MAE

MAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0
) -> float

Compute the mean absolute error (MAE) between two states.

∑_(channels) ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|

Given the correct domain_extent, this is consistent with the following functional norm:

‖ u - uʳ ‖_L¹(Ω) = ∫_Ω |u(x) - uʳ(x)| dx

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred. If not specified, the MAE is computed against zero, i.e., the norm of u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def MAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the mean absolute error (MAE) between two states.

        ∑_(channels) ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        ‖ u - uʳ ‖_L¹(Ω) = ∫_Ω |u(x) - uʳ(x)| dx

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
        If not specified, the MAE is computed against zero, i.e., the norm of
        `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="absolute",
        domain_extent=domain_extent,
        inner_exponent=1.0,
        outer_exponent=1.0,
    )

exponax.metrics.RMSE

RMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0
) -> float

Compute the root mean squared error (RMSE) between two states.

(∑_(channels) √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²))

Given the correct domain_extent, this is consistent with the following functional norm:

(‖ u - uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx)

The channel axis is summed after the aggregation. Hence, it is also summed after the square root. If you need the RMSE per channel, consider using exponax.metrics.spatial_aggregator directly.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred. If not specified, the RMSE is computed against zero, i.e., the norm of u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def RMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the root mean squared error (RMSE) between two states.

        (∑_(channels) √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²))

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        (‖ u - uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx)

    The channel axis is summed **after** the aggregation. Hence, it is also
    summed **after** the square root. If you need the RMSE per channel, consider
    using [`exponax.metrics.spatial_aggregator`][] directly.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
        If not specified, the RMSE is computed against zero, i.e., the norm of
        `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="absolute",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=0.5,
    )
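
If per-channel values are needed, as the note above suggests, a minimal sketch using exponax.metrics.spatial_aggregator directly could look like this (shapes and data are illustrative):

import jax
import jax.numpy as jnp

import exponax as ex

# Hypothetical two-channel 1D states on 64 points.
key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred = jax.random.normal(key_pred, (2, 64))
u_ref = jax.random.normal(key_ref, (2, 64))

# RMSE per channel: aggregate |diff|^2 over space with outer exponent 1/2,
# mapped over the channel axis instead of summing it away.
rmse_per_channel = jax.vmap(
    lambda diff: ex.metrics.spatial_aggregator(
        diff, inner_exponent=2.0, outer_exponent=0.5
    )
)(u_pred - u_ref)  # shape (2,)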

exponax.metrics.nMSE

nMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the normalized mean squared error (nMSE) between two states.

∑_(channels) [∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / ∑_(space) (L/N)ᴰ |uₕʳ|²]

Given the correct domain_extent, this is consistent with the following functional norm:

‖ u - uʳ ‖²_L²(Ω) / ‖ uʳ ‖²_L²(Ω) = ∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω |uʳ(x)|² dx

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def nMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the normalized mean squared error (nMSE) between two states.

        ∑_(channels) [∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / ∑_(space) (L/N)ᴰ |uₕʳ|²]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        ‖ u - uʳ ‖²_L²(Ω) / ‖ uʳ ‖²_L²(Ω) = ∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω |uʳ(x)|² dx

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="normalized",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=1.0,
    )
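
Because the normalization happens per channel before the channel sum, the nMSE factors into MSE(u_pred, u_ref) / MSE(u_ref) only in the single-channel case. A small, hedged sanity check of that relation (data illustrative):

import jax
import jax.numpy as jnp

import exponax as ex

key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred = jax.random.normal(key_pred, (1, 64))  # single channel
u_ref = jax.random.normal(key_ref, (1, 64))

# Leaving the reference unspecified in MSE computes the norm of the input,
# so the ratio below is exactly the single-channel nMSE.
lhs = ex.metrics.nMSE(u_pred, u_ref)
rhs = ex.metrics.MSE(u_pred, u_ref) / ex.metrics.MSE(u_ref)
assert jnp.allclose(lhs, rhs)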

exponax.metrics.nMAE

nMAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the normalized mean absolute error (nMAE) between two states.

∑_(channels) [∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / ∑_(space) (L/N)ᴰ |uₕʳ|]

Given the correct domain_extent, this is consistent with the following functional norm:

‖ u - uʳ ‖_L¹(Ω) / ‖ uʳ ‖_L¹(Ω) = ∫_Ω |u(x) - uʳ(x)| dx / ∫_Ω |uʳ(x)| dx

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def nMAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the normalized mean absolute error (nMAE) between two states.

        ∑_(channels) [∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / ∑_(space) (L/N)ᴰ |uₕʳ|]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        ‖ u - uʳ ‖_L¹(Ω) / ‖ uʳ ‖_L¹(Ω) = ∫_Ω |u(x) - uʳ(x)| dx / ∫_Ω |uʳ(x)| dx

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="normalized",
        domain_extent=domain_extent,
        inner_exponent=1.0,
        outer_exponent=1.0,
    )

exponax.metrics.nRMSE

nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the normalized root mean squared error (nRMSE) between two states.

∑_(channels) [√(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / √(∑_(space) (L/N)ᴰ |uₕʳ|²)]

Given the correct domain_extent, this is consistent with the following functional norm:

(‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω |uʳ(x)|² dx)

The channel axis is summed after the aggregation. Hence, it is also summed after the square root and after normalization. If you need more fine-grained control, consider using exponax.metrics.spatial_aggregator directly.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the normalized root mean squared error (nRMSE) between two states.

        ∑_(channels) [√(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / √(∑_(space) (L/N)ᴰ
        |uₕʳ|²)]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        (‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
        |uʳ(x)|² dx)

    The channel axis is summed **after** the aggregation. Hence, it is also
    summed **after** the square root and after normalization. If you need more
    fine-grained control, consider using
    [`exponax.metrics.spatial_aggregator`][] directly.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="normalized",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=0.5,
    )

exponax.metrics.sMAE

sMAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the symmetric mean absolute error (sMAE) between two states.

∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / (∑_(space) (L/N)ᴰ |uₕ| + ∑_(space) (L/N)ᴰ |uₕʳ|)]

Given the correct domain_extent, this is consistent with the following functional norm:

2 ∫_Ω |u(x) - uʳ(x)| dx / (∫_Ω |u(x)| dx + ∫_Ω |uʳ(x)| dx)

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Info

This symmetric metric is bounded between 0 and C with C being the number of channels.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def sMAE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the symmetric mean absolute error (sMAE) between two states.

        ∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / (∑_(space) (L/N)ᴰ |uₕ| + ∑_(space) (L/N)ᴰ |uₕʳ|)]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        2 ∫_Ω |u(x) - uʳ(x)| dx / (∫_Ω |u(x)| dx + ∫_Ω |uʳ(x)| dx)

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.

    !!! info
        This symmetric metric is bounded between 0 and C with C being the number
        of channels.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="symmetric",
        domain_extent=domain_extent,
        inner_exponent=1.0,
        outer_exponent=1.0,
    )

exponax.metrics.sMSE

sMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the symmetric mean squared error (sMSE) between two states.

∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / (∑_(space) (L/N)ᴰ |uₕ|² + ∑_(space) (L/N)ᴰ |uₕʳ|²)]

Given the correct domain_extent, this is consistent with the following functional norm:

2 ∫_Ω |u(x) - uʳ(x)|² dx / (∫_Ω |u(x)|² dx + ∫_Ω |uʳ(x)|² dx)

The channel axis is summed after the aggregation.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Info

This symmetric metric is bounded between 0 and C with C being the number of channels.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def sMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the symmetric mean squared error (sMSE) between two states.

        ∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / (∑_(space) (L/N)ᴰ |uₕ|² + ∑_(space) (L/N)ᴰ |uₕʳ|²)]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        2 ∫_Ω |u(x) - uʳ(x)|² dx / (∫_Ω |u(x)|² dx + ∫_Ω |uʳ(x)|² dx)

    The channel axis is summed **after** the aggregation.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.

    !!! info
        This symmetric metric is bounded between 0 and C with C being the number
        of channels.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="symmetric",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=1.0,
    )

exponax.metrics.sRMSE

sRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0
) -> float

Compute the symmetric root mean squared error (sRMSE) between two states.

∑_(channels) [2 √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / (√(∑_(space) (L/N)ᴰ |uₕ|²) + √(∑_(space) (L/N)ᴰ |uₕʳ|²))]

Given the correct domain_extent, this is consistent with the following functional norm:

2 √(∫_Ω |u(x) - uʳ(x)|² dx) / (√(∫_Ω |u(x)|² dx) + √(∫_Ω |uʳ(x)|² dx))

The channel axis is summed after the aggregation. Hence, it is also summed after the square root and after normalization. If you need more fine-grained control, consider using exponax.metrics.spatial_aggregator directly.

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Info

This symmetric metric is bounded between 0 and C with C being the number of channels.

Arguments:

  • u_pred: The state array, must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • u_ref: The reference state array. Must have the same shape as u_pred.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ. Must be provided to obtain the consistent norm. If this metric is used as an optimization objective, it can often be ignored since it only contributes a multiplicative factor.
Source code in exponax/metrics/_spatial.py
def sRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the symmetric root mean squared error (sRMSE) between two states.

        ∑_(channels) [2 √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / (√(∑_(space) (L/N)ᴰ
        |uₕ|²) + √(∑_(space) (L/N)ᴰ |uₕʳ|²))]

    Given the correct `domain_extent`, this is consistent with the following
    functional norm:

        2 √(∫_Ω |u(x) - uʳ(x)|² dx) / (√(∫_Ω |u(x)|² dx) + √(∫_Ω |uʳ(x)|² dx))

    The channel axis is summed **after** the aggregation. Hence, it is also
    summed **after** the square root and after normalization. If you need more
    fine-grained control, consider using
    [`exponax.metrics.spatial_aggregator`][] directly.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.

    !!! info
        This symmetric metric is bounded between 0 and C with C being the number
        of channels.


    **Arguments:**

    - `u_pred`: The state array, must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
        provided to obtain the consistent norm. If this metric is used as an
        optimization objective, it can often be ignored since it only
        contributes a multiplicative factor.
    """
    return spatial_norm(
        u_pred,
        u_ref,
        mode="symmetric",
        domain_extent=domain_extent,
        inner_exponent=2.0,
        outer_exponent=0.5,
    )

exponax.metrics.spatial_norm

spatial_norm(
    state: Float[Array, "C ... N"],
    state_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    mode: Literal[
        "absolute", "normalized", "symmetric"
    ] = "absolute",
    domain_extent: float = 1.0,
    inner_exponent: float = 2.0,
    outer_exponent: Optional[float] = None
) -> float

Compute the consistent counterpart of the Lᴾ functional norm.

See exponax.metrics.spatial_aggregator for more details. This function sums over the channel axis after aggregation. If you need lower-level control, consider using exponax.metrics.spatial_aggregator directly.

This function allows providing a second state (state_ref) to compute either the absolute, normalized, or symmetric difference. The "absolute" mode computes

(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p)

while the "normalized" mode computes

(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕʳ‖_L^p(Ω))^(q*p))

and the "symmetric" mode computes

2 * (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕ‖_L^p(Ω))^(q*p) + (‖uₕʳ‖_L^p(Ω))^(q*p))

In all modes, the channels are summed after the aggregation. The inner_exponent corresponds to p in the above formulas. The outer_exponent corresponds to q. If it is not specified, it is set to q = 1/p to get a valid norm.

Tip

To operate on states with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.

Arguments:

  • state: The state tensor. Must follow the Exponax convention with a leading channel axis, and either one, two, or three subsequent spatial axes.
  • state_ref: The reference state tensor. Must have the same shape as state. If not specified, only the absolute norm of state is computed.
  • mode: The mode of the norm. Either "absolute", "normalized", or "symmetric".
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ.
  • inner_exponent: The exponent p in the L^p norm.
  • outer_exponent: The exponent q the result after aggregation is raised to. If not specified, it is set to q = 1/p.
Source code in exponax/metrics/_spatial.py
def spatial_norm(
    state: Float[Array, "C ... N"],
    state_ref: Optional[Float[Array, "C ... N"]] = None,
    *,
    mode: Literal["absolute", "normalized", "symmetric"] = "absolute",
    domain_extent: float = 1.0,
    inner_exponent: float = 2.0,
    outer_exponent: Optional[float] = None,
) -> float:
    """
    Compute the consistent counterpart of the `Lᴾ` functional norm.

    See [`exponax.metrics.spatial_aggregator`][] for more details. This function
    sums over the channel axis **after aggregation**. If you need lower-level
    control, consider using [`exponax.metrics.spatial_aggregator`][] directly.

    This function allows providing a second state (`state_ref`) to compute
    either the absolute, normalized, or symmetric difference. The `"absolute"`
    mode computes

        (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p)

    while the `"normalized"` mode computes

        (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕʳ‖_L^p(Ω))^(q*p))

    and the `"symmetric"` mode computes

        2 * (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕ‖_L^p(Ω))^(q*p) + (‖uₕʳ‖_L^p(Ω))^(q*p))

    In all modes, the channels are summed **after** the aggregation. The
    `inner_exponent` corresponds to `p` in the above formulas. The
    `outer_exponent` corresponds to `q`. If it is not specified, it is set to `q
    = 1/p` to get a valid norm.

    !!! tip
        To operate on states with a leading batch axis, use `jax.vmap`. Then the
        batch axis can be reduced, e.g., by `jnp.mean`. As a helper for this,
        [`exponax.metrics.mean_metric`][] is provided.


    **Arguments:**

    - `state`: The state tensor. Must follow the `Exponax` convention with a
        leading channel axis, and either one, two, or three subsequent spatial
        axes.
    - `state_ref`: The reference state tensor. Must have the same shape as
        `state`. If not specified, only the absolute norm of `state` is
        computed.
    - `mode`: The mode of the norm. Either `"absolute"`, `"normalized"`, or
        `"symmetric"`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
    - `inner_exponent`: The exponent `p` in the L^p norm.
    - `outer_exponent`: The exponent `q` the result after aggregation is raised
        to. If not specified, it is set to `q = 1/p`.
    """
    if state_ref is None:
        if mode == "normalized":
            raise ValueError("mode 'normalized' requires state_ref")
        if mode == "symmetric":
            raise ValueError("mode 'symmetric' requires state_ref")
        diff = state
    else:
        diff = state - state_ref

    diff_norm_per_channel = jax.vmap(
        lambda s: spatial_aggregator(
            s,
            domain_extent=domain_extent,
            inner_exponent=inner_exponent,
            outer_exponent=outer_exponent,
        ),
    )(diff)

    if mode == "normalized":
        ref_norm_per_channel = jax.vmap(
            lambda r: spatial_aggregator(
                r,
                domain_extent=domain_extent,
                inner_exponent=inner_exponent,
                outer_exponent=outer_exponent,
            ),
        )(state_ref)
        normalized_diff_per_channel = diff_norm_per_channel / ref_norm_per_channel
        norm_per_channel = normalized_diff_per_channel
    elif mode == "symmetric":
        state_norm_per_channel = jax.vmap(
            lambda s: spatial_aggregator(
                s,
                domain_extent=domain_extent,
                inner_exponent=inner_exponent,
                outer_exponent=outer_exponent,
            ),
        )(state)
        ref_norm_per_channel = jax.vmap(
            lambda r: spatial_aggregator(
                r,
                domain_extent=domain_extent,
                inner_exponent=inner_exponent,
                outer_exponent=outer_exponent,
            ),
        )(state_ref)
        symmetric_diff_per_channel = (
            2 * diff_norm_per_channel / (state_norm_per_channel + ref_norm_per_channel)
        )
        norm_per_channel = symmetric_diff_per_channel
    else:
        norm_per_channel = diff_norm_per_channel

    return jnp.sum(norm_per_channel)
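
As the source above makes explicit, the named metrics in this module are thin wrappers around this function; e.g., nRMSE corresponds to mode="normalized" with p = 2 and q = 1/2. A small sketch making the exponents visible (data illustrative):

import jax
import jax.numpy as jnp

import exponax as ex

key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred = jax.random.normal(key_pred, (2, 64))  # two-channel 1D state
u_ref = jax.random.normal(key_ref, (2, 64))

# nRMSE spelled out via spatial_norm: normalized mode, p = 2, q = 1/2.
via_norm = ex.metrics.spatial_norm(
    u_pred,
    u_ref,
    mode="normalized",
    inner_exponent=2.0,
    outer_exponent=0.5,
)
assert jnp.allclose(via_norm, ex.metrics.nRMSE(u_pred, u_ref))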

exponax.metrics.spatial_aggregator

spatial_aggregator(
    state_no_channel: Float[Array, "... N"],
    *,
    num_spatial_dims: Optional[int] = None,
    domain_extent: float = 1.0,
    num_points: Optional[int] = None,
    inner_exponent: float = 2.0,
    outer_exponent: Optional[float] = None
) -> float

Aggregate over the spatial axes of a (channel-less) state tensor to get a consistent counterpart to a functional L^p norm in the continuous case.

Assume the Exponax convention: the domain is always the scaled hypercube Ω = (0, L)ᴰ (with L = domain_extent), each spatial dimension is discretized uniformly into N points (so there are Nᴰ points in total), and the left boundary is a degree of freedom while the right one is not. Then the following relation holds between a continuous function u(x) and its discretely sampled counterpart uₕ

‖ u(x) ‖_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) ≈ ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)

where the summation ∑ᵢ must be understood as a sum over all Nᴰ points across all spatial dimensions. The inner_exponent corresponds to p in the above formula. This function also allows setting the outer exponent q, which enters via

( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^q

If it is not specified, it is set to q = 1/p to get a valid norm.

Tip

To apply this function to a state tensor with a leading channel axis, use jax.vmap.

Arguments:

  • state_no_channel: The state tensor without a leading channel axis.
  • num_spatial_dims: The number of spatial dimensions. If not specified, it is inferred from the number of axes in state_no_channel.
  • domain_extent: The extent L of the domain Ω = (0, L)ᴰ.
  • num_points: The number of points N in each spatial dimension. If not specified, it is inferred from the last axis of state_no_channel.
  • inner_exponent: The exponent p in the L^p norm.
  • outer_exponent: The exponent q the result after aggregation is raised to. If not specified, it is set to q = 1/p.

Warning

To get a truly consistent counterpart to the continuous norm, the domain_extent must be set. This is relevant to compare performance across domain sizes. However, if this is just used as a training objective, the domain_extent can be set to 1.0 since it only contributes a multiplicative factor.

Info

The approximation to the continuous integral is of the following form:

  • Exact if the state is bandlimited.
  • Exponentially convergent if the state is smooth; it is converged once the state becomes effectively bandlimited under num_points.
  • Polynomially convergent in all other cases.

Source code in exponax/metrics/_spatial.py
def spatial_aggregator(
    state_no_channel: Float[Array, "... N"],
    *,
    num_spatial_dims: Optional[int] = None,
    domain_extent: float = 1.0,
    num_points: Optional[int] = None,
    inner_exponent: float = 2.0,
    outer_exponent: Optional[float] = None,
) -> float:
    """
    Aggregate over the spatial axes of a (channel-less) state tensor to get a
    *consistent* counterpart to a functional L^p norm in the continuous case.

    Assume the `Exponax` convention: the domain is always the scaled hypercube
    `Ω = (0, L)ᴰ` (with `L = domain_extent`), each spatial dimension is
    discretized uniformly into `N` points (so there are `Nᴰ` points in total),
    and the left boundary is a degree of freedom while the right one is not.
    Then the following relation holds between a continuous function `u(x)` and
    its discretely sampled counterpart `uₕ`

        ‖ u(x) ‖_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) ≈ ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)

    where the summation `∑ᵢ` must be understood as a sum over all `Nᴰ` points
    across all spatial dimensions. The `inner_exponent` corresponds to `p` in
    the above formula. This function also allows setting the outer exponent
    `q`, which enters via

        ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^q

    If it is not specified, it is set to `q = 1/p` to get a valid norm.

    !!! tip
        To apply this function to a state tensor with a leading channel axis,
        use `jax.vmap`.

    **Arguments:**

    - `state_no_channel`: The state tensor **without a leading channel
        axis**.
    - `num_spatial_dims`: The number of spatial dimensions. If not specified,
        it is inferred from the number of axes in `state_no_channel`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
    - `num_points`: The number of points `N` in each spatial dimension. If not
        specified, it is inferred from the last axis of `state_no_channel`.
    - `inner_exponent`: The exponent `p` in the L^p norm.
    - `outer_exponent`: The exponent `q` the result after aggregation is raised
        to. If not specified, it is set to `q = 1/p`.

    !!! warning
        To get a truly consistent counterpart to the continuous norm, the
        `domain_extent` must be set. This is relevant to compare performance
        across domain sizes. However, if this is just used as a training
        objective, the `domain_extent` can be set to `1.0` since it only
        contributes a multiplicative factor.

    !!! info
        The approximation to the continuous integral is of the following form:

        - **Exact** if the state is bandlimited.
        - **Exponentially convergent** if the state is smooth. It is converged
            once the state becomes effectively bandlimited under `num_points`.
        - **Polynomially convergent** in all other cases.
    """
    if num_spatial_dims is None:
        num_spatial_dims = state_no_channel.ndim
    if num_points is None:
        num_points = state_no_channel.shape[-1]

    if outer_exponent is None:
        outer_exponent = 1 / inner_exponent

    scale = (domain_extent / num_points) ** num_spatial_dims

    aggregated = jnp.sum(jnp.abs(state_no_channel) ** inner_exponent)

    return (scale * aggregated) ** outer_exponent
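
As a small consistency check under the convention above (test function and grid size are illustrative): for the bandlimited function u(x) = sin(2πx/L) on Ω = (0, L), the squared L² norm is exactly L/2, and the discrete aggregation reproduces it up to floating point:

import jax.numpy as jnp

import exponax as ex

L = 3.0
N = 64
x = jnp.arange(N) * L / N  # left boundary included, right boundary excluded
u = jnp.sin(2 * jnp.pi * x / L)

# Inner exponent p = 2 with outer exponent q = 1 yields the *squared* L2
# norm, which for this bandlimited function equals L/2 exactly.
squared_norm = ex.metrics.spatial_aggregator(
    u, domain_extent=L, inner_exponent=2.0, outer_exponent=1.0
)
assert jnp.allclose(squared_norm, L / 2)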