Skip to content

Poisson Solver¤

(Included for completeness; this is not a time-dependent stepper.)

exponax.poisson.Poisson ¤

Bases: Module

Source code in exponax/_poisson.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class Poisson(eqx.Module):
    # Description of the periodic grid the solver was built for.
    num_spatial_dims: int
    domain_extent: float
    num_points: int
    dx: float

    # Element-wise inverse of the Fourier-space operator, computed once in
    # `__init__`; entries where the operator is exactly zero are stored as 0.
    _inv_operator: Complex[Array, "1 ... N"]

    def __init__(
        self,
        num_spatial_dims: int,
        domain_extent: float,
        num_points: int,
        *,
        order=2,
    ):
        """
        Exact solver for the Poisson equation under periodic boundary
        conditions.

        Unlike the other steppers in this package, this one does not advance
        a time-dependent PDE. It solves the Poisson equation

        $$ u_{xx} = - f $$

        for a given right hand side $f$, and is included for completeness.

        **Arguments:**
            - `num_spatial_dims`: The number of spatial dimensions.
            - `domain_extent`: The extent of the domain.
            - `num_points`: The number of points in each spatial dimension.
            - `order`: The order of the Poisson equation. Defaults to 2. Use
              `order=4` for the biharmonic equation.
        """
        self.num_spatial_dims = num_spatial_dims
        self.domain_extent = domain_extent
        self.num_points = num_points

        # Convention: the N grid points do **not** include the right
        # boundary point, hence the division by N.
        self.dx = domain_extent / num_points

        laplace_like = build_laplace_operator(
            build_derivative_operator(num_spatial_dims, domain_extent, num_points),
            order=order,
        )

        # Entries where the operator is exactly zero cannot be inverted;
        # storing 0 there picks the mean zero solution.
        vanishing = laplace_like == 0
        self._inv_operator = jnp.where(vanishing, 0.0, 1 / laplace_like)

    def step_fourier(
        self,
        f_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Solve the Poisson equation for a right hand side given in Fourier
        space.

        **Arguments:**
            - `f_hat`: The Fourier transform of the right hand side.

        **Returns:**
            - `u_hat`: The Fourier transform of the solution.
        """
        # The minus sign comes from the convention `u_xx = -f`.
        u_hat = -self._inv_operator * f_hat
        return u_hat

    def step(
        self,
        f: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Solve the Poisson equation for a right hand side given in real space.

        The input is transformed with a real FFT over the spatial axes,
        solved via `step_fourier`, and transformed back.

        **Arguments:**
            - `f`: The right hand side.

        **Returns:**
            - `u`: The solution.
        """
        axes = space_indices(self.num_spatial_dims)
        grid_shape = spatial_shape(self.num_spatial_dims, self.num_points)

        rhs_hat = jnp.fft.rfftn(f, axes=axes)
        solution_hat = self.step_fourier(rhs_hat)
        # The explicit output shape `s` is passed so the inverse real FFT
        # returns the original grid size.
        return jnp.fft.irfftn(solution_hat, axes=axes, s=grid_shape)

    def __call__(
        self,
        f: Float[Array, "C ... N"],
    ) -> Float[Array, "C ... N"]:
        """
        Check the spatial shape of `f`, then solve in real space via `step`.

        **Arguments:**
            - `f`: The right hand side; all axes after the first must match
              the spatial shape of this solver.

        **Returns:**
            - `u`: The solution.

        **Raises:**
            - `ValueError`: If `f.shape[1:]` differs from the expected
              spatial shape.
        """
        shape_matches = f.shape[1:] == spatial_shape(
            self.num_spatial_dims, self.num_points
        )
        if shape_matches:
            return self.step(f)
        raise ValueError(
            f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}"
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    *,
    order=2
)

Exactly solve the Poisson equation with periodic boundary conditions.

This "stepper" is different from all other steppers in this package in that it does not solve a time-dependent PDE. Instead, it solves the Poisson equation

\[ u_{xx} = - f \]

for a given right hand side \(f\).

It is included for completeness.

Arguments:

- num_spatial_dims: The number of spatial dimensions.
- domain_extent: The extent of the domain.
- num_points: The number of points in each spatial dimension.
- order: The order of the Poisson equation. Defaults to 2. You can also set order=4 for the biharmonic equation.

Source code in exponax/_poisson.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __init__(
    self,
    num_spatial_dims: int,
    domain_extent: float,
    num_points: int,
    *,
    order=2,
):
    """
    Exact solver for the Poisson equation under periodic boundary
    conditions.

    Unlike the other steppers in this package, this one does not advance a
    time-dependent PDE. It solves the Poisson equation

    $$ u_{xx} = - f $$

    for a given right hand side $f$, and is included for completeness.

    **Arguments:**
        - `num_spatial_dims`: The number of spatial dimensions.
        - `domain_extent`: The extent of the domain.
        - `num_points`: The number of points in each spatial dimension.
        - `order`: The order of the Poisson equation. Defaults to 2. Use
          `order=4` for the biharmonic equation.
    """
    self.num_spatial_dims = num_spatial_dims
    self.domain_extent = domain_extent
    self.num_points = num_points

    # Convention: the N grid points do **not** include the right boundary
    # point, hence the division by N.
    self.dx = domain_extent / num_points

    laplace_like = build_laplace_operator(
        build_derivative_operator(num_spatial_dims, domain_extent, num_points),
        order=order,
    )

    # Entries where the operator is exactly zero cannot be inverted; storing
    # 0 there picks the mean zero solution.
    vanishing = laplace_like == 0
    self._inv_operator = jnp.where(vanishing, 0.0, 1 / laplace_like)
__call__ ¤
__call__(
    f: Float[Array, "C ... N"]
) -> Float[Array, "C ... N"]
Source code in exponax/_poisson.py
102
103
104
105
106
107
108
109
110
def __call__(
    self,
    f: Float[Array, "C ... N"],
) -> Float[Array, "C ... N"]:
    """
    Check the spatial shape of `f`, then solve in real space via `step`.

    **Arguments:**
        - `f`: The right hand side; all axes after the first must match the
          spatial shape of this solver.

    **Returns:**
        - `u`: The solution.

    **Raises:**
        - `ValueError`: If `f.shape[1:]` differs from the expected spatial
          shape.
    """
    shape_matches = f.shape[1:] == spatial_shape(
        self.num_spatial_dims, self.num_points
    )
    if shape_matches:
        return self.step(f)
    raise ValueError(
        f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}"
    )