
Utilities to handle the grid and the states on it

exponax.make_grid

make_grid(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    *,
    full: bool = False,
    zero_centered: bool = False,
    indexing: str = "ij"
) -> Float[Array, "D ... N"]

Return a grid in the spatial domain. A grid in d dimensions is an array of shape (d,) + (num_points,)*d whose first axis indexes the coordinate components: grid[i] holds the i-th coordinate of every grid point.

Note that if num_spatial_dims = 1, the returned array has a singleton dimension in the first axis, i.e., the shape is (1, num_points).

Arguments:
- num_spatial_dims: The number of spatial dimensions.
- domain_extent: The extent of the domain in each spatial dimension.
- num_points: The number of points in each spatial dimension.
- full: Whether to include the right boundary point in the grid. Default: False. The right point is redundant for periodic boundary conditions and is not considered a degree of freedom. With full=True, each spatial axis has num_points + 1 points. Use this option, for example, if you need a full grid for plotting.
- zero_centered: Whether to center the grid around zero. Default: False. By default, the points along each axis lie in [0, domain_extent) (or [0, domain_extent] with full=True); with zero_centered=True, all points are shifted by -domain_extent / 2.
- indexing: The indexing convention passed to jnp.meshgrid, either 'ij' (matrix indexing) or 'xy' (Cartesian indexing). Default: 'ij'.

Returns:
- grid: The grid in the spatial domain. Shape: (num_spatial_dims, num_points, ..., num_points), with num_points + 1 per spatial axis if full=True.
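The shape conventions above can be checked with a minimal sketch that mirrors the source of make_grid in exponax/_utils.py. The function make_grid_sketch is a hypothetical stand-in, not part of exponax, and NumPy is used instead of jax.numpy only so the snippet runs without JAX; np.linspace and np.meshgrid take the same arguments as their jax.numpy counterparts used here.

```python
import numpy as np


def make_grid_sketch(num_spatial_dims, domain_extent, num_points, *,
                     full=False, zero_centered=False, indexing="ij"):
    # Hypothetical NumPy re-implementation of exponax.make_grid
    if full:
        # num_points + 1 points, right boundary included
        grid_1d = np.linspace(0, domain_extent, num_points + 1, endpoint=True)
    else:
        # num_points points on [0, domain_extent), right boundary excluded
        grid_1d = np.linspace(0, domain_extent, num_points, endpoint=False)
    if zero_centered:
        grid_1d = grid_1d - domain_extent / 2
    # One coordinate array per spatial axis, stacked along a new leading axis
    return np.stack(np.meshgrid(*([grid_1d] * num_spatial_dims), indexing=indexing))


grid_line = make_grid_sketch(1, 1.0, 8)
print(grid_line.shape)  # (1, 8): singleton leading axis in 1D

grid_plane = make_grid_sketch(2, 1.0, 8)
print(grid_plane.shape)  # (2, 8, 8)

grid_plane_full = make_grid_sketch(2, 1.0, 8, full=True)
print(grid_plane_full.shape)  # (2, 9, 9)

grid_centered = make_grid_sketch(1, 1.0, 4, zero_centered=True)
print(grid_centered[0].tolist())  # [-0.5, -0.25, 0.0, 0.25]
```

With the default 'ij' indexing, grid_plane[0] varies along the first spatial axis and grid_plane[1] along the second.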

Source code in exponax/_utils.py
```python
import jax.numpy as jnp
from jaxtyping import Array, Float


def make_grid(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    *,
    full: bool = False,
    zero_centered: bool = False,
    indexing: str = "ij",
) -> Float[Array, "D ... N"]:
    """
    Return a grid in the spatial domain. A grid in d dimensions is an array of
    shape (d,) + (num_points,)*d with the first axis indexing the coordinate
    components.

    Note that if `num_spatial_dims = 1`, the returned array has a singleton
    dimension in the first axis, i.e., the shape is `(1, num_points)`.

    **Arguments:**
        - `num_spatial_dims`: The number of spatial dimensions.
        - `domain_extent`: The extent of the domain in each spatial dimension.
        - `num_points`: The number of points in each spatial dimension.
        - `full`: Whether to include the right boundary point in the grid.
            Default: `False`. The right point is redundant for periodic boundary
            conditions and is not considered a degree of freedom. Use this
            option, for example, if you need a full grid for plotting. If
            `True`, each spatial axis has `num_points + 1` points.
        - `zero_centered`: Whether to center the grid around zero. Default:
            `False`. By default the grid considers a domain of [0,
            domain_extent)^(num_spatial_dims); if `True`, all points are
            shifted by `-domain_extent / 2`.
        - `indexing`: The indexing convention passed to `jnp.meshgrid`, `'ij'`
            (matrix) or `'xy'` (Cartesian). Default: `'ij'`.

    **Returns:**
        - `grid`: The grid in the spatial domain. Shape:
            `(num_spatial_dims, num_points, ..., num_points)`, with
            `num_points + 1` per spatial axis if `full=True`.
    """
    if full:
        # num_points + 1 points on the closed interval [0, domain_extent]
        grid_1d = jnp.linspace(0, domain_extent, num_points + 1, endpoint=True)
    else:
        # num_points points on [0, domain_extent); the periodic right point is dropped
        grid_1d = jnp.linspace(0, domain_extent, num_points, endpoint=False)

    if zero_centered:
        grid_1d -= domain_extent / 2

    # Same 1D coordinates along every spatial axis
    grid_list = [
        grid_1d,
    ] * num_spatial_dims

    # Stack the coordinate arrays along a new leading axis of size num_spatial_dims
    grid = jnp.stack(
        jnp.meshgrid(*grid_list, indexing=indexing),
    )

    return grid
```
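One way to see why the right boundary point is dropped by default: for a periodic function, sampling on [0, L) gives exactly whole periods, so a discrete Fourier transform sees a single mode, while adding the duplicate endpoint breaks this. The sketch below assumes a typical use of the grid (evaluating a function pointwise) and builds the grid inline with NumPy in place of jax.numpy and exponax.make_grid so it runs standalone; the FFT check is illustrative and not part of exponax.

```python
import numpy as np

L = 2.0 * np.pi  # domain_extent
N = 16           # num_points

# Periodic grid as make_grid(1, L, N) builds it: shape (1, N), x in [0, L)
grid = np.stack(np.meshgrid(np.linspace(0, L, N, endpoint=False), indexing="ij"))

# Evaluate sin(3x); slicing grid[0:1] keeps a leading axis of size 1
u = np.sin(3 * grid[0:1])
u_hat = np.fft.rfft(u[0])
print(u.shape, int(np.argmax(np.abs(u_hat))))  # (1, 16) 3

# Same function on the full grid (N + 1 points, right endpoint duplicates x = 0)
x_full = np.linspace(0, L, N + 1, endpoint=True)
u_full_hat = np.fft.rfft(np.sin(3 * x_full))

# Largest magnitude outside mode 3: round-off only on the periodic grid,
# clearly nonzero (spectral leakage) when the redundant point is included
off_periodic = np.delete(np.abs(u_hat), 3).max()
off_full = np.delete(np.abs(u_full_hat), 3).max()
print(bool(off_periodic < 1e-10), bool(off_full > 0.1))  # True True
```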

exponax.wrap_bc

wrap_bc(u)

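Wraps the periodic boundary conditions around the array u: along every spatial axis, the first entry is appended at the end, which makes the periodic right boundary point explicit.

This can be used to plot the solution of a periodic problem on the full interval [0, L] by plotting wrap_bc(u) instead of u.

Parameters:
- u: The array to wrap, shape (C, N, ..., N). The leading (channel) axis is not padded, so a 1D state must have shape (C, N); an array of shape (N,) is returned unchanged.

Returns:
- u_wrapped: The wrapped array, shape (C, N + 1, ..., N + 1).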
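The effect of the padding on a small 1D state can be seen in the sketch below. The function wrap_like_wrap_bc is a hypothetical NumPy stand-in that mirrors the source of wrap_bc; np.pad with mode="wrap" follows the same semantics as jnp.pad. The second call illustrates that the first axis is always skipped, so an array without a channel axis is not padded at all.

```python
import numpy as np


def wrap_like_wrap_bc(u):
    # Hypothetical NumPy stand-in mirroring exponax.wrap_bc:
    # the first axis is skipped, every remaining axis gets one wrapped point
    _, *spatial_shape = u.shape
    padding_config = ((0, 0),) + ((0, 1),) * len(spatial_shape)
    return np.pad(u, padding_config, mode="wrap")


# 1D state with a leading channel axis: shape (1, N) with N = 4
u = np.array([[0.0, 1.0, 2.0, 3.0]])
u_wrapped = wrap_like_wrap_bc(u)
print(u_wrapped.tolist())  # [[0.0, 1.0, 2.0, 3.0, 0.0]]

# Without a channel axis, the only axis is treated as the channel: no padding
u_flat = np.array([0.0, 1.0, 2.0, 3.0])
padded_flat = wrap_like_wrap_bc(u_flat)
print(padded_flat.shape)  # (4,)
```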

Source code in exponax/_utils.py
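```python
import jax.numpy as jnp


def wrap_bc(u):
    """
    Wraps the periodic boundary conditions around the array `u`: along every
    spatial axis, the first entry is appended at the end.

    This can be used to plot the solution of a periodic problem on the full
    interval [0, L] by plotting `wrap_bc(u)` instead of `u`.

    **Parameters:**
        - `u`: The array to wrap, shape `(C, N, ..., N)`. The leading (channel)
            axis is not padded, so a 1D state must have shape `(C, N)`.

    **Returns:**
        - `u_wrapped`: The wrapped array, shape `(C, N + 1, ..., N + 1)`.
    """
    # The first axis is a channel axis; all remaining axes are spatial
    _, *spatial_shape = u.shape
    num_spatial_dims = len(spatial_shape)

    # No padding on the channel axis; append one point at the right end of
    # each spatial axis, taken periodically from the left end
    padding_config = ((0, 0),) + ((0, 1),) * num_spatial_dims
    u_wrapped = jnp.pad(u, padding_config, mode="wrap")

    return u_wrapped
```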
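For plotting on the closed domain, the wrapped state pairs naturally with a grid built with full=True, since both have N + 1 points per spatial axis. The 2D sketch below checks that the wrapped values coincide with evaluating a periodic function directly on the full grid. It uses NumPy in place of jax.numpy and builds both grids inline; wrap_like_wrap_bc is a hypothetical stand-in mirroring the source above, not the exponax function itself.

```python
import numpy as np


def wrap_like_wrap_bc(u):
    # Hypothetical NumPy stand-in: channel axis untouched,
    # one periodic point appended per spatial axis
    _, *spatial_shape = u.shape
    padding_config = ((0, 0),) + ((0, 1),) * len(spatial_shape)
    return np.pad(u, padding_config, mode="wrap")


L, N = 1.0, 8

# Periodic grid (as make_grid(2, L, N)) and plotting grid (as make_grid(2, L, N, full=True))
x = np.linspace(0, L, N, endpoint=False)
x_full = np.linspace(0, L, N + 1, endpoint=True)
grid = np.stack(np.meshgrid(x, x, indexing="ij"))                 # (2, N, N)
grid_full = np.stack(np.meshgrid(x_full, x_full, indexing="ij"))  # (2, N + 1, N + 1)


def f(g):
    # Periodic test function; g[0:1] and g[1:2] keep a channel axis of size 1
    return np.sin(2 * np.pi * g[0:1] / L) * np.cos(2 * np.pi * g[1:2] / L)


u = f(grid)                      # (1, N, N)
u_wrapped = wrap_like_wrap_bc(u)  # (1, N + 1, N + 1)
u_on_full = f(grid_full)         # evaluated directly, right boundary included

print(u_wrapped.shape, np.allclose(u_wrapped, u_on_full))  # (1, 9, 9) True
```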