
Utilities to take spectral derivatives

exponax.derivative

derivative(
    field: Float[Array, "C ... N"],
    domain_extent: float,
    *,
    order: int = 1,
    indexing: str = "ij"
) -> Union[
    Float[Array, "C D ... N"],
    Float[Array, "D ... N"],
]

Perform the spectral derivative of a field. In higher dimensions, this defaults to the gradient (the collection of all partial derivatives). In 1d, the resulting channel dimension holds the derivative. If the function is called with a d-dimensional field that has 1 channel, the result is a d-dimensional field with d channels (one per partial derivative), i.e., of shape (D, ..., N) with D = d. If the field originally had C > 1 channels, the result is a matrix field with C rows and d columns, i.e., of shape (C, D, ..., N).
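
As an illustration of this shape convention, here is a minimal sketch in plain jax.numpy. It is not the Exponax implementation: spectral_gradient is a hypothetical name, and the sketch assumes a periodic domain (0, L)ᵈ sampled at N equispaced points per axis (right endpoint excluded), the real-FFT layout with N//2+1 modes on the last axis, and "ij" indexing.

```python
import jax.numpy as jnp


def spectral_gradient(field, domain_extent, *, order=1):
    """Hypothetical sketch, not Exponax's implementation.

    Assumes `field` has shape (C, N, ..., N) and samples a periodic function on
    (0, L)^d at N equispaced points per axis (right endpoint excluded).
    """
    num_channels = field.shape[0]
    d = field.ndim - 1  # spatial dims inferred from everything after the channel axis
    n = field.shape[-1]
    # Integer wavenumbers: full FFT layout on leading axes, rfft layout on the last axis
    freqs = [jnp.fft.fftfreq(n, 1.0 / n)] * (d - 1) + [jnp.fft.rfftfreq(n, 1.0 / n)]
    k = (2 * jnp.pi / domain_extent) * jnp.stack(jnp.meshgrid(*freqs, indexing="ij"))
    operator = (1j * k) ** order  # shape (D, N, ..., N//2+1)
    field_hat = jnp.fft.rfftn(field, axes=tuple(range(1, d + 1)))
    if num_channels == 1:
        # (1, ...) * (D, ...) broadcasts to (D, ...): one channel per partial derivative
        der_hat = operator * field_hat
    else:
        # (C, 1, ...) * (1, D, ...) -> (C, D, ...): a C x D matrix field
        der_hat = field_hat[:, None] * operator[None]
    spatial_axes = tuple(range(der_hat.ndim - d, der_hat.ndim))
    return jnp.fft.irfftn(der_hat, s=(n,) * d, axes=spatial_axes)


# Usage on (0, 2*pi)^2 with N = 16
n = 16
x = jnp.linspace(0.0, 2 * jnp.pi, n, endpoint=False)
X, Y = jnp.meshgrid(x, x, indexing="ij")
u = jnp.sin(X) * jnp.cos(2 * Y)

grad_u = spectral_gradient(u[None], 2 * jnp.pi)         # shape (2, 16, 16)
jac = spectral_gradient(jnp.stack([u, u]), 2 * jnp.pi)  # shape (2, 2, 16, 16)
```

The single-channel input yields one channel per partial derivative, while the two-channel input yields a 2 x 2 matrix field whose entry [c, j] is the derivative of channel c along axis j.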

Note that applying this operator twice causes issues at the Nyquist mode if the number of degrees of freedom N is even. In that case, consider using the order option instead, which raises the derivative operator to the requested power in a single step.
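
A 1d sketch of why this happens, in plain jax.numpy with a hypothetical helper derivative_1d (not the Exponax API). On the domain (0, 2π) with even N, the Nyquist mode cos(N x / 2) has the first derivative -(N/2) sin(N x / 2), which vanishes at every grid point x_j = 2πj/N. One application of the first derivative therefore discards that mode, and a second application cannot recover it, whereas multiplying by (ik)² in a single step keeps it.

```python
import jax.numpy as jnp


def derivative_1d(u, domain_extent, order=1):
    """Hypothetical helper: spectral derivative of a real periodic 1d signal."""
    n = u.shape[-1]
    # Angular wavenumbers 2*pi*m/L for the rfft modes m = 0, ..., N//2
    k = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=domain_extent / n)
    return jnp.fft.irfft((1j * k) ** order * jnp.fft.rfft(u), n=n)


n = 8  # even N: the last rfft mode (m = N//2 = 4) is the Nyquist mode
L = 2 * jnp.pi
x = jnp.linspace(0.0, L, n, endpoint=False)
u = jnp.cos(4 * x)  # pure Nyquist mode, alternating +1, -1 on the grid

second_in_one_step = derivative_1d(u, L, order=2)    # approx. -16 * cos(4x)
first_twice = derivative_1d(derivative_1d(u, L), L)  # approx. 0: Nyquist content lost
```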

Warning

The number of spatial dimensions (num_spatial_dims) is inferred from the array shape, which only works if the array follows the Exponax convention, i.e., a leading channel axis and no leading batch axis. For a batched operation, use jax.vmap on this function.
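
A minimal sketch of the vmap pattern, assuming the same channel-first layout. The function scalar_gradient is a hypothetical stand-in written in plain jax.numpy (it is not exponax.derivative) that, like the function documented here, infers the number of spatial dimensions from the axes after the channel axis.

```python
import jax
import jax.numpy as jnp


def scalar_gradient(field, domain_extent):
    """Hypothetical sketch: spectral gradient of a (1, N, ..., N) field -> (D, N, ..., N)."""
    # Spatial dims are inferred from everything after the channel axis, so an
    # extra leading batch axis would be misread as one more spatial axis.
    d = field.ndim - 1
    n = field.shape[-1]
    freqs = [jnp.fft.fftfreq(n, 1.0 / n)] * (d - 1) + [jnp.fft.rfftfreq(n, 1.0 / n)]
    k = (2 * jnp.pi / domain_extent) * jnp.stack(jnp.meshgrid(*freqs, indexing="ij"))
    axes = tuple(range(1, d + 1))
    field_hat = jnp.fft.rfftn(field, axes=axes)
    return jnp.fft.irfftn(1j * k * field_hat, s=(n,) * d, axes=axes)


# A batch of four single-channel 2d fields: shape (B, C, N, N) = (4, 1, 16, 16)
n = 16
x = jnp.linspace(0.0, 2 * jnp.pi, n, endpoint=False)
X, _ = jnp.meshgrid(x, x, indexing="ij")
batch = jnp.stack([jnp.sin(m * X)[None] for m in range(1, 5)])

# vmap strips the batch axis, so each call sees the expected (1, N, N) layout
grads = jax.vmap(scalar_gradient, in_axes=(0, None))(batch, 2 * jnp.pi)
```

Called directly on the (4, 1, 16, 16) batch, the same function would treat the array as a 3d field; under jax.vmap each call sees a (1, 16, 16) array, and the result has shape (4, 2, 16, 16).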

Arguments:

  • field: The field to differentiate, shape (C, ..., N,). C can be 1 for a scalar field or D for a vector field.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᵈ.
  • order: The order of the derivative. Default is 1.
  • indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".

Returns:

  • field_der: The derivative of the field, shape (C, D, ..., N) if C > 1, else (D, ..., N).
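
The indexing argument is passed on to jax.numpy.meshgrid, where it decides the axis order of the resulting coordinate grids. A small standalone illustration of the two conventions:

```python
import jax.numpy as jnp

x = jnp.arange(3.0)  # coordinates along the first axis (3 points)
y = jnp.arange(5.0)  # coordinates along the second axis (5 points)

# "ij" (matrix) indexing: grid axes follow the argument order -> shape (3, 5)
X_ij, Y_ij = jnp.meshgrid(x, y, indexing="ij")
# "xy" (Cartesian) indexing: the first two grid axes are swapped -> shape (5, 3)
X_xy, Y_xy = jnp.meshgrid(x, y, indexing="xy")
```

With "ij", grid axis i follows the i-th coordinate, which matches a channel-first array layout such as (C, N_x, N_y); "xy" swaps the first two axes.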
Source code in exponax/_spectral.py
def derivative(
    field: Float[Array, "C ... N"],
    domain_extent: float,
    *,
    order: int = 1,
    indexing: str = "ij",
) -> Union[Float[Array, "C D ... N"], Float[Array, "D ... N"]]:
    """
    Perform the spectral derivative of a field. In higher dimensions, this
    defaults to the gradient (the collection of all partial derivatives). In 1d,
    the resulting channel dimension holds the derivative. If the function is
    called with a d-dimensional field which has 1 channel, the result will be a
    d-dimensional field with d channels (one per partial derivative). If the
    field originally had C channels, the result will be a matrix field with C
    rows and d columns.

    Note that applying this operator twice causes issues at the Nyquist mode if
    the number of degrees of freedom N is even. In that case, consider using the
    `order` option instead, which raises the derivative operator to the
    requested power in a single step.

    !!! warning
        The number of spatial dimensions (`num_spatial_dims`) is inferred from
        the array shape, which only works if the array follows the Exponax
        convention, i.e., a leading channel axis and no leading batch axis. For
        a batched operation, use `jax.vmap` on this function.

    **Arguments:**

    - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be
        `1` for a scalar field or `D` for a vector field.
    - `domain_extent`: The size of the domain `L`; in higher dimensions
        the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`.
    - `order`: The order of the derivative. Default is `1`.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
        Either `"ij"` or `"xy"`. Default is `"ij"`.

    **Returns:**

    - `field_der`: The derivative of the field, shape `(C, D, ..., N)` if
        `C > 1`, else `(D, ..., N)`.
    """
    channel_shape = field.shape[0]
    spatial_shape = field.shape[1:]
    num_spatial_dims = len(spatial_shape)
    num_points = spatial_shape[0]
    derivative_operator = build_derivative_operator(
        num_spatial_dims, domain_extent, num_points, indexing=indexing
    )
    # A Nyquist fix (required for even N, no effect for odd N) is intentionally
    # not applied here:
    # derivative_operator = derivative_operator * nyquist_filter_mask(
    #     num_spatial_dims, num_points
    # )
    # Raise the first-derivative operator to the requested order in one step
    derivative_operator_fixed = derivative_operator**order

    field_hat = fft(field, num_spatial_dims=num_spatial_dims)
    if channel_shape == 1:
        # Do not introduce another channel axis
        field_der_hat = derivative_operator_fixed * field_hat
    else:
        # Create a "derivative axis" right after the channel axis
        field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...]

    field_der = ifft(
        field_der_hat, num_spatial_dims=num_spatial_dims, num_points=num_points
    )

    return field_der
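
The commented-out fix above multiplies the derivative operator by Exponax's nyquist_filter_mask, whose definition is not shown on this page. As a hypothetical sketch of what such a mask could look like, assuming it zeros every mode whose index on any axis is the Nyquist index (which only exists for even N), written in plain jax.numpy:

```python
import jax.numpy as jnp


def nyquist_mask(num_spatial_dims, num_points):
    """Hypothetical sketch: 0 on Nyquist modes of an rfftn layout, 1 elsewhere.

    Returns shape (1, N, ..., N//2+1) so it broadcasts against an operator of
    shape (D, N, ..., N//2+1). For odd N there is no Nyquist mode: all ones.
    """
    n = num_points
    full_axis = jnp.ones(n)           # fftfreq layout: Nyquist (-N/2) sits at index N//2
    last_axis = jnp.ones(n // 2 + 1)  # rfftfreq layout: Nyquist (+N/2) is the last entry
    if n % 2 == 0:
        full_axis = full_axis.at[n // 2].set(0.0)
        last_axis = last_axis.at[-1].set(0.0)
    axes = [full_axis] * (num_spatial_dims - 1) + [last_axis]
    mask = jnp.ones(())
    for grid in jnp.meshgrid(*axes, indexing="ij"):
        mask = mask * grid  # a mode is kept only if no axis sits at the Nyquist index
    return mask[None]
```

In 1d, for instance, multiplying the operator by this mask before raising it to order would make order=2 and two first-derivative applications agree, at the price of dropping the Nyquist content in both.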

exponax.spectral.make_incompressible

make_incompressible(
    field: Float[Array, "D ... N"], *, indexing: str = "ij"
)

Makes a velocity field incompressible by solving the associated pressure Poisson equation and subtracting the pressure gradient.

With the divergence of the velocity field as the right-hand side, solve the Poisson equation for pressure p

Δp = ∇ ⋅ v⃗

and then correct the velocity field to be incompressible

v⃗ ← v⃗ - ∇p
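
Per Fourier mode, the projection can be written out as follows, assuming the standard periodic convention in which ∂ⱼ becomes multiplication by i kⱼ. For the correction v⃗ ← v⃗ - ∇p to remove the divergence, the pressure must satisfy Δp = ∇ ⋅ v⃗, since ∇ ⋅ (v⃗ - ∇p) = ∇ ⋅ v⃗ - Δp.

```latex
% Periodic domain: \partial_j \leftrightarrow i k_j with \vec{k} = \tfrac{2\pi}{L}\vec{m}, \vec{m} \in \mathbb{Z}^D
\widehat{\nabla \cdot \vec{v}} = i\,\vec{k} \cdot \hat{\vec{v}},
\qquad
\widehat{\Delta p} = -|\vec{k}|^2\, \hat{p}

% Pressure Poisson solve for \vec{k} \neq \vec{0} (the mean mode is left unchanged)
\Delta p = \nabla \cdot \vec{v}
\quad\Longrightarrow\quad
\hat{p} = -\frac{i\,\vec{k} \cdot \hat{\vec{v}}}{|\vec{k}|^2}

% Correction step
\hat{\vec{v}} \leftarrow \hat{\vec{v}} - i\vec{k}\,\hat{p}
= \hat{\vec{v}} - \frac{\vec{k}\,(\vec{k} \cdot \hat{\vec{v}})}{|\vec{k}|^2}
```

The last expression is unchanged if k⃗ is multiplied by any nonzero constant, so the factor 2π/L, and with it the domain extent, drops out of the projection.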

Arguments:

  • field: The velocity field to make incompressible, shape (D, ..., N,). Must have exactly as many channels as there are spatial axes.
  • indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".

Returns:

  • incompressible_field: The incompressible velocity field, shape (D, ..., N,).
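
A self-contained jax.numpy sketch of the same kind of projection (a standard spectral Leray projection), assuming a periodic cube with N equispaced points per axis and "ij" indexing. The name project_divergence_free is hypothetical and not part of Exponax; the sketch uses integer wavenumbers because the projector is invariant under rescaling them.

```python
import jax.numpy as jnp


def project_divergence_free(v):
    """Hypothetical sketch of a spectral (Leray) projection.

    Assumes v has shape (D, N, ..., N) and samples a periodic vector field on a
    cube at N equispaced points per axis. The 2*pi/L scaling of the wavenumbers
    is omitted because the projection does not depend on it.
    """
    d = v.shape[0]
    n = v.shape[-1]
    axes = tuple(range(1, d + 1))
    freqs = [jnp.fft.fftfreq(n, 1.0 / n)] * (d - 1) + [jnp.fft.rfftfreq(n, 1.0 / n)]
    k = jnp.stack(jnp.meshgrid(*freqs, indexing="ij"))  # (D, N, ..., N//2+1)
    v_hat = jnp.fft.rfftn(v, axes=axes)
    k_dot_v = jnp.sum(k * v_hat, axis=0, keepdims=True)  # proportional to the divergence
    k_sq = jnp.sum(k**2, axis=0, keepdims=True)  # negated Laplacian symbol, up to (2*pi/L)^2
    # At the mean mode k = 0 both k_dot_v and k_sq vanish; use 1 to avoid 0/0
    k_sq_safe = jnp.where(k_sq == 0, 1.0, k_sq)
    v_hat = v_hat - k * k_dot_v / k_sq_safe
    return jnp.fft.irfftn(v_hat, s=(n,) * d, axes=axes)


# Usage: divergence-free part + gradient part + constant mean on (0, 2*pi)^2
n = 16
x = jnp.linspace(0.0, 2 * jnp.pi, n, endpoint=False)
X, Y = jnp.meshgrid(x, x, indexing="ij")
w = jnp.stack([jnp.sin(X) * jnp.cos(Y), -jnp.cos(X) * jnp.sin(Y)])  # div w = 0
g = jnp.stack([-2 * jnp.sin(2 * X + Y), -jnp.sin(2 * X + Y)])  # g = grad of cos(2x + y)
mean = jnp.array([0.5, -0.3])[:, None, None]
projected = project_divergence_free(w + g + mean)  # approx. w + mean
```

The gradient component is removed, while the divergence-free part and the constant mean pass through unchanged.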
Source code in exponax/_spectral.py
def make_incompressible(
    field: Float[Array, "D ... N"],
    *,
    indexing: str = "ij",
):
    """
    Makes a velocity field incompressible by solving the associated pressure
    Poisson equation and subtracting the pressure gradient.

    With the divergence of the velocity field as the right-hand side, solve the
    Poisson equation for pressure `p`

        Δp = ∇ ⋅ v⃗

    and then correct the velocity field to be incompressible

        v⃗ ← v⃗ - ∇p

    **Arguments:**

    - `field`: The velocity field to make incompressible, shape `(D, ..., N,)`.
        Must have exactly as many channels as there are spatial axes.
    - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
        Either `"ij"` or `"xy"`. Default is `"ij"`.

    **Returns:**

    - `incompressible_field`: The incompressible velocity field, shape `(D, ...,
        N,)`.
    """
    channel_shape = field.shape[0]
    spatial_shape = field.shape[1:]
    num_spatial_dims = len(spatial_shape)
    if channel_shape != num_spatial_dims:
        raise ValueError(
            f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}."
        )
    num_points = spatial_shape[0]

    derivative_operator = build_derivative_operator(
        num_spatial_dims, 1.0, num_points, indexing=indexing
    )  # domain_extent does not matter because it will cancel out

    incompressible_field_hat = fft(field, num_spatial_dims=num_spatial_dims)

    divergence = jnp.sum(
        derivative_operator * incompressible_field_hat, axis=0, keepdims=True
    )

    laplace_operator = build_laplace_operator(derivative_operator)

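    # Invert the Laplacian, avoiding division by zero at the mean (k = 0) mode,
    # where the divergence vanishes anyway
    inv_laplace_operator = jnp.where(
        laplace_operator == 0,
        1.0,
        1.0 / laplace_operator,
    )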

    # Solve Δp = ∇ ⋅ v⃗ mode by mode, so that subtracting ∇p removes the divergence
    pseudo_pressure = inv_laplace_operator * divergence

    pseudo_pressure_gradient = derivative_operator * pseudo_pressure

    incompressible_field_hat = incompressible_field_hat - pseudo_pressure_gradient

    incompressible_field = ifft(
        incompressible_field_hat,
        num_spatial_dims=num_spatial_dims,
        num_points=num_points,
    )

    return incompressible_field