Skip to content

Utilities to take spectral derivatives

exponax.derivative

derivative(
    field: Float[Array, "C ... N"],
    domain_extent: float,
    *,
    order: int = 1,
    indexing: str = "ij"
) -> Union[
    Float[Array, "C D ... (N//2)+1"],
    Float[Array, "D ... (N//2)+1"],
]

Perform the spectral derivative of a field. In higher dimensions, this defaults to the gradient (the collection of all partial derivatives). In 1d, the resulting channel dimension holds the derivative. If the function is called with a d-dimensional field which has 1 channel, the result will be a d-dimensional field with d channels (one per partial derivative). If the field originally had C channels, the result will be a matrix field with C rows and d columns.

Note that applying this operator twice will produce issues at the Nyquist mode if the number of degrees of freedom N is even. To avoid this, use the order option instead of applying the operator repeatedly.

Arguments: - field: The field to differentiate, shape (C, ..., N,). C can be 1 for a scalar field or D for a vector field. - domain_extent: The domain extent L. - order: The order of the derivative. Default is 1. - indexing: The indexing scheme to use for jax.numpy.meshgrid. Either "ij" or "xy". Default is "ij".

Returns: - field_der: The derivative of the field in physical (real) space, shape (C, D, ..., N) if C > 1, or (D, ..., N) if C == 1.

Source code in exponax/_spectral.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def derivative(
    field: Float[Array, "C ... N"],
    domain_extent: float,
    *,
    order: int = 1,
    indexing: str = "ij",
) -> Union[Float[Array, "C D ... N"], Float[Array, "D ... N"]]:
    """
    Perform the spectral derivative of a field. In higher dimensions, this
    defaults to the gradient (the collection of all partial derivatives). In 1d,
    the resulting channel dimension holds the derivative. If the function is
    called with a d-dimensional field which has 1 channel, the result will be a
    d-dimensional field with d channels (one per partial derivative). If the
    field originally had C channels, the result will be a matrix field with C
    rows and d columns.

    Note that applying this operator twice will produce issues at the Nyquist
    mode if the number of degrees of freedom N is even. To avoid this, use the
    `order` option instead of applying the operator repeatedly.

    **Arguments:**
        - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be
            `1` for a scalar field or `D` for a vector field. The operator is
            built from the size of the first spatial axis only, so every
            spatial axis is presumably expected to have the same `N`.
        - `domain_extent`: The domain extent `L`.
        - `order`: The order of the derivative. Default is `1`.
        - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
          Either `"ij"` or `"xy"`. Default is `"ij"`.

    **Returns:**
        - `field_der`: The derivative of the field in physical (real) space,
          shape `(C, D, ..., N)` if `C > 1`, or `(D, ..., N)` if `C == 1`.
    """
    channel_shape = field.shape[0]
    spatial_shape = field.shape[1:]
    D = len(spatial_shape)
    N = spatial_shape[0]
    derivative_operator = build_derivative_operator(
        D, domain_extent, N, indexing=indexing
    )
    # Raise the symbol to the requested power instead of applying the first
    # derivative repeatedly (see the Nyquist note in the docstring). A Nyquist
    # filter (`nyquist_filter_mask`) was considered here but deliberately not
    # applied.
    derivative_operator_power = derivative_operator**order

    field_hat = jnp.fft.rfftn(field, axes=space_indices(D))
    if channel_shape == 1:
        # Do not introduce another channel axis: the singleton channel axis of
        # `field_hat` broadcasts against the operator's leading (presumably
        # size-D derivative) axis.
        field_der_hat = derivative_operator_power * field_hat
    else:
        # Create a "derivative axis" right after the channel axis.
        field_der_hat = field_hat[:, None] * derivative_operator_power[None, ...]

    # Passing `s=spatial_shape` restores the original length N of the last
    # axis (rfftn stored only N//2 + 1 modes there), so the output lives on
    # the same grid as the input.
    field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D))

    return field_der

exponax.make_incompressible

make_incompressible(
    field: Float[Array, "D ... N"], *, indexing: str = "ij"
)
Source code in exponax/_spectral.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def make_incompressible(
    field: Float[Array, "D ... N"],
    *,
    indexing: str = "ij",
):
    """
    Project a vector field onto its divergence-free part using spectral
    derivatives (Helmholtz/Leray projection).

    The field is split as `u = w + grad(p)` with `div(w) = 0`. The
    pseudo-pressure `p` solves the Poisson problem `laplace(p) = div(u)` in
    Fourier space, and its gradient is subtracted:

        w_hat = u_hat - Dop * (Dop . u_hat) / (Dop . Dop)

    where `Dop` is the spectral derivative operator. Modes at which
    `Dop . Dop == 0` are left unchanged.

    As with `derivative`, results at the Nyquist mode may be inexact for an
    even number of points.

    **Arguments:**
        - `field`: The vector field, shape `(D, ..., N)`. The number of
          channels must equal the number of spatial dimensions `D`. The
          operator is built from the size of the first spatial axis only, so
          every spatial axis is presumably expected to have the same `N`.
        - `indexing`: The indexing scheme passed to
          `build_derivative_operator`. Either `"ij"` or `"xy"`. Default is
          `"ij"`.

    **Returns:**
        - The divergence-free part of `field`, same shape as `field`.

    **Raises:**
        - `ValueError`: If the number of channels differs from the number of
          spatial dimensions.
    """
    channel_shape = field.shape[0]
    spatial_shape = field.shape[1:]
    num_spatial_dims = len(spatial_shape)
    if channel_shape != num_spatial_dims:
        raise ValueError(
            f"Expected the number of channels to be {num_spatial_dims}, got {channel_shape}."
        )
    num_points = spatial_shape[0]

    derivative_operator = build_derivative_operator(
        num_spatial_dims, 1.0, num_points, indexing=indexing
    )  # domain_extent does not matter because it will cancel out

    field_hat = jnp.fft.rfftn(field, axes=space_indices(num_spatial_dims))

    # Spectral divergence Dop . u_hat, keeping a singleton channel axis.
    divergence_hat = jnp.sum(
        derivative_operator * field_hat, axis=0, keepdims=True
    )

    # Symbol of the Laplacian, Dop . Dop (= -|k|^2 if Dop = i*k). It is built
    # from the same operator so that its sign convention is guaranteed to
    # match the divergence and gradient above and below.
    laplace_symbol = jnp.sum(derivative_operator**2, axis=0, keepdims=True)

    # Invert the Laplacian without dividing by zero. Modes with a zero symbol
    # get a factor of 1; assuming Dop = i*k, the divergence is also zero
    # there, so the pseudo-pressure vanishes on those modes.
    safe_laplace_symbol = jnp.where(laplace_symbol == 0, 1.0, laplace_symbol)
    inv_laplace_symbol = 1.0 / safe_laplace_symbol

    # Pseudo-pressure: solution of laplace(p) = div(u).
    pseudo_pressure_hat = inv_laplace_symbol * divergence_hat

    # Remove the curl-free part: w_hat = u_hat - Dop * p_hat. This makes
    # Dop . w_hat = Dop . u_hat - (Dop . Dop) p_hat = 0 on every mode with a
    # nonzero symbol.
    pseudo_pressure_gradient_hat = derivative_operator * pseudo_pressure_hat
    incompressible_field_hat = field_hat - pseudo_pressure_gradient_hat

    incompressible_field = jnp.fft.irfftn(
        incompressible_field_hat, s=spatial_shape, axes=space_indices(num_spatial_dims)
    )

    return incompressible_field