Skip to content

ETDRK Backbone¤

Core classes that implement the Exponential Time Differencing Runge-Kutta (ETDRK) method as timesteppers for semi-linear PDEs of the form \(\partial_t u = \mathcal{L} u + \mathcal{N}(u)\). They require the time step size \(\Delta t\), the linear operator in Fourier space \(\hat{\mathcal{L}}_h\), and the non-linear operator in Fourier space \(\hat{\mathcal{N}}_h\).

exponax.etdrk.ETDRK0 ¤

Bases: BaseETDRK

Exactly solve a linear PDE in Fourier space

Source code in exponax/etdrk/_etdrk_0.py
 6
 7
 8
 9
10
11
12
13
14
15
class ETDRK0(BaseETDRK):
    """
    Exactly solve a linear PDE in Fourier space

    Without a nonlinear term, ``u_t = L u`` is integrated exactly over one
    step by multiplying every Fourier mode with ``exp(dt * L)``, the factor
    precomputed as ``_exp_term`` by the base-class constructor.
    """

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one step of size ``dt``."""
        propagator = self._exp_term
        return propagator * u_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
)
Source code in exponax/etdrk/_base_etdrk.py
19
20
21
22
23
24
25
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    """
    Store the step size and precompute the exact linear propagator.

    Arguments:
        dt: time step size.
        linear_operator: linear operator in Fourier space; the exponential
            ``exp(dt * linear_operator)`` is cached as ``_exp_term``.
    """
    self.dt = dt
    scaled_operator = self.dt * linear_operator
    self._exp_term = jnp.exp(scaled_operator)
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_0.py
11
12
13
14
15
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """Advance the state by one step: scale every mode by ``exp(dt * L)``."""
    propagator = self._exp_term
    return propagator * u_hat

exponax.etdrk.ETDRK1 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_1.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class ETDRK1(BaseETDRK):
    """
    First-order exponential time differencing (ETD1, "exponential Euler").

    One step computes

        u_next = exp(dt L) u + coef_1 N(u),
        coef_1 = dt * phi_1(dt L),   phi_1(z) = (exp(z) - 1) / z,

    where ``N`` is ``nonlinear_fun``. ``phi_1`` is evaluated as the mean over
    ``num_circle_points`` points on a circle of radius ``circle_radius``
    centred at ``dt L`` (contour-integral approach of Kassam & Trefethen),
    which sidesteps the 0/0 of the direct formula at ``L = 0``.
    """

    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Arguments:
            dt: time step size.
            linear_operator: linear operator in Fourier space.
            nonlinear_fun: callable that maps ``u_hat`` to the nonlinear
                term in Fourier space.
            num_circle_points: number of quadrature points on the contour.
            circle_radius: radius of the contour around ``dt * L``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun

        # One contour per entry of dt * L, laid out along a new last axis.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )

        # Where dt * L is real, phi(dt * L) is real and the imaginary part of
        # the quadrature is pure round-off, so only the real part is kept.
        # Where dt * L is complex (e.g. odd-order derivative terms, whose
        # Fourier symbols are imaginary) the imaginary part is genuine and
        # must not be discarded.
        is_real_mode = jnp.imag(linear_operator) == 0

        def contour_mean(values):
            # Trapezoidal rule on the circle: approximates phi at its centre.
            mean = jnp.mean(values, axis=-1)
            return jnp.where(is_real_mode, mean.real, mean)

        self._coef_1 = dt * contour_mean((jnp.exp(LR) - 1) / LR)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one ETD1 step."""
        return self._exp_term * u_hat + self._coef_1 * self._nonlinear_fun(u_hat)
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_1.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the ETD1 coefficient ``coef_1 = dt * (exp(dt L) - 1) / (dt L)``.

    The phi-function is evaluated as the mean over ``num_circle_points``
    points on a circle of radius ``circle_radius`` centred at ``dt * L``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun

    # One contour per entry of dt * L, laid out along a new last axis.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )

    # Keep only the real part where dt * L is real (the imaginary part is
    # round-off there); keep the full complex value where dt * L is complex,
    # since discarding it would give wrong coefficients for those modes.
    is_real_mode = jnp.imag(linear_operator) == 0

    def contour_mean(values):
        # Trapezoidal rule on the circle: approximates phi at its centre.
        mean = jnp.mean(values, axis=-1)
        return jnp.where(is_real_mode, mean.real, mean)

    self._coef_1 = dt * contour_mean((jnp.exp(LR) - 1) / LR)
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_1.py
32
33
34
35
36
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state by one ETD1 (exponential Euler) step in Fourier space.
    """
    linear_part = self._exp_term * u_hat
    nonlinear_part = self._coef_1 * self._nonlinear_fun(u_hat)
    return linear_part + nonlinear_part

exponax.etdrk.ETDRK2 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_2.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class ETDRK2(BaseETDRK):
    """
    Second-order exponential time differencing Runge-Kutta (ETD2RK).

    One step computes

        a      = exp(dt L) u + coef_1 N(u)
        u_next = a + coef_2 (N(a) - N(u))

    with ``coef_1 = dt * phi_1(dt L)``, ``phi_1(z) = (exp(z) - 1) / z`` and
    ``coef_2 = dt * phi_2(dt L)``, ``phi_2(z) = (exp(z) - 1 - z) / z**2``.
    Both phi-functions are evaluated as means over ``num_circle_points``
    points on a circle of radius ``circle_radius`` centred at ``dt L``.
    """

    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Arguments:
            dt: time step size.
            linear_operator: linear operator in Fourier space.
            nonlinear_fun: callable that maps ``u_hat`` to the nonlinear
                term in Fourier space.
            num_circle_points: number of quadrature points on the contour.
            circle_radius: radius of the contour around ``dt * L``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun

        # One contour per entry of dt * L, laid out along a new last axis.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )

        # Where dt * L is real the phi-values are real and the imaginary part
        # of the quadrature is round-off, so it is dropped. Where dt * L is
        # complex (e.g. odd-order derivative terms) the imaginary part is
        # genuine and must be kept.
        is_real_mode = jnp.imag(linear_operator) == 0

        def contour_mean(values):
            # Trapezoidal rule on the circle: approximates phi at its centre.
            mean = jnp.mean(values, axis=-1)
            return jnp.where(is_real_mode, mean.real, mean)

        self._coef_1 = dt * contour_mean((jnp.exp(LR) - 1) / LR)

        self._coef_2 = dt * contour_mean((jnp.exp(LR) - 1 - LR) / LR**2)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one ETD2RK step."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        # Predictor: an ETD1 step.
        u_stage_1_hat = self._exp_term * u_hat + self._coef_1 * u_nonlin_hat

        # Corrector: account for the change of N across the step.
        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        u_next_hat = u_stage_1_hat + self._coef_2 * (
            u_stage_1_nonlin_hat - u_nonlin_hat
        )
        return u_next_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_2.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the ETD2RK coefficients ``coef_1 = dt * phi_1(dt L)`` and
    ``coef_2 = dt * phi_2(dt L)`` via a contour mean around ``dt * L``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun

    # One contour per entry of dt * L, laid out along a new last axis.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )

    # Keep only the real part where dt * L is real (the imaginary part is
    # round-off there); keep the full complex value where dt * L is complex,
    # since discarding it would give wrong coefficients for those modes.
    is_real_mode = jnp.imag(linear_operator) == 0

    def contour_mean(values):
        # Trapezoidal rule on the circle: approximates phi at its centre.
        mean = jnp.mean(values, axis=-1)
        return jnp.where(is_real_mode, mean.real, mean)

    self._coef_1 = dt * contour_mean((jnp.exp(LR) - 1) / LR)

    self._coef_2 = dt * contour_mean((jnp.exp(LR) - 1 - LR) / LR**2)
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_2.py
35
36
37
38
39
40
41
42
43
44
45
46
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state by one ETD2RK step: an ETD1 predictor followed by a
    correction proportional to the change of the nonlinear term.
    """
    nonlin_now = self._nonlinear_fun(u_hat)
    predictor = self._exp_term * u_hat + self._coef_1 * nonlin_now
    nonlin_difference = self._nonlinear_fun(predictor) - nonlin_now
    return predictor + self._coef_2 * nonlin_difference

exponax.etdrk.ETDRK3 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_3.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class ETDRK3(BaseETDRK):
    """
    Third-order exponential time differencing Runge-Kutta (Cox & Matthews).

    One step computes

        a      = exp(dt L / 2) u + coef_1 N(u)
        b      = exp(dt L) u + coef_2 (2 N(a) - N(u))
        u_next = exp(dt L) u + coef_3 N(u) + coef_4 N(a) + coef_5 N(b)

    All coefficients are ``dt`` times phi-type functions of ``dt L``, each
    evaluated as a mean over ``num_circle_points`` points on a circle of
    radius ``circle_radius`` centred at ``dt L``.
    """

    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Arguments:
            dt: time step size.
            linear_operator: linear operator in Fourier space.
            nonlinear_fun: callable that maps ``u_hat`` to the nonlinear
                term in Fourier space.
            num_circle_points: number of quadrature points on the contour.
            circle_radius: radius of the contour around ``dt * L``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

        # One contour per entry of dt * L, laid out along a new last axis.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )

        # Where dt * L is real the phi-values are real and the imaginary part
        # of the quadrature is round-off, so it is dropped. Where dt * L is
        # complex (e.g. odd-order derivative terms) the imaginary part is
        # genuine and must be kept.
        is_real_mode = jnp.imag(linear_operator) == 0

        def contour_mean(values):
            # Trapezoidal rule on the circle: approximates phi at its centre.
            mean = jnp.mean(values, axis=-1)
            return jnp.where(is_real_mode, mean.real, mean)

        self._coef_1 = dt * contour_mean((jnp.exp(LR / 2) - 1) / LR)

        self._coef_2 = dt * contour_mean((jnp.exp(LR) - 1) / LR)

        self._coef_3 = dt * contour_mean(
            (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3)
        )

        self._coef_4 = dt * contour_mean(
            (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3)
        )

        self._coef_5 = dt * contour_mean(
            (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3)
        )

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one third-order ETDRK step."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        # Stage at the half step.
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

        # Stage at the full step.
        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
            2 * u_stage_1_nonlin_hat - u_nonlin_hat
        )

        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)

        # Third-order combination of the three nonlinear evaluations.
        u_next_hat = (
            self._exp_term * u_hat
            + self._coef_3 * u_nonlin_hat
            + self._coef_4 * u_stage_1_nonlin_hat
            + self._coef_5 * u_stage_2_nonlin_hat
        )

        return u_next_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_3.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the half-step propagator and the five third-order ETDRK
    coefficients via a contour mean around ``dt * L``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

    # One contour per entry of dt * L, laid out along a new last axis.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )

    # Keep only the real part where dt * L is real (the imaginary part is
    # round-off there); keep the full complex value where dt * L is complex,
    # since discarding it would give wrong coefficients for those modes.
    is_real_mode = jnp.imag(linear_operator) == 0

    def contour_mean(values):
        # Trapezoidal rule on the circle: approximates phi at its centre.
        mean = jnp.mean(values, axis=-1)
        return jnp.where(is_real_mode, mean.real, mean)

    self._coef_1 = dt * contour_mean((jnp.exp(LR / 2) - 1) / LR)

    self._coef_2 = dt * contour_mean((jnp.exp(LR) - 1) / LR)

    self._coef_3 = dt * contour_mean(
        (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3)
    )

    self._coef_4 = dt * contour_mean(
        (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3)
    )

    self._coef_5 = dt * contour_mean(
        (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3)
    )
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "E ... (N//2)+1"]
) -> Complex[Array, "E ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_3.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state by one third-order ETDRK step in Fourier space.

    The state carries a channel axis ``C`` (like every other stepper), not
    the leading axis ``E`` of the linear operator.
    """
    u_nonlin_hat = self._nonlinear_fun(u_hat)
    # Stage at the half step.
    u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

    # Stage at the full step.
    u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
    u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
        2 * u_stage_1_nonlin_hat - u_nonlin_hat
    )

    u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)

    # Third-order combination of the three nonlinear evaluations.
    u_next_hat = (
        self._exp_term * u_hat
        + self._coef_3 * u_nonlin_hat
        + self._coef_4 * u_stage_1_nonlin_hat
        + self._coef_5 * u_stage_2_nonlin_hat
    )

    return u_next_hat

exponax.etdrk.ETDRK4 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_4.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class ETDRK4(BaseETDRK):
    """
    Fourth-order exponential time differencing Runge-Kutta (Cox & Matthews,
    in the form used by Kassam & Trefethen).

    One step computes

        a      = exp(dt L / 2) u + coef_1 N(u)
        b      = exp(dt L / 2) u + coef_2 N(a)
        c      = exp(dt L / 2) a + coef_3 (2 N(b) - N(u))
        u_next = exp(dt L) u + coef_4 N(u) + 2 coef_5 (N(a) + N(b))
                 + coef_6 N(c)

    ``coef_1 = coef_2 = coef_3`` is the half-step phi-coefficient. All
    coefficients are evaluated as means over ``num_circle_points`` points on
    a circle of radius ``circle_radius`` centred at ``dt L``.
    """

    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]
    _coef_6: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Arguments:
            dt: time step size.
            linear_operator: linear operator in Fourier space.
            nonlinear_fun: callable that maps ``u_hat`` to the nonlinear
                term in Fourier space.
            num_circle_points: number of quadrature points on the contour.
            circle_radius: radius of the contour around ``dt * L``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

        # One contour per entry of dt * L, laid out along a new last axis.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )

        # Where dt * L is real the phi-values are real and the imaginary part
        # of the quadrature is round-off, so it is dropped. Where dt * L is
        # complex (e.g. odd-order derivative terms) the imaginary part is
        # genuine and must be kept.
        is_real_mode = jnp.imag(linear_operator) == 0

        def contour_mean(values):
            # Trapezoidal rule on the circle: approximates phi at its centre.
            mean = jnp.mean(values, axis=-1)
            return jnp.where(is_real_mode, mean.real, mean)

        self._coef_1 = dt * contour_mean((jnp.exp(LR / 2) - 1) / LR)

        # All three intermediate stages use the same half-step coefficient.
        self._coef_2 = self._coef_1
        self._coef_3 = self._coef_1

        self._coef_4 = dt * contour_mean(
            (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3)
        )

        self._coef_5 = dt * contour_mean(
            (2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3)
        )

        self._coef_6 = dt * contour_mean(
            (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3)
        )

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one ETDRK4 step."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        u_stage_2_hat = (
            self._half_exp_term * u_hat + self._coef_2 * u_stage_1_nonlin_hat
        )

        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)
        # Note: the third stage propagates from stage 1, not from u_hat.
        u_stage_3_hat = self._half_exp_term * u_stage_1_hat + self._coef_3 * (
            2 * u_stage_2_nonlin_hat - u_nonlin_hat
        )

        u_stage_3_nonlin_hat = self._nonlinear_fun(u_stage_3_hat)

        # The two half-step stages share the weight 2 * coef_5.
        u_next_hat = (
            self._exp_term * u_hat
            + self._coef_4 * u_nonlin_hat
            + self._coef_5 * 2 * (u_stage_1_nonlin_hat + u_stage_2_nonlin_hat)
            + self._coef_6 * u_stage_3_nonlin_hat
        )

        return u_next_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_4.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the half-step propagator and the ETDRK4 coefficients via a
    contour mean around ``dt * L``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

    # One contour per entry of dt * L, laid out along a new last axis.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )

    # Keep only the real part where dt * L is real (the imaginary part is
    # round-off there); keep the full complex value where dt * L is complex,
    # since discarding it would give wrong coefficients for those modes.
    is_real_mode = jnp.imag(linear_operator) == 0

    def contour_mean(values):
        # Trapezoidal rule on the circle: approximates phi at its centre.
        mean = jnp.mean(values, axis=-1)
        return jnp.where(is_real_mode, mean.real, mean)

    self._coef_1 = dt * contour_mean((jnp.exp(LR / 2) - 1) / LR)

    # All three intermediate stages use the same half-step coefficient.
    self._coef_2 = self._coef_1
    self._coef_3 = self._coef_1

    self._coef_4 = dt * contour_mean(
        (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3)
    )

    self._coef_5 = dt * contour_mean(
        (2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3)
    )

    self._coef_6 = dt * contour_mean(
        (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3)
    )
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_4.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state by one ETDRK4 step in Fourier space.

    Three intermediate stages (two from ``u_hat``, the third from the first
    stage) feed four nonlinear evaluations into the final combination.
    """
    n_start = self._nonlinear_fun(u_hat)
    stage_a = self._half_exp_term * u_hat + self._coef_1 * n_start

    n_a = self._nonlinear_fun(stage_a)
    stage_b = self._half_exp_term * u_hat + self._coef_2 * n_a

    n_b = self._nonlinear_fun(stage_b)
    stage_c = self._half_exp_term * stage_a + self._coef_3 * (2 * n_b - n_start)

    n_c = self._nonlinear_fun(stage_c)

    combined = (
        self._exp_term * u_hat
        + self._coef_4 * n_start
        + self._coef_5 * 2 * (n_a + n_b)
        + self._coef_6 * n_c
    )
    return combined

exponax.etdrk.BaseETDRK ¤

Bases: Module, ABC

Source code in exponax/etdrk/_base_etdrk.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class BaseETDRK(eqx.Module, ABC):
    """
    Common base of the ETDRK timesteppers.

    Stores the step size ``dt`` and the exact propagator ``exp(dt * L)`` of
    the linear part; concrete subclasses implement ``step_fourier``.
    """

    dt: float
    _exp_term: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
    ):
        """
        Arguments:
            dt: time step size.
            linear_operator: linear operator in Fourier space; its
                exponential ``exp(dt * linear_operator)`` is cached.
        """
        self.dt = dt
        scaled_operator = self.dt * linear_operator
        self._exp_term = jnp.exp(scaled_operator)

    @abstractmethod
    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance the state in Fourier space.

        Maps the Fourier coefficients ``u_hat`` to those one step of size
        ``dt`` later.
        """
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
)
Source code in exponax/etdrk/_base_etdrk.py
19
20
21
22
23
24
25
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    """
    Remember ``dt`` and cache the linear propagator ``exp(dt * L)``.
    """
    self.dt = dt
    scaled_operator = self.dt * linear_operator
    self._exp_term = jnp.exp(scaled_operator)
step_fourier abstractmethod ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]

Advance the state in Fourier space.

Source code in exponax/etdrk/_base_etdrk.py
27
28
29
30
31
32
33
34
35
@abstractmethod
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state in Fourier space.

    Concrete steppers return the Fourier coefficients one step of size
    ``dt`` after ``u_hat``.
    """

exponax.etdrk.roots_of_unity ¤

roots_of_unity(M: int) -> Complex[Array, M]

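Return a (complex-valued) array of M equally spaced points on the unit circle, \(e^{2\pi i (k - 1/2)/M}\) for \(k = 1, \dots, M\). Strictly, these are the M-th roots of \(-1\) rather than of unity.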

Source code in exponax/etdrk/_utils.py
 9
10
11
12
13
14
def roots_of_unity(M: int) -> Complex[Array, "M"]:
    """
    Return a (complex-valued) array of M equally spaced points on the unit
    circle.

    The k-th entry (k = 1, ..., M) is ``exp(2j * pi * (k - 1/2) / M)``. The
    half-step offset makes the set closed under complex conjugation and
    excludes the point 1; strictly, these are the M-th roots of -1 rather
    than of unity.
    """
    half_offset_indices = jnp.arange(1, M + 1) - 0.5
    return jnp.exp(2j * jnp.pi * half_offset_indices / M)