
Interpolation

... or utilities to move between different grid representations.

exponax.map_between_resolutions

map_between_resolutions(
    state: Float[Array, "C ... N"],
    new_num_points: int,
    *,
    oddball_zero: bool = True
) -> Float[Array, "C ... N_new"]

Upsamples or downsamples a state in Exponax convention to a new resolution via manipulation of its Fourier representation.

This approach is much more efficient than exponax.FourierInterpolator, but it can only move the state between uniform Cartesian grids of different resolutions.

Info

If the new resolution is higher than the old resolution, the state is upsampled; if it is lower, the state is downsampled. If the given state is bandlimited, i.e., the highest wavenumber containing non-zero energy is at most N//2, then upsampling is exact (no interpolation error). Likewise, downsampling is exact (no coarsening artifacts) if the given state was bandlimited and would still be bandlimited at the new resolution. Otherwise, high-frequency (fine-scale) information is lost.
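
To make the Fourier-resampling idea concrete, here is a minimal 1D sketch in plain `jax.numpy`. It is not the library's implementation: `resample_1d` is a hypothetical helper that assumes a real signal of shape `(..., N)` sampled at `N` uniform points of a periodic domain. It normalizes the `rfft` coefficients by the old `N`, copies the modes shared by both grids into a zero array sized for the new grid, zeros the Nyquist ("oddball") mode of the coarser grid when that grid is even, and rescales by the new `N` before the inverse transform.

```python
import jax.numpy as jnp


def resample_1d(u, new_n, oddball_zero=True):
    """Hypothetical helper: spectral up/downsampling of a periodic signal (..., N)."""
    old_n = u.shape[-1]
    if new_n == old_n:
        return u
    # Normalized Fourier coefficients (mode amplitudes no longer depend on N).
    u_hat = jnp.fft.rfft(u, axis=-1) / old_n
    m = min(old_n, new_n)  # the coarser grid decides which modes survive
    keep = m // 2 + 1
    new_hat = jnp.zeros(u.shape[:-1] + (new_n // 2 + 1,), dtype=u_hat.dtype)
    new_hat = new_hat.at[..., :keep].set(u_hat[..., :keep])
    # On an even coarse grid, the Nyquist mode is ambiguous; drop it.
    if oddball_zero and m % 2 == 0:
        new_hat = new_hat.at[..., m // 2].set(0.0)
    # Undo the normalization for the new resolution and transform back.
    return jnp.fft.irfft(new_hat * new_n, n=new_n, axis=-1)


def grid(n):
    return jnp.arange(n) / n  # n uniform points on the periodic domain [0, 1)


def f(x):
    # Bandlimited test signal: wavenumbers 1 and 3 only.
    return jnp.sin(2 * jnp.pi * x) + 0.5 * jnp.cos(2 * jnp.pi * 3 * x)


u16 = f(grid(16))
u64 = resample_1d(u16, 64)  # upsampling a bandlimited signal: exact up to round-off
u8 = resample_1d(u16, 8)    # wavenumber 3 < 8 // 2, so downsampling is exact too
u9 = resample_1d(u16, 9)    # odd target grids work the same way
u6 = resample_1d(u16, 6)    # wavenumber 3 is the Nyquist mode of 6 points: dropped
```

The last call shows the lossy case: on a 6-point grid wavenumber 3 is the Nyquist mode, so with `oddball_zero=True` only the `sin(2πx)` component survives.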

Arguments:

  • state: The state to interpolate. Must conform to the Exponax standard with a leading channel axis (can be a singleton axis if there is only one channel) followed by one, two, or three spatial axes (depending on the number of spatial dimensions). All spatial axes must have the same number of points.
  • new_num_points: The new number of points in each spatial dimension.
  • oddball_zero: Whether to zero out the Nyquist ("oddball") mode for even-sized grids: the old grid when upsampling, the new grid when downsampling. This is usually preferred.

Returns:

  • new_state: The state interpolated to the new resolution. This will have the same number of channels as the input state.
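
Why zeroing the Nyquist mode is usually preferred: on an even grid, the mode at wavenumber `N//2` is ambiguous (a sine at that wavenumber vanishes at every grid point), and the real FFT stores it without the factor of two that every other nonzero wavenumber receives from its negative partner. The sketch below uses plain `jax.numpy`, independent of Exponax, to copy that coefficient naively into a finer grid. The amplitude comes out doubled. It only illustrates the ambiguity and makes no claim about Exponax's exact output with `oddball_zero=False`.

```python
import jax.numpy as jnp

n, n_fine = 8, 16
x = jnp.arange(n) / n
x_fine = jnp.arange(n_fine) / n_fine

# At wavenumber n // 2, a sine is invisible on the grid and a cosine alternates +1, -1.
nyquist_sin = jnp.sin(2 * jnp.pi * (n // 2) * x)
nyquist_cos = jnp.cos(2 * jnp.pi * (n // 2) * x)

# Naive zero-padding that keeps the Nyquist coefficient (i.e., no oddball zeroing).
u_hat = jnp.fft.rfft(nyquist_cos) / n
padded = jnp.zeros(n_fine // 2 + 1, dtype=u_hat.dtype).at[: n // 2 + 1].set(u_hat)
naive = jnp.fft.irfft(padded * n_fine, n=n_fine)

# The coarse Nyquist bin is an interior bin on the fine grid, so irfft counts it
# twice: the result is 2 * cos(2*pi*4*x) instead of cos(2*pi*4*x).
expected_doubled = 2 * jnp.cos(2 * jnp.pi * (n // 2) * x_fine)
```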
Source code in exponax/_interpolation.py
def map_between_resolutions(
    state: Float[Array, "C ... N"],
    new_num_points: int,
    *,
    oddball_zero: bool = True,
) -> Float[Array, "C ... N_new"]:
    """
    Upsamples or downsamples a state in `Exponax` convention to a new resolution
    via manipulation of its Fourier representation.

    This approach is much more efficient than `exponax.FourierInterpolator` but
    can only move the state between uniform Cartesian grids of different
    resolutions.

    !!! info
        If the new resolution is higher than the old resolution, the state is
        upsampled. If the new resolution is lower than the old resolution, the
        state is downsampled. If the given state is bandlimited, i.e., the
        highest wavenumber containing non-zero energy is at most `N//2`, then
        upsampling will be exact (no interpolation error). Also, in case of
        downsampling: if the given state was bandlimited and would still be
        bandlimited in the new resolution, this downsampling will also
        be exact, i.e., no coarsening artifacts. If this is not the case, one
        loses high-frequency (fine scale) information.

    **Arguments:**

    - `state`: The state to interpolate. Must conform to the `Exponax`
        standard with a leading channel axis (can be a singleton axis if there
        is only one channel), and one, two, or three subsequent spatial axes
        (depending on the number of spatial dimensions). All spatial axes must
        have the same number of points.
    - `new_num_points`: The new number of points in each spatial dimension.
    - `oddball_zero`: Whether to zero out the Nyquist frequency in case of
        even-sized grids. This is usually preferred.

    **Returns:**

    - `new_state`: The state interpolated to the new resolution. This will have
        the same number of channels as the input state.
    """
    num_spatial_dims = state.ndim - 1
    old_num_points = state.shape[-1]
    num_channels = state.shape[0]

    if old_num_points == new_num_points:
        return state

    old_state_hat_scaled = fft(
        state, num_spatial_dims=num_spatial_dims
    ) / build_scaling_array(
        num_spatial_dims,
        old_num_points,
        mode="norm_compensation",
    )

    if new_num_points > old_num_points:
        # Upscaling
        if old_num_points % 2 == 0 and oddball_zero:
            old_state_hat_scaled *= oddball_filter_mask(
                num_spatial_dims, old_num_points
            )

    new_state_hat_scaled = jnp.zeros(
        (num_channels,) + wavenumber_shape(num_spatial_dims, new_num_points),
        dtype=old_state_hat_scaled.dtype,
    )

    modes_slices: list[list[slice]] = get_modes_slices(
        num_spatial_dims,
        min(old_num_points, new_num_points),
    )

    for block_slice in modes_slices:
        new_state_hat_scaled = new_state_hat_scaled.at[block_slice].set(
            old_state_hat_scaled[block_slice]
        )

    new_state_hat = new_state_hat_scaled * build_scaling_array(
        num_spatial_dims,
        new_num_points,
        mode="norm_compensation",
    )
    if old_num_points > new_num_points:
        # Downscaling
        if new_num_points % 2 == 0 and oddball_zero:
            new_state_hat *= oddball_filter_mask(num_spatial_dims, new_num_points)

    new_state = ifft(
        new_state_hat,
        num_spatial_dims=num_spatial_dims,
        num_points=new_num_points,
    )

    return new_state
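
In more than one dimension, the `rfftn` layout keeps only non-negative wavenumbers along the last axis but both signs along the others, so the retained modes form separate blocks; that is what the loop over `modes_slices` above copies. The following hedged 2D sketch spells out those two blocks by hand with plain `jax.numpy` slicing. `resample_2d` and `sample_state` are hypothetical; the sketch assumes an `"ij"`-indexed state of shape `(C, N, N)`, always zeros the oddball modes, and makes no claim about how `get_modes_slices` or `oddball_filter_mask` are implemented.

```python
import jax.numpy as jnp


def resample_2d(u, new_n):
    """Hypothetical helper: spectral resampling of a (C, N, N) state, Nyquist zeroed."""
    c, old_n, _ = u.shape
    if new_n == old_n:
        return u
    m = min(old_n, new_n)
    pos = m // 2 + 1    # wavenumbers 0, 1, ..., m // 2 (last one is Nyquist if m is even)
    neg = (m - 1) // 2  # wavenumbers -neg, ..., -1 along the full (first spatial) axis
    u_hat = jnp.fft.rfftn(u, axes=(-2, -1)) / old_n**2
    new_hat = jnp.zeros((c, new_n, new_n // 2 + 1), dtype=u_hat.dtype)
    # Block 1: non-negative wavenumbers along the first spatial axis.
    new_hat = new_hat.at[:, :pos, :pos].set(u_hat[:, :pos, :pos])
    # Block 2: negative wavenumbers, stored at the end of the first spatial axis.
    if neg > 0:
        new_hat = new_hat.at[:, new_n - neg :, :pos].set(u_hat[:, old_n - neg :, :pos])
    if m % 2 == 0:
        # Zero the "oddball" (Nyquist) modes of the coarser grid along both axes.
        new_hat = new_hat.at[:, m // 2, :].set(0.0)
        new_hat = new_hat.at[:, :, m // 2].set(0.0)
    return jnp.fft.irfftn(new_hat * new_n**2, s=(new_n, new_n), axes=(-2, -1))


def sample_state(n):
    """Two-channel bandlimited test state on an n x n grid over [0, 1)^2."""
    x, y = jnp.meshgrid(jnp.arange(n) / n, jnp.arange(n) / n, indexing="ij")
    return jnp.stack([
        jnp.sin(2 * jnp.pi * x) * jnp.cos(2 * jnp.pi * 2 * y),
        jnp.cos(2 * jnp.pi * 3 * x) + jnp.sin(2 * jnp.pi * y),
    ])


u16 = sample_state(16)
u32 = resample_2d(u16, 32)  # shape (2, 32, 32)
u7 = resample_2d(u16, 7)    # shape (2, 7, 7); all wavenumbers <= 3 survive
```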

exponax.FourierInterpolator

Bases: equinox.Module

Source code in exponax/_interpolation.py
class FourierInterpolator(eqx.Module):
    num_spatial_dims: int
    domain_extent: float
    num_points: int
    state_hat_scaled: Complex[Array, "C ... (N//2)+1"]
    wavenumbers: Float[Array, "D ... (N//2)+1"]

    def __init__(
        self,
        state: Float[Array, "C ... N"],
        *,
        domain_extent: float = 1.0,
        indexing: Literal["ij", "xy"] = "ij",
    ):
        """
        Builds an interpolation function for an `Exponax` state using its
        Fourier representation.

        After instantiation, the interpolant can be called with a query
        coordinate `x ∈ ℝᴰ` (e.g., `x = jnp.array([0.3, 0.5])` in 2D) to obtain
        the corresponding value. If the query coordinate is not within the
        domain, i.e., `x ∉ Ω = [0, L]ᴰ`, the returned result is found in its
        periodic extension.

        !!! info
            If the state is band-limited, i.e., the highest wavenumber
            containing non-zero energy is at most `N//2`, then the
            interpolation will be exact (no interpolation error).

        !!! warning
            This interpolation uses global basis functions. Hence its memory and
            computation for evaluating one query location scale with `O(N^D)`.
            Consequently, if multiple query locations are to be evaluated in
            parallel (via `jax.vmap`), the memory and computation scale with
            `O(N^D * M)` where `M` is the number of query locations. This can
            easily exceed available resources. In such cases, either consider
            evaluating the query locations in smaller batches or resort to local
            basis interpolants like linear or cubic splines (see
            `scipy.interpolate` or its JAX analogues).

        **Arguments:**

        - `state`: The state to interpolate. Must conform to the `Exponax`
            standard with a leading channel axis (can be a singleton axis if
            there is only one channel), and one, two, or three subsequent
            spatial axes (depending on the number of spatial dimensions). All
            spatial axes must have the same number of points.
        - `domain_extent`: The size of the domain `L`; in higher dimensions the
            domain is assumed to be a scaled hypercube `Ω = (0, L)ᴰ`.
        - `indexing`: The indexing convention of the spatial axes. The default
            `"ij"` follows the `Exponax` convention.
        """
        self.num_spatial_dims = state.ndim - 1
        self.domain_extent = domain_extent
        self.num_points = state.shape[-1]

        self.state_hat_scaled = fft(state, num_spatial_dims=self.num_spatial_dims) / (
            build_scaling_array(
                self.num_spatial_dims,
                self.num_points,
                mode="reconstruction",
                indexing=indexing,
            )
        )
        self.wavenumbers = build_scaled_wavenumbers(
            self.num_spatial_dims,
            self.domain_extent,
            self.num_points,
            indexing=indexing,
        )

    def __call__(
        self,
        x: Float[Array, "D"],
    ) -> Float[Array, "C"]:
        """
        Evaluate the interpolant at the query location `x`.

        **Arguments:**

        - `x`: The query location. Must be a vector of length `D` where `D` is
            the number of spatial dimensions. This must match the number of
            spatial dimensions of the state used to build the interpolant.

        **Returns:**

        - `interpolated_value`: The interpolated value at the query location
            `x`. This will have as many channels as the state used to build the
            interpolant.


        !!! tip
            To evaluate the interpolant at multiple query locations in parallel,
            use `jax.vmap`. For example, in 1d:

            ```python

            print(state.shape)  # (C, N)

            interpolator = FourierInterpolator(state, domain_extent=1.0)

            print(query_locations.shape)  # (1, M)

            interpolated_values = jax.vmap(
                interpolator, in_axes=-1, out_axes=-1,
            )(query_locations)

            print(interpolated_values.shape)  # (C, M)

            ```

            If the query locations have multiple batch axes (e.g., to represent
            another grid), consider using nested `jax.vmap` calls. For example,
            in 2D:

            ```python

            print(state.shape)  # (C, N, N)

            interpolator = FourierInterpolator(state, domain_extent=1.0)

            print(query_locations.shape)  # (2, M, P)

            interpolated_values = jax.vmap(
                jax.vmap(interpolator, in_axes=-1, out_axes=-1), in_axes=-2,
                out_axes=-2,
            )(query_locations)

            print(interpolated_values.shape)  # (C, M, P)

            ```

        !!! warning
            This interpolation uses global basis functions. Hence its memory and
            computation for evaluating one query location scale with `O(N^D)`.
            Consequently, if multiple query locations are to be evaluated in
            parallel (via `jax.vmap`), the memory and computation scale with
            `O(N^D * M)` where `M` is the number of query locations. This can
            easily exceed available resources. In such cases, consider
            evaluating the query locations in smaller batches.
        """
        # Adds singleton axes for each spatial dimension
        x_bloated: Float[Array, "D ... 1"] = jnp.expand_dims(
            x, axis=space_indices(self.num_spatial_dims)
        )

        # The exponential term sums over the wavenumber dimension axis (`"D"`)
        exp_term: Complex[Array, "... (N//2)+1"] = jnp.exp(
            jnp.sum(1j * self.wavenumbers * x_bloated, axis=0)
        )

        # Re-add a singleton channel axis to have broadcasting work correctly
        exp_term: Complex[Array, "1 ... (N//2)+1"] = exp_term[None, ...]

        interpolation_operation: Complex[Array, "C ... (N//2)+1"] = (
            self.state_hat_scaled * exp_term
        )

        interpolated_value: Float[Array, "C"] = jnp.real(
            jax.vmap(jnp.sum)(interpolation_operation)
        )

        return interpolated_value
__init__
__init__(
    state: Float[Array, "C ... N"],
    *,
    domain_extent: float = 1.0,
    indexing: Literal["ij", "xy"] = "ij"
)

Builds an interpolation function for an Exponax state using its Fourier representation.

After instantiation, the interpolant can be called with a query coordinate x ∈ ℝᴰ (e.g., x = jnp.array([0.3, 0.5]) in 2D) to obtain the corresponding value. If the query coordinate is not within the domain, i.e., x ∉ Ω = [0, L]ᴰ, the returned result is found in its periodic extension.

Info

If the state is band-limited, i.e., the highest wavenumber containing non-zero energy is at most N//2, then the interpolation will be exact (no interpolation error).
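
A minimal 1D sketch of the same idea in plain `jax.numpy`: evaluate the truncated Fourier series directly at an arbitrary point. `fourier_interpolate_1d` is hypothetical and is not the Exponax class. Its weights (`1/N` for the zero mode and, for even `N`, the Nyquist mode; `2/N` otherwise) are the usual real-FFT reconstruction factors and are an assumption about what the `"reconstruction"` scaling amounts to. For a band-limited input, the off-grid value matches the underlying function, and shifting the query by `L` returns the same value (periodic extension).

```python
import jax.numpy as jnp


def fourier_interpolate_1d(u, x, domain_extent=1.0):
    """Hypothetical helper: evaluate the Fourier interpolant of u (C, N) at scalar x."""
    n = u.shape[-1]
    u_hat = jnp.fft.rfft(u, axis=-1)  # (C, N//2 + 1)
    # Non-negative modes stand in for their negative partners, hence the factor 2.
    weights = jnp.full(n // 2 + 1, 2.0 / n).at[0].set(1.0 / n)
    if n % 2 == 0:
        weights = weights.at[-1].set(1.0 / n)  # the Nyquist mode has no partner
    k = 2 * jnp.pi * jnp.arange(n // 2 + 1) / domain_extent  # angular wavenumbers
    return jnp.real(jnp.sum(u_hat * weights * jnp.exp(1j * k * x), axis=-1))


L, n = 2.0, 16
grid = jnp.arange(n) * L / n


def f(x):
    return jnp.sin(2 * jnp.pi * x / L) + 0.3 * jnp.cos(2 * jnp.pi * 3 * x / L)


u = f(grid)[None]  # leading singleton channel axis: shape (1, N)
value = fourier_interpolate_1d(u, 0.37, domain_extent=L)  # shape (1,)
```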

Warning

This interpolation uses global basis functions. Hence, its memory and computation for evaluating one query location scale with O(N^D). Consequently, if multiple query locations are to be evaluated in parallel (via jax.vmap), the memory and computation scale with O(N^D * M), where M is the number of query locations. This can easily exceed available resources. In such cases, either evaluate the query locations in smaller batches or resort to local basis interpolants like linear or cubic splines (see scipy.interpolate or its JAX analogues).
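
One way to follow the batching advice is to split the `M` query locations into fixed-size chunks, apply `jax.vmap` within each chunk, and apply `jax.lax.map` sequentially across chunks, so that roughly one chunk's worth of `O(N^D)` intermediates is live at a time. `evaluate_in_chunks` is a hypothetical helper. It assumes `interpolant` maps a `(D,)` location to `(C,)` values, as the `__call__` method is documented to do, and that queries are laid out as `(D, M)`, matching the `jax.vmap` tip in the `__call__` documentation. A stand-in interpolant is used so the sketch runs without Exponax.

```python
import jax
import jax.numpy as jnp


def evaluate_in_chunks(interpolant, queries, chunk_size):
    """Hypothetical helper. queries: (D, M); interpolant: (D,) -> (C,). Returns (C, M)."""
    d, m = queries.shape
    pad = (-m) % chunk_size
    # Repeat the last query so M becomes a multiple of chunk_size; extras are dropped later.
    padded = jnp.pad(queries, ((0, 0), (0, pad)), mode="edge")
    chunks = padded.reshape(d, -1, chunk_size).transpose(1, 0, 2)  # (num_chunks, D, chunk)
    batched = jax.vmap(interpolant, in_axes=-1, out_axes=-1)  # (D, chunk) -> (C, chunk)
    out = jax.lax.map(batched, chunks)  # sequential over chunks: (num_chunks, C, chunk)
    c = out.shape[1]
    return out.transpose(1, 0, 2).reshape(c, -1)[:, :m]


def stand_in(x):
    """Stand-in for an interpolant with C = 2 channels in D = 2 dimensions."""
    return jnp.stack([jnp.sin(x[0]) * jnp.cos(x[1]), x[0] + x[1]])


queries = jnp.stack([jnp.linspace(0.0, 1.0, 10), jnp.linspace(1.0, 2.0, 10)])  # (2, 10)
values = evaluate_in_chunks(stand_in, queries, chunk_size=4)  # (2, 10)
```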

Arguments:

  • state: The state to interpolate. Must conform to the Exponax standard with a leading channel axis (can be a singleton axis if there is only one channel) followed by one, two, or three spatial axes (depending on the number of spatial dimensions). All spatial axes must have the same number of points.
  • domain_extent: The size of the domain L; in higher dimensions the domain is assumed to be a scaled hypercube Ω = (0, L)ᴰ.
  • indexing: The indexing convention of the spatial axes. The default "ij" follows the Exponax convention.
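
The difference between the two `indexing` conventions is easiest to see on a small grid. With `"ij"`, the first spatial axis runs along the first coordinate `x`, so `state[c, i, j]` is the value at `(x_i, y_j)`; with `"xy"` (the Cartesian convention of `jnp.meshgrid`), the two axes swap roles. The snippet only uses `jnp.meshgrid` and assumes nothing about Exponax internals beyond the documented default `"ij"`.

```python
import jax.numpy as jnp

n = 4
coords = jnp.arange(n) / n  # 1D grid on [0, 1)

x_ij, y_ij = jnp.meshgrid(coords, coords, indexing="ij")
x_xy, y_xy = jnp.meshgrid(coords, coords, indexing="xy")

# "ij": x varies along axis 0; "xy": x varies along axis 1.
x_along_axis0 = x_ij[:, 0]
x_along_axis1 = x_xy[0, :]
```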
__call__
__call__(x: Float[Array, "D"]) -> Float[Array, "C"]

Evaluate the interpolant at the query location x.

Arguments:

  • x: The query location. Must be a vector of length D where D is the number of spatial dimensions. This must match the number of spatial dimensions of the state used to build the interpolant.

Returns:

  • interpolated_value: The interpolated value at the query location x. This will have as many channels as the state used to build the interpolant.

Tip

To evaluate the interpolant at multiple query locations in parallel, use jax.vmap. For example, in 1d:

print(state.shape)  # (C, N)

interpolator = FourierInterpolator(state, domain_extent=1.0)

print(query_locations.shape)  # (1, M)

interpolated_values = jax.vmap(
    interpolator, in_axes=-1, out_axes=-1,
)(query_locations)

print(interpolated_values.shape)  # (C, M)

If the query locations have multiple batch axes (e.g., to represent another grid), consider using nested jax.vmap calls. For example, in 2D:

print(state.shape)  # (C, N, N)

interpolator = FourierInterpolator(state, domain_extent=1.0)

print(query_locations.shape)  # (2, M, P)

interpolated_values = jax.vmap(
    jax.vmap(interpolator, in_axes=-1, out_axes=-1), in_axes=-2,
    out_axes=-2,
)(query_locations)

print(interpolated_values.shape)  # (C, M, P)
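
The shapes in the 2D tip can be checked without Exponax by building a 2D Fourier interpolant by hand. `make_interpolant_2d` is a hypothetical stand-in written with plain `jax.numpy`. It follows the same recipe as `__call__` (phase `k·x` summed over the spatial dimensions, exponentiated, multiplied with scaled coefficients, summed over all modes, real part taken), but its coefficient scaling is an assumption, not the library's `build_scaling_array`. The nested `jax.vmap` then maps `(2, M, P)` query locations to `(C, M, P)` values.

```python
import jax
import jax.numpy as jnp


def make_interpolant_2d(u, domain_extent=1.0):
    """Hypothetical stand-in: Fourier interpolant of an "ij"-indexed state (C, N, N)."""
    n = u.shape[-1]
    u_hat = jnp.fft.rfftn(u, axes=(-2, -1))  # (C, N, N//2 + 1)
    # Half-spectrum weights along the last axis: 1 for the zero (and even-N Nyquist) mode, else 2.
    w = jnp.full(n // 2 + 1, 2.0).at[0].set(1.0)
    if n % 2 == 0:
        w = w.at[-1].set(1.0)
    coeffs = u_hat * w / n**2
    k0 = 2 * jnp.pi * jnp.fft.fftfreq(n, d=1.0 / n) / domain_extent  # full axis
    k1 = 2 * jnp.pi * jnp.arange(n // 2 + 1) / domain_extent  # half axis

    def interpolant(x):  # x: (2,) -> (C,)
        phase = k0[:, None] * x[0] + k1[None, :] * x[1]
        return jnp.real(jnp.sum(coeffs * jnp.exp(1j * phase), axis=(-2, -1)))

    return interpolant


n = 16
g = jnp.arange(n) / n
x, y = jnp.meshgrid(g, g, indexing="ij")
state = (jnp.sin(2 * jnp.pi * x) * jnp.cos(2 * jnp.pi * 2 * y))[None]  # (1, N, N)
interpolator = make_interpolant_2d(state)

qx, qy = jnp.meshgrid(jnp.linspace(0.0, 1.0, 5), jnp.linspace(0.0, 1.0, 7), indexing="ij")
query_locations = jnp.stack([qx, qy])  # (2, M, P) with M = 5, P = 7

interpolated_values = jax.vmap(
    jax.vmap(interpolator, in_axes=-1, out_axes=-1), in_axes=-2, out_axes=-2,
)(query_locations)  # (1, 5, 7)
```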

Warning

This interpolation uses global basis functions. Hence, its memory and computation for evaluating one query location scale with O(N^D). Consequently, if multiple query locations are to be evaluated in parallel (via jax.vmap), the memory and computation scale with O(N^D * M), where M is the number of query locations. This can easily exceed available resources. In such cases, consider evaluating the query locations in smaller batches.
