Skip to content

ETDRK Backbone¤

Core classes that implement the Exponential Time Differencing Runge-Kutta (ETDRK) method for solving semi-linear PDEs, in the form of timesteppers. They require the time step size \(\Delta t\), the linear operator in Fourier space \(\hat{\mathcal{L}}_h\), and (for all schemes except ETDRK0) the non-linear operator in Fourier space \(\hat{\mathcal{N}}_h\).

exponax.etdrk.ETDRK0 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_0.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class ETDRK0(BaseETDRK):
    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
    ):
        r"""
        Exactly solve a linear PDE in Fourier space.

        $$
            \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
            \hat{u}_h^{[t]}
        $$

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            grid points along each of the preceding spatial axes (the halved
            last axis is the layout of a real-valued FFT).
        """
        super().__init__(dt, linear_operator)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance the Fourier-space state `u_hat` by one time step.

        There is no nonlinear part, so the step is a pointwise product with
        the exponential propagator `self._exp_term` (set up by `BaseETDRK`;
        per the formula above it represents `exp(L_h * dt)`).
        """
        return self._exp_term * u_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
)

Exactly solve a linear PDE in Fourier space.

\[ \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} \]

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of grid points along each of the preceding spatial axes.
Source code in exponax/etdrk/_etdrk_0.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    r"""
    Exactly solve a linear PDE in Fourier space.

    $$
        \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
        \hat{u}_h^{[t]}
    $$

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        grid points along each of the preceding spatial axes (the halved
        last axis is the layout of a real-valued FFT).
    """
    super().__init__(dt, linear_operator)
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_0.py
30
31
32
33
34
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Propagate the Fourier coefficients `u_hat` over one time step.

    With no nonlinear contribution, the update is just the elementwise
    product with the precomputed exponential factor `self._exp_term`.
    """
    propagator = self._exp_term
    u_next_hat = propagator * u_hat
    return u_next_hat

exponax.etdrk.ETDRK1 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_1.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class ETDRK1(BaseETDRK):
    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        r"""
        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
        with a **first order approximation**.

        Adapted from Eq. (4) of [Cox and Matthews
        (2002)](https://doi.org/10.1006/jcph.2002.6995):

        $$
            \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
            \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
            1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
        $$

        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
        the nonlinear differential operator.

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            grid points along each of the preceding spatial axes (the halved
            last axis is the layout of a real-valued FFT).
        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
            nonlinear differential operator.
        - `num_circle_points`: The number of points on the unit circle used to
            approximate the numerically challenging coefficients.
        - `circle_radius`: The radius of the circle used to approximate the
            numerically challenging coefficients.

        !!! warning
            The nonlinear function must take care of proper dealiasing.
            `BaseNonlinearFun` handles this automatically via its `fft` and
            `ifft` methods which apply pre- and post-dealiasing.

        !!! note
            The numerically stable evaluation of the coefficients follows
            [Kassam and Trefethen
            (2005)](https://doi.org/10.1137/S1064827502410633).
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun

        roots = roots_of_unity(num_circle_points)
        L_dt = linear_operator * dt

        def scan_body(acc, root):
            # Evaluate (exp(z) - 1) / z at one point of the circle centred at
            # L*dt; averaging over the circle avoids the cancellation at z ~ 0.
            lr = circle_radius * root + L_dt
            exp_lr = jnp.exp(lr)
            return acc + ((exp_lr - 1) / lr).real, None

        # The carry must have exactly the dtype the body produces, i.e. the
        # real part of `circle_radius * root + L_dt`. Deriving it from
        # `L_dt.real` alone makes `jax.lax.scan` reject the carry when the
        # roots are in a higher precision than the operator (e.g. a
        # float32/complex64 operator with 64-bit mode enabled).
        zeros = jnp.zeros_like((circle_radius * roots[0] + L_dt).real)
        sum_c1, _ = jax.lax.scan(scan_body, zeros, roots)
        mean_c1 = sum_c1 / num_circle_points
        self._coef_1 = dt * mean_c1

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance `u_hat` by one first-order ETDRK step: the exponentially
        propagated state plus `_coef_1` times the nonlinearity at `u_hat`.
        """
        return self._exp_term * u_hat + self._coef_1 * self._nonlinear_fun(u_hat)
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta with a first order approximation.

Adapted from Eq. (4) of Cox and Matthews (2002):

\[ \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \]

where \(\hat{\mathcal{N}}_h\) is the Fourier pseudo-spectral treatment of the nonlinear differential operator.

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of grid points along each of the preceding spatial axes.
  • nonlinear_fun: The Fourier pseudo-spectral treatment of the nonlinear differential operator.
  • num_circle_points: The number of points on the unit circle used to approximate the numerically challenging coefficients.
  • circle_radius: The radius of the circle used to approximate the numerically challenging coefficients.

Warning

The nonlinear function must take care of proper dealiasing. BaseNonlinearFun handles this automatically via its fft and ifft methods which apply pre- and post-dealiasing.

Note

The numerically stable evaluation of the coefficients follows Kassam and Trefethen (2005).

Source code in exponax/etdrk/_etdrk_1.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    r"""
    Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
    with a **first order approximation**.

    Adapted from Eq. (4) of [Cox and Matthews
    (2002)](https://doi.org/10.1006/jcph.2002.6995):

    $$
        \hat{u}_h^{[t+1]} = \exp(\hat{\mathcal{L}}_h \Delta t) \odot
        \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
        1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
    $$

    where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
    the nonlinear differential operator.

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        grid points along each of the preceding spatial axes (the halved
        last axis is the layout of a real-valued FFT).
    - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
        nonlinear differential operator.
    - `num_circle_points`: The number of points on the unit circle used to
        approximate the numerically challenging coefficients.
    - `circle_radius`: The radius of the circle used to approximate the
        numerically challenging coefficients.

    !!! warning
        The nonlinear function must take care of proper dealiasing.
        `BaseNonlinearFun` handles this automatically via its `fft` and
        `ifft` methods which apply pre- and post-dealiasing.

    !!! note
        The numerically stable evaluation of the coefficients follows
        [Kassam and Trefethen
        (2005)](https://doi.org/10.1137/S1064827502410633).
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun

    roots = roots_of_unity(num_circle_points)
    L_dt = linear_operator * dt

    def scan_body(acc, root):
        # Evaluate (exp(z) - 1) / z at one point of the circle centred at
        # L*dt; averaging over the circle avoids the cancellation at z ~ 0.
        lr = circle_radius * root + L_dt
        exp_lr = jnp.exp(lr)
        return acc + ((exp_lr - 1) / lr).real, None

    # The carry must have exactly the dtype the body produces, i.e. the real
    # part of `circle_radius * root + L_dt`. Deriving it from `L_dt.real`
    # alone makes `jax.lax.scan` reject the carry when the roots are in a
    # higher precision than the operator (e.g. a float32/complex64 operator
    # with 64-bit mode enabled).
    zeros = jnp.zeros_like((circle_radius * roots[0] + L_dt).real)
    sum_c1, _ = jax.lax.scan(scan_body, zeros, roots)
    mean_c1 = sum_c1 / num_circle_points
    self._coef_1 = dt * mean_c1
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_1.py
78
79
80
81
82
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Take one first-order exponential time-differencing step.

    The result is the linearly propagated state plus the nonlinear term,
    evaluated at the current state and weighted by `self._coef_1`.
    """
    linear_part = self._exp_term * u_hat
    nonlinear_part = self._coef_1 * self._nonlinear_fun(u_hat)
    return linear_part + nonlinear_part

exponax.etdrk.ETDRK2 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_2.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
class ETDRK2(BaseETDRK):
    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        r"""
        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
        with a **second order approximation**.

        Adapted from Eq. (22) of [Cox and Matthews
        (2002)](https://doi.org/10.1006/jcph.2002.6995):

        $$
            \begin{aligned}
                \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot
                \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
                1}{\hat{\mathcal{L}}_h} \odot
                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{[t+1]} &=
                \hat{u}_h^* + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1 -
                \hat{\mathcal{L}}_h \Delta t}{\hat{\mathcal{L}}_h^2 \Delta t}
                \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) -
                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right)
            \end{aligned}
        $$

        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
        the nonlinear differential operator.

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            grid points along each of the preceding spatial axes (the halved
            last axis is the layout of a real-valued FFT).
        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
            nonlinear differential operator.
        - `num_circle_points`: The number of points on the unit circle used to
            approximate the numerically challenging coefficients.
        - `circle_radius`: The radius of the circle used to approximate the
            numerically challenging coefficients.

        !!! warning
            The nonlinear function must take care of proper dealiasing.
            `BaseNonlinearFun` handles this automatically via its `fft` and
            `ifft` methods which apply pre- and post-dealiasing.

        !!! note
            The numerically stable evaluation of the coefficients follows
            [Kassam and Trefethen
            (2005)](https://doi.org/10.1137/S1064827502410633).
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun

        roots = roots_of_unity(num_circle_points)
        L_dt = linear_operator * dt

        def scan_body(accs, root):
            # Contour-sample both coefficient functions at one circle point
            # around L*dt to avoid cancellation near z ~ 0.
            lr = circle_radius * root + L_dt
            exp_lr = jnp.exp(lr)
            c1 = ((exp_lr - 1) / lr).real
            c2 = ((exp_lr - 1 - lr) / lr**2).real
            return (accs[0] + c1, accs[1] + c2), None

        # The carry must have exactly the dtype the body produces, i.e. the
        # real part of `circle_radius * root + L_dt`. Deriving it from
        # `L_dt.real` alone makes `jax.lax.scan` reject the carry when the
        # roots are in a higher precision than the operator (e.g. a
        # float32/complex64 operator with 64-bit mode enabled).
        zeros = jnp.zeros_like((circle_radius * roots[0] + L_dt).real)
        (sum_c1, sum_c2), _ = jax.lax.scan(scan_body, (zeros, zeros), roots)
        mean_c1 = sum_c1 / num_circle_points
        mean_c2 = sum_c2 / num_circle_points
        self._coef_1 = dt * mean_c1
        self._coef_2 = dt * mean_c2

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance `u_hat` by one second-order ETDRK step: a first-order
        predictor followed by a correction built from the difference of the
        nonlinearity at the predictor and at the current state.
        """
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        u_stage_1_hat = self._exp_term * u_hat + self._coef_1 * u_nonlin_hat

        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        u_next_hat = u_stage_1_hat + self._coef_2 * (
            u_stage_1_nonlin_hat - u_nonlin_hat
        )
        return u_next_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta with a second order approximation.

Adapted from Eq. (22) of Cox and Matthews (2002):

\[ \begin{aligned} \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1}{\hat{\mathcal{L}}_h} \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{[t+1]} &= \hat{u}_h^* + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1 - \hat{\mathcal{L}}_h \Delta t}{\hat{\mathcal{L}}_h^2 \Delta t} \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right) \end{aligned} \]

where \(\hat{\mathcal{N}}_h\) is the Fourier pseudo-spectral treatment of the nonlinear differential operator.

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of grid points along each of the preceding spatial axes.
  • nonlinear_fun: The Fourier pseudo-spectral treatment of the nonlinear differential operator.
  • num_circle_points: The number of points on the unit circle used to approximate the numerically challenging coefficients.
  • circle_radius: The radius of the circle used to approximate the numerically challenging coefficients.

Warning

The nonlinear function must take care of proper dealiasing. BaseNonlinearFun handles this automatically via its fft and ifft methods which apply pre- and post-dealiasing.

Note

The numerically stable evaluation of the coefficients follows Kassam and Trefethen (2005).

Source code in exponax/etdrk/_etdrk_2.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    r"""
    Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
    with a **second order approximation**.

    Adapted from Eq. (22) of [Cox and Matthews
    (2002)](https://doi.org/10.1006/jcph.2002.6995):

    $$
        \begin{aligned}
            \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot
            \hat{u}_h^{[t]} + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) -
            1}{\hat{\mathcal{L}}_h} \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{[t+1]} &=
            \hat{u}_h^* + \frac{\exp(\hat{\mathcal{L}}_h \Delta t) - 1 -
            \hat{\mathcal{L}}_h \Delta t}{\hat{\mathcal{L}}_h^2 \Delta t}
            \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) -
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right)
        \end{aligned}
    $$

    where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
    the nonlinear differential operator.

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        grid points along each of the preceding spatial axes (the halved
        last axis is the layout of a real-valued FFT).
    - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
        nonlinear differential operator.
    - `num_circle_points`: The number of points on the unit circle used to
        approximate the numerically challenging coefficients.
    - `circle_radius`: The radius of the circle used to approximate the
        numerically challenging coefficients.

    !!! warning
        The nonlinear function must take care of proper dealiasing.
        `BaseNonlinearFun` handles this automatically via its `fft` and
        `ifft` methods which apply pre- and post-dealiasing.

    !!! note
        The numerically stable evaluation of the coefficients follows
        [Kassam and Trefethen
        (2005)](https://doi.org/10.1137/S1064827502410633).
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun

    roots = roots_of_unity(num_circle_points)
    L_dt = linear_operator * dt

    def scan_body(accs, root):
        # Contour-sample both coefficient functions at one circle point
        # around L*dt to avoid cancellation near z ~ 0.
        lr = circle_radius * root + L_dt
        exp_lr = jnp.exp(lr)
        c1 = ((exp_lr - 1) / lr).real
        c2 = ((exp_lr - 1 - lr) / lr**2).real
        return (accs[0] + c1, accs[1] + c2), None

    # The carry must have exactly the dtype the body produces, i.e. the real
    # part of `circle_radius * root + L_dt`. Deriving it from `L_dt.real`
    # alone makes `jax.lax.scan` reject the carry when the roots are in a
    # higher precision than the operator (e.g. a float32/complex64 operator
    # with 64-bit mode enabled).
    zeros = jnp.zeros_like((circle_radius * roots[0] + L_dt).real)
    (sum_c1, sum_c2), _ = jax.lax.scan(scan_body, (zeros, zeros), roots)
    mean_c1 = sum_c1 / num_circle_points
    mean_c2 = sum_c2 / num_circle_points
    self._coef_1 = dt * mean_c1
    self._coef_2 = dt * mean_c2
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_2.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Take one second-order exponential time-differencing step.

    A first-order predictor is formed from the current state. It is then
    corrected by `self._coef_2` times the change in the nonlinear term
    between the predictor and the current state.
    """
    nonlin_now = self._nonlinear_fun(u_hat)
    predictor = self._exp_term * u_hat + self._coef_1 * nonlin_now

    nonlin_predicted = self._nonlinear_fun(predictor)
    correction = self._coef_2 * (nonlin_predicted - nonlin_now)
    return predictor + correction

exponax.etdrk.ETDRK3 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_3.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
class ETDRK3(BaseETDRK):
    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        r"""
        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
        with a **third order approximation**.

        Adapted from Eq. (23-25) of [Cox and Matthews
        (2002)](https://doi.org/10.1006/jcph.2002.6995):

        $$
        \begin{aligned}
            \hat{u}_h^*
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t / 2)
            \odot
            \hat{u}_h^{[t]}
            +
            \frac{
                \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
            }{
                \hat{\mathcal{L}}_h
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}).
            \\
            \hat{u}_h^{**}
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \odot
            \hat{u}_h^{[t]}
            +
            \frac{
                \exp(\hat{\mathcal{L}}_h \Delta t) - 1
            }{
                \hat{\mathcal{L}}_h
            }
            \odot
            \left(
                2 \hat{\mathcal{N}}_h(\hat{u}_h^*)
                -
                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
            \right).
            \\
            \hat{u}_h^{[t+1]}
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \odot
            \hat{u}_h^{[t]}
            \\
            &+
            \frac{
                -4 - \hat{\mathcal{L}}_h \Delta t
                +
                \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    4 - 3 \hat{\mathcal{L}}_h \Delta t
                    +
                    \left(
                        \hat{\mathcal{L}}_h \Delta t
                    \right)^2
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
            \\
            &+
            4 \frac{
                2 + \hat{\mathcal{L}}_h \Delta t
                +
                \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    -2 + \hat{\mathcal{L}}_h \Delta t
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^*)
            \\
            &+
            \frac{
                -4 - 3 \hat{\mathcal{L}}_h \Delta t
                -
                \left(
                    \hat{\mathcal{L}}_h \Delta t
                \right)^2
                +
                \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    4 - \hat{\mathcal{L}}_h \Delta t
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{**})
        \end{aligned}
        $$

        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
        the nonlinear differential operator. Note that the first stage
        $\hat{u}_h^*$ propagates $\hat{u}_h^{[t]}$ over half a step, whereas
        the second stage $\hat{u}_h^{**}$ propagates it over the full step.

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            grid points along each of the preceding spatial axes (the halved
            last axis is the layout of a real-valued FFT).
        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
            nonlinear differential operator.
        - `num_circle_points`: The number of points on the unit circle used to
            approximate the numerically challenging coefficients.
        - `circle_radius`: The radius of the circle used to approximate the
            numerically challenging coefficients.

        !!! warning
            The nonlinear function must take care of proper dealiasing.
            `BaseNonlinearFun` handles this automatically via its `fft` and
            `ifft` methods which apply pre- and post-dealiasing.

        !!! note
            The numerically stable evaluation of the coefficients follows
            [Kassam and Trefethen
            (2005)](https://doi.org/10.1137/S1064827502410633).
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

        roots = roots_of_unity(num_circle_points)
        L_dt = linear_operator * dt

        def scan_body(accs, root):
            # Sample all five coefficient functions (written in z = L*dt, so
            # each carries one fewer factor of dt) at one circle point.
            lr = circle_radius * root + L_dt
            exp_lr = jnp.exp(lr)
            exp_lr_half = jnp.exp(lr / 2)
            c1 = ((exp_lr_half - 1) / lr).real
            c2 = ((exp_lr - 1) / lr).real
            c3 = ((-4 - lr + exp_lr * (4 - 3 * lr + lr**2)) / lr**3).real
            c4 = ((4.0 * (2.0 + lr + exp_lr * (-2 + lr))) / lr**3).real
            c5 = ((-4 - 3 * lr - lr**2 + exp_lr * (4 - lr)) / lr**3).real
            return (
                accs[0] + c1,
                accs[1] + c2,
                accs[2] + c3,
                accs[3] + c4,
                accs[4] + c5,
            ), None

        # The carry must have exactly the dtype the body produces, i.e. the
        # real part of `circle_radius * root + L_dt`. Deriving it from
        # `L_dt.real` alone makes `jax.lax.scan` reject the carry when the
        # roots are in a higher precision than the operator (e.g. a
        # float32/complex64 operator with 64-bit mode enabled).
        zeros = jnp.zeros_like((circle_radius * roots[0] + L_dt).real)
        (s1, s2, s3, s4, s5), _ = jax.lax.scan(scan_body, (zeros,) * 5, roots)
        mean_c1 = s1 / num_circle_points
        mean_c2 = s2 / num_circle_points
        mean_c3 = s3 / num_circle_points
        mean_c4 = s4 / num_circle_points
        mean_c5 = s5 / num_circle_points
        self._coef_1 = dt * mean_c1
        self._coef_2 = dt * mean_c2
        self._coef_3 = dt * mean_c3
        self._coef_4 = dt * mean_c4
        self._coef_5 = dt * mean_c5

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance `u_hat` by one third-order ETDRK step.

        The step uses three nonlinear evaluations: at `u_hat`, at a half-step
        stage and at a full-step stage. They are combined with `_coef_3`,
        `_coef_4` and `_coef_5`.
        """
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        # Second stage: full-step propagation of u_hat (not of stage 1).
        u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
            2 * u_stage_1_nonlin_hat - u_nonlin_hat
        )

        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)

        u_next_hat = (
            self._exp_term * u_hat
            + self._coef_3 * u_nonlin_hat
            + self._coef_4 * u_stage_1_nonlin_hat
            + self._coef_5 * u_stage_2_nonlin_hat
        )

        return u_next_hat
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta with a third order approximation.

Adapted from Eq. (23-25) of Cox and Matthews (2002):

\[ \begin{aligned} \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{ \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1 }{ \hat{\mathcal{L}}_h } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}). \\ \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} + \frac{ \exp(\hat{\mathcal{L}}_h \Delta t) - 1 }{ \hat{\mathcal{L}}_h } \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^*) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right). \\ \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} \\ &+ \frac{ -4 - \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left( \hat{\mathcal{L}}_h \Delta t \right)^2 \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \\ &+ 4 \frac{ 2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \hat{\mathcal{N}}_h(\hat{u}_h^*) \\ &+ \frac{ -4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{**}) \end{aligned} \]

where \(\hat{\mathcal{N}}_h\) is the Fourier pseudo-spectral treatment of the nonlinear differential operator.

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of grid points along each of the preceding spatial axes.
  • nonlinear_fun: The Fourier pseudo-spectral treatment of the nonlinear differential operator.
  • num_circle_points: The number of points on the unit circle used to approximate the numerically challenging coefficients.
  • circle_radius: The radius of the circle used to approximate the numerically challenging coefficients.

Warning

The nonlinear function must take care of proper dealiasing. BaseNonlinearFun handles this automatically via its fft and ifft methods which apply pre- and post-dealiasing.

Note

The numerically stable evaluation of the coefficients follows Kassam and Trefethen (2005).

Source code in exponax/etdrk/_etdrk_3.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    r"""
    Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
    with a **third order approximation**.

    Adapted from Eq. (23-25) of [Cox and Matthews
    (2002)](https://doi.org/10.1006/jcph.2002.6995):

    $$
    \begin{aligned}
        \hat{u}_h^*
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t / 2)
        \odot
        \hat{u}_h^{[t]}
        +
        \frac{
            \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
        }{
            \hat{\mathcal{L}}_h
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}),
        \\
        \hat{u}_h^{**}
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t)
        \odot
        \hat{u}_h^{[t]}
        +
        \frac{
            \exp(\hat{\mathcal{L}}_h \Delta t) - 1
        }{
            \hat{\mathcal{L}}_h
        }
        \odot
        \left(
            2 \hat{\mathcal{N}}_h(\hat{u}_h^*)
            -
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
        \right),
        \\
        \hat{u}_h^{[t+1]}
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t)
        \odot
        \hat{u}_h^{[t]}
        \\
        &+
        \frac{
            -4 - \hat{\mathcal{L}}_h \Delta t
            +
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                4 - 3 \hat{\mathcal{L}}_h \Delta t
                +
                \left(
                    \hat{\mathcal{L}}_h \Delta t
                \right)^2
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
        \\
        &+
        4 \frac{
            2 + \hat{\mathcal{L}}_h \Delta t
            +
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                -2 + \hat{\mathcal{L}}_h \Delta t
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^*)
        \\
        &+
        \frac{
            -4 - 3 \hat{\mathcal{L}}_h \Delta t
            -
            \left(
                \hat{\mathcal{L}}_h \Delta t
            \right)^2
            +
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                4 - \hat{\mathcal{L}}_h \Delta t
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{**})
    \end{aligned}
    $$

    where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
    the nonlinear differential operator.

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        dimensions in the former spatial axes.
    - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
        nonlinear differential operator.
    - `num_circle_points`: The number of quadrature points on the circular
        contour (of radius `circle_radius`, centered at each entry of
        $\hat{\mathcal{L}}_h \Delta t$) used to evaluate the numerically
        challenging coefficients. Must be at least 1.
    - `circle_radius`: The radius of that contour.

    **Raises:**

    - `ValueError`: If `num_circle_points` is smaller than 1.

    !!! warning
        The nonlinear function must take care of proper dealiasing.
        `BaseNonlinearFun` handles this automatically via its `fft` and
        `ifft` methods which apply pre- and post-dealiasing.

    !!! note
        The numerically stable evaluation of the coefficients follows
        [Kassam and Trefethen
        (2005)](https://doi.org/10.1137/S1064827502410633). The contour mean
        is kept complex-valued; its (round-off) imaginary part is dropped
        only for entries where the linear operator is real.
    """
    if num_circle_points < 1:
        raise ValueError(
            f"num_circle_points must be at least 1, got {num_circle_points}"
        )
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

    roots = roots_of_unity(num_circle_points)
    L_dt = linear_operator * dt

    def scan_body(accs, root):
        # Evaluate the coefficient functions on a circle around each entry
        # of L*dt so that the removable singularity at zero is never hit.
        lr = circle_radius * root + L_dt
        exp_lr = jnp.exp(lr)
        exp_lr_half = jnp.exp(lr / 2)
        c1 = (exp_lr_half - 1) / lr
        c2 = (exp_lr - 1) / lr
        c3 = (-4 - lr + exp_lr * (4 - 3 * lr + lr**2)) / lr**3
        c4 = (4.0 * (2.0 + lr + exp_lr * (-2 + lr))) / lr**3
        c5 = (-4 - 3 * lr - lr**2 + exp_lr * (4 - lr)) / lr**3
        return tuple(acc + c for acc, c in zip(accs, (c1, c2, c3, c4, c5))), None

    # Accumulate in complex arithmetic: for non-real operator entries (e.g.
    # odd-order derivatives built from `1j * k`) the imaginary part of the
    # contour mean is part of the exact coefficient and must not be dropped.
    zeros = jnp.zeros(jnp.shape(L_dt), dtype=jnp.result_type(L_dt, roots))
    sums, _ = jax.lax.scan(scan_body, (zeros,) * 5, roots)

    # The contour points are closed under complex conjugation, so for real
    # operator entries the mean is real up to round-off; keep only its real
    # part there.
    is_real_entry = jnp.imag(L_dt) == 0

    def coefficient(total):
        mean = total / num_circle_points
        return dt * jnp.where(is_real_entry, mean.real, mean)

    coef_1, coef_2, coef_3, coef_4, coef_5 = (coefficient(s) for s in sums)
    self._coef_1 = coef_1
    self._coef_2 = coef_2
    self._coef_3 = coef_3
    self._coef_4 = coef_4
    self._coef_5 = coef_5
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "E ... (N//2)+1"],
) -> Complex[Array, "E ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_3.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state by one third-order ETDRK step of size `self.dt` in
    Fourier space.

    **Arguments:**

    - `u_hat`: The current state in Fourier space.

    **Returns:**

    - The state `self.dt` time units later, in Fourier space.
    """
    # Stage 1: half step with the nonlinearity frozen at the current state.
    u_nonlin_hat = self._nonlinear_fun(u_hat)
    u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

    # Stage 2: full step using the linearly extrapolated nonlinearity
    # 2 N(u*) - N(u).
    u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
    u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
        2 * u_stage_1_nonlin_hat - u_nonlin_hat
    )

    u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)

    # Final combination of the three nonlinear evaluations.
    return (
        self._exp_term * u_hat
        + self._coef_3 * u_nonlin_hat
        + self._coef_4 * u_stage_1_nonlin_hat
        + self._coef_5 * u_stage_2_nonlin_hat
    )

exponax.etdrk.ETDRK4 ¤

Bases: BaseETDRK

Source code in exponax/etdrk/_etdrk_4.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
class ETDRK4(BaseETDRK):
    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]
    _coef_6: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        r"""
        Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
        with a **fourth order approximation**.

        Adapted from Eq. (26-29) of [Cox and Matthews
        (2002)](https://doi.org/10.1006/jcph.2002.6995):

        $$
        \begin{aligned}
            \hat{u}_h^* &=
            \exp(\hat{\mathcal{L}}_h \Delta t / 2)
            \odot
            \hat{u}_h^{[t]}
            +
            \frac{
                \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
            }{
                \hat{\mathcal{L}}_h
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}),
            \\
            \hat{u}_h^{**}
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t / 2)
            \odot
            \hat{u}_h^{[t]}
            +
            \frac{
                \exp(\hat{\mathcal{L}}_h \Delta t / 2) - 1
            }{
                \hat{\mathcal{L}}_h
            } \odot \hat{\mathcal{N}}_h(\hat{u}_h^*),
            \\
            \hat{u}_h^{***}
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t / 2)
            \odot
            \hat{u}_h^{*}
            +
            \frac{
                \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
            }{
                \hat{\mathcal{L}}_h
            }
            \odot
            \left(
                2 \hat{\mathcal{N}}_h(\hat{u}_h^{**})
                -
                \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
            \right),
            \\
            \hat{u}_h^{[t+1]}
            &=
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \odot
            \hat{u}_h^{[t]}
            \\
            &+
            \frac{
                -4 - \hat{\mathcal{L}}_h \Delta t
                +
                \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    4 - 3 \hat{\mathcal{L}}_h \Delta t
                    +
                    \left(
                        \hat{\mathcal{L}}_h \Delta t
                    \right)^2
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
            \\
            &+
            2 \frac{
                2 + \hat{\mathcal{L}}_h \Delta t
                +
                \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    -2 + \hat{\mathcal{L}}_h \Delta t
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \left(
                \hat{\mathcal{N}}_h(\hat{u}_h^*)
                +
                \hat{\mathcal{N}}_h(\hat{u}_h^{**})
            \right)
            \\
            &+
            \frac{
                -4 - 3 \hat{\mathcal{L}}_h \Delta t
                - \left(
                    \hat{\mathcal{L}}_h \Delta t
                \right)^2
                + \exp(\hat{\mathcal{L}}_h \Delta t)
                \left(
                    4 - \hat{\mathcal{L}}_h \Delta t
                \right)
            }{
                \hat{\mathcal{L}}_h^3 (\Delta t)^2
            }
            \odot
            \hat{\mathcal{N}}_h(\hat{u}_h^{***})
        \end{aligned}
        $$

        where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
        the nonlinear differential operator.

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            dimensions in the former spatial axes.
        - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
            nonlinear differential operator.
        - `num_circle_points`: The number of quadrature points on the circular
            contour (of radius `circle_radius`, centered at each entry of
            $\hat{\mathcal{L}}_h \Delta t$) used to evaluate the numerically
            challenging coefficients. Must be at least 1.
        - `circle_radius`: The radius of that contour.

        **Raises:**

        - `ValueError`: If `num_circle_points` is smaller than 1.

        !!! warning
            The nonlinear function must take care of proper dealiasing.
            `BaseNonlinearFun` handles this automatically via its `fft` and
            `ifft` methods which apply pre- and post-dealiasing.

        !!! note
            The numerically stable evaluation of the coefficients follows
            [Kassam and Trefethen
            (2005)](https://doi.org/10.1137/S1064827502410633). The contour
            mean is kept complex-valued; its (round-off) imaginary part is
            dropped only for entries where the linear operator is real.
        """
        if num_circle_points < 1:
            raise ValueError(
                f"num_circle_points must be at least 1, got {num_circle_points}"
            )
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

        roots = roots_of_unity(num_circle_points)
        L_dt = linear_operator * dt

        def scan_body(accs, root):
            # Evaluate the coefficient functions on a circle around each entry
            # of L*dt so that the removable singularity at zero is never hit.
            lr = circle_radius * root + L_dt
            exp_lr = jnp.exp(lr)
            exp_lr_half = jnp.exp(lr / 2)
            c1 = (exp_lr_half - 1) / lr
            c4 = (-4 - lr + exp_lr * (4 - 3 * lr + lr**2)) / lr**3
            c5 = (2 + lr + exp_lr * (-2 + lr)) / lr**3
            c6 = (-4 - 3 * lr - lr**2 + exp_lr * (4 - lr)) / lr**3
            return tuple(acc + c for acc, c in zip(accs, (c1, c4, c5, c6))), None

        # Accumulate in complex arithmetic: for non-real operator entries
        # (e.g. odd-order derivatives built from `1j * k`) the imaginary part
        # of the contour mean is part of the exact coefficient.
        zeros = jnp.zeros(jnp.shape(L_dt), dtype=jnp.result_type(L_dt, roots))
        sums, _ = jax.lax.scan(scan_body, (zeros,) * 4, roots)

        # The contour points are closed under complex conjugation, so for
        # real operator entries the mean is real up to round-off; keep only
        # its real part there.
        is_real_entry = jnp.imag(L_dt) == 0

        def coefficient(total):
            mean = total / num_circle_points
            return dt * jnp.where(is_real_entry, mean.real, mean)

        coef_1, coef_4, coef_5, coef_6 = (coefficient(s) for s in sums)
        self._coef_1 = coef_1
        # All three intermediate stages use (exp(L dt / 2) - 1) / L.
        self._coef_2 = coef_1
        self._coef_3 = coef_1
        self._coef_4 = coef_4
        self._coef_5 = coef_5
        self._coef_6 = coef_6

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance the state by one fourth-order ETDRK step of size `self.dt` in
        Fourier space.

        **Arguments:**

        - `u_hat`: The current state in Fourier space.

        **Returns:**

        - The state `self.dt` time units later, in Fourier space.
        """
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat

        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        u_stage_2_hat = (
            self._half_exp_term * u_hat + self._coef_2 * u_stage_1_nonlin_hat
        )

        # The third stage starts from the first stage (not from u_hat) and
        # propagates it over the remaining half step.
        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)
        u_stage_3_hat = self._half_exp_term * u_stage_1_hat + self._coef_3 * (
            2 * u_stage_2_nonlin_hat - u_nonlin_hat
        )

        u_stage_3_nonlin_hat = self._nonlinear_fun(u_stage_3_hat)

        return (
            self._exp_term * u_hat
            + self._coef_4 * u_nonlin_hat
            + self._coef_5 * 2 * (u_stage_1_nonlin_hat + u_stage_2_nonlin_hat)
            + self._coef_6 * u_stage_3_nonlin_hat
        )
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0
)

Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta with a fourth order approximation.

Adapted from Eq. (26-29) of Cox and Matthews (2002):

\[ \begin{aligned} \hat{u}_h^* &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{ \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1 }{ \hat{\mathcal{L}}_h } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}), \\ \hat{u}_h^{**} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{[t]} + \frac{ \exp(\hat{\mathcal{L}}_h \Delta t / 2) - 1 }{ \hat{\mathcal{L}}_h } \odot \hat{\mathcal{N}}_h(\hat{u}_h^*), \\ \hat{u}_h^{***} &= \exp(\hat{\mathcal{L}}_h \Delta t / 2) \odot \hat{u}_h^{*} + \frac{ \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1 }{ \hat{\mathcal{L}}_h } \odot \left( 2 \hat{\mathcal{N}}_h(\hat{u}_h^{**}) - \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \right), \\ \hat{u}_h^{[t+1]} &= \exp(\hat{\mathcal{L}}_h \Delta t) \odot \hat{u}_h^{[t]} \\ &+ \frac{ -4 - \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - 3 \hat{\mathcal{L}}_h \Delta t + \left( \hat{\mathcal{L}}_h \Delta t \right)^2 \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}) \\ &+ 2 \frac{ 2 + \hat{\mathcal{L}}_h \Delta t + \exp(\hat{\mathcal{L}}_h \Delta t) \left( -2 + \hat{\mathcal{L}}_h \Delta t \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \left( \hat{\mathcal{N}}_h(\hat{u}_h^*) + \hat{\mathcal{N}}_h(\hat{u}_h^{**}) \right) \\ &+ \frac{ -4 - 3 \hat{\mathcal{L}}_h \Delta t - \left( \hat{\mathcal{L}}_h \Delta t \right)^2 + \exp(\hat{\mathcal{L}}_h \Delta t) \left( 4 - \hat{\mathcal{L}}_h \Delta t \right) }{ \hat{\mathcal{L}}_h^3 (\Delta t)^2 } \odot \hat{\mathcal{N}}_h(\hat{u}_h^{***}) \end{aligned} \]

where \(\hat{\mathcal{N}}_h\) is the Fourier pseudo-spectral treatment of the nonlinear differential operator.

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of dimensions in the former spatial axes.
  • nonlinear_fun: The Fourier pseudo-spectral treatment of the nonlinear differential operator.
  • num_circle_points: The number of quadrature points on the circular contour (of radius circle_radius, centered at each entry of \(\hat{\mathcal{L}}_h \Delta t\)) used to evaluate the numerically challenging coefficients.
  • circle_radius: The radius of the circle used to approximate the numerically challenging coefficients.

Warning

The nonlinear function must take care of proper dealiasing. BaseNonlinearFun handles this automatically via its fft and ifft methods which apply pre- and post-dealiasing.

Note

The numerically stable evaluation of the coefficients follows Kassam and Trefethen (2005).

Source code in exponax/etdrk/_etdrk_4.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    r"""
    Solve a semi-linear PDE using Exponential Time Differencing Runge-Kutta
    with a **fourth order approximation**.

    Adapted from Eq. (26-29) of [Cox and Matthews
    (2002)](https://doi.org/10.1006/jcph.2002.6995):

    $$
    \begin{aligned}
        \hat{u}_h^* &=
        \exp(\hat{\mathcal{L}}_h \Delta t / 2)
        \odot
        \hat{u}_h^{[t]}
        +
        \frac{
            \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
        }{
            \hat{\mathcal{L}}_h
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{[t]}),
        \\
        \hat{u}_h^{**}
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t / 2)
        \odot
        \hat{u}_h^{[t]}
        +
        \frac{
            \exp(\hat{\mathcal{L}}_h \Delta t / 2) - 1
        }{
            \hat{\mathcal{L}}_h
        } \odot \hat{\mathcal{N}}_h(\hat{u}_h^*),
        \\
        \hat{u}_h^{***}
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t / 2)
        \odot
        \hat{u}_h^{*}
        +
        \frac{
            \exp(\hat{\mathcal{L}}_h \Delta t/2) - 1
        }{
            \hat{\mathcal{L}}_h
        }
        \odot
        \left(
            2 \hat{\mathcal{N}}_h(\hat{u}_h^{**})
            -
            \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
        \right),
        \\
        \hat{u}_h^{[t+1]}
        &=
        \exp(\hat{\mathcal{L}}_h \Delta t)
        \odot
        \hat{u}_h^{[t]}
        \\
        &+
        \frac{
            -4 - \hat{\mathcal{L}}_h \Delta t
            +
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                4 - 3 \hat{\mathcal{L}}_h \Delta t
                +
                \left(
                    \hat{\mathcal{L}}_h \Delta t
                \right)^2
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{[t]})
        \\
        &+
        2 \frac{
            2 + \hat{\mathcal{L}}_h \Delta t
            +
            \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                -2 + \hat{\mathcal{L}}_h \Delta t
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \left(
            \hat{\mathcal{N}}_h(\hat{u}_h^*)
            +
            \hat{\mathcal{N}}_h(\hat{u}_h^{**})
        \right)
        \\
        &+
        \frac{
            -4 - 3 \hat{\mathcal{L}}_h \Delta t
            - \left(
                \hat{\mathcal{L}}_h \Delta t
            \right)^2
            + \exp(\hat{\mathcal{L}}_h \Delta t)
            \left(
                4 - \hat{\mathcal{L}}_h \Delta t
            \right)
        }{
            \hat{\mathcal{L}}_h^3 (\Delta t)^2
        }
        \odot
        \hat{\mathcal{N}}_h(\hat{u}_h^{***})
    \end{aligned}
    $$

    where $\hat{\mathcal{N}}_h$ is the Fourier pseudo-spectral treatment of
    the nonlinear differential operator.

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        dimensions in the former spatial axes.
    - `nonlinear_fun`: The Fourier pseudo-spectral treatment of the
        nonlinear differential operator.
    - `num_circle_points`: The number of quadrature points on the circular
        contour (of radius `circle_radius`, centered at each entry of
        $\hat{\mathcal{L}}_h \Delta t$) used to evaluate the numerically
        challenging coefficients. Must be at least 1.
    - `circle_radius`: The radius of that contour.

    **Raises:**

    - `ValueError`: If `num_circle_points` is smaller than 1.

    !!! warning
        The nonlinear function must take care of proper dealiasing.
        `BaseNonlinearFun` handles this automatically via its `fft` and
        `ifft` methods which apply pre- and post-dealiasing.

    !!! note
        The numerically stable evaluation of the coefficients follows
        [Kassam and Trefethen
        (2005)](https://doi.org/10.1137/S1064827502410633). The contour mean
        is kept complex-valued; its (round-off) imaginary part is dropped
        only for entries where the linear operator is real.
    """
    if num_circle_points < 1:
        raise ValueError(
            f"num_circle_points must be at least 1, got {num_circle_points}"
        )
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)

    roots = roots_of_unity(num_circle_points)
    L_dt = linear_operator * dt

    def scan_body(accs, root):
        # Evaluate the coefficient functions on a circle around each entry
        # of L*dt so that the removable singularity at zero is never hit.
        lr = circle_radius * root + L_dt
        exp_lr = jnp.exp(lr)
        exp_lr_half = jnp.exp(lr / 2)
        c1 = (exp_lr_half - 1) / lr
        c4 = (-4 - lr + exp_lr * (4 - 3 * lr + lr**2)) / lr**3
        c5 = (2 + lr + exp_lr * (-2 + lr)) / lr**3
        c6 = (-4 - 3 * lr - lr**2 + exp_lr * (4 - lr)) / lr**3
        return tuple(acc + c for acc, c in zip(accs, (c1, c4, c5, c6))), None

    # Accumulate in complex arithmetic: for non-real operator entries (e.g.
    # odd-order derivatives built from `1j * k`) the imaginary part of the
    # contour mean is part of the exact coefficient.
    zeros = jnp.zeros(jnp.shape(L_dt), dtype=jnp.result_type(L_dt, roots))
    sums, _ = jax.lax.scan(scan_body, (zeros,) * 4, roots)

    # The contour points are closed under complex conjugation, so for real
    # operator entries the mean is real up to round-off; keep only its real
    # part there.
    is_real_entry = jnp.imag(L_dt) == 0

    def coefficient(total):
        mean = total / num_circle_points
        return dt * jnp.where(is_real_entry, mean.real, mean)

    coef_1, coef_4, coef_5, coef_6 = (coefficient(s) for s in sums)
    self._coef_1 = coef_1
    # All three intermediate stages use (exp(L dt / 2) - 1) / L.
    self._coef_2 = coef_1
    self._coef_3 = coef_1
    self._coef_4 = coef_4
    self._coef_5 = coef_5
    self._coef_6 = coef_6
step_fourier ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_4.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Perform one fourth-order ETDRK step of length `self.dt` on a Fourier
    space state.

    **Arguments:**

    - `u_hat`: The state at the current time, in Fourier space.

    **Returns:**

    - The state at the next time level, in Fourier space.
    """
    nonlin = self._nonlinear_fun
    half_prop = self._half_exp_term

    # Three intermediate stages a, b, c, each followed by a nonlinear
    # evaluation; c is built on top of a rather than on u_hat.
    n_0 = nonlin(u_hat)
    a = half_prop * u_hat + self._coef_1 * n_0
    n_a = nonlin(a)
    b = half_prop * u_hat + self._coef_2 * n_a
    n_b = nonlin(b)
    c = half_prop * a + self._coef_3 * (2 * n_b - n_0)
    n_c = nonlin(c)

    linear_part = self._exp_term * u_hat
    return (
        linear_part
        + self._coef_4 * n_0
        + self._coef_5 * 2 * (n_a + n_b)
        + self._coef_6 * n_c
    )

exponax.etdrk.BaseETDRK ¤

Bases: Module, ABC

Source code in exponax/etdrk/_base_etdrk.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class BaseETDRK(eqx.Module, ABC):
    # Time step size shared by all ETDRK variants.
    dt: float
    # Elementwise exponential of the linear operator scaled by `dt`,
    # precomputed once in `__init__`.
    _exp_term: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
    ):
        """
        Common setup for exponential time differencing Runge-Kutta steppers:
        stores the step size and precomputes `exp(dt * linear_operator)`.

        **Arguments:**

        - `dt`: The time step size.
        - `linear_operator`: The linear operator of the PDE. Must have a leading
            channel axis, followed by one, two or three spatial axes whereas the
            last axis must be of size `(N//2)+1` where `N` is the number of
            dimensions in the former spatial axes.

        !!! Example
            Constructing the linear operator of the 1D heat equation:

            ```python
            import jax.numpy as jnp
            import exponax as ex

            # Define the linear operator
            N = 256
            L = 5.0  # The domain size
            D = 1  # Being in 1D

            derivative_operator = 1j * ex.spectral.build_derivative_operator(
                D,
                L,
                N,
            )

            print(derivative_operator.shape)  # (1, (N//2)+1)

            nu = 0.01 # The diffusion coefficient

            linear_operator = nu * derivative_operator**2
            ```
        """
        scaled_operator = dt * linear_operator
        self.dt = dt
        self._exp_term = jnp.exp(scaled_operator)

    @abstractmethod
    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Map a Fourier space state to the Fourier space state one step
        (`self.dt`) later. Implemented by each concrete ETDRK scheme.

        **Arguments:**

        - `u_hat`: The state at the current time, in Fourier space.

        **Returns:**

        - The state `self.dt` time units later, in Fourier space.
        """
__init__ ¤
__init__(
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
)

Base class for exponential time differencing Runge-Kutta methods.

Arguments:

  • dt: The time step size.
  • linear_operator: The linear operator of the PDE. Must have a leading channel axis, followed by one, two or three spatial axes whereas the last axis must be of size (N//2)+1 where N is the number of dimensions in the former spatial axes.

Example

Below is an example how to get the linear operator for the heat equation.

import jax.numpy as jnp
import exponax as ex

# Define the linear operator
N = 256
L = 5.0  # The domain size
D = 1  # Being in 1D

derivative_operator = 1j * ex.spectral.build_derivative_operator(
    D,
    L,
    N,
)

print(derivative_operator.shape)  # (1, (N//2)+1)

nu = 0.01 # The diffusion coefficient

linear_operator = nu * derivative_operator**2
Source code in exponax/etdrk/_base_etdrk.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    """
    Common setup for exponential time differencing Runge-Kutta steppers:
    stores the step size and precomputes `exp(dt * linear_operator)`.

    **Arguments:**

    - `dt`: The time step size.
    - `linear_operator`: The linear operator of the PDE. Must have a leading
        channel axis, followed by one, two or three spatial axes whereas the
        last axis must be of size `(N//2)+1` where `N` is the number of
        dimensions in the former spatial axes.

    !!! Example
        Constructing the linear operator of the 1D heat equation:

        ```python
        import jax.numpy as jnp
        import exponax as ex

        # Define the linear operator
        N = 256
        L = 5.0  # The domain size
        D = 1  # Being in 1D

        derivative_operator = 1j * ex.spectral.build_derivative_operator(
            D,
            L,
            N,
        )

        print(derivative_operator.shape)  # (1, (N//2)+1)

        nu = 0.01 # The diffusion coefficient

        linear_operator = nu * derivative_operator**2
        ```
    """
    scaled_operator = dt * linear_operator
    self.dt = dt
    self._exp_term = jnp.exp(scaled_operator)
step_fourier abstractmethod ¤
step_fourier(
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]

Advance the state in Fourier space.

Arguments:

  • u_hat: The previous state in Fourier space.

Returns:

  • The next state in Fourier space, i.e., self.dt time units later.
Source code in exponax/etdrk/_base_etdrk.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
@abstractmethod
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Map a Fourier space state to the Fourier space state one step
    (`self.dt`) later. Implemented by each concrete ETDRK scheme.

    **Arguments:**

    - `u_hat`: The state at the current time, in Fourier space.

    **Returns:**

    - The state `self.dt` time units later, in Fourier space.
    """

exponax.etdrk.roots_of_unity ¤

roots_of_unity(M: int) -> Complex[Array, M]

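Return a complex-valued array of M points equally spaced on the unit circle, namely exp(2πi(k - 1/2)/M) for k = 1, …, M (the M-th roots of unity rotated by half a step, i.e. the M-th roots of -1). Useful as quadrature nodes for contour integrals in the complex plane.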

Arguments:

  • M: The number of roots of unity.

Returns:

  • roots: The M roots of unity in an array of shape (M,).
Source code in exponax/etdrk/_utils.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def roots_of_unity(M: int) -> Complex[Array, "M"]:
    """
    Return (complex-valued) array with M roots of unity. Useful to perform
    contour integrals in the complex plane.

    **Arguments:**

    - `M`: The number of roots of unity.

    **Returns:**

    - `roots`: The M roots of unity in an array of shape `(M,)`.
    """
    # return jnp.exp(1j * jnp.pi * (jnp.arange(1, M+1) - 0.5) / M)
    return jnp.exp(2j * jnp.pi * (jnp.arange(1, M + 1) - 0.5) / M)