Skip to content

Fourier nRMSE

exponax.metrics.fourier_nRMSE

fourier_nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    eps: float = 1e-05
) -> float

Compute the normalized root mean squared error (nRMSE) between two fields in Fourier space.

Arguments:

  • u_pred (array): The first field to be used in the error computation.
  • u_ref (array): The second field to be used in the error computation.
  • low (int, optional): The low-pass filter cutoff. Default is 0.
  • high (int, optional): The high-pass filter cutoff. Default is the Nyquist frequency.
  • eps (float, optional): Small value to avoid division by zero and to remove numerical rounding artifacts from the FFT. Default is 1e-5.

Returns:

  • nrmse (float): The normalized root mean squared error between the fields
Source code in exponax/_metrics.py (lines 543–575):
def fourier_nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    eps: float = 1e-5,
) -> float:
    """
    Compute the normalized root mean squared error (nRMSE) between two fields
    in Fourier space.

    Both fields carry a leading channel axis; every remaining axis is treated
    as spatial and is transformed.

    **Arguments**:

    - `u_pred` (array): The predicted field, with a leading channel axis.
    - `u_ref` (array): The reference field whose spectral norm normalizes the
        error; same layout as `u_pred`.
    - `low` (int, optional): Lowest wavenumber included in the error; modes
        below it are discarded. Default (`None`) is 0.
    - `high` (int, optional): Highest wavenumber included in the error.
        Default (`None`) is just beyond the Nyquist frequency.
    - `eps` (float, optional): Small value to avoid division by zero and to
        remove numerical rounding artifacts from the FFT. Default is 1e-5.

    **Returns**:

    - `nrmse` (float): The normalized root mean squared error between the
        fields.
    """
    # All axes except the leading channel axis are spatial.
    return _fourier_nRMSE(
        u_pred,
        u_ref,
        low=low,
        high=high,
        num_spatial_dims=len(u_pred.shape) - 1,
        eps=eps,
    )

exponax.metrics.mean_fourier_nRMSE

mean_fourier_nRMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    eps: float = 1e-05
) -> float

Compute the mean nRMSE between two fields in Fourier space. Use this function to correctly operate on arrays with a batch axis.

Arguments:

  • u_pred (array): The first field to be used in the error computation.
  • u_ref (array): The second field to be used in the error computation.
  • low (int, optional): The low-pass filter cutoff. Default is 0.
  • high (int, optional): The high-pass filter cutoff. Default is the Nyquist frequency.
  • eps (float, optional): Small value to avoid division by zero and to remove numerical rounding artifacts from the FFT. Default is 1e-5.

Returns:

  • mean_nrmse (float): The mean normalized root mean squared error between the fields
Source code in exponax/_metrics.py (lines 578–609):
def mean_fourier_nRMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    eps: float = 1e-5,
) -> float:
    """
    Compute the mean nRMSE between two fields in Fourier space, averaged over
    a leading batch axis. Use this function (rather than `fourier_nRMSE`)
    whenever the arrays carry a batch axis.

    **Arguments**:

    - `u_pred` (array): The batch of predicted fields.
    - `u_ref` (array): The batch of reference fields; same layout as
        `u_pred`.
    - `low` (int, optional): Lowest wavenumber included in the error; modes
        below it are discarded. Default (`None`) is 0.
    - `high` (int, optional): Highest wavenumber included in the error.
        Default (`None`) is just beyond the Nyquist frequency.
    - `eps` (float, optional): Small value to avoid division by zero and to
        remove numerical rounding artifacts from the FFT. Default is 1e-5.

    **Returns**:

    - `mean_nrmse` (float): The per-sample nRMSE values averaged over the
        batch axis.
    """

    def _single_sample_nrmse(pred_sample, ref_sample):
        # vmap strips the batch axis, so each sample has the (C, ..., N)
        # layout that `fourier_nRMSE` expects.
        return fourier_nRMSE(pred_sample, ref_sample, low=low, high=high, eps=eps)

    per_sample = jax.vmap(_single_sample_nrmse)(u_pred, u_ref)
    return jnp.mean(per_sample)

exponax.metrics._fourier_nRMSE

_fourier_nRMSE(
    u_pred: Float[Array, "... N"],
    u_ref: Float[Array, "... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    num_spatial_dims: Optional[int] = None,
    eps: float = 1e-05
) -> float

Low-level function to compute the normalized root mean squared error (nRMSE) between two fields in Fourier space.

If num_spatial_dims is not provided, it will be inferred from the shape of the input fields. Please adjust this argument if you call this function with an array that also contains channels (even for arrays with singleton channels).

Arguments:

  • u_pred (array): The first field to be used in the error computation.
  • u_ref (array): The second field to be used in the error computation.
  • low (int, optional): The low-pass filter cutoff. Default is 0.
  • high (int, optional): The high-pass filter cutoff. Default is the Nyquist frequency.
  • num_spatial_dims (int, optional): The number of spatial dimensions in the field. If None, it will be inferred from the shape of the input fields and then is the number of axes present. Default is None.
  • eps (float, optional): Small value to avoid division by zero and to remove numerical rounding artifacts from the FFT. Default is 1e-5.
Source code in exponax/_metrics.py (lines 454–540):
def _fourier_nRMSE(
    u_pred: Float[Array, "... N"],
    u_ref: Float[Array, "... N"],
    *,
    low: Optional[int] = None,
    high: Optional[int] = None,
    num_spatial_dims: Optional[int] = None,
    eps: float = 1e-5,
) -> float:
    """
    Low-level routine computing the normalized root mean squared error (nRMSE)
    between two fields in Fourier space, restricted to a band of wavenumbers.

    When `num_spatial_dims` is omitted, every axis of the inputs is treated as
    spatial. If the arrays also carry a channel axis (even a singleton one),
    pass `num_spatial_dims` explicitly so that axis is not transformed.

    **Arguments**:

    - `u_pred` (array): The predicted field.
    - `u_ref` (array): The reference field; its masked spectral norm is the
        normalizer.
    - `low` (int, optional): Lowest wavenumber kept; modes below it are
        excluded. Default (`None`) is 0.
    - `high` (int, optional): Highest wavenumber kept. Default (`None`) is
        `N // 2 + 1`, i.e. just beyond the Nyquist frequency `N // 2`.
    - `num_spatial_dims` (int, optional): Number of spatial axes to
        transform. If `None`, it is set to the total number of axes of
        `u_pred`. Default is `None`.
    - `eps` (float, optional): Small value to avoid division by zero and to
        remove numerical rounding artifacts from the FFT: spectral
        coefficients whose magnitude is below `eps` are zeroed. Default is
        1e-5.

    **Returns**:

    - `nrmse` (float): The masked spectral L2 norm of `u_pred - u_ref`
        divided by (the masked spectral L2 norm of `u_ref` plus `eps`).
    """
    if num_spatial_dims is None:
        # No channel axis assumed: every axis is spatial.
        num_spatial_dims = len(u_pred.shape)
    fft_axes = space_indices(num_spatial_dims)
    # Assumes every spatial axis shares the same resolution N.
    num_points = u_pred.shape[-1]

    low = 0 if low is None else low
    high = (num_points // 2) + 1 if high is None else high

    # The filter cutoff is inclusive, so `low - 1` selects exactly the modes
    # strictly below `low`; inverting that mask keeps the modes >= `low`.
    below_low = low_pass_filter_mask(
        num_spatial_dims,
        num_points,
        cutoff=low - 1,
    )
    up_to_high = low_pass_filter_mask(
        num_spatial_dims,
        num_points,
        cutoff=high,
    )
    band = jnp.invert(below_low) & up_to_high

    def _band_spectrum(field):
        spectrum = jnp.fft.rfftn(field, axes=fft_axes)
        # FFT round-off around machine precision would otherwise be visible in
        # the nRMSE, so coefficients smaller than `eps` are zeroed out.
        spectrum = jnp.where(
            jnp.abs(spectrum) < eps,
            jnp.zeros_like(spectrum),
            spectrum,
        )
        return spectrum * band

    def _spectral_norm(spectrum):
        # vdot conjugates its first argument, giving sum(|z|^2) for complex z.
        flat = spectrum.flatten()
        return jnp.sqrt(jnp.vdot(flat, flat)).real

    pred_band = _band_spectrum(u_pred)
    ref_band = _band_spectrum(u_ref)

    return _spectral_norm(pred_band - ref_band) / (_spectral_norm(ref_band) + eps)