Skip to content

Helper¤

exponax.build_ic_set ¤

build_ic_set(
    ic_generator,
    *,
    num_points: int,
    num_samples: int,
    key: PRNGKeyArray
) -> Float[Array, "S 1 ... N"]

Generate a set of initial conditions by repeatedly sampling from a random initial condition generator and stacking the samples along a new leading axis.

Arguments: - ic_generator: A callable invoked as ic_generator(num_points, key=subkey) that returns one discretized initial condition per call. - num_points: The resolution argument forwarded unchanged to ic_generator. - num_samples: The number of initial conditions to sample. - key: The PRNGKey to use for sampling.

Returns: - ic_set: The set of initial conditions. Shape: (S, 1, ..., N). S = num_samples.

Source code in exponax/_utils.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def build_ic_set(
    ic_generator,
    *,
    num_points: int,
    num_samples: int,
    key: PRNGKeyArray,
) -> Float[Array, "S 1 ... N"]:
    """
    Generate a set of initial conditions by repeatedly sampling from a random
    initial condition generator.

    **Arguments:**
        - `ic_generator`: A callable invoked as
            `ic_generator(num_points, key=subkey)` that returns one
            discretized initial condition per call (the calling convention
            of the random IC generators in `exponax.ic`). Every call must
            return an array of the same shape and dtype.
        - `num_points`: Resolution argument forwarded unchanged to
            `ic_generator` on every call.
        - `num_samples`: The number of initial conditions to sample.
        - `key`: The PRNGKey to use for sampling. A fresh subkey is split off
            for each sample, so the result is reproducible for a fixed key.

    **Returns:**
        - `ic_set`: The sampled initial conditions stacked along a new leading
            axis. Shape: `(S, 1, ..., N)` with `S = num_samples` for
            single-channel generators; in general `(S,) + sample.shape`.
    """

    def scan_fn(carry_key, _):
        # Thread the key through the scan: keep one half for the next
        # iteration and consume the other half for this sample.
        carry_key, sample_key = jr.split(carry_key)
        ic = ic_generator(num_points, key=sample_key)
        return carry_key, ic

    # `scan` (instead of a Python loop) traces `ic_generator` once and stacks
    # the per-step outputs along a new leading axis of length `num_samples`.
    _, ic_set = jax.lax.scan(scan_fn, key, None, length=num_samples)

    return ic_set

exponax.ic.MultiChannelIC ¤

Bases: Module

Source code in exponax/ic/_multi_channel.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class MultiChannelIC(eqx.Module):
    # The wrapped initial conditions; their outputs are concatenated along
    # axis 0 (the channel axis `C` of the return annotation) in this order.
    initial_conditions: tuple[BaseIC, ...]

    def __init__(self, initial_conditions: tuple[BaseIC, ...]):
        """
        A multi-channel initial condition.

        **Arguments**:
            - `initial_conditions`: A tuple (or any iterable) of initial
                conditions. It is stored as a tuple so the field always
                matches its annotation, e.g. when a caller passes a list.
                Must be non-empty, since evaluation concatenates the
                individual outputs.
        """
        self.initial_conditions = tuple(initial_conditions)

    def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
        """
        Evaluate the initial condition.

        **Arguments**:
            - `x`: The grid points, passed unchanged to every wrapped
                initial condition.

        **Returns**:
            - `u`: The outputs of all wrapped initial conditions,
                concatenated along axis 0.
        """
        channel_values = [ic(x) for ic in self.initial_conditions]
        return jnp.concatenate(channel_values, axis=0)
__init__ ¤
__init__(initial_conditions: tuple[BaseIC, ...])

A multi-channel initial condition.

Arguments: - initial_conditions: A tuple of initial conditions.

Source code in exponax/ic/_multi_channel.py
12
13
14
15
16
17
18
19
def __init__(self, initial_conditions: tuple[BaseIC, ...]):
    """
    A multi-channel initial condition.

    **Arguments**:
        - `initial_conditions`: A tuple (or any iterable) of initial
            conditions. It is stored as a tuple so the field always matches
            its annotation, e.g. when a caller passes a list.
    """
    self.initial_conditions = tuple(initial_conditions)
__call__ ¤
__call__(
    x: Float[Array, "D ... N"]
) -> Float[Array, "C ... N"]

Evaluate the initial condition.

Arguments: - x: The grid points.

Returns: - u: The initial condition evaluated at the grid points.

Source code in exponax/ic/_multi_channel.py
21
22
23
24
25
26
27
28
29
30
31
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]:
    """
    Evaluate every wrapped initial condition on the same grid and stack the
    results channel-wise.

    **Arguments**:
        - `x`: The grid points, forwarded to each wrapped initial condition.

    **Returns**:
        - `u`: The individual evaluations concatenated along axis 0.
    """
    per_condition = [condition(x) for condition in self.initial_conditions]
    stacked = jnp.concatenate(per_condition, axis=0)
    return stacked

exponax.ic.RandomMultiChannelICGenerator ¤

Bases: Module

Source code in exponax/ic/_multi_channel.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class RandomMultiChannelICGenerator(eqx.Module):
    # One generator per channel group; outputs are concatenated along axis 0
    # in this order.
    ic_generators: tuple[BaseRandomICGenerator, ...]

    def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]):
        """
        A multi-channel random initial condition generator.

        **Arguments**:
            - `ic_generators`: A tuple (or any iterable) of initial condition
                generators. It is stored as a tuple so the field always
                matches its annotation.
        """
        self.ic_generators = tuple(ic_generators)

    def _split_keys(self, key: PRNGKeyArray):
        """Split `key` into one subkey per wrapped generator (same order)."""
        return jax.random.split(key, len(self.ic_generators))

    def gen_ic_fun(self, *, key: PRNGKeyArray) -> MultiChannelIC:
        """
        Draw one random initial condition function per wrapped generator and
        combine them into a `MultiChannelIC`.

        **Arguments**:
            - `key`: The PRNGKey; split into one subkey per generator, in the
                same way as in `__call__`.

        **Returns**:
            - A `MultiChannelIC` holding the drawn initial conditions as a
                tuple, in generator order.
        """
        ic_funs = tuple(
            ic_gen.gen_ic_fun(key=k)
            for ic_gen, k in zip(self.ic_generators, self._split_keys(key))
        )
        return MultiChannelIC(ic_funs)

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "C ... N"]:
        """
        Sample every wrapped generator and concatenate the results.

        **Arguments**:
            - `num_points`: Resolution argument forwarded to each generator.
            - `key`: The PRNGKey; split into one subkey per generator.

        **Returns**:
            - The per-generator samples concatenated along axis 0.
        """
        u_list = [
            ic_gen(num_points, key=k)
            for ic_gen, k in zip(self.ic_generators, self._split_keys(key))
        ]
        return jnp.concatenate(u_list, axis=0)
__init__ ¤
__init__(ic_generators: tuple[BaseRandomICGenerator, ...])

A multi-channel random initial condition generator.

Arguments: - ic_generators: A tuple of initial condition generators.

Source code in exponax/ic/_multi_channel.py
37
38
39
40
41
42
43
44
def __init__(self, ic_generators: tuple[BaseRandomICGenerator, ...]):
    """
    A multi-channel random initial condition generator.

    **Arguments**:
        - `ic_generators`: A tuple (or any iterable) of initial condition
            generators. It is stored as a tuple so the field always matches
            its annotation.
    """
    self.ic_generators = tuple(ic_generators)
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]
Source code in exponax/ic/_multi_channel.py
56
57
58
59
60
61
62
63
64
65
66
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "C ... N"]:
    """
    Draw one sample per wrapped generator and stack them channel-wise.

    **Arguments**:
        - `num_points`: Resolution argument passed to every generator.
        - `key`: The PRNGKey, split into one subkey per generator.

    **Returns**:
        - The individual samples concatenated along axis 0.
    """
    generator_count = len(self.ic_generators)
    subkeys = jax.random.split(key, generator_count)
    samples = [
        generator(num_points, key=subkey)
        for generator, subkey in zip(self.ic_generators, subkeys)
    ]
    return jnp.concatenate(samples, axis=0)

exponax.ic.ClampingICGenerator ¤

Bases: BaseRandomICGenerator

Source code in exponax/ic/_clamping.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class ClampingICGenerator(BaseRandomICGenerator):
    # The generator whose samples are rescaled.
    ic_gen: BaseRandomICGenerator
    # Target `(lower, upper)` range of the rescaled samples.
    limits: tuple[float, float]

    def __init__(
        self, ic_gen: BaseRandomICGenerator, limits: tuple[float, float] = (0, 1)
    ):
        """
        A generator based on another generator that rescales each sample so
        that it spans a given range (min-max normalization; values are not
        clipped).

        **Arguments**:
            - `ic_gen`: The initial condition generator whose samples are
                rescaled.
            - `limits`: The `(lower, upper)` bounds of the target range.
        """
        self.ic_gen = ic_gen
        self.limits = limits
        # Mirror the wrapped generator's dimensionality.
        self.num_spatial_dims = ic_gen.num_spatial_dims

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """
        Draw a sample from the wrapped generator and rescale it to `limits`.

        The sample's minimum is mapped to `limits[0]` and its maximum to
        `limits[1]`. A constant sample (zero span) is mapped entirely onto
        `limits[0]` instead of producing NaNs.

        **Arguments**:
            - `num_points`: Resolution argument forwarded to `ic_gen`.
            - `key`: The PRNGKey forwarded to `ic_gen`.

        **Returns**:
            - The rescaled sample, same shape as the wrapped generator's
                output.
        """
        ic = self.ic_gen(num_points=num_points, key=key)
        lower, upper = self.limits
        ic_above_zero = ic - jnp.min(ic)
        span = jnp.max(ic_above_zero)
        # Divide by 1 when the span is zero: this avoids NaNs in the forward
        # pass and in gradients (a `where` on the quotient would not).
        safe_span = jnp.where(span > 0, span, 1.0)
        ic_unit = ic_above_zero / safe_span
        return ic_unit * (upper - lower) + lower
__init__ ¤
__init__(
    ic_gen: BaseRandomICGenerator,
    limits: tuple[float, float] = (0, 1),
)

A generator based on another generator that rescales its output (min-max normalization) so that each sample spans a given range.

Arguments: - ic_gen: The initial condition generator to clamp. - limits: The lower and upper limits of the clamping range.

Source code in exponax/ic/_clamping.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def __init__(
    self, ic_gen: BaseRandomICGenerator, limits: tuple[float, float] = (0, 1)
):
    """
    Wrap another generator so that its samples are rescaled to a target
    range.

    **Arguments**:
        - `ic_gen`: The generator whose samples are rescaled.
        - `limits`: The `(lower, upper)` bounds of the target range.
    """
    # Inherit the spatial dimensionality from the wrapped generator.
    self.num_spatial_dims = ic_gen.num_spatial_dims
    self.ic_gen = ic_gen
    self.limits = limits
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_clamping.py
26
27
28
29
30
31
32
33
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """
    Draw a sample from the wrapped generator and rescale it to `limits`.

    The sample's minimum is mapped to `limits[0]` and its maximum to
    `limits[1]`. A constant sample (zero span) is mapped entirely onto
    `limits[0]` instead of producing NaNs.
    """
    ic = self.ic_gen(num_points=num_points, key=key)
    lower, upper = self.limits
    ic_above_zero = ic - jnp.min(ic)
    span = jnp.max(ic_above_zero)
    # Divide by 1 when the span is zero: this avoids NaNs in the forward
    # pass and in gradients (a `where` on the quotient would not).
    safe_span = jnp.where(span > 0, span, 1.0)
    ic_unit = ic_above_zero / safe_span
    return ic_unit * (upper - lower) + lower

exponax.ic.ScaledICGenerator ¤

Bases: BaseRandomICGenerator

Source code in exponax/ic/_scaled.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class ScaledICGenerator(BaseRandomICGenerator):
    # The generator whose samples are multiplied by `scale`.
    ic_gen: BaseRandomICGenerator
    # Constant multiplicative factor applied to every sample.
    scale: float

    def __init__(self, ic_gen: BaseRandomICGenerator, scale: float):
        """
        Wrap another generator and multiply its samples by a constant.

        Most useful on top of generators that already normalize their output,
        e.g. via `max_one=True` or `std_one=True`.

        **Arguments**:
            - `ic_gen`: The generator to wrap.
            - `scale`: The multiplicative factor.
        """
        # Inherit the spatial dimensionality from the wrapped generator.
        self.num_spatial_dims = ic_gen.num_spatial_dims
        self.ic_gen = ic_gen
        self.scale = scale

    def gen_ic_fun(self, *, key: PRNGKeyArray) -> BaseIC:
        """Draw an IC function from `ic_gen` and wrap it in a `ScaledIC`."""
        base_ic = self.ic_gen.gen_ic_fun(key=key)
        return ScaledIC(base_ic, scale=self.scale)

    def __call__(
        self, num_points: int, *, key: PRNGKeyArray
    ) -> Float[Array, "1 ... N"]:
        """Sample `ic_gen` and multiply the result by `scale`."""
        unscaled = self.ic_gen(num_points=num_points, key=key)
        return unscaled * self.scale
__init__ ¤
__init__(ic_gen: BaseRandomICGenerator, scale: float)

A scaled initial condition generator.

Works best in combination with initial conditions that have max_one=True or std_one=True.

Arguments: - ic_gen: The initial condition generator. - scale: The scaling factor.

Source code in exponax/ic/_scaled.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def __init__(self, ic_gen: BaseRandomICGenerator, scale: float):
    """
    Wrap another generator and multiply its samples by a constant.

    Most useful on top of generators that already normalize their output,
    e.g. via `max_one=True` or `std_one=True`.

    **Arguments**:
        - `ic_gen`: The generator to wrap.
        - `scale`: The multiplicative factor.
    """
    # Inherit the spatial dimensionality from the wrapped generator.
    self.num_spatial_dims = ic_gen.num_spatial_dims
    self.ic_gen = ic_gen
    self.scale = scale
__call__ ¤
__call__(
    num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_scaled.py
36
37
38
39
40
def __call__(
    self, num_points: int, *, key: PRNGKeyArray
) -> Float[Array, "1 ... N"]:
    """Sample the wrapped generator and multiply the result by `scale`."""
    unscaled = self.ic_gen(num_points=num_points, key=key)
    return unscaled * self.scale

exponax.ic.ScaledIC ¤

Bases: BaseIC

Source code in exponax/ic/_scaled.py
 6
 7
 8
 9
10
11
class ScaledIC(BaseIC):
    """An initial condition whose values are a constant multiple of another."""

    # The initial condition being scaled.
    ic: BaseIC
    # Constant multiplicative factor applied to `ic`'s output.
    scale: float

    def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
        """Evaluate the wrapped IC on `x` and multiply by `scale`."""
        unscaled = self.ic(x)
        return unscaled * self.scale
__call__ ¤
__call__(
    x: Float[Array, "D ... N"]
) -> Float[Array, "1 ... N"]
Source code in exponax/ic/_scaled.py
10
11
def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]:
    """Evaluate the wrapped initial condition on `x` and multiply by `scale`."""
    unscaled = self.ic(x)
    return unscaled * self.scale