Fourier-spectral utilities

exponax.fft

fft(
    field: Float[Array, "C ... N"],
    *,
    num_spatial_dims: Optional[int] = None
) -> Complex[Array, "C ... (N//2)+1"]

Perform a real-valued FFT of a field. This function is designed for states in Exponax with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.

Only accepts real-valued input fields and performs a real-valued FFT. Hence, the last axis of the returned field is of length N//2+1.
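As an illustration of this shape convention, here is a minimal NumPy sketch. It assumes that `jax.numpy.fft.rfftn` behaves like `numpy.fft.rfftn` (JAX mirrors the NumPy FFT API) and that the spatial axes are the trailing `num_spatial_dims` axes. `fft_sketch` is a hypothetical helper for illustration, not part of Exponax.

```python
import numpy as np


def fft_sketch(field, num_spatial_dims=None):
    """Hypothetical NumPy analogue of exponax.fft (not the library code)."""
    if num_spatial_dims is None:
        # Exponax convention assumed: one leading channel axis, the rest are spatial
        num_spatial_dims = field.ndim - 1
    # Transform only the trailing spatial axes; the channel axis is left untouched
    spatial_axes = tuple(range(-num_spatial_dims, 0))
    return np.fft.rfftn(field, axes=spatial_axes)


for shape in [(1, 10), (2, 32, 32), (3, 8, 8, 8)]:
    print(shape, "->", fft_sketch(np.zeros(shape)).shape)
# (1, 10) -> (1, 6)
# (2, 32, 32) -> (2, 32, 17)
# (3, 8, 8, 8) -> (3, 8, 8, 5)
```

Only the right-most axis is halved; all other spatial axes keep their full length N.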

Warning

The argument num_spatial_dims can only be correctly inferred if the array follows the Exponax convention, e.g., no leading batch axis. For a batched operation, use jax.vmap on this function.
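The following NumPy sketch (an illustration, not Exponax code) shows why the inference fails for a batched array of shape (B, C, N): counting `ndim - 1 = 2` spatial axes makes the transform also run over the channel axis. Transforming each sample separately, which is what wrapping the function in `jax.vmap` amounts to, gives the intended result.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 2, 16))  # (B, C, N): leading batch axis, 2 channels, 1D

# Inferring num_spatial_dims = batch.ndim - 1 = 2 would transform axes (-2, -1),
# i.e., it would also mix the two channels
wrong = np.fft.rfftn(batch, axes=(-2, -1))

# Per-sample transform over the single spatial axis (what a vmap over samples does)
right = np.stack([np.fft.rfftn(sample, axes=(-1,)) for sample in batch])

print(wrong.shape, right.shape)  # (4, 2, 9) (4, 2, 9)
print(np.allclose(wrong, right))  # False
```

The shapes agree, so this mistake does not raise an error; only the values are wrong. In JAX, the batched call suggested by the warning would be `jax.vmap(exponax.fft)(batch)`.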

Arguments:

  • field: The state to transform.
  • num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. It can be inferred from the array if the array follows the Exponax convention; for example, a leading batch axis is not allowed. In such a case, use jax.vmap on this function.

Returns:

  • field_hat: The transformed field, shape (C, ..., N//2+1).

Info

Internally uses jax.numpy.fft.rfftn with the default norm="backward". This means that the forward FFT (this function) does not normalize its result; only the exponax.ifft function applies normalization. To extract the amplitudes of the coefficients, divide by exponax.spectral.build_scaling_array.

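To make the normalization convention concrete, here is a small NumPy check. It assumes that `numpy.fft.rfftn`/`irfftn` match their `jax.numpy` counterparts with `norm="backward"`, the default in both. For a constant field, the zero-mode coefficient equals the value times the total number of grid points, and only the inverse transform divides by that number.

```python
import numpy as np

n = 8
state = np.full((1, n, n), 2.5)  # constant 2D field, one channel

# norm="backward" (the default): the forward transform is not scaled
state_hat = np.fft.rfftn(state, axes=(1, 2))
print(np.isclose(state_hat[0, 0, 0], 2.5 * n**2))  # True: value times number of points
print(np.allclose(state_hat.ravel()[1:], 0.0))     # True: no other mode is excited

# The inverse transform applies the 1 / N**2 factor, so the round trip recovers the state
state_back = np.fft.irfftn(state_hat, s=(n, n), axes=(1, 2))
print(np.allclose(state_back, state))  # True
```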
Source code in exponax/_spectral.py
def fft(
    field: Float[Array, "C ... N"],
    *,
    num_spatial_dims: Optional[int] = None,
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Perform a **real-valued** FFT of a field. This function is designed for
    states in `Exponax` with a leading channel axis and then one, two, or three
    subsequent spatial axes, **each of the same length** N.

    Only accepts real-valued input fields and performs a real-valued FFT. Hence,
    the last axis of the returned field is of length N//2+1.

    !!! warning
        The argument `num_spatial_dims` can only be correctly inferred if the
        array follows the Exponax convention, e.g., no leading batch axis. For a
        batched operation, use `jax.vmap` on this function.

    **Arguments:**

    - `field`: The state to transform.
    - `num_spatial_dims`: The number of spatial dimensions, i.e., how many
        spatial axes follow the channel axis. Can be inferred from the array if
        it follows the Exponax convention. For example, it is not allowed to
        have a leading batch axis, in such a case use `jax.vmap` on this
        function.

    **Returns:**

    - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`.

    !!! info
        Internally uses `jax.numpy.fft.rfftn` with the default settings for the
        `norm` argument with `norm="backward"`. This means that the forward FFT
        (this function) does not apply any normalization to the result, only the
        [`exponax.ifft`][] function applies normalization. To extract the
        amplitude of the coefficients divide by
        `exponax.spectral.build_scaling_array`.
    """
    if num_spatial_dims is None:
        num_spatial_dims = field.ndim - 1

    return jnp.fft.rfftn(field, axes=space_indices(num_spatial_dims))

exponax.ifft

ifft(
    field_hat: Complex[Array, "C ... (N//2)+1"],
    *,
    num_spatial_dims: Optional[int] = None,
    num_points: Optional[int] = None
) -> Float[Array, "C ... N"]

Perform the inverse real-valued FFT of a field. This is the inverse operation of exponax.fft. This function is designed for states in Exponax with a leading channel axis and then one, two, or three following spatial axes. In state space all spatial axes have the same length N (here called num_points).

Requires a complex-valued field in Fourier space with the last axis of length N//2+1.

Info

The number of points (N, or num_points) must be provided if the number of spatial dimensions is 1. Otherwise, it can be inferred from the shape of the field.
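The reason N is needed in 1D is that the rfft of a signal with an even length N and that of a signal with length N + 1 both have N//2 + 1 entries, so N cannot be recovered from the shape alone. In two or more dimensions, the leading spatial axes keep their full length N, which is why the implementation reads it from `field_hat.shape[-2]`. A NumPy illustration:

```python
import numpy as np

signal_even = np.cos(np.arange(8))  # N = 8
signal_odd = np.cos(np.arange(9))   # N = 9

hat_even = np.fft.rfft(signal_even)
hat_odd = np.fft.rfft(signal_odd)
print(hat_even.shape, hat_odd.shape)  # (5,) (5,)

# Without the length, irfft assumes an even N = 2 * (5 - 1) = 8
print(np.fft.irfft(hat_odd).shape)  # (8,)
# Passing the length (the role of num_points) recovers the odd-length signal
print(np.allclose(np.fft.irfft(hat_odd, n=9), signal_odd))  # True
```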

Warning

The argument num_spatial_dims can only be correctly inferred if the array follows the Exponax convention, e.g., no leading batch axis. For a batched operation, use jax.vmap on this function.

Arguments:

  • field_hat: The transformed field, shape (C, ..., N//2+1).
  • num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. It can be inferred from the array if the array follows the Exponax convention; for example, a leading batch axis is not allowed. In such a case, use jax.vmap on this function.
  • num_points: The number of points in each spatial dimension. Can be inferred if num_spatial_dims >= 2.

Returns:

  • field: The state in physical space, shape (C, ..., N).

Info

Internally uses jax.numpy.fft.irfftn with the default norm="backward". This means that the forward FFT (exponax.fft) does not normalize its output; only the inverse FFT (this function) applies normalization. Hence, if you define a state in Fourier space and transform it back, consider using exponax.spectral.build_scaling_array to scale the complex values correctly before the inverse transform.

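A minimal NumPy sketch of defining a 1D state directly through its coefficients and transforming it back. It assumes that the 1D scaling for interior modes (neither k = 0 nor the Nyquist mode) is N / 2, as follows from the unnormalized forward transform; in Exponax, build_scaling_array would supply this array instead of a hand-written factor.

```python
import numpy as np

n = 16
x = np.arange(n) * 2 * np.pi / n  # periodic grid on [0, 2pi), right endpoint excluded

# Coefficients as they would be read off: cosine amplitude in the real part,
# negative sine amplitude in the imaginary part
coefs = np.zeros((1, n // 2 + 1), dtype=complex)
coefs[0, 2] = 3.0   # +3.0 * cos(2x)
coefs[0, 5] = 0.5j  # -0.5 * sin(5x)

# Assumed scaling for interior 1D modes: multiply by N / 2 before the inverse FFT
field_hat = coefs * (n / 2)
field = np.fft.irfft(field_hat, n=n, axis=-1)  # norm="backward" divides by N

print(np.allclose(field[0], 3 * np.cos(2 * x) - 0.5 * np.sin(5 * x)))  # True
```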
Source code in exponax/_spectral.py
def ifft(
    field_hat: Complex[Array, "C ... (N//2)+1"],
    *,
    num_spatial_dims: Optional[int] = None,
    num_points: Optional[int] = None,
) -> Float[Array, "C ... N"]:
    """
    Perform the inverse **real-valued** FFT of a field. This is the inverse
    operation of `exponax.fft`. This function is designed for states in
    `Exponax` with a leading channel axis and then one, two, or three following
    spatial axes. In state space all spatial axes have the same length N (here
    called `num_points`).

    Requires a complex-valued field in Fourier space with the last axis of
    length N//2+1.

    !!! info
        The number of points (N, or `num_points`) must be provided if the number
        of spatial dimensions is 1. Otherwise, it can be inferred from the shape
        of the field.

    !!! warning
        The argument `num_spatial_dims` can only be correctly inferred if the
        array follows the Exponax convention, e.g., no leading batch axis. For a
        batched operation, use `jax.vmap` on this function.

    **Arguments:**

    - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`.
    - `num_spatial_dims`: The number of spatial dimensions, i.e., how many
        spatial axes follow the channel axis. Can be inferred from the array if
        it follows the Exponax convention. For example, it is not allowed to
        have a leading batch axis, in such a case use `jax.vmap` on this
        function.
    - `num_points`: The number of points in each spatial dimension. Can be
        inferred if `num_spatial_dims` >= 2

    **Returns:**

    - `field`: The state in physical space, shape `(C, ..., N,)`.

    !!! info
        Internally uses `jax.numpy.fft.irfftn` with the default settings for the
        `norm` argument with `norm="backward"`. This means that the forward FFT
        [`exponax.fft`][] function does not apply any normalization to the
        input, only the inverse FFT (this function) applies normalization.
        Hence, if you want to define a state in Fourier space and inversely
        transform it, consider using [`exponax.spectral.build_scaling_array`][]
        to correctly scale the complex values before transforming them back.
    """
    if num_spatial_dims is None:
        num_spatial_dims = field_hat.ndim - 1

    if num_points is None:
        if num_spatial_dims >= 2:
            num_points = field_hat.shape[-2]
        else:
            raise ValueError("num_points must be provided if num_spatial_dims == 1.")
    return jnp.fft.irfftn(
        field_hat,
        s=spatial_shape(num_spatial_dims, num_points),
        axes=space_indices(num_spatial_dims),
    )

exponax.get_spectrum

get_spectrum(
    state: Float[Array, "C ... N"], *, power: bool = True
) -> Float[Array, "C (N//2)+1"]

Compute the Fourier spectrum of a state, either the power spectrum or the amplitude spectrum.

Info

The returned array will always have two axes, no matter how many spatial axes the input has.

Arguments:

  • state: The state to compute the spectrum of. The state must follow the Exponax convention with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.
  • power: Whether to compute the power spectrum (True) or the amplitude spectrum (False). Default is True, i.e., the power spectrum.

Returns:

  • spectrum: The spectrum of the state, shape (C, (N//2)+1).

Tip

The spectrum is usually best presented with a logarithmic y-axis, either via plt.semilogy or plt.loglog. It can also help to clip the spectrum from below at a small threshold, for example with jnp.maximum(spectrum, 1e-10), so that numerically zero modes do not dominate the logarithmic axis.

Info

If it is applied to a vorticity field with power=True (default), it produces the enstrophy spectrum.

Note

The binning in higher dimensions can sometimes be counterintuitive. For example, on a 2D grid, if mode [2, 2] is populated, it is not represented in the 2-bin (i.e., when indexing the returned array of this function at [2]) but in the 3-bin, because its distance from the center is sqrt(2**2 + 2**2) = 2.8284..., which is not in the range of the 2-bin [1.5, 2.5).

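A NumPy sketch of the radial binning for a 2D state, mirroring the loop over wavenumber bins of width dk = 1 in the source. It assumes integer wavenumbers as produced by `fftfreq`/`rfftfreq` and a "reconstruction" scaling of N**2 that is halved on the interior columns of the rfft axis. `radial_power_spectrum_2d` is a hypothetical helper, not Exponax API. For cos(2x + 2y), all power lands in the 3-bin, as the note describes.

```python
import numpy as np


def radial_power_spectrum_2d(state):
    """Hypothetical NumPy sketch of radial binning for a (C, N, N) real state."""
    _, n, _ = state.shape
    state_hat = np.fft.rfftn(state, axes=(1, 2))

    # Assumed "reconstruction" scaling: N**2 overall, halved on the interior
    # columns of the halved (right-most) axis, i.e., excluding k = 0 and Nyquist
    scaling = np.full((n, n // 2 + 1), float(n * n))
    scaling[:, 1 : (n + 1) // 2] /= 2.0
    power = 0.5 * np.abs(state_hat / scaling) ** 2

    # Integer wavenumbers on the rfft layout (ij indexing) and their norm
    k_full = np.fft.fftfreq(n, d=1.0 / n)
    k_half = np.fft.rfftfreq(n, d=1.0 / n)
    kx, ky = np.meshgrid(k_full, k_half, indexing="ij")
    k_norm = np.sqrt(kx**2 + ky**2)

    # One bin per integer wavenumber k, covering [k - 0.5, k + 0.5)
    bins = []
    for k in range(n // 2 + 1):
        mask = (k_norm >= k - 0.5) & (k_norm < k + 0.5)
        bins.append(np.where(mask, power, 0.0).sum(axis=(1, 2)))
    return np.stack(bins, axis=-1)  # shape (C, N//2 + 1)


n = 16
x = np.arange(n) * 2 * np.pi / n
X, Y = np.meshgrid(x, x, indexing="ij")
state = np.cos(2 * X + 2 * Y)[None]  # mode [2, 2], shape (1, N, N)

spectrum = radial_power_spectrum_2d(state)
print(np.argmax(spectrum[0]))  # 3: |(2, 2)| = 2.83 falls into [2.5, 3.5)
print(np.isclose(spectrum[0, 3], 0.5))  # True, the mean of cos(...)**2
```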
Source code in exponax/_spectral.py
def get_spectrum(
    state: Float[Array, "C ... N"],
    *,
    power: bool = True,
) -> Float[Array, "C (N//2)+1"]:
    """
    Compute the Fourier spectrum of a state, either the power spectrum or the
    amplitude spectrum.

    !!! info
        The returned array will always have two axes, no matter how many spatial
        axes the input has.

    **Arguments:**

    - `state`: The state to compute the spectrum of. The state must follow the
        `Exponax` convention with a leading channel axis and then one, two, or
        three subsequent spatial axes, **each of the same length** N.
    - `power`: Whether to compute the power spectrum or the amplitude spectrum.
        Default is `True`, i.e., the power spectrum.

    **Returns:**

    - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`.

    !!! tip
        The spectrum is usually best presented with a logarithmic y-axis, either
        as `plt.semilogy` or `plt.loglog`. Sometimes it can be helpful to set
        the spectrum below a threshold to zero to better visualize the relevant
        parts of the spectrum. This can be done with `jnp.maximum(spectrum,
        1e-10)` for example.

    !!! info
        If it is applied to a vorticity field with `power=True` (default), it
        produces the enstrophy spectrum.

    !!! note
        The binning in higher dimensions can sometimes be counterintuitive. For
        example, on a 2D grid if mode `[2, 2]` is populated, this is not
        represented in the 2-bin (i.e., when indexing the returning array of
        this function at `[2]`), but in the 3-bin because its distance from the
        center is `sqrt(2**2 + 2**2) = 2.8284...` which is not in the range of
        the 2-bin `[1.5, 2.5)`.
    """
    num_spatial_dims = state.ndim - 1
    num_points = state.shape[-1]

    state_hat = fft(state, num_spatial_dims=num_spatial_dims)
    state_hat_scaled = state_hat / build_scaling_array(
        num_spatial_dims,
        num_points,
        mode="reconstruction",  # because of rfft
    )

    if power:
        magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2
    else:
        magnitude = jnp.abs(state_hat_scaled)

    if num_spatial_dims == 1:
        # 1D does not need any binning and can be returned directly
        return magnitude

    wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points)
    wavenumbers_1d = build_wavenumbers(1, num_points)
    wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True)

    dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0]

    spectrum = []

    def power_in_bucket(p, k):
        lower_limit = k - dk / 2
        upper_limit = k + dk / 2
        mask = (wavenumbers_norm[0] >= lower_limit) & (
            wavenumbers_norm[0] < upper_limit
        )
        # return jnp.sum(p[mask])
        return jnp.where(
            mask,
            p,
            0.0,
        ).sum()

    def scan_fn(_, k):
        return None, jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k)

    _, spectrum = jax.lax.scan(scan_fn, None, wavenumbers_1d[0, :])

    spectrum = jnp.moveaxis(spectrum, 0, -1)

    # for k in wavenumbers_1d[0, :]:
    #     spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))

    # spectrum = jnp.stack(spectrum, axis=-1)

    return spectrum

exponax.spectral.get_fourier_coefficients

get_fourier_coefficients(
    state: Float[Array, "C ... N"],
    *,
    scaling_compensation_mode: Optional[
        Literal[
            "norm_compensation",
            "reconstruction",
            "coef_extraction",
        ]
    ] = "coef_extraction",
    round: Optional[int] = 5,
    indexing: str = "ij"
) -> Complex[Array, "C ... (N//2)+1"]

Extract the Fourier coefficients of a state in Fourier space.

It correctly compensates the scaling used in exponax.fft such that the coefficient values can be directly read off from the array.

Arguments:

  • state: The state following the Exponax convention with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.
  • scaling_compensation_mode: The mode of the scaling array used to compensate the scaling of the Fourier transform. The mode "norm_compensation" produces the coefficient array that jnp.fft.rfftn would produce with norm="forward" instead of the default norm="backward" (which is also the default used in Exponax). The mode "reconstruction" is similar but compensates for the fact that the rfft only stores half of the coefficients along the right-most axis. The mode "coef_extraction" allows reading off the coefficient, e.g., at index [i, j] (in 2D), directly, whereas the other modes require considering both the positive and the negative wavenumbers. Set to None to apply no scaling compensation. See also exponax.spectral.build_scaling_array for more information.
  • round: The number of decimals to round the coefficients to. The default of 5 compensates for the round-off errors of the FFT in single precision, so that all coefficients that should carry no energy are exactly zero. Set to None to disable rounding.
  • indexing: The indexing scheme to use for jax.numpy.meshgrid.

Returns:

  • coefficients: The Fourier coefficients of the state.

Warning

Do not use the results of this function with the exponax.viz utilities: they periodically wrap the boundary, which is not meaningful in Fourier space.

Tip

Use this function to visualize the coefficients in higher dimensions. For example, in 2D:

state_2d = ...  # shape (1, N, N)

coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)

# shape (1, N, (N//2)+1)

plt.imshow(
    jnp.log10(jnp.abs(coef_2d[0])),
)

And in 3D, which requires the vape4d volume renderer (https://github.com/KeKsBoTer/vape4d) to be installed and only works on GPU devices:

state_3d = ...  # shape (1, N, N, N)

coef_3d = exponax.spectral.get_fourier_coefficients(
    state_3d, round=None,
)

images = exponax.viz.volume_render_state_3d(
    jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
)

plt.imshow(images[0])

To make the half-spectrum along the real-valued (right-most) axis more prominent, consider flipping that axis via

coef_3d_flipped = jnp.flip(coef_3d, axis=-1)

Tip

Interpretation guide: In general, for an FFT following the NumPy conventions, we have:

  • Positive amplitudes on cosine signals have positive coefficients in the real part of both the positive and the negative wavenumber.
  • Positive amplitudes on sine signals have negative coefficients in the imaginary part of the positive wavenumber and positive coefficients in the imaginary part of the negative wavenumber.

As such, if the output of this function on a 1D state was

array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j,]])

This would correspond to a signal with:

  • A constant offset of +3.0
  • A first sine mode with amplitude +1.5
  • A second cosine mode with amplitude +0.3
  • A second sine mode with amplitude -0.8

In higher dimensions, the interpretation arises from the tensor product. Also be aware that for a (1, N, N) state, the coefficients have the shape (1, N, (N//2)+1).

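The example output above can be reproduced with NumPy. The sketch assumes a grid of N = 6 points on [0, 2pi), so that N//2 + 1 = 4 coefficients result, and a 1D "coef_extraction" scaling of N for the mean and the Nyquist mode and N / 2 for all other modes. Both are assumptions for illustration, not taken from the Exponax source.

```python
import numpy as np

n = 6  # N // 2 + 1 = 4 coefficients, as in the example output
x = np.arange(n) * 2 * np.pi / n
state = (3.0 + 1.5 * np.sin(x) + 0.3 * np.cos(2 * x) - 0.8 * np.sin(2 * x))[None]

state_hat = np.fft.rfft(state, axis=-1)

# Assumed 1D "coef_extraction" scaling: N for k = 0 and the Nyquist mode, N / 2 otherwise
scaling = np.full(n // 2 + 1, n / 2)
scaling[0] = n
if n % 2 == 0:
    scaling[-1] = n

coefficients = np.round(state_hat / scaling, 5)  # round=5, as in the default
print(np.allclose(coefficients, [[3.0, -1.5j, 0.3 + 0.8j, 0.0]]))  # True
```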
Source code in exponax/_spectral.py
def get_fourier_coefficients(
    state: Float[Array, "C ... N"],
    *,
    scaling_compensation_mode: Optional[
        Literal["norm_compensation", "reconstruction", "coef_extraction"]
    ] = "coef_extraction",
    round: Optional[int] = 5,
    indexing: str = "ij",
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Extract the Fourier coefficients of a state in Fourier space.

    It correctly compensates the scaling used in `exponax.fft` such that the
    coefficient values can be directly read off from the array.

    **Arguments:**

    - `state`: The state following the `Exponax` convention with a leading
        channel axis and then one, two, or three subsequent spatial axes, each
        of the same length N.
    - `scaling_compensation_mode`: The mode of the scaling array to use to
        compensate the scaling of the Fourier transform. The mode
        `"norm_compensation"` would produce the coefficient array as produced if
        `jnp.fft.rfftn` was applied with `norm="forward"`, instead of the
        default of `norm="backward"` which is also the default used in
        `Exponax`. The mode `"reconstruction"` is similar to that but
        compensates for the fact that the rfft only has half of the coefficients
        along the right-most axis. The mode `"coef_extraction"` allows one to read
        off the coefficient e.g. at index [i, j] (in 2D) directly, whereas in the
        other modes, one would require to consider both the positive and
        negative wavenumbers. Can be set to `None` to not apply any scaling
        compensation. See also [`exponax.spectral.build_scaling_array`][] for
        more information.
    - `round`: The number of decimals to round the coefficients to. Default is
        `5` which compensates for the rounding errors created by the FFT in
        single precision such that all coefficients that should not carry any
        energy also have zero value. Set to `None` to not round.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.

    **Returns:**

    - `coefficients`: The Fourier coefficients of the state.

    !!! warning
        Do not use the results of this function together with the `exponax.viz`
        utilities since they will periodically wrap the boundary condition which
        is not needed in Fourier space.

    !!! tip
        Use this function to visualize the coefficients in higher dimensions.
        For example in 2D

        ```python

        state_2d = ...  # shape (1, N, N)

        coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)

        # shape (1, N, (N//2)+1)

        plt.imshow(
            jnp.log10(jnp.abs(coef_2d[0])),
        )

        ```

        And in 3D (requires the [`vape4d`](https://github.com/KeKsBoTer/vape4d)
        volume renderer to be installed - only works on GPU devices).

        ```python

        state_3d = ...  # shape (1, N, N, N)

        coef_3d = exponax.spectral.get_fourier_coefficients(
            state_3d, round=None,
        )

        images = ex.viz.volume_render_state_3d(
            jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
        )

        plt.imshow(images[0])

        ```

        To have the major half to the real-valued axis more prominent, consider
        flipping it via

        ```python

        coef_3d_flipped = jnp.flip(coef_3d, axis=-1)

        ```

    !!! tip
        **Interpretation Guide** In general for a FFT following the NumPy
        conventions, we have:

        * Positive amplitudes on cosine signals have positive coefficients in
            the real part of both the positive and the negative wavenumber.
        * Positive amplitudes on sine signals have negative coefficients in the
            imaginary part of the positive wavenumber and positive coefficients
            in the imaginary part of the negative wavenumber.

        As such, if the output of this function on a 1D state was

        ```python

        array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j,]])

        ```

        This would correspond to a signal with:

        * A constant offset of +3.0
        * A first sine mode with amplitude +1.5
        * A second cosine mode with amplitude +0.3
        * A second sine mode with amplitude -0.8

        In higher dimensions, the interpretation arise out of the tensor
        product. Also be aware that for a `(1, N, N)` state, the coefficients
        are in the shape `(1, N, (N//2)+1)`.
    """
    state_hat = fft(state)
    if scaling_compensation_mode is not None:
        scaling = build_scaling_array(
            state.ndim - 1,
            state.shape[-1],
            mode=scaling_compensation_mode,
            indexing=indexing,
        )
        coefficients = state_hat / scaling
    else:
        coefficients = state_hat

    if round is not None:
        coefficients = jnp.round(coefficients, round)

    return coefficients

exponax.spectral.build_scaling_array

build_scaling_array(
    num_spatial_dims: int,
    num_points: int,
    *,
    mode: Literal[
        "norm_compensation",
        "reconstruction",
        "coef_extraction",
    ],
    indexing: str = "ij"
) -> Float[Array, "1 ... (N//2)+1"]

When exponax.fft is used, the resulting array in Fourier space represents a scaled version of the Fourier coefficients. Use this function to produce arrays to counteract this scaling based on the task.

  1. "norm_compensation": The scaling is exactly the scaling the exponax.ifft applies.
  2. "reconstruction": Technically "norm_compensation" should provide an array of coefficients that can be used to build a Fourier interpolant (i.e., what exponax.FourierInterpolator does). However, since exponax.fft uses the real-valued FFT, there is only half of the contribution for the coefficients along the right-most axis. This mode provides the scaling to counteract this.
  3. "coef_extraction": The former modes (in higher dimensions) do not produce coefficients equal to the amplitudes in physical space, because each mode contributes at both the positive and the negative wavenumber. For example, if the signal 3 * cos(2x) is discretized on the domain [0, 2pi] with 10 points, the Fourier coefficient at the 2nd wavenumber has amplitude 3/2 after rescaling with mode "norm_compensation". This mode provides the scaling to extract the correct coefficients.

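A quick NumPy check of the numbers in the "coef_extraction" example, assuming the 10 points are spaced uniformly on [0, 2pi) with the right endpoint excluded (the usual periodic grid). The raw rfft coefficient at wavenumber 2 is 3 * 10 / 2 = 15; "norm_compensation" turns it into 3/2, and a division by N / 2 turns it into the amplitude 3.

```python
import numpy as np

n = 10
x = np.arange(n) * 2 * np.pi / n  # uniform periodic grid, right endpoint excluded
coef = np.fft.rfft(3 * np.cos(2 * x))[2]

print(np.isclose(coef, 15.0))           # True: unnormalized forward FFT
print(np.isclose(coef / n, 1.5))        # True: "norm_compensation" divides by N
print(np.isclose(coef / (n / 2), 3.0))  # True: 1D extraction divides by N / 2
```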
Arguments:

  • num_spatial_dims: The number of spatial dimensions.
  • num_points: The number of points in each spatial dimension.
  • mode: The mode of the scaling array. Either "norm_compensation", "reconstruction", or "coef_extraction".
  • indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".

Returns:

  • scaling: The scaling array.
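The three branches in the source only differ in the denominators passed to the private helper `_build_scaling_array`. For an interior mode in 2D (no zero or Nyquist component), a plausible reading, assumed here and not confirmed by the source shown, is that the scaling factorizes into N / others_scaling_denominator along the first axis and N / right_most_scaling_denominator along the right-most axis. The NumPy sketch checks this reading on cos(x) * cos(y), whose amplitude is 1. `interior_scaling` is a hypothetical helper.

```python
import numpy as np


def interior_scaling(n, right_most_denominator, others_denominator):
    # Hypothetical: per-axis factor N / denominator, multiplied over both axes (2D)
    return (n / others_denominator) * (n / right_most_denominator)


n = 16
x = np.arange(n) * 2 * np.pi / n
X, Y = np.meshgrid(x, x, indexing="ij")
coef = np.fft.rfftn(np.cos(X) * np.cos(Y))[1, 1]  # mode (k_x, k_y) = (1, 1)

for mode, (right_most, others) in {
    "norm_compensation": (1, 1),
    "reconstruction": (2, 1),
    "coef_extraction": (2, 2),
}.items():
    print(mode, np.round((coef / interior_scaling(n, right_most, others)).real, 6))
# norm_compensation 0.25
# reconstruction 0.5
# coef_extraction 1.0
```

Only "coef_extraction" returns the physical amplitude 1; the other two modes spread it over the four (or, in the halved rfft layout, two) mirrored wavenumbers.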

Source code in exponax/_spectral.py
def build_scaling_array(
    num_spatial_dims: int,
    num_points: int,
    *,
    mode: Literal["norm_compensation", "reconstruction", "coef_extraction"],
    indexing: str = "ij",
) -> Float[Array, "1 ... (N//2)+1"]:
    """
    When `exponax.fft` is used, the resulting array in Fourier space represents
    a scaled version of the Fourier coefficients. Use this function to produce
    arrays to counteract this scaling based on the task.

    1. `"norm_compensation"`: The scaling is exactly the scaling the
       `exponax.ifft` applies.
    2. `"reconstruction"`: Technically `"norm_compensation"` should provide an
        array of coefficients that can be used to build a Fourier interpolant
        (i.e., what [`exponax.FourierInterpolator`][] does). However, since
        [`exponax.fft`][] uses the real-valued FFT, there is only half of the
        contribution for the coefficients along the right-most axis. This mode
        provides the scaling to counteract this.
    3. `"coef_extraction"`: Any of the former modes (in higher dimensions) does
        not produce the same coefficients as the amplitude in the physical space
        (because there is a coefficient contribution both in the positive and
        negative wavenumber). For example, if the signal `3 * cos(2x)` was
        discretized on the domain `[0, 2pi]` with 10 points, the amplitude of
        the Fourier coefficient at the 2nd wavenumber would be `3/2` if rescaled
        with mode `"norm_compensation"`. This mode provides the scaling to
        extract the correct coefficients.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions.
    - `num_points`: The number of points in each spatial dimension.
    - `mode`: The mode of the scaling array. Either `"norm_compensation"`,
        `"reconstruction"`, or `"coef_extraction"`.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
        Either `"ij"` or `"xy"`. Default is `"ij"`.

    **Returns:**

    - `scaling`: The scaling array.
    """
    if mode == "norm_compensation":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=1,
            others_scaling_denominator=1,
            indexing=indexing,
        )
    elif mode == "reconstruction":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=2,
            others_scaling_denominator=1,
            indexing=indexing,
        )
    elif mode == "coef_extraction":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=2,
            others_scaling_denominator=2,
            indexing=indexing,
        )
    else:
        raise ValueError("Invalid mode.")