
Fourier-spectral utilities

exponax.fft

fft(
    field: Float[Array, "C ... N"],
    *,
    num_spatial_dims: int | None = None
) -> Complex[Array, "C ... (N//2)+1"]

Perform a real-valued FFT of a field. This function is designed for states in Exponax with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.

Only accepts real-valued input fields and performs a real-valued FFT. Hence, the last axis of the returned field is of length N//2+1.

Warning

The argument num_spatial_dims can only be correctly inferred if the array follows the Exponax convention, e.g., no leading batch axis. For a batched operation, use jax.vmap on this function.

Arguments:

  • field: The state to transform.
  • num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. Can be inferred from the array if it follows the Exponax convention; in particular, there must be no leading batch axis. For batched inputs, use jax.vmap on this function.

Returns:

  • field_hat: The transformed field, shape (C, ..., N//2+1).

Info

Internally uses jax.numpy.fft.rfftn with its default normalization, norm="backward". This means that the forward FFT (this function) applies no normalization to the result; only the exponax.ifft function normalizes. To extract the amplitudes of the coefficients, divide by exponax.spectral.build_scaling_array.
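The following sketch mimics the call with NumPy, whose `numpy.fft.rfftn` follows the same conventions as `jax.numpy.fft.rfftn`. The helper `fft_sketch` is hypothetical; it only assumes, like `exponax.fft`, that the transform runs over the trailing spatial axes while the leading channel axis is left alone. It shows the halved last axis and that, without forward normalization, the zero mode is simply the sum over all grid points.

```python
import numpy as np


def fft_sketch(field, num_spatial_dims=None):
    """Hypothetical stand-in for exponax.fft (NumPy instead of JAX)."""
    if num_spatial_dims is None:
        # Exponax convention: one leading channel axis, the rest is space.
        num_spatial_dims = field.ndim - 1
    spatial_axes = tuple(range(-num_spatial_dims, 0))
    # Default norm="backward": the forward transform is not scaled.
    return np.fft.rfftn(field, axes=spatial_axes)


N = 8
x = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
u = (2.0 + np.cos(x))[None, :]  # shape (1, N): one channel, 1D

u_hat = fft_sketch(u)
print(u_hat.shape)  # (1, 5), i.e., (1, N//2 + 1)
print(u_hat[0, 0].real)  # approx. 16.0 = 2.0 * N (sum over the grid)
print(u_hat[0, 1].real)  # approx. 4.0 = N / 2 for a unit-amplitude cosine

u_2d = np.ones((1, N, N))
print(fft_sketch(u_2d).shape)  # (1, 8, 5): only the last axis is halved
```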

Source code in exponax/_spectral.py
def fft(
    field: Float[Array, "C ... N"],
    *,
    num_spatial_dims: int | None = None,
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Perform a **real-valued** FFT of a field. This function is designed for
    states in `Exponax` with a leading channel axis and then one, two, or three
    subsequent spatial axes, **each of the same length** N.

    Only accepts real-valued input fields and performs a real-valued FFT. Hence,
    the last axis of the returned field is of length N//2+1.

    !!! warning
        The argument `num_spatial_dims` can only be correctly inferred if the
        array follows the Exponax convention, e.g., no leading batch axis. For a
        batched operation, use `jax.vmap` on this function.

    **Arguments:**

    - `field`: The state to transform.
    - `num_spatial_dims`: The number of spatial dimensions, i.e., how many
        spatial axes follow the channel axis. Can be inferred from the array if
        it follows the Exponax convention; in particular, there must be no
        leading batch axis. For batched inputs, use `jax.vmap` on this
        function.

    **Returns:**

    - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`.

    !!! info
        Internally uses `jax.numpy.fft.rfftn` with its default normalization,
        `norm="backward"`. This means that the forward FFT (this function)
        applies no normalization to the result; only the [`exponax.ifft`][]
        function normalizes. To extract the amplitudes of the coefficients,
        divide by `exponax.spectral.build_scaling_array`.
    """
    if num_spatial_dims is None:
        num_spatial_dims = field.ndim - 1

    return jnp.fft.rfftn(field, axes=space_indices(num_spatial_dims))

exponax.ifft

ifft(
    field_hat: Complex[Array, "C ... (N//2)+1"],
    *,
    num_spatial_dims: int | None = None,
    num_points: int | None = None
) -> Float[Array, "C ... N"]

Perform the inverse real-valued FFT of a field. This is the inverse operation of exponax.fft. This function is designed for states in Exponax with a leading channel axis and then one, two, or three following spatial axes. In state space all spatial axes have the same length N (here called num_points).

Requires a complex-valued field in Fourier space with the last axis of length N//2+1.

Info

The number of points (N, or num_points) must be provided if the number of spatial dimensions is 1. Otherwise, it can be inferred from the shape of the field.

Warning

The argument num_spatial_dims can only be correctly inferred if the array follows the Exponax convention, e.g., no leading batch axis. For a batched operation, use jax.vmap on this function.

Arguments:

  • field_hat: The transformed field, shape (C, ..., N//2+1).
  • num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. Can be inferred from the array if it follows the Exponax convention; in particular, there must be no leading batch axis. For batched inputs, use jax.vmap on this function.
  • num_points: The number of points in each spatial dimension. Can be inferred if num_spatial_dims >= 2.

Returns:

  • field: The state in physical space, shape (C, ..., N,).

Info

Internally uses jax.numpy.fft.irfftn with its default normalization, norm="backward". This means that the forward FFT (exponax.fft) applies no normalization to its output, and only the inverse FFT (this function) normalizes. Hence, if you define a state in Fourier space and transform it back, consider using exponax.spectral.build_scaling_array to scale the complex values correctly before the inverse transform.
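A NumPy sketch of the inverse, with `ifft_sketch` as a hypothetical stand-in that mirrors the branching of the source below. It illustrates two points from the text: in 1D the length N cannot be recovered from the N//2+1 stored modes (N = 8 and N = 9 both give 5), and a state defined directly in Fourier space must be scaled up before the inverse transform. The factor N/2 used for a mode with a conjugate partner is an assumption chosen to match the `3 * cos(2x)` example given for `build_scaling_array`.

```python
import numpy as np


def ifft_sketch(field_hat, num_spatial_dims=None, num_points=None):
    """Hypothetical stand-in for exponax.ifft (NumPy instead of JAX)."""
    if num_spatial_dims is None:
        num_spatial_dims = field_hat.ndim - 1
    if num_points is None:
        if num_spatial_dims >= 2:
            # The second-to-last axis still stores all N modes.
            num_points = field_hat.shape[-2]
        else:
            raise ValueError("num_points must be provided if num_spatial_dims == 1.")
    spatial_axes = tuple(range(-num_spatial_dims, 0))
    # irfftn with norm="backward" divides by the total number of grid points.
    return np.fft.irfftn(
        field_hat, s=(num_points,) * num_spatial_dims, axes=spatial_axes
    )


# 1D: both N = 8 and N = 9 yield 5 rfft modes, so N must be passed explicitly.
print(np.fft.rfft(np.zeros(8)).shape, np.fft.rfft(np.zeros(9)).shape)

# 2D round trip with an odd N that is inferred from the shape.
rng = np.random.default_rng(0)
u = rng.standard_normal((1, 9, 9))
u_back = ifft_sketch(np.fft.rfftn(u, axes=(-2, -1)))
print(np.allclose(u, u_back))  # True

# Define 3 * cos(2x) directly in Fourier space (N = 10).
N = 10
u_hat = np.zeros((1, N // 2 + 1), dtype=complex)
u_hat[0, 2] = 3.0 * (N / 2)  # undo the 1/N scaling plus the missing conjugate
x = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
print(np.allclose(ifft_sketch(u_hat, num_points=N), 3.0 * np.cos(2 * x)))  # True
```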

Source code in exponax/_spectral.py
def ifft(
    field_hat: Complex[Array, "C ... (N//2)+1"],
    *,
    num_spatial_dims: int | None = None,
    num_points: int | None = None,
) -> Float[Array, "C ... N"]:
    """
    Perform the inverse **real-valued** FFT of a field. This is the inverse
    operation of `exponax.fft`. This function is designed for states in
    `Exponax` with a leading channel axis and then one, two, or three following
    spatial axes. In state space all spatial axes have the same length N (here
    called `num_points`).

    Requires a complex-valued field in Fourier space with the last axis of
    length N//2+1.

    !!! info
        The number of points (N, or `num_points`) must be provided if the number
        of spatial dimensions is 1. Otherwise, it can be inferred from the shape
        of the field.

    !!! warning
        The argument `num_spatial_dims` can only be correctly inferred if the
        array follows the Exponax convention, e.g., no leading batch axis. For a
        batched operation, use `jax.vmap` on this function.

    **Arguments:**

    - `field_hat`: The transformed field, shape `(C, ..., N//2+1)`.
    - `num_spatial_dims`: The number of spatial dimensions, i.e., how many
        spatial axes follow the channel axis. Can be inferred from the array if
        it follows the Exponax convention; in particular, there must be no
        leading batch axis. For batched inputs, use `jax.vmap` on this
        function.
    - `num_points`: The number of points in each spatial dimension. Can be
        inferred if `num_spatial_dims` >= 2.

    **Returns:**

    - `field`: The state in physical space, shape `(C, ..., N,)`.

    !!! info
        Internally uses `jax.numpy.fft.irfftn` with its default normalization,
        `norm="backward"`. This means that the forward FFT ([`exponax.fft`][])
        applies no normalization to its output, and only the inverse FFT (this
        function) normalizes. Hence, if you define a state in Fourier space and
        transform it back, consider using
        [`exponax.spectral.build_scaling_array`][] to scale the complex values
        correctly before the inverse transform.
    """
    if num_spatial_dims is None:
        num_spatial_dims = field_hat.ndim - 1

    if num_points is None:
        if num_spatial_dims >= 2:
            num_points = field_hat.shape[-2]
        else:
            raise ValueError("num_points must be provided if num_spatial_dims == 1.")
    return jnp.fft.irfftn(
        field_hat,
        s=spatial_shape(num_spatial_dims, num_points),
        axes=space_indices(num_spatial_dims),
    )

exponax.get_spectrum

get_spectrum(
    state: Float[Array, "C ... N"],
    *,
    power: bool = True,
    radial_binning: Literal["average", "sum"] = "sum"
) -> Float[Array, "C (N//2)+1"]

Compute the Fourier spectrum of a state, either the power spectrum or the amplitude spectrum.

Info

The returned array will always have two axes, no matter how many spatial axes the input has.

Arguments:

  • state: The state to compute the spectrum of. The state must follow the Exponax convention with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.
  • power: Whether to compute the power spectrum or the amplitude spectrum. Default is True meaning the power spectrum.
  • radial_binning: How to aggregate Fourier modes within each radial (spherical shell) bin. Either "sum" (default) or "average".
    • "sum": Computes the total power/amplitude in each bin. This is the conventional approach in turbulence literature where the Kolmogorov -5/3 law applies. Preserves Parseval's theorem (sum over all bins equals total energy).
    • "average": Computes the mean power/amplitude per mode in each bin. This gives a spectral density that is resolution-independent and removes the geometric scaling with wavenumber.

Returns:

  • spectrum: The spectrum of the state, shape (C, (N//2)+1).

Tip

The spectrum is usually best presented with a logarithmic y-axis, either as plt.semilogy or plt.loglog. Sometimes it can be helpful to clip the spectrum from below at a threshold to better visualize the relevant parts of the spectrum (fast Fourier transforms introduce rounding errors, which can accumulate when radially binning in higher dimensions). This can be done with jnp.maximum(spectrum, 1e-10), for example. Recommended lower thresholds are 1e-10 and 1e-5 for the power spectrum and amplitude spectrum, respectively, both in single precision. (Reason: 1e-5 is slightly higher than the precision limit for single precision, and 1e-10 is its square, since the power spectrum is the amplitude spectrum squared.)

Info

If it is applied to a vorticity field with power=True (default), it produces the enstrophy spectrum.

Info

Multi-Channel Fields: For example, when computing the power spectrum of a velocity field in 2D or 3D, this function by default computes the power spectrum of each velocity component separately and returns an array of shape (D, (N//2)+1), where D is the number of spatial dimensions. Typically, one is interested in the total kinetic energy, which involves the squared norm of the velocity vector. Its spectrum is obtained by summing the power spectrum across the channel axis, i.e., jnp.sum(spectrum, axis=0).

Note

The binning in higher dimensions can sometimes be counterintuitive. For example, on a 2D grid, if mode [2, 2] is populated, it is not represented in the 2-bin (i.e., at index [2] of the returned array) but in the 3-bin, because its distance from the origin is sqrt(2**2 + 2**2) = 2.8284..., which lies outside the 2-bin range [1.5, 2.5).

Note

On shell surface area scaling: In continuous formulations, the 1D isotropic spectrum relates to the spectral density tensor via a shell surface area factor: E(k) = 2πk · Φ(k) in 2D and E(k) = 4πk² · Φ(k) in 3D. In this discrete implementation: (a) with radial_binning="sum": The geometric factor is implicit because the number of discrete modes in each bin grows proportionally to the shell surface area. (b) with radial_binning="average": The geometric factor is divided out, yielding a per-mode density.
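The implicit geometric factor can be made visible by counting how many stored Fourier modes fall into each radial bin. The NumPy sketch below does not call Exponax; it assumes integer wavenumbers laid out like `numpy.fft.fftfreq` on the full axes and `numpy.fft.rfftfreq` on the last axis, with bins `[k - 1/2, k + 1/2)` as in the source. In 2D the counts grow roughly like πk (the 2πk shell factor, halved because the rfft stores only half of the plane), and for any non-empty bin the `"sum"` result equals the `"average"` result times this count.

```python
import numpy as np

N = 32
k_full = np.fft.fftfreq(N, d=1.0 / N)   # 0, 1, ..., 15, -16, ..., -1
k_half = np.fft.rfftfreq(N, d=1.0 / N)  # 0, 1, ..., 16 (rfft axis)
kx, ky = np.meshgrid(k_full, k_half, indexing="ij")
k_norm = np.sqrt(kx**2 + ky**2)  # shape (N, N//2 + 1)

# Number of stored modes in each bin [k - 1/2, k + 1/2)
counts = np.array(
    [np.sum((k_norm >= k - 0.5) & (k_norm < k + 0.5)) for k in k_half]
)
print(counts[:11])

# Roughly proportional to k: compare with the half-shell estimate pi * k
for k in (5, 10):
    print(k, counts[k], round(np.pi * k, 1))
```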

Note

On the radial bin range: In D > 1 dimensions, the radial bins only extend up to N//2 + 1 (the 1D Nyquist frequency), not to sqrt(D) * (N//2 + 1) (the corner of the wavenumber cube). Modes with |k| > (N//2 + 1) that exist in the corners of the Cartesian wavenumber grid are not included. The spectrum thus covers the Nyquist sphere inscribed in the wavenumber cube.
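A small NumPy check of the two binning notes above, using the same assumed integer-wavenumber layout and the bin rule `[k - 1/2, k + 1/2)` from the source (the helper `bin_of` is hypothetical): mode `[2, 2]` lands in the 3-bin, and a corner mode such as `[N/2, N/2]` lies outside every bin.

```python
import numpy as np

N = 16
k_full = np.fft.fftfreq(N, d=1.0 / N)   # 0, ..., 7, -8, ..., -1
k_half = np.fft.rfftfreq(N, d=1.0 / N)  # 0, ..., 8
kx, ky = np.meshgrid(k_full, k_half, indexing="ij")
k_norm = np.sqrt(kx**2 + ky**2)


def bin_of(norm, dk=1.0):
    """Return the bin center k with k - dk/2 <= norm < k + dk/2, or None."""
    for k in k_half:
        if k - dk / 2 <= norm < k + dk / 2:
            return int(k)
    return None


print(k_norm[2, 2], bin_of(k_norm[2, 2]))  # approx. 2.828, bin 3
print(k_norm[8, 8], bin_of(k_norm[8, 8]))  # approx. 11.31, no bin (None)
```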

Note

Parseval's identity: With power=True and radial_binning="sum", the spectrum satisfies jnp.sum(spectrum) == 0.5 * jnp.mean(u**2) (per channel). In 1D this holds exactly; in higher dimensions it holds for signals whose energy is contained within the Nyquist sphere (see note above on radial bin range). Note that this is a sum in Fourier space equaling a mean in physical space — the reverse of the textbook form (1/N^D) Σ|û|² = Σ|u|². The inversion happens because the coefficients are divided by scaling arrays of order N^D; since power is quadratic, this scales as N^(2D), which cancels the N^D factor from the sum over Fourier modes, leaving a single N^D factor in the denominator on the right-hand side, which is the mean in physical space.
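The 1D branch of the source can be re-traced in NumPy to check this identity numerically. The scalings are assumptions read off from the comment in the source: "norm_compensation" divides every mode by N, while "reconstruction" divides by N/2 except at the zero mode and, for even N, the Nyquist mode, where it equals the norm compensation. The function name is hypothetical.

```python
import numpy as np


def power_spectrum_1d_sketch(u):
    """Hypothetical NumPy re-trace of get_spectrum(u, power=True) in 1D."""
    n = u.shape[-1]
    u_hat_abs = np.abs(np.fft.rfft(u, axis=-1))
    norm = np.full(n // 2 + 1, float(n))  # assumed "norm_compensation"
    recon = norm / 2  # assumed "reconstruction": halved ...
    recon[0] = n  # ... except at the zero mode
    if n % 2 == 0:
        recon[-1] = n  # ... and at the Nyquist mode for even N
    return 0.5 * (u_hat_abs / recon) * (u_hat_abs / norm)


N = 16
x = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
u = (1.0 + 3.0 * np.sin(2 * x))[None, :]

spectrum = power_spectrum_1d_sketch(u)
print(spectrum.shape)  # (1, 9)
print(np.isclose(spectrum.sum(), 0.5 * np.mean(u**2)))  # True: both are 2.75
```

The same check also passes for odd N, where no Nyquist mode is stored.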

Source code in exponax/_spectral.py
def get_spectrum(
    state: Float[Array, "C ... N"],
    *,
    power: bool = True,
    radial_binning: Literal["average", "sum"] = "sum",
) -> Float[Array, "C (N//2)+1"]:
    """
    Compute the Fourier spectrum of a state, either the power spectrum or the
    amplitude spectrum.

    !!! info
        The returned array will always have two axes, no matter how many spatial
        axes the input has.

    **Arguments:**

    - `state`: The state to compute the spectrum of. The state must follow the
        `Exponax` convention with a leading channel axis and then one, two, or
        three subsequent spatial axes, **each of the same length** N.
    - `power`: Whether to compute the power spectrum or the amplitude spectrum.
        Default is `True` meaning the power spectrum.
    - `radial_binning`: How to aggregate Fourier modes within each radial
        (spherical shell) bin. Either `"sum"` (default) or `"average"`.
        - `"sum"`: Computes the total power/amplitude in each bin. This is the
          conventional approach in turbulence literature where the Kolmogorov
          -5/3 law applies. Preserves Parseval's theorem (sum over all bins
          equals total energy).
        - `"average"`: Computes the mean power/amplitude per mode in each bin.
          This gives a spectral density that is resolution-independent and
          removes the geometric scaling with wavenumber.

    **Returns:**

    - `spectrum`: The spectrum of the state, shape `(C, (N//2)+1)`.

    !!! tip
        The spectrum is usually best presented with a logarithmic y-axis, either
        as `plt.semilogy` or `plt.loglog`. Sometimes it can be helpful to clip
        the spectrum from below at a threshold to better visualize the relevant
        parts of the spectrum (fast Fourier transforms introduce rounding
        errors, which can accumulate when radially binning in higher
        dimensions). This can be done with `jnp.maximum(spectrum, 1e-10)`, for
        example. Recommended lower thresholds are `1e-10` and `1e-5` for the
        power spectrum and amplitude spectrum, respectively, both in single
        precision. (Reason: `1e-5` is slightly higher than the precision
        limit for single precision, and `1e-10` is its square, since the
        power spectrum is the amplitude spectrum squared.)

    !!! info
        If it is applied to a vorticity field with `power=True` (default), it
        produces the enstrophy spectrum.

    !!! info
        Multi-Channel Fields: For example, when computing the power spectrum of
        a velocity field in 2D or 3D, this function by default computes the
        power spectrum of each velocity component separately and returns an
        array of shape `(D, (N//2)+1)`, where `D` is the number of spatial
        dimensions. Typically, one is interested in the total kinetic energy,
        which involves the squared norm of the velocity vector. Its spectrum is
        obtained by summing the power spectrum across the channel axis, i.e.,
        `jnp.sum(spectrum, axis=0)`.

    !!! note
        The binning in higher dimensions can sometimes be counterintuitive. For
        example, on a 2D grid, if mode `[2, 2]` is populated, it is not
        represented in the 2-bin (i.e., at index `[2]` of the returned array)
        but in the 3-bin, because its distance from the origin is
        `sqrt(2**2 + 2**2) = 2.8284...`, which lies outside the 2-bin range
        `[1.5, 2.5)`.

    !!! note
        **On shell surface area scaling:** In continuous formulations, the 1D
        isotropic spectrum relates to the spectral density tensor via a shell
        surface area factor: `E(k) = 2πk · Φ(k)` in 2D and `E(k) = 4πk² · Φ(k)`
        in 3D. In this discrete implementation: (a) with `radial_binning="sum"`:
        The geometric factor is implicit because the number of discrete modes in
        each bin grows proportionally to the shell surface area. (b) with
        `radial_binning="average"`: The geometric factor is divided out,
        yielding a per-mode density.

    !!! note
        **On the radial bin range:** In D > 1 dimensions, the radial bins only
        extend up to `N//2 + 1` (the 1D Nyquist frequency), not to `sqrt(D) *
        (N//2 + 1)` (the corner of the wavenumber cube). Modes with `|k| > (N//2
        + 1)` that exist in the corners of the Cartesian wavenumber grid are not
        included. The spectrum thus covers the **Nyquist sphere** inscribed in
        the wavenumber cube.

    !!! note
        **Parseval's identity:** With `power=True` and `radial_binning="sum"`,
        the spectrum satisfies `jnp.sum(spectrum) == 0.5 * jnp.mean(u**2)` (per
        channel). In 1D this holds exactly; in higher dimensions it holds for
        signals whose energy is contained within the Nyquist sphere (see note
        above on radial bin range). Note that this is a **sum** in Fourier
        space equaling a **mean** in physical space — the reverse of the
        textbook form `(1/N^D) Σ|û|² = Σ|u|²`. The inversion happens because the
        coefficients are divided by scaling arrays of order N^D; since power is
        quadratic, this scales as N^(2D), which cancels the N^D factor from the
        sum over Fourier modes, leaving a single N^D factor in the denominator
        on the right-hand side, which is the mean in physical space.
    """
    num_spatial_dims = state.ndim - 1
    num_points = state.shape[-1]

    state_hat_abs = jnp.abs(fft(state, num_spatial_dims=num_spatial_dims))
    magnitude = state_hat_abs / build_scaling_array(
        num_spatial_dims,
        num_points,
        mode="reconstruction",
    )

    if power:
        # The "reconstruction" scaling doubles intermediate rfft-axis modes
        # to account for the missing conjugate.  This is correct for the
        # amplitude (linear in |c|), but squaring would turn that 2x into 4x.
        # Mixing one "reconstruction"-scaled and one "norm_compensation"-scaled
        # factor gives: |û|² / (recon * norm).  This equals |c|² for DC and
        # Nyquist (where recon == norm) and |c|²/2 for intermediate modes
        # (where the conjugate pair together contribute half each), so that
        # sum(spectrum) == 0.5 * mean(u²).
        magnitude_norm_compensated = state_hat_abs / build_scaling_array(
            num_spatial_dims,
            num_points,
            mode="norm_compensation",
        )
        quantity = 0.5 * magnitude * magnitude_norm_compensated
    else:
        quantity = magnitude

    if num_spatial_dims == 1:
        # 1D does not need any binning and can be returned directly
        return quantity

    wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points)
    wavenumbers_1d = build_wavenumbers(1, num_points)
    wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True)

    dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0]

    spectrum = []

    def power_in_bucket(p, k):
        lower_limit = k - dk / 2
        upper_limit = k + dk / 2
        mask = (wavenumbers_norm[0] >= lower_limit) & (
            wavenumbers_norm[0] < upper_limit
        )
        if radial_binning == "average":
            return jnp.nanmean(p, where=mask)
        else:  # radial_binning == "sum"
            return jnp.nansum(p, where=mask)

    def scan_fn(_, k):
        return None, jax.vmap(power_in_bucket, in_axes=(0, None))(quantity, k)

    _, spectrum = jax.lax.scan(scan_fn, None, wavenumbers_1d[0, :])

    spectrum = jnp.moveaxis(spectrum, 0, -1)

    # for k in wavenumbers_1d[0, :]:
    #     spectrum.append(jax.vmap(power_in_bucket, in_axes=(0, None))(magnitude, k))

    # spectrum = jnp.stack(spectrum, axis=-1)

    return spectrum

exponax.spectral.get_fourier_coefficients

get_fourier_coefficients(
    state: Float[Array, "C ... N"],
    *,
    scaling_compensation_mode: (
        Literal[
            "norm_compensation",
            "reconstruction",
            "coef_extraction",
        ]
        | None
    ) = "coef_extraction",
    round: int | None = 5,
    indexing: str = "ij"
) -> Complex[Array, "C ... (N//2)+1"]

Extract the Fourier coefficients of a state in Fourier space.

It correctly compensates the scaling used in exponax.fft such that the coefficient values can be directly read off from the array.

Arguments:

  • state: The state following the Exponax convention with a leading channel axis and then one, two, or three subsequent spatial axes, each of the same length N.
  • scaling_compensation_mode: The mode of the scaling array used to compensate for the scaling of the Fourier transform. The mode "norm_compensation" produces the coefficient array that jnp.fft.rfftn would return with norm="forward" instead of the default norm="backward" (which is also the default used in Exponax). The mode "reconstruction" is similar but additionally compensates for the fact that the rfft stores only half of the coefficients along the right-most axis. The mode "coef_extraction" allows reading off the coefficient, e.g., at index [i, j] (in 2D), directly, whereas the other modes require considering both the positive and the negative wavenumbers. Set to None to apply no scaling compensation. See also exponax.spectral.build_scaling_array for more information.
  • round: The number of decimals to round the coefficients to. Default is 5 which compensates for the rounding errors created by the FFT in single precision such that all coefficients that should not carry any energy also have zero value. Set to None to not round.
  • indexing: The indexing scheme to use for jax.numpy.meshgrid.

Returns:

  • coefficients: The Fourier coefficients of the state.

Warning

Do not use the results of this function together with the exponax.viz utilities, since they apply a periodic wrap at the boundary, which is not meaningful in Fourier space.

Tip

Use this function to visualize the coefficients in higher dimensions. For example, in 2D:

state_2d = ...  # shape (1, N, N)

coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)

# shape (1, N, (N//2)+1)

plt.imshow(
    jnp.log10(jnp.abs(coef_2d[0])),
)

And in 3D (this requires the vape4d volume renderer, which only works on GPU devices):

state_3d = ...  # shape (1, N, N, N)

coef_3d = exponax.spectral.get_fourier_coefficients(
    state_3d, round=None,
)

images = exponax.viz.volume_render_state_3d(
    jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
)

plt.imshow(images[0])

To make the half-spectrum stored along the last (real-valued FFT) axis more prominent, consider flipping that axis via

coef_3d_flipped = jnp.flip(coef_3d, axis=-1)

Tip

Interpretation guide: In general, for an FFT following the NumPy conventions, we have:

  • Positive amplitudes on cosine signals have positive coefficients in the real part of both the positive and the negative wavenumber.
  • Positive amplitudes on sine signals have negative coefficients in the imaginary part of the positive wavenumber and positive coefficients in the imaginary part of the negative wavenumber.

As such, if the output of this function on a 1D state was

array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j,]])

This would correspond to a signal with:

  • A constant offset of +3.0
  • A first sine mode with amplitude +1.5
  • A second cosine mode with amplitude +0.3
  • A second sine mode with amplitude -0.8

In higher dimensions, the interpretation follows from the tensor product. Also note that for a (1, N, N) state, the coefficients have the shape (1, N, (N//2)+1).
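The example array above can be reproduced with a short NumPy sketch for N = 6 grid points, which gives four stored modes. The helper `coef_extraction_1d_sketch` is hypothetical and assumes that, in 1D, the "coef_extraction" scaling is N for the zero mode and the Nyquist mode and N/2 otherwise; it then rounds to 5 decimals like the default round=5.

```python
import numpy as np


def coef_extraction_1d_sketch(u, decimals=5):
    """Hypothetical NumPy analogue of get_fourier_coefficients in 1D."""
    n = u.shape[-1]
    u_hat = np.fft.rfft(u, axis=-1)
    scaling = np.full(n // 2 + 1, n / 2)  # modes with a conjugate partner
    scaling[0] = n  # zero mode
    if n % 2 == 0:
        scaling[-1] = n  # Nyquist mode
    # np.round rounds real and imaginary parts separately
    return np.round(u_hat / scaling, decimals)


N = 6
x = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
# offset 3.0, first sine +1.5, second cosine +0.3, second sine -0.8
u = 3.0 + 1.5 * np.sin(x) + 0.3 * np.cos(2 * x) - 0.8 * np.sin(2 * x)

coefficients = coef_extraction_1d_sketch(u[None, :])
print(coefficients)  # should match the example array above
```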

Source code in exponax/_spectral.py
def get_fourier_coefficients(
    state: Float[Array, "C ... N"],
    *,
    scaling_compensation_mode: Literal[
        "norm_compensation", "reconstruction", "coef_extraction"
    ]
    | None = "coef_extraction",
    round: int | None = 5,
    indexing: str = "ij",
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Extract the Fourier coefficients of a state in Fourier space.

    It correctly compensates the scaling used in `exponax.fft` such that the
    coefficient values can be directly read off from the array.

    **Arguments:**

    - `state`: The state following the `Exponax` convention with a leading
        channel axis and then one, two, or three subsequent spatial axes, each
        of the same length N.
    - `scaling_compensation_mode`: The mode of the scaling array used to
        compensate for the scaling of the Fourier transform. The mode
        `"norm_compensation"` produces the coefficient array that
        `jnp.fft.rfftn` would return with `norm="forward"` instead of the
        default `norm="backward"` (which is also the default used in
        `Exponax`). The mode `"reconstruction"` is similar but additionally
        compensates for the fact that the rfft stores only half of the
        coefficients along the right-most axis. The mode `"coef_extraction"`
        allows reading off the coefficient, e.g., at index [i, j] (in 2D),
        directly, whereas the other modes require considering both the
        positive and the negative wavenumbers. Set to `None` to apply no
        scaling compensation. See also
        [`exponax.spectral.build_scaling_array`][] for more information.
    - `round`: The number of decimals to round the coefficients to. Default is
        `5` which compensates for the rounding errors created by the FFT in
        single precision such that all coefficients that should not carry any
        energy also have zero value. Set to `None` to not round.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.

    **Returns:**

    - `coefficients`: The Fourier coefficients of the state.

    !!! warning
        Do not use the results of this function together with the `exponax.viz`
        utilities, since they apply a periodic wrap at the boundary, which is
        not meaningful in Fourier space.

    !!! tip
        Use this function to visualize the coefficients in higher dimensions.
        For example, in 2D:

        ```python

        state_2d = ...  # shape (1, N, N)

        coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)

        # shape (1, N, (N//2)+1)

        plt.imshow(
            jnp.log10(jnp.abs(coef_2d[0])),
        )

        ```

        And in 3D (this requires the
        [`vape4d`](https://github.com/KeKsBoTer/vape4d) volume renderer, which
        only works on GPU devices):

        ```python

        state_3d = ...  # shape (1, N, N, N)

        coef_3d = exponax.spectral.get_fourier_coefficients(
            state_3d, round=None,
        )

        images = exponax.viz.volume_render_state_3d(
            jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
        )

        plt.imshow(images[0])

        ```

        To make the half-spectrum stored along the last (real-valued FFT) axis
        more prominent, consider flipping that axis via

        ```python

        coef_3d_flipped = jnp.flip(coef_3d, axis=-1)

        ```

    !!! tip
        **Interpretation guide:** In general, for an FFT following the NumPy
        conventions, we have:

        * Positive amplitudes on cosine signals have positive coefficients in
            the real part of both the positive and the negative wavenumber.
        * Positive amplitudes on sine signals have negative coefficients in the
            imaginary part of the positive wavenumber and positive coefficients
            in the imaginary part of the negative wavenumber.

        As such, if the output of this function on a 1D state was

        ```python

        array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j,]])

        ```

        This would correspond to a signal with:

        * A constant offset of +3.0
        * A first sine mode with amplitude +1.5
        * A second cosine mode with amplitude +0.3
        * A second sine mode with amplitude -0.8

        In higher dimensions, the interpretation follows from the tensor
        product. Also note that for a `(1, N, N)` state, the coefficients have
        the shape `(1, N, (N//2)+1)`.
    """
    state_hat = fft(state)
    if scaling_compensation_mode is not None:
        scaling = build_scaling_array(
            state.ndim - 1,
            state.shape[-1],
            mode=scaling_compensation_mode,
            indexing=indexing,
        )
        coefficients = state_hat / scaling
    else:
        coefficients = state_hat

    if round is not None:
        coefficients = jnp.round(coefficients, round)

    return coefficients

exponax.spectral.build_scaling_array

build_scaling_array(
    num_spatial_dims: int,
    num_points: int,
    *,
    mode: Literal[
        "norm_compensation",
        "reconstruction",
        "coef_extraction",
    ],
    indexing: str = "ij"
) -> Float[Array, "1 ... (N//2)+1"]

When exponax.fft is used, the resulting array in Fourier space represents a scaled version of the Fourier coefficients. Use this function to produce arrays to counteract this scaling based on the task.

  1. "norm_compensation": The scaling is exactly the scaling that exponax.ifft applies.
  2. "reconstruction": Technically, "norm_compensation" should provide an array of coefficients that can be used to build a Fourier interpolant (i.e., what exponax.FourierInterpolator does). However, since exponax.fft uses the real-valued FFT, only the non-negative wavenumbers are stored along the right-most axis, so the coefficients there carry only half of their contribution. This mode provides the scaling to counteract this.
  3. "coef_extraction": The former modes do not, in general, produce coefficients equal to the amplitudes in physical space (for "reconstruction", this concerns higher dimensions), because a real-valued mode splits its amplitude between the positive and the negative wavenumber. For example, if the signal 3 * cos(2x) is discretized on the domain [0, 2pi] with 10 points, the Fourier coefficient at wavenumber 2 is 3/2 instead of the amplitude 3 if rescaled with mode "norm_compensation". This mode provides the scaling to extract the correct amplitudes.
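The 1D example in item 3 can be reproduced with a few lines of NumPy (jax.numpy.fft.rfft follows the same unnormalized "backward" convention). This is a sketch under the assumption that the 10 points are x_j = 2*pi*j/10 with the right endpoint excluded; the arrays norm_compensation and coef_extraction below are hand-built stand-ins for the 1D scaling arrays, not output of build_scaling_array.

```python
import numpy as np

n = 10
x = 2 * np.pi * np.arange(n) / n  # 10 points on [0, 2*pi), right endpoint excluded
state = (3.0 * np.cos(2 * x))[None]  # shape (1, 10): one channel, one spatial axis

state_hat = np.fft.rfft(state)  # unnormalized forward transform, shape (1, 6)

# Hand-built stand-ins for the 1D scaling arrays (not Exponax's implementation).
k = np.arange(n // 2 + 1)
norm_compensation = np.full(k.shape, float(n))  # what an ifft divides by
# n is even here, so k == n // 2 is the Nyquist mode; it keeps the divisor n
coef_extraction = np.where((k == 0) | (k == n // 2), float(n), n / 2)

print(state_hat[0, 2].real / norm_compensation[2])  # about 1.5, i.e. 3/2
print(state_hat[0, 2].real / coef_extraction[2])  # about 3.0, the amplitude
```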

Arguments:

  • num_spatial_dims: The number of spatial dimensions.
  • num_points: The number of points in each spatial dimension.
  • mode: The mode of the scaling array. Either "norm_compensation", "reconstruction", or "coef_extraction".
  • indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".

Returns:

  • scaling: The scaling array.
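The arguments above can be illustrated with a small NumPy sketch that assembles a scaling array from per-axis 1D factors combined with numpy.meshgrid. The function scaling_sketch is hypothetical: it takes the (right-most, others) denominator pairs that the source listing passes for each mode and assumes that modes without a conjugate partner (the mean and, for even N, the Nyquist mode) keep the divisor N. It is not Exponax's private _build_scaling_array. The test field, 3 cos(2x) cos(y) on an 8 x 8 grid, is chosen for illustration.

```python
import numpy as np


def scaling_sketch(num_spatial_dims, num_points, right_most_den, others_den, indexing="ij"):
    """Hypothetical stand-in for build_scaling_array; it reuses only the
    denominator pairs from the source, not Exponax's private helper."""
    n = num_points

    def axis_scaling(wavenumbers, den):
        # Mean (and, for even n, Nyquist) modes have no conjugate partner: divisor n.
        lonely = wavenumbers == 0
        if n % 2 == 0:
            lonely = lonely | (np.abs(wavenumbers) == n // 2)
        return np.where(lonely, float(n), n / den)

    k_other = np.rint(np.fft.fftfreq(n) * n).astype(int)  # 0, 1, ..., -1
    k_right = np.arange(n // 2 + 1)  # 0, 1, ..., n // 2 (rfft axis)
    axes = [axis_scaling(k_other, others_den)] * (num_spatial_dims - 1)
    axes.append(axis_scaling(k_right, right_most_den))
    # "ij" keeps the argument order, matching the (N, ..., N//2+1) layout of rfftn
    grids = np.meshgrid(*axes, indexing=indexing)
    return np.prod(np.stack(grids), axis=0)[None]  # prepend the channel axis


n = 8
x = 2 * np.pi * np.arange(n) / n
X, Y = np.meshgrid(x, x, indexing="ij")
state = (3.0 * np.cos(2 * X) * np.cos(Y))[None]  # shape (1, 8, 8)
state_hat = np.fft.rfftn(state, axes=(1, 2))  # shape (1, 8, 5)

# (right-most denominator, others denominator) per mode, as in the source listing
modes = {"norm_compensation": (1, 1), "reconstruction": (2, 1), "coef_extraction": (2, 2)}
for mode, dens in modes.items():
    scaling = scaling_sketch(2, n, *dens)
    print(mode, (state_hat / scaling)[0, 2, 1].real)  # entry k_x = 2, k_y = 1
```

For the entry at k_x = 2, k_y = 1, the three modes give roughly 0.75, 1.5 and 3.0, so only the "coef_extraction" denominators recover the amplitude 3 in 2D. In this sketch, passing indexing="xy" swaps the two axes and yields shape (1, N//2+1, N), which no longer lines up with the rfftn output.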
Source code in exponax/_spectral.py
def build_scaling_array(
    num_spatial_dims: int,
    num_points: int,
    *,
    mode: Literal["norm_compensation", "reconstruction", "coef_extraction"],
    indexing: str = "ij",
) -> Float[Array, "1 ... (N//2)+1"]:
    """
    When `exponax.fft` is used, the resulting array in Fourier space holds
    a scaled version of the Fourier coefficients. Use this function to build
    an array that counteracts this scaling; the appropriate scaling depends
    on the task:

    1. `"norm_compensation"`: The scaling is exactly the scaling that
       `exponax.ifft` applies.
    2. `"reconstruction"`: Technically, `"norm_compensation"` should provide
       an array of coefficients that can be used to build a Fourier
       interpolant (i.e., what [`exponax.FourierInterpolator`][] does).
       However, since [`exponax.fft`][] uses the real-valued FFT, only the
       non-negative wavenumbers are stored along the right-most axis, so the
       coefficients there carry only half of their contribution. This mode
       provides the scaling to counteract this.
    3. `"coef_extraction"`: The former modes do not, in general, produce
       coefficients equal to the amplitudes in physical space (for
       `"reconstruction"`, this concerns higher dimensions), because a
       real-valued mode splits its amplitude between the positive and the
       negative wavenumber. For example, if the signal `3 * cos(2x)` is
       discretized on the domain `[0, 2pi]` with 10 points, the Fourier
       coefficient at wavenumber 2 is `3/2` instead of the amplitude `3` if
       rescaled with mode `"norm_compensation"`. This mode provides the
       scaling to extract the correct amplitudes.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions.
    - `num_points`: The number of points in each spatial dimension.
    - `mode`: The mode of the scaling array. Either `"norm_compensation"`,
        `"reconstruction"`, or `"coef_extraction"`.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
        Either `"ij"` or `"xy"`. Default is `"ij"`.

    **Returns:**

    - `scaling`: The scaling array.
    """
    if mode == "norm_compensation":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=1,
            others_scaling_denominator=1,
            indexing=indexing,
        )
    elif mode == "reconstruction":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=2,
            others_scaling_denominator=1,
            indexing=indexing,
        )
    elif mode == "coef_extraction":
        return _build_scaling_array(
            num_spatial_dims,
            num_points,
            right_most_scaling_denominator=2,
            others_scaling_denominator=2,
            indexing=indexing,
        )
    else:
        raise ValueError("Invalid mode.")
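To make the "reconstruction" mode concrete, here is a minimal NumPy sketch of a 2D Fourier interpolant in the spirit of exponax.FourierInterpolator (whose actual implementation is not reproduced here). It assumes a single 2D field without a channel axis, a band-limited test signal chosen for illustration, and two hypothetical helpers, reconstruction_scaling_2d and evaluate_interpolant, that re-derive the scaling from the denominators in the source listing above (2 along the right-most axis, 1 along the other axis) rather than calling Exponax.

```python
import numpy as np


def reconstruction_scaling_2d(n):
    """Hypothetical 2D stand-in for mode="reconstruction" (denominator 2 on the
    right-most axis, 1 on the other axis); shape (n, n // 2 + 1)."""
    k_right = np.arange(n // 2 + 1)
    lonely = k_right == 0
    if n % 2 == 0:
        lonely = lonely | (k_right == n // 2)  # Nyquist mode of the rfft axis
    right = np.where(lonely, float(n), n / 2)
    other = np.full(n, float(n))  # denominator 1: every mode divided by n
    return other[:, None] * right[None, :]


def evaluate_interpolant(coefs, n, x, y):
    """Sum Re(c * exp(i * (kx * x + ky * y))) over all stored rfft2 modes."""
    kx = np.fft.fftfreq(n) * n  # 0, 1, ..., -1 (full axis)
    ky = np.arange(n // 2 + 1)  # 0, 1, ..., n // 2 (rfft axis)
    phase = np.exp(1j * (kx[:, None] * x + ky[None, :] * y))
    return float(np.sum((coefs * phase).real))


def f(x, y):
    return 1.0 + 2.0 * np.cos(x) * np.sin(2 * y) + 0.5 * np.sin(3 * x + y)


n = 16
grid = 2 * np.pi * np.arange(n) / n  # periodic grid, right endpoint excluded
X, Y = np.meshgrid(grid, grid, indexing="ij")
coefs = np.fft.rfftn(f(X, Y)) / reconstruction_scaling_2d(n)  # shape (16, 9)

print(evaluate_interpolant(coefs, n, 0.3, 1.1), f(0.3, 1.1))  # agree to round-off
```

Because the signal has no content at or above the Nyquist wavenumber, the interpolant matches f at an off-grid point up to round-off. Replacing the divisor n along the first axis with n / 2 (the "coef_extraction" choice) would double the contribution of every mode with non-zero kx, and the two printed values would no longer agree.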