
Base Configuration

trainax.configuration.BaseConfiguration

Bases: Module, ABC

Source code in trainax/configuration/_base_configuration.py
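# Imports required to run this excerpt on its own.
from abc import ABC, abstractmethod

import equinox as eqx
from jaxtyping import PyTree


class BaseConfiguration(eqx.Module, ABC):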
    @abstractmethod
    def __call__(
        self,
        stepper: eqx.Module,
        data: PyTree,
        *,
        ref_stepper: eqx.Module = None,
        residuum_fn: eqx.Module = None,
    ) -> float:
        """
        Evaluate the configuration on the given data.

        **Arguments:**

        - `stepper`: The stepper to use for the configuration. Must
            have the signature `stepper(u_prev: PyTree) -> u_next: PyTree`.
        - `data`: The data to evaluate the configuration on. This
            depends on the concrete configuration. In the most reduced case, it
            just contains the set of initial states.
        - `ref_stepper`: The reference stepper to use for some
            configurations. (keyword-only argument)
        - `residuum_fn`: The residuum function to use for some
            configurations. (keyword-only argument)

        **Returns:**

        - The loss value computed by this configuration.
        """
        pass
__call__ (abstractmethod)
__call__(
    stepper: eqx.Module,
    data: PyTree,
    *,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module = None
) -> float

Evaluate the configuration on the given data.

Arguments:

  • stepper: The stepper to use for the configuration. Must have the signature stepper(u_prev: PyTree) -> u_next: PyTree.
  • data: The data to evaluate the configuration on. Its structure depends on the concrete configuration; in the simplest case, it contains only the initial states.
  • ref_stepper: The reference stepper, used only by some configurations. Keyword-only; defaults to None.
  • residuum_fn: The residuum function, used only by some configurations. Keyword-only; defaults to None.

Returns:

  • The loss value computed by this configuration.