ETDRK Backbone
Core classes that implement the Exponential Time Differencing Runge-Kutta (ETDRK)
method for solving semi-linear PDEs, in the form of timesteppers. They require
the time step size \(\Delta t\), the linear operator in Fourier space \(\hat{\mathcal{L}}_h\), and the non-linear operator in Fourier space \(\hat{\mathcal{N}}_h\).
exponax.etdrk.ETDRK0
Bases: BaseETDRK
Exactly solve a linear PDE in Fourier space
Source code in exponax/etdrk/_etdrk_0.py
6
7
8
9
10
11
12
13
14
class ETDRK0(BaseETDRK):
    """
    Exactly solve a linear PDE in Fourier space.

    With no nonlinear part, a step is just the entrywise product of the state
    with the propagator ``exp(dt * linear_operator)`` precomputed by
    ``BaseETDRK.__init__``.
    """

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one step of size ``dt`` with the exact propagator."""
        propagator = self._exp_term
        advanced = propagator * u_hat
        return advanced
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
)
Source code in exponax/etdrk/_base_etdrk.py
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    """
    Store the step size and precompute the one-step exponential propagator.

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space; it is scaled by
            ``dt`` and exponentiated entrywise.
    """
    self.dt = dt
    scaled_operator = self.dt * linear_operator
    self._exp_term = jnp.exp(scaled_operator)
|
step_fourier
step_fourier(
u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_0.py
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance ``u_hat`` by one step of size ``dt``.

    The state is multiplied entrywise by the precomputed propagator
    ``self._exp_term``, which is exact for a purely linear PDE.
    """
    propagator = self._exp_term
    advanced = propagator * u_hat
    return advanced
|
exponax.etdrk.ETDRK1
Bases: BaseETDRK
Source code in exponax/etdrk/_etdrk_1.py
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ETDRK1(BaseETDRK):
    """
    First-order exponential time differencing (exponential Euler) stepper.

    One step computes ``exp(dt L) u_hat + coef_1 * N(u_hat)``, where
    ``coef_1 = (exp(dt L) - 1) / L`` is evaluated by a contour average.
    """

    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Precompute the exponential propagator and the ETD coefficient.

        Args:
            dt: Time step size.
            linear_operator: Linear operator in Fourier space (applied
                entrywise).
            nonlinear_fun: Called on the Fourier-space state to obtain the
                nonlinear term in Fourier space.
            num_circle_points: Number of contour points used to evaluate the
                coefficient.
            circle_radius: Radius of the contour circle around each entry of
                ``dt * linear_operator``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        # The coefficient is phi(dt * L) with phi(z) = (exp(z) - 1) / z,
        # evaluated as the mean of phi over circle points centred at dt * L.
        # This avoids the cancellation of (exp(z) - 1) / z for small |z|.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )
        # The imaginary part is deliberately kept. For a real operator it is
        # only round-off, because the circle points come in conjugate pairs.
        # For a complex operator (e.g. odd derivatives) it belongs to the
        # exact coefficient, so taking ``.real`` would be wrong.
        self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one step of size ``dt`` (one nonlinear evaluation)."""
        return self._exp_term * u_hat + self._coef_1 * self._nonlinear_fun(u_hat)
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
nonlinear_fun: BaseNonlinearFun,
*,
num_circle_points: int = 16,
circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_1.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the exponential propagator and the ETDRK1 coefficient.

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space (applied entrywise).
        nonlinear_fun: Called on the Fourier-space state to obtain the
            nonlinear term in Fourier space.
        num_circle_points: Number of contour points used to evaluate the
            coefficient.
        circle_radius: Radius of the contour circle around each entry of
            ``dt * linear_operator``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    # phi(z) = (exp(z) - 1) / z is evaluated at z = dt * L as its mean over
    # circle points centred at dt * L, avoiding cancellation for small |z|.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )
    # Keep the imaginary part. For real L it is only round-off, since the
    # circle points come in conjugate pairs. For complex L (e.g. odd
    # derivatives) it is part of the exact coefficient.
    self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)
|
step_fourier
step_fourier(
u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_1.py
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance ``u_hat`` by one exponential-Euler step of size ``dt``.

    The result is the propagated state plus ``self._coef_1`` times the
    nonlinear term evaluated at ``u_hat``.
    """
    linear_part = self._exp_term * u_hat
    nonlinear_part = self._coef_1 * self._nonlinear_fun(u_hat)
    return linear_part + nonlinear_part
|
exponax.etdrk.ETDRK2
Bases: BaseETDRK
Source code in exponax/etdrk/_etdrk_2.py
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class ETDRK2(BaseETDRK):
    """
    Second-order exponential time differencing Runge-Kutta stepper.

    An exponential-Euler predictor is followed by a correction proportional
    to the change of the nonlinear term (two nonlinear evaluations per step).
    """

    _nonlinear_fun: BaseNonlinearFun
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Precompute the exponential propagator and the ETD coefficients.

        Args:
            dt: Time step size.
            linear_operator: Linear operator in Fourier space (applied
                entrywise).
            nonlinear_fun: Called on the Fourier-space state to obtain the
                nonlinear term in Fourier space.
            num_circle_points: Number of contour points used to evaluate the
                coefficients.
            circle_radius: Radius of the contour circle around each entry of
                ``dt * linear_operator``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        # Each coefficient is an analytic function of z = dt * L. It is
        # evaluated as the mean over circle points centred at dt * L, which
        # avoids cancellation in the divided differences for small |z|.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )
        # The imaginary part is kept. For real L it is round-off only, since
        # the circle points come in conjugate pairs. For complex L it is part
        # of the exact coefficient, so ``.real`` would corrupt it.
        # coef_1 = (exp(dt L) - 1) / L
        self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)
        # coef_2 = (exp(dt L) - 1 - dt L) / (dt L^2)
        self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1 - LR) / LR**2, axis=-1)

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance ``u_hat`` by one step of size ``dt``."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        # Predictor: exponential Euler step.
        u_stage_1_hat = self._exp_term * u_hat + self._coef_1 * u_nonlin_hat
        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        # Corrector: account for the change of the nonlinear term over the step.
        u_next_hat = u_stage_1_hat + self._coef_2 * (
            u_stage_1_nonlin_hat - u_nonlin_hat
        )
        return u_next_hat
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
nonlinear_fun: BaseNonlinearFun,
*,
num_circle_points: int = 16,
circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_2.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the exponential propagator and the ETDRK2 coefficients.

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space (applied entrywise).
        nonlinear_fun: Called on the Fourier-space state to obtain the
            nonlinear term in Fourier space.
        num_circle_points: Number of contour points used to evaluate the
            coefficients.
        circle_radius: Radius of the contour circle around each entry of
            ``dt * linear_operator``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    # Coefficients are analytic functions of z = dt * L. They are evaluated
    # as means over circle points centred at dt * L to avoid cancellation
    # for small |z|.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )
    # The imaginary part is kept. It is round-off for real L (conjugate-paired
    # points) but part of the exact value for complex L.
    # coef_1 = (exp(dt L) - 1) / L
    self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)
    # coef_2 = (exp(dt L) - 1 - dt L) / (dt L^2)
    self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1 - LR) / LR**2, axis=-1)
|
step_fourier
step_fourier(
u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_2.py
35
36
37
38
39
40
41
42
43
44
45
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance ``u_hat`` by one ETDRK2 step of size ``dt``.

    An exponential-Euler predictor is corrected by ``self._coef_2`` times the
    difference between the nonlinear term at the predictor and at ``u_hat``.
    """
    nonlin_now = self._nonlinear_fun(u_hat)
    predictor = self._exp_term * u_hat + self._coef_1 * nonlin_now
    nonlin_pred = self._nonlinear_fun(predictor)
    correction = self._coef_2 * (nonlin_pred - nonlin_now)
    return predictor + correction
|
exponax.etdrk.ETDRK3
Bases: BaseETDRK
Source code in exponax/etdrk/_etdrk_3.py
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class ETDRK3(BaseETDRK):
    """
    Third-order exponential time differencing Runge-Kutta stepper
    (Cox-Matthews ETD3RK), using three nonlinear evaluations per step.
    """

    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Precompute the half- and full-step propagators and the ETD coefficients.

        Args:
            dt: Time step size.
            linear_operator: Linear operator in Fourier space (applied
                entrywise).
            nonlinear_fun: Called on the Fourier-space state to obtain the
                nonlinear term in Fourier space.
            num_circle_points: Number of contour points used to evaluate the
                coefficients.
            circle_radius: Radius of the contour circle around each entry of
                ``dt * linear_operator``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
        # Coefficients are analytic functions of z = dt * L. They are evaluated
        # as means over circle points centred at dt * L, which avoids the
        # cancellation of the closed forms (divisions by z, z^3) for small |z|.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )
        # The imaginary part is kept. It is round-off for real L (conjugate-
        # paired points) but part of the exact coefficient for complex L.
        # coef_1 = (exp(dt L / 2) - 1) / L   (half-step stage)
        self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1)
        # coef_2 = (exp(dt L) - 1) / L       (full-step stage)
        self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)
        # coef_3..coef_5: final-combination weights for N(u), N(a), N(b).
        self._coef_3 = dt * jnp.mean(
            (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1
        )
        self._coef_4 = dt * jnp.mean(
            (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3), axis=-1
        )
        self._coef_5 = dt * jnp.mean(
            (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1
        )

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance the state ``u_hat`` by one step of size ``dt``."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        # Stage a: exponential-Euler half step.
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat
        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        # Stage b: full step driven by the extrapolation 2 N(a) - N(u).
        u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
            2 * u_stage_1_nonlin_hat - u_nonlin_hat
        )
        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)
        # Final weighted combination of the three nonlinear evaluations.
        u_next_hat = (
            self._exp_term * u_hat
            + self._coef_3 * u_nonlin_hat
            + self._coef_4 * u_stage_1_nonlin_hat
            + self._coef_5 * u_stage_2_nonlin_hat
        )
        return u_next_hat
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
nonlinear_fun: BaseNonlinearFun,
*,
num_circle_points: int = 16,
circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_3.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the half- and full-step propagators and the ETDRK3 coefficients.

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space (applied entrywise).
        nonlinear_fun: Called on the Fourier-space state to obtain the
            nonlinear term in Fourier space.
        num_circle_points: Number of contour points used to evaluate the
            coefficients.
        circle_radius: Radius of the contour circle around each entry of
            ``dt * linear_operator``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
    # Coefficients are analytic functions of z = dt * L. They are evaluated
    # as means over circle points centred at dt * L to avoid cancellation
    # for small |z|.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )
    # The imaginary part is kept. It is round-off for real L (conjugate-paired
    # points) but part of the exact coefficient for complex L.
    self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1)
    self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1)
    self._coef_3 = dt * jnp.mean(
        (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1
    )
    self._coef_4 = dt * jnp.mean(
        (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3), axis=-1
    )
    self._coef_5 = dt * jnp.mean(
        (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1
    )
|
step_fourier
step_fourier(
u_hat: Complex[Array, "E ... (N//2)+1"]
) -> Complex[Array, "E ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_3.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state ``u_hat`` by one ETDRK3 step of size ``dt``.

    The state is annotated with the channel axis ``C``, consistent with the
    other steppers. ``E`` is reserved for the precomputed coefficient arrays.
    """
    u_nonlin_hat = self._nonlinear_fun(u_hat)
    # Stage a: exponential-Euler half step.
    u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat
    u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
    # Stage b: full step driven by the extrapolation 2 N(a) - N(u).
    u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * (
        2 * u_stage_1_nonlin_hat - u_nonlin_hat
    )
    u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)
    # Final weighted combination of the three nonlinear evaluations.
    u_next_hat = (
        self._exp_term * u_hat
        + self._coef_3 * u_nonlin_hat
        + self._coef_4 * u_stage_1_nonlin_hat
        + self._coef_5 * u_stage_2_nonlin_hat
    )
    return u_next_hat
|
exponax.etdrk.ETDRK4
Bases: BaseETDRK
Source code in exponax/etdrk/_etdrk_4.py
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class ETDRK4(BaseETDRK):
    """
    Fourth-order exponential time differencing Runge-Kutta stepper
    (Cox-Matthews ETDRK4), using four nonlinear evaluations per step.
    """

    _nonlinear_fun: BaseNonlinearFun
    _half_exp_term: Complex[Array, "E ... (N//2)+1"]
    _coef_1: Complex[Array, "E ... (N//2)+1"]
    _coef_2: Complex[Array, "E ... (N//2)+1"]
    _coef_3: Complex[Array, "E ... (N//2)+1"]
    _coef_4: Complex[Array, "E ... (N//2)+1"]
    _coef_5: Complex[Array, "E ... (N//2)+1"]
    _coef_6: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
        nonlinear_fun: BaseNonlinearFun,
        *,
        num_circle_points: int = 16,
        circle_radius: float = 1.0,
    ):
        """
        Precompute the half- and full-step propagators and the ETD coefficients.

        Args:
            dt: Time step size.
            linear_operator: Linear operator in Fourier space (applied
                entrywise).
            nonlinear_fun: Called on the Fourier-space state to obtain the
                nonlinear term in Fourier space.
            num_circle_points: Number of contour points used to evaluate the
                coefficients.
            circle_radius: Radius of the contour circle around each entry of
                ``dt * linear_operator``.
        """
        super().__init__(dt, linear_operator)
        self._nonlinear_fun = nonlinear_fun
        self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
        # Coefficients are analytic functions of z = dt * L. They are evaluated
        # as means over circle points centred at dt * L to avoid cancellation
        # of the closed forms (divisions by z, z^3) for small |z|.
        LR = (
            circle_radius * roots_of_unity(num_circle_points)
            + linear_operator[..., jnp.newaxis] * dt
        )
        # The imaginary part is kept. It is round-off for real L (conjugate-
        # paired points) but part of the exact coefficient for complex L.
        # The three half-step stages share (exp(dt L / 2) - 1) / L.
        self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1)
        self._coef_2 = self._coef_1
        self._coef_3 = self._coef_1
        # Final-combination weights for N(u), N(a) + N(b), and N(c).
        self._coef_4 = dt * jnp.mean(
            (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1
        )
        self._coef_5 = dt * jnp.mean(
            (2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3), axis=-1
        )
        self._coef_6 = dt * jnp.mean(
            (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1
        )

    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """Advance the state ``u_hat`` by one step of size ``dt``."""
        u_nonlin_hat = self._nonlinear_fun(u_hat)
        # Stage a: half step from u with N(u).
        u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat
        u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat)
        # Stage b: half step from u with N(a).
        u_stage_2_hat = (
            self._half_exp_term * u_hat + self._coef_2 * u_stage_1_nonlin_hat
        )
        u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat)
        # Stage c: half step from a with 2 N(b) - N(u).
        u_stage_3_hat = self._half_exp_term * u_stage_1_hat + self._coef_3 * (
            2 * u_stage_2_nonlin_hat - u_nonlin_hat
        )
        u_stage_3_nonlin_hat = self._nonlinear_fun(u_stage_3_hat)
        u_next_hat = (
            self._exp_term * u_hat
            + self._coef_4 * u_nonlin_hat
            + self._coef_5 * 2 * (u_stage_1_nonlin_hat + u_stage_2_nonlin_hat)
            + self._coef_6 * u_stage_3_nonlin_hat
        )
        return u_next_hat
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
nonlinear_fun: BaseNonlinearFun,
*,
num_circle_points: int = 16,
circle_radius: float = 1.0
)
Source code in exponax/etdrk/_etdrk_4.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
    nonlinear_fun: BaseNonlinearFun,
    *,
    num_circle_points: int = 16,
    circle_radius: float = 1.0,
):
    """
    Precompute the half- and full-step propagators and the ETDRK4 coefficients.

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space (applied entrywise).
        nonlinear_fun: Called on the Fourier-space state to obtain the
            nonlinear term in Fourier space.
        num_circle_points: Number of contour points used to evaluate the
            coefficients.
        circle_radius: Radius of the contour circle around each entry of
            ``dt * linear_operator``.
    """
    super().__init__(dt, linear_operator)
    self._nonlinear_fun = nonlinear_fun
    self._half_exp_term = jnp.exp(0.5 * dt * linear_operator)
    # Coefficients are analytic functions of z = dt * L. They are evaluated
    # as means over circle points centred at dt * L to avoid cancellation
    # for small |z|.
    LR = (
        circle_radius * roots_of_unity(num_circle_points)
        + linear_operator[..., jnp.newaxis] * dt
    )
    # The imaginary part is kept. It is round-off for real L (conjugate-paired
    # points) but part of the exact coefficient for complex L.
    self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1)
    self._coef_2 = self._coef_1
    self._coef_3 = self._coef_1
    self._coef_4 = dt * jnp.mean(
        (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1
    )
    self._coef_5 = dt * jnp.mean(
        (2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3), axis=-1
    )
    self._coef_6 = dt * jnp.mean(
        (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1
    )
|
step_fourier
step_fourier(
u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Source code in exponax/etdrk/_etdrk_4.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance ``u_hat`` by one ETDRK4 step of size ``dt``.

    Three half-step stages (a, b, c) are built from the half-step propagator.
    They are combined with the full-step propagator and the weights
    ``_coef_4``..``_coef_6``.
    """
    half = self._half_exp_term
    n_u = self._nonlinear_fun(u_hat)

    a_hat = half * u_hat + self._coef_1 * n_u
    n_a = self._nonlinear_fun(a_hat)

    b_hat = half * u_hat + self._coef_2 * n_a
    n_b = self._nonlinear_fun(b_hat)

    c_hat = half * a_hat + self._coef_3 * (2 * n_b - n_u)
    n_c = self._nonlinear_fun(c_hat)

    return (
        self._exp_term * u_hat
        + self._coef_4 * n_u
        + self._coef_5 * 2 * (n_a + n_b)
        + self._coef_6 * n_c
    )
|
exponax.etdrk.BaseETDRK
Bases: Module, ABC
Source code in exponax/etdrk/_base_etdrk.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class BaseETDRK(eqx.Module, ABC):
    """
    Abstract base for ETDRK time steppers that operate in Fourier space.

    Stores the step size ``dt`` and the entrywise exponential
    ``exp(dt * linear_operator)``. Subclasses implement ``step_fourier``.
    """

    dt: float
    _exp_term: Complex[Array, "E ... (N//2)+1"]

    def __init__(
        self,
        dt: float,
        linear_operator: Complex[Array, "E ... (N//2)+1"],
    ):
        self.dt = dt
        scaled_operator = self.dt * linear_operator
        self._exp_term = jnp.exp(scaled_operator)

    @abstractmethod
    def step_fourier(
        self,
        u_hat: Complex[Array, "C ... (N//2)+1"],
    ) -> Complex[Array, "C ... (N//2)+1"]:
        """
        Advance the state in Fourier space.
        """
        ...
|
__init__
__init__(
dt: float,
linear_operator: Complex[Array, "E ... (N//2)+1"],
)
Source code in exponax/etdrk/_base_etdrk.py
def __init__(
    self,
    dt: float,
    linear_operator: Complex[Array, "E ... (N//2)+1"],
):
    """
    Record ``dt`` and cache ``exp(dt * linear_operator)`` (entrywise).

    Args:
        dt: Time step size.
        linear_operator: Linear operator in Fourier space.
    """
    self.dt = dt
    scaled_operator = self.dt * linear_operator
    self._exp_term = jnp.exp(scaled_operator)
|
step_fourier
abstractmethod
step_fourier(
u_hat: Complex[Array, "C ... (N//2)+1"]
) -> Complex[Array, "C ... (N//2)+1"]
Advance the state in Fourier space.
Source code in exponax/etdrk/_base_etdrk.py
27
28
29
30
31
32
33
34
@abstractmethod
def step_fourier(
    self,
    u_hat: Complex[Array, "C ... (N//2)+1"],
) -> Complex[Array, "C ... (N//2)+1"]:
    """
    Advance the state in Fourier space.

    Abstract: concrete steppers override this. The base body does nothing.
    """
    ...
|
exponax.etdrk.roots_of_unity
roots_of_unity(M: int) -> Complex[Array, M]
Return a complex-valued array of M equally spaced points on the unit circle, offset by half a spacing (i.e. the M-th roots of -1).
Source code in exponax/etdrk/_utils.py
def roots_of_unity(M: int) -> Complex[Array, "M"]:
    """
    Return a complex-valued array of ``M`` equally spaced unit-circle points.

    The points are ``exp(2j * pi * (k - 1/2) / M)`` for ``k = 1, ..., M``.
    Because of the half-spacing offset, they are strictly the M-th roots of
    ``-1`` rather than of ``1``. The set is closed under complex conjugation.
    For even ``M``, no point lies on the real axis.

    Args:
        M: Number of points.

    Returns:
        Complex array of shape ``(M,)``.
    """
    half_offset_indices = jnp.arange(1, M + 1) - 0.5
    return jnp.exp(2j * jnp.pi * half_offset_indices / M)
|