Fourier-spectral utilities
exponax.fft
fft(
field: Float[Array, "C ... N"],
*,
num_spatial_dims: Optional[int] = None
) -> Complex[Array, "C ... (N//2)+1"]
Perform a real-valued FFT of a field. This function is designed for states in Exponax with a leading channel axis followed by one, two, or three spatial axes, each of the same length N.
It only accepts real-valued input fields and performs a real-valued FFT. Hence, the last axis of the returned field has length N//2+1.
Warning
The argument num_spatial_dims can only be inferred correctly if the array follows the Exponax convention; for example, it must not have a leading batch axis. For a batched operation, use jax.vmap on this function.
Arguments:
- field: The state to transform.
- num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. If None (the default), it is inferred from the array, which only works if the array follows the Exponax convention. For example, a leading batch axis is not allowed; in such a case, use jax.vmap on this function.
Returns:
- field_hat: The transformed field, shape (C, ..., N//2+1).
Info
Internally uses jax.numpy.fft.rfftn with the default setting of the norm argument, norm="backward". This means that the forward FFT (this function) does not apply any normalization to the result; only exponax.ifft applies normalization. To extract the amplitudes of the coefficients, divide the result by exponax.spectral.build_scaling_array.
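As a rough illustration of the axis handling described above, the sketch below applies jax.numpy.fft.rfftn to the trailing spatial axes only and infers their count from the array rank. `fft_sketch` is a hypothetical helper written for this example, not the exponax implementation; it assumes the Exponax layout (C, N, ..., N) without a batch axis, which is why a batch goes through jax.vmap.

```python
import jax
import jax.numpy as jnp


def fft_sketch(field, num_spatial_dims=None):
    """Hypothetical stand-in for exponax.fft (illustration only)."""
    # Assumes the Exponax layout (C, N, ..., N): every axis after the
    # channel axis is spatial, so there must be no leading batch axis.
    if num_spatial_dims is None:
        num_spatial_dims = field.ndim - 1
    # Transform only the trailing spatial axes; the channel axis is untouched.
    axes = tuple(range(-num_spatial_dims, 0))
    # Default norm="backward": the forward transform applies no scaling.
    return jnp.fft.rfftn(field, axes=axes)


N = 16
state_2d = jnp.ones((3, N, N))  # 3 channels on a 16 x 16 grid
state_hat = fft_sketch(state_2d)
print(state_hat.shape)  # (3, 16, 9): the last axis has N//2 + 1 entries
# Unnormalized: the zero mode of a field of ones is about N * N = 256
print(state_hat[0, 0, 0].real)

# For a batch of states, vmap over the batch axis so that each call
# again sees the (C, N, N) layout and infers num_spatial_dims correctly.
batch = jnp.ones((5, 3, N, N))
batch_hat = jax.vmap(fft_sketch)(batch)
print(batch_hat.shape)  # (5, 3, 16, 9)
```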
Source code in exponax/_spectral.py
exponax.ifft
ifft(
field_hat: Complex[Array, "C ... (N//2)+1"],
*,
num_spatial_dims: Optional[int] = None,
num_points: Optional[int] = None
) -> Float[Array, "C ... N"]
Perform the inverse real-valued FFT of a field. This is the inverse operation of exponax.fft. This function is designed for states in Exponax with a leading channel axis followed by one, two, or three spatial axes. In state space, all spatial axes have the same length N (here called num_points).
Requires a complex-valued field in Fourier space whose last axis has length N//2+1.
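Info
The number of points (N, or num_points) must be provided if the number of spatial dimensions is 1, because a last axis of length N//2+1 does not reveal whether N is even or odd. Otherwise, it can be inferred from the shape of the field, since all spatial axes except the last one keep their full length N.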
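The 1D ambiguity can be seen directly with jax.numpy.fft (used here in place of exponax.ifft for illustration): an odd and an even grid produce real-FFT outputs of the same length, so the inverse transform needs the number of points. In 2D, by contrast, the full-length axis -2 still reveals N.

```python
import jax.numpy as jnp

# A sine wave sampled on 9 (odd) and 8 (even) points
u_odd = jnp.sin(2 * jnp.pi * jnp.arange(9) / 9)
u_even = jnp.sin(2 * jnp.pi * jnp.arange(8) / 8)

# Both real FFTs have N//2 + 1 = 5 entries, so N cannot be read off the shape
u_odd_hat = jnp.fft.rfft(u_odd)
u_even_hat = jnp.fft.rfft(u_even)
print(u_odd_hat.shape, u_even_hat.shape)  # (5,) (5,)

# Without n, irfft assumes an even length 2 * (5 - 1) = 8
print(jnp.fft.irfft(u_odd_hat).shape)  # (8,)

# Supplying n (the role of num_points) recovers the original 9-point signal
u_odd_rec = jnp.fft.irfft(u_odd_hat, n=9)

# In 2D only the last axis is halved, so N can be read off axis -2
v_hat = jnp.fft.rfftn(jnp.zeros((9, 9)))
print(v_hat.shape)  # (9, 5), i.e. N = v_hat.shape[-2] = 9
```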
Warning
The argument num_spatial_dims can only be inferred correctly if the array follows the Exponax convention; for example, it must not have a leading batch axis. For a batched operation, use jax.vmap on this function.
Arguments:
- field_hat: The transformed field, shape (C, ..., N//2+1).
- num_spatial_dims: The number of spatial dimensions, i.e., how many spatial axes follow the channel axis. If None (the default), it is inferred from the array, which only works if the array follows the Exponax convention. For example, a leading batch axis is not allowed; in such a case, use jax.vmap on this function.
- num_points: The number of points in each spatial dimension. Can be inferred if num_spatial_dims >= 2.
Returns:
- field: The state in physical space, shape (C, ..., N).
Info
Internally uses jax.numpy.fft.irfftn with the default setting of the norm argument, norm="backward". This means that the forward FFT (exponax.fft) does not apply any normalization, and only the inverse FFT (this function) normalizes, dividing by the total number of grid points N^D (with D the number of spatial dimensions). Hence, if you want to define a state in Fourier space and transform it back, consider using exponax.spectral.build_scaling_array to scale the complex values correctly before the inverse transform.
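The effect of the "backward" convention can be checked in 1D with plain jax.numpy.fft. The sketch below shows that the raw forward coefficients of 2 + 3 cos(x) are scaled by N (mean) and N/2 (cosine mode), and then defines the same state directly in Fourier space by multiplying amplitudes with such factors before the inverse transform. `amplitude_scaling_1d` is a hypothetical helper written for this example; it mirrors the idea behind exponax.spectral.build_scaling_array but is not its implementation.

```python
import jax.numpy as jnp

N = 8
x = jnp.arange(N) * 2 * jnp.pi / N  # periodic grid on [0, 2pi), endpoint excluded
u = 2.0 + 3.0 * jnp.cos(x)

# norm="backward": the forward transform is unscaled
u_hat = jnp.fft.rfft(u)
print(u_hat[0].real, u_hat[1].real)  # about 16.0 (= 2 * N) and 12.0 (= 3 * N / 2)


def amplitude_scaling_1d(n):
    """Hypothetical: factor from physical amplitude to raw rfft coefficient."""
    # Interior modes split their amplitude between +k and -k: factor n / 2.
    s = jnp.full(n // 2 + 1, n / 2)
    # The mean (and, for even n, the Nyquist mode) has no partner: factor n.
    s = s.at[0].set(n)
    if n % 2 == 0:
        s = s.at[-1].set(n)
    return s


# Define the same state via its amplitudes: mean 2.0, cosine mode 1 with 3.0
amplitudes = jnp.zeros(N // 2 + 1, dtype=jnp.complex64)
amplitudes = amplitudes.at[0].set(2.0).at[1].set(3.0)
# irfft divides by N, which the scaling factors compensate
u_from_fourier = jnp.fft.irfft(amplitudes * amplitude_scaling_1d(N), n=N)
```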
Source code in exponax/_spectral.py
exponax.get_spectrum
get_spectrum(
state: Float[Array, "C ... N"], *, power: bool = True
) -> Float[Array, "C (N//2)+1"]
Compute the Fourier spectrum of a state, either the power spectrum or the amplitude spectrum.
Info
The returned array will always have two axes, no matter how many spatial axes the input has.
Arguments:
- state: The state to compute the spectrum of. The state must follow the Exponax convention with a leading channel axis followed by one, two, or three spatial axes, each of the same length N.
- power: Whether to compute the power spectrum (True) or the amplitude spectrum (False). Default is True, meaning the power spectrum.
Returns:
- spectrum: The spectrum of the state, shape (C, (N//2)+1).
Tip
The spectrum is usually best presented with a logarithmic y-axis, either with plt.semilogy or plt.loglog. It can also help to clip values below a small threshold to that threshold, so that near-zero entries do not dominate the logarithmic axis; for example, jnp.maximum(spectrum, 1e-10).
Info
If it is applied to a vorticity field with power=True (the default), it produces the enstrophy spectrum.
Note
The binning in higher dimensions can sometimes be counterintuitive. For example, on a 2D grid, if mode [2, 2] is populated, it is not represented in the 2-bin (i.e., at index [2] of the array returned by this function) but in the 3-bin, because its distance from the center is sqrt(2**2 + 2**2) = 2.8284..., which is not in the range [1.5, 2.5) of the 2-bin.
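To see which bin a mode lands in, the sketch below computes the wavenumber norm of every entry of a 2D real-FFT array and rounds it to the nearest integer, so that bin k covers [k - 0.5, k + 0.5), matching the ranges stated in the note. `radial_bin_indices_2d` is a hypothetical helper, not the exponax implementation.

```python
import jax.numpy as jnp


def radial_bin_indices_2d(n):
    """Hypothetical helper: spectrum bin of each entry of an (n, n//2+1) rfft array."""
    # Integer wavenumbers along the full axis (0, 1, ..., -1) and the halved axis
    k0 = jnp.fft.fftfreq(n, d=1.0 / n)
    k1 = jnp.fft.rfftfreq(n, d=1.0 / n)
    K0, K1 = jnp.meshgrid(k0, k1, indexing="ij")
    k_norm = jnp.sqrt(K0**2 + K1**2)
    # Bin k collects all modes with k - 0.5 <= |k| < k + 0.5
    return jnp.floor(k_norm + 0.5).astype(jnp.int32)


bins = radial_bin_indices_2d(16)
print(bins.shape)  # (16, 9)
print(bins[2, 2])  # 3, since sqrt(8) = 2.83 is outside [1.5, 2.5)
print(bins[2, 1])  # 2, since sqrt(5) = 2.24
print(bins[1, 1])  # 1, since sqrt(2) = 1.41
```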
Source code in exponax/_spectral.py
exponax.spectral.get_fourier_coefficients
get_fourier_coefficients(
state: Float[Array, "C ... N"],
*,
scaling_compensation_mode: Optional[
Literal[
"norm_compensation",
"reconstruction",
"coef_extraction",
]
] = "coef_extraction",
round: Optional[int] = 5,
indexing: str = "ij"
) -> Complex[Array, "C ... (N//2)+1"]
Extract the Fourier coefficients of a state.
It correctly compensates for the scaling used in exponax.fft such that the coefficient values can be read off directly from the array.
Arguments:
- state: The state following the Exponax convention with a leading channel axis followed by one, two, or three spatial axes, each of the same length N.
- scaling_compensation_mode: The mode of the scaling array used to compensate for the scaling of the Fourier transform. The mode "norm_compensation" produces the coefficient array that jnp.fft.rfftn would return with norm="forward" instead of the default norm="backward" (which is also the default used in Exponax). The mode "reconstruction" is similar but additionally compensates for the fact that the rfft only holds half of the coefficients along the right-most axis. The mode "coef_extraction" (the default) allows reading off a coefficient, e.g., at index [i, j] (in 2D), directly, whereas the other modes require considering both the positive and the negative wavenumbers. Set to None to not apply any scaling compensation. See also exponax.spectral.build_scaling_array for more information.
- round: The number of decimals to round the coefficients to. The default of 5 compensates for the rounding errors of the FFT in single precision, such that all coefficients that should not carry any energy are exactly zero. Set to None to not round.
- indexing: The indexing scheme to use for jax.numpy.meshgrid.
Returns:
- coefficients: The Fourier coefficients of the state, shape (C, ..., (N//2)+1).
Warning
Do not use the results of this function together with the exponax.viz utilities, since they periodically wrap the boundary, which is not needed in Fourier space.
Tip
Use this function to visualize the coefficients in higher dimensions. For example, in 2D:
```python
import exponax
import jax.numpy as jnp
import matplotlib.pyplot as plt

state_2d = ...  # shape (1, N, N)
coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)
# shape (1, N, (N//2)+1)
plt.imshow(
    jnp.log10(jnp.abs(coef_2d[0])),
)
```
And in 3D (requires the vape4d volume renderer to be installed; only works on GPU devices):
```python
state_3d = ...  # shape (1, N, N, N)
coef_3d = exponax.spectral.get_fourier_coefficients(
    state_3d, round=None,
)
images = exponax.viz.volume_render_state_3d(
    jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
)
plt.imshow(images[0])
```
To make the half along the real-valued (right-most) axis more prominent, consider flipping it via
```python
coef_3d_flipped = jnp.flip(coef_3d, axis=-1)
```
Tip
Interpretation guide: In general, for an FFT following the NumPy conventions, we have:
- Positive amplitudes on cosine signals produce positive real parts in the coefficients of both the positive and the negative wavenumber.
- Positive amplitudes on sine signals produce a negative imaginary part in the coefficient of the positive wavenumber and a positive imaginary part in the coefficient of the negative wavenumber.
As such, if the output of this function on a 1D state was
```python
array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j]])
```
this would correspond to a signal with:
- A constant offset of +3.0
- A first sine mode with amplitude +1.5
- A second cosine mode with amplitude +0.3
- A second sine mode with amplitude -0.8
In higher dimensions, the interpretation arises from the tensor product. Also be aware that for a (1, N, N) state, the coefficients have the shape (1, N, (N//2)+1).
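The sign conventions above can be checked in 1D by building the signal from the example and extracting its coefficients with jax.numpy.fft.rfft. Dividing the mean by N and every other mode by N/2 is an assumption that reproduces the "coef_extraction" behavior for odd N (no Nyquist mode); this is a sketch, not the exponax source.

```python
import jax.numpy as jnp

N = 7  # odd, so there is no Nyquist mode to treat separately
x = jnp.arange(N) * 2 * jnp.pi / N  # periodic grid on [0, 2pi)
u = (
    3.0
    + 1.5 * jnp.sin(x)
    + 0.3 * jnp.cos(2 * x)
    - 0.8 * jnp.sin(2 * x)
)

u_hat = jnp.fft.rfft(u)  # shape (N//2 + 1,) = (4,)

# Assumed extraction scaling for odd N: the mean is divided by N, every
# other mode by N / 2 (its amplitude is split between +k and -k)
scaling = jnp.array([N] + [N / 2] * (N // 2))
coef = u_hat / scaling
print(coef)  # approximately [3.+0.j, 0.-1.5j, 0.3+0.8j, 0.+0.j]
```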
Source code in exponax/_spectral.py
exponax.spectral.build_scaling_array
build_scaling_array(
num_spatial_dims: int,
num_points: int,
*,
mode: Literal[
"norm_compensation",
"reconstruction",
"coef_extraction",
],
indexing: str = "ij"
) -> Float[Array, "1 ... (N//2)+1"]
When exponax.fft is used, the resulting array in Fourier space represents a scaled version of the Fourier coefficients. Use this function to produce arrays that counteract this scaling, depending on the task:
- "norm_compensation": The scaling is exactly the scaling that exponax.ifft applies.
- "reconstruction": Technically, "norm_compensation" should provide an array of coefficients that can be used to build a Fourier interpolant (i.e., what exponax.FourierInterpolator does). However, since exponax.fft uses the real-valued FFT, the coefficients along the right-most axis only carry half of their contribution. This mode provides the scaling to counteract this.
- "coef_extraction": In higher dimensions, neither of the former modes produces coefficients equal to the amplitudes in physical space, because each coefficient has a contribution at both the positive and the negative wavenumber. For example, if the signal 3 * cos(2x) was discretized on the domain [0, 2pi] with 10 points, the Fourier coefficient at the 2nd wavenumber would be 3/2 if rescaled with mode "norm_compensation". This mode provides the scaling to extract the correct coefficients.
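The 3/2 example can be reproduced with jax.numpy.fft alone. The divisors below (N, matching what the inverse transform applies, and N/2 for extracting the amplitude of an interior mode) are assumptions derived from the description above, not values read from build_scaling_array.

```python
import jax.numpy as jnp

N = 10
x = jnp.arange(N) * 2 * jnp.pi / N  # periodic grid on [0, 2pi), endpoint excluded
u = 3.0 * jnp.cos(2 * x)

u_hat = jnp.fft.rfft(u)  # raw coefficient at wavenumber 2 is 3 * N / 2 = 15

# Dividing by N (what the inverse transform applies) leaves half the amplitude,
# because the other half sits at the negative wavenumber -2
norm_compensated = u_hat / N
print(norm_compensated[2].real)  # about 1.5

# Dividing by N / 2 recovers the physical amplitude of this interior mode
# (the mean and the Nyquist mode would need a divisor of N instead)
extracted = u_hat / (N / 2)
print(extracted[2].real)  # about 3.0
```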
Arguments:
- num_spatial_dims: The number of spatial dimensions.
- num_points: The number of points in each spatial dimension.
- mode: The mode of the scaling array. Either "norm_compensation", "reconstruction", or "coef_extraction".
- indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".
Returns:
- scaling: The scaling array, shape (1, ..., (N//2)+1).
Source code in exponax/_spectral.py