Skip to content

Helper¤

exponax.build_ic_set ¤

build_ic_set(
    ic_generator: Callable[
        [int, PRNGKeyArray], Float[Array, "C ... N"]
    ],
    *,
    num_points: int,
    num_samples: int,
    key: PRNGKeyArray
) -> Float[Array, "S C ... N"]

Generate a set of initial conditions by repeatedly sampling from a given random initial condition generator, discretizing each sample with num_points points.

Arguments:

  • ic_generator: A function that takes a number of points and a PRNGKey (passed as the keyword argument key) and returns an array representing the discrete state of an initial condition. The shape of the returned array is (C, ..., N).
  • num_points: The number of points forwarded to ic_generator for every sample.
  • num_samples: The number of initial conditions to sample.
  • key: The PRNGKey to use for sampling.

Returns:

  • ic_set: The set of initial conditions. Shape: (S, C, ..., N). S = num_samples.
Source code in exponax/_utils.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def build_ic_set(
    ic_generator: Callable[[int, PRNGKeyArray], Float[Array, "C ... N"]],
    *,
    num_points: int,
    num_samples: int,
    key: PRNGKeyArray,
) -> Float[Array, "S C ... N"]:
    """
    Generate a set of initial conditions by repeatedly sampling from a random
    initial condition generator.

    **Arguments:**

    - `ic_generator`: A function that takes a number of points and a PRNGKey
        (passed as the keyword argument `key`) and returns an array
        representing the discrete state of an initial condition. The shape of
        the returned array is `(C, ..., N)`.
    - `num_points`: The number of points forwarded to `ic_generator` for
        every sample.
    - `num_samples`: The number of initial conditions to sample.
    - `key`: The PRNGKey to use for sampling.

    **Returns:**

    - `ic_set`: The set of initial conditions. Shape: `(S, C, ..., N)`.
        `S = num_samples`.
    """

    def scan_fn(carry_key, _):
        # Advance the carried key and hand a fresh subkey to the generator so
        # that every sample uses a different key.
        carry_key, sample_key = jr.split(carry_key)
        ic = ic_generator(num_points, key=sample_key)
        return carry_key, ic

    # `lax.scan` stacks the per-step outputs along a new leading axis of
    # length `num_samples`; the generator must therefore return arrays of a
    # fixed shape and dtype.
    _, ic_set = jax.lax.scan(scan_fn, key, None, length=num_samples)

    return ic_set

exponax.ic.MultiChannelIC ¤

Bases: Module

Source code in exponax/ic/_multi_channel.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class MultiChannelIC(eqx.Module):
    # The component initial conditions, evaluated in order.
    initial_conditions: tuple[BaseIC, ...]

    def __init__(self, initial_conditions: tuple[BaseIC, ...]):
        """
        An initial condition assembled from several component initial
        conditions whose outputs are stacked along the channel axis.

        **Arguments**:

        - `initial_conditions`: A tuple of initial conditions; each one is
            evaluated on the same grid and contributes its channels in order.
        """
        self.initial_conditions = initial_conditions

    def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
        """
        Evaluate every component initial condition at `x` and stack the
        results.

        **Arguments**:

        - `x`: The grid points.

        **Returns**:

        - `u`: The component outputs concatenated along the leading (channel)
            axis, in the order of `self.initial_conditions`.
        """
        components = self.initial_conditions
        stacked = [component(x) for component in components]
        return jnp.concatenate(stacked, axis=0)
__init__ ¤
__init__(initial_conditions: tuple[BaseIC, ...])

A multi-channel initial condition.

Arguments:

  • initial_conditions: A tuple of initial conditions.
Source code in exponax/ic/_multi_channel.py
12
13
14
15
16
17
18
19
20
def __init__(self, initial_conditions: tuple[BaseIC, ...]):
    """
    An initial condition assembled from several component initial conditions
    whose outputs are stacked along the channel axis.

    **Arguments**:

    - `initial_conditions`: A tuple of initial conditions, stored as given.
    """
    # Store the components unchanged; evaluation happens lazily in __call__.
    self.initial_conditions = initial_conditions
__call__ ¤
__call__(
    x: Float[Array, "D ... N"]
) -> Float[Array, "C ... N"]

Evaluate the initial condition.

Arguments:

  • x: The grid points.

Returns:

  • u: The initial condition evaluated at the grid points.
Source code in exponax/ic/_multi_channel.py
22
23
24
25
26
27
28
29
30
31
32
33
34
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
    """
    Evaluate every component initial condition at `x` and stack the results.

    **Arguments**:

    - `x`: The grid points.

    **Returns**:

    - `u`: The component outputs concatenated along the leading (channel)
        axis, in the order of `self.initial_conditions`.
    """
    components = self.initial_conditions
    stacked = [component(x) for component in components]
    return jnp.concatenate(stacked, axis=0)

exponax.ic.RandomMultiChannelICGenerator ¤

Bases: Module

Source code in exponax/ic/_multi_channel.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class RandomMultiChannelICGenerator(eqx.Module):
    # One random generator per channel group; outputs are concatenated in
    # this order.
    ic_generators: tuple[BaseRandomICGenerator, ...]

    def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]):
        """
        A multi-channel random initial condition generator. Use this for
        problems with multiple channels, like Burgers in higher dimensions or
        the Gray-Scott dynamics.

        **Arguments**:

        - `ic_generators`: A tuple of initial condition generators.

        !!! example
            Below is an example for generating a random multi-channel initial
            condition for the three-dimensional Burgers equation which has three
            channels. For simplicity, we will use the same IC generator for each
            channel.

            ```python
            import jax
            import exponax as ex

            single_channel_ic_gen = ex.ic.RandomTruncatedFourierSeries(
                3,
                max_one=True,
            )
            multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
                (single_channel_ic_gen,) * 3
            )

            ic = multi_channel_ic_gen(100, key=jax.random.PRNGKey(0))
            ```
        """
        self.ic_generators = ic_generators

    def _split_keys(self, key: PRNGKeyArray):
        """Split `key` into one subkey per wrapped generator."""
        return jax.random.split(key, len(self.ic_generators))

    def gen_ic_fun(self, *, key: PRNGKeyArray) -> MultiChannelIC:
        """
        Sample one initial condition function per wrapped generator and
        bundle them into a `MultiChannelIC`.

        **Arguments**:

        - `key`: The PRNGKey; it is split so that every wrapped generator
            receives its own subkey.

        **Returns**:

        - `ic_fun`: A `MultiChannelIC` holding the sampled functions in the
            order of `self.ic_generators`.
        """
        # Build a tuple to match the `tuple[BaseIC, ...]` field annotation of
        # MultiChannelIC.
        ic_funs = tuple(
            ic_gen.gen_ic_fun(key=k)
            for ic_gen, k in zip(self.ic_generators, self._split_keys(key))
        )
        return MultiChannelIC(ic_funs)

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "C ... N"]:
        """
        Sample every wrapped generator with its own subkey and concatenate the
        resulting discrete states along the leading (channel) axis.

        **Arguments**:

        - `num_points`: The number of points forwarded to every wrapped
            generator.
        - `key`: The PRNGKey; it is split so that every wrapped generator
            receives its own subkey.

        **Returns**:

        - `u`: The concatenated state, channels ordered as
            `self.ic_generators`.
        """
        u_list = [
            ic_gen(num_points, key=k)
            for ic_gen, k in zip(self.ic_generators, self._split_keys(key))
        ]
        return jnp.concatenate(u_list, axis=0)
__init__ ¤
__init__(ic_generators: tuple[BaseRandomICGenerator, ...])

A multi-channel random initial condition generator. Use this for problems with multiple channels, like Burgers in higher dimensions or the Gray-Scott dynamics.

Arguments:

  • ic_generators: A tuple of initial condition generators.

Example

Below is an example for generating a random multi-channel initial condition for the three-dimensional Burgers equation which has three channels. For simplicity, we will use the same IC generator for each channel.

import jax
import exponax as ex

single_channel_ic_gen = ex.ic.RandomTruncatedFourierSeries(
    3,
    max_one=True,
)
multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
    [single_channel_ic_gen,] * 3
)

ic = multi_channel_ic_gen(100, key=jax.random.PRNGKey(0))
Source code in exponax/ic/_multi_channel.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]):
    """
    Combine several random initial condition generators into one that
    produces multi-channel states, e.g. for Burgers in more than one spatial
    dimension or for the Gray-Scott dynamics.

    **Arguments**:

    - `ic_generators`: A tuple of initial condition generators, one per
        channel group.

    !!! example
        A three-channel initial condition for the three-dimensional Burgers
        equation, reusing the same single-channel generator for every
        channel:

        ```python
        import jax
        import exponax as ex

        single_channel_ic_gen = ex.ic.RandomTruncatedFourierSeries(
            3,
            max_one=True,
        )
        multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator(
            (single_channel_ic_gen,) * 3
        )

        ic = multi_channel_ic_gen(100, key=jax.random.PRNGKey(0))
        ```
    """
    # Stored as given; the generators are sampled lazily on each call.
    self.ic_generators = ic_generators
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]
Source code in exponax/ic/_multi_channel.py
83
84
85
86
87
88
89
90
91
92
93
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]:
    """
    Sample every wrapped generator with its own subkey and concatenate the
    resulting discrete states along the leading (channel) axis.

    **Arguments**:

    - `num_points`: The number of points forwarded to every wrapped generator.
    - `key`: The PRNGKey; split into one subkey per wrapped generator.

    **Returns**:

    - `u`: The concatenated state, channels ordered as `self.ic_generators`.
    """
    subkeys = jax.random.split(key, len(self.ic_generators))
    channel_blocks = []
    for generator, subkey in zip(self.ic_generators, subkeys):
        channel_blocks.append(generator(num_points, key=subkey))
    return jnp.concatenate(channel_blocks, axis=0)

exponax.ic.ClampingICGenerator ¤

Bases: BaseRandomICGenerator

Source code in exponax/ic/_clamping.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class ClampingICGenerator(BaseRandomICGenerator):
    # The wrapped generator whose samples are rescaled.
    ic_gen: BaseRandomICGenerator
    # (lower, upper) bounds of the target range.
    limits: tuple[float, float]

    def __init__(
        self, ic_gen: BaseRandomICGenerator, limits: tuple[float, float] = (0, 1)
    ):
        """
        A generator based on another generator whose output is linearly
        rescaled into a given range.

        Despite the name, no clipping takes place: each sample is min-max
        normalized so that its minimum maps to `limits[0]` and its maximum
        maps to `limits[1]`, preserving the shape of the profile.

        Some dynamics (like the Fisher-KPP equation) require such initial
        conditions.

        **Arguments**:

        - `ic_gen`: The initial condition generator whose output is rescaled.
        - `limits`: The lower and upper limits of the target range.
        """
        self.ic_gen = ic_gen
        self.limits = limits
        self.num_spatial_dims = ic_gen.num_spatial_dims

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """
        Sample a state from the wrapped generator and rescale it into
        `limits`.

        **Arguments**:

        - `num_points`: The number of points forwarded to the wrapped
            generator.
        - `key`: The PRNGKey forwarded to the wrapped generator.

        **Returns**:

        - `u`: The rescaled state. A spatially constant sample (zero range)
            is mapped to `limits[0]` everywhere instead of producing NaNs.
        """
        ic = self.ic_gen(num_points=num_points, key=key)
        ic_above_zero = ic - jnp.min(ic)
        peak = jnp.max(ic_above_zero)
        # A constant sample has peak == 0, and dividing by it would give NaN
        # everywhere. Substituting 1 keeps the (all-zero) numerator at zero,
        # so the result sits at the lower limit.
        safe_peak = jnp.where(peak > 0, peak, 1.0)
        ic_unit = ic_above_zero / safe_peak
        lower, upper = self.limits
        return ic_unit * (upper - lower) + lower
__init__ ¤
__init__(
    ic_gen: BaseRandomICGenerator,
    limits: tuple[float, float] = (0, 1),
)

A generator based on another generator whose output is linearly rescaled (min-max normalized) so that it spans the given range; values are not clipped.

Some dynamics (like the Fisher-KPP equation) require such initial conditions.

Arguments:

  • ic_gen: The initial condition generator whose output is rescaled.
  • limits: The lower and upper limits of the clamping range.
Source code in exponax/ic/_clamping.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def __init__(
    self, ic_gen: BaseRandomICGenerator, limits: tuple[float, float] = (0, 1)
):
    """
    Wrap `ic_gen` so that every sample it produces is linearly rescaled to
    span `limits`.

    No clipping happens: the minimum of a sample is mapped to `limits[0]` and
    its maximum to `limits[1]`. Initial conditions in a fixed range are
    needed by some dynamics, e.g. the Fisher-KPP equation.

    **Arguments**:

    - `ic_gen`: The wrapped initial condition generator.
    - `limits`: The `(lower, upper)` bounds of the target range.
    """
    self.ic_gen = ic_gen
    self.limits = limits
    # Mirror the spatial dimensionality of the wrapped generator.
    self.num_spatial_dims = ic_gen.num_spatial_dims
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_clamping.py
30
31
32
33
34
35
36
37
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """
    Sample a state from the wrapped generator and rescale it into `limits`.

    **Arguments**:

    - `num_points`: The number of points forwarded to the wrapped generator.
    - `key`: The PRNGKey forwarded to the wrapped generator.

    **Returns**:

    - `u`: The rescaled state. A spatially constant sample (zero range) is
        mapped to `limits[0]` everywhere instead of producing NaNs.
    """
    ic = self.ic_gen(num_points=num_points, key=key)
    ic_above_zero = ic - jnp.min(ic)
    peak = jnp.max(ic_above_zero)
    # A constant sample has peak == 0, and dividing by it would give NaN
    # everywhere. Substituting 1 keeps the (all-zero) numerator at zero, so
    # the result sits at the lower limit.
    safe_peak = jnp.where(peak > 0, peak, 1.0)
    ic_unit = ic_above_zero / safe_peak
    lower, upper = self.limits
    return ic_unit * (upper - lower) + lower

exponax.ic.ScaledICGenerator ¤

Bases: BaseRandomICGenerator

Source code in exponax/ic/_scaled.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class ScaledICGenerator(BaseRandomICGenerator):
    # The wrapped generator whose samples are scaled.
    ic_gen: BaseRandomICGenerator
    # Constant multiplicative factor.
    scale: float

    def __init__(self, ic_gen: BaseRandomICGenerator, scale: float):
        """
        Wrap `ic_gen` so that each of its samples is multiplied by `scale`.

        Pairs well with generators created with `max_one=True` or
        `std_one=True`.

        **Arguments**:

        - `ic_gen`: The wrapped initial condition generator.
        - `scale`: The constant factor applied to every sample.
        """
        self.ic_gen = ic_gen
        self.scale = scale
        # Inherit the spatial dimensionality from the wrapped generator.
        self.num_spatial_dims = ic_gen.num_spatial_dims

    def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC:
        """
        Sample an initial condition function from the wrapped generator and
        wrap it in a `ScaledIC` carrying `self.scale`.
        """
        base_fun = self.ic_gen.gen_ic_fun(key=key)
        return ScaledIC(base_fun, scale=self.scale)

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """
        Sample a discrete state from the wrapped generator and multiply it by
        `self.scale`.
        """
        unscaled = self.ic_gen(num_points=num_points, key=key)
        return unscaled * self.scale
__init__ ¤
__init__(ic_gen: BaseRandomICGenerator, scale: float)

A scaled initial condition generator.

Works best in combination with initial conditions that have max_one=True or std_one=True.

Arguments:

  • ic_gen: The initial condition generator.
  • scale: The scaling factor.
Source code in exponax/ic/_scaled.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def __init__(self, ic_gen: BaseRandomICGenerator, scale: float):
    """
    Wrap `ic_gen` so that each of its samples is multiplied by `scale`.

    Pairs well with generators created with `max_one=True` or
    `std_one=True`.

    **Arguments**:

    - `ic_gen`: The wrapped initial condition generator.
    - `scale`: The constant factor applied to every sample.
    """
    self.ic_gen = ic_gen
    self.scale = scale
    # Inherit the spatial dimensionality from the wrapped generator.
    self.num_spatial_dims = ic_gen.num_spatial_dims
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_scaled.py
37
38
39
40
41
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """
    Sample a discrete state from the wrapped generator and multiply it by
    `self.scale`.
    """
    unscaled = self.ic_gen(num_points=num_points, key=key)
    return unscaled * self.scale

exponax.ic.ScaledIC ¤

Bases: BaseIC

Source code in exponax/ic/_scaled.py
 6
 7
 8
 9
10
11
class ScaledIC(BaseIC):
    """
    An initial condition function whose values are those of another initial
    condition multiplied by a constant factor.
    """

    # The wrapped initial condition function.
    ic: BaseIC
    # Constant multiplicative factor applied to the wrapped output.
    scale: float

    def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
        """
        Evaluate the wrapped initial condition at `x` and scale the result.

        **Arguments**:

        - `x`: The grid points.

        **Returns**:

        - `u`: `self.ic(x)` multiplied by `self.scale`.
        """
        unscaled = self.ic(x)
        return unscaled * self.scale
__call__ ¤
__call__(
    x: Float[Array, "D ... N"]
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_scaled.py
10
11
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
    """
    Evaluate the wrapped initial condition at `x` and scale the result.

    **Arguments**:

    - `x`: The grid points.

    **Returns**:

    - `u`: `self.ic(x)` multiplied by `self.scale`.
    """
    unscaled = self.ic(x)
    return unscaled * self.scale