Callbacks

trainax.callback.SaveNetwork

Bases: BaseCallback

Source code in trainax/callback/_save_network.py
class SaveNetwork(BaseCallback):
    path: str
    file_name: str

    def __init__(
        self,
        every: int,
        path: str,
        file_name: str,
        name: str = "network_saved",
    ):
        """
        Callback to write the network state to a file `every` update step.

        **Arguments:**

        - `every`: The frequency of the callback.
        - `path`: The path to save the network state.
        - `file_name`: The file name to save the network state.
        - `name`: The name of the callback
        """
        self.path = path
        self.file_name = file_name
        super().__init__(every, name)

    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> Any:
        concrete_file_name = f"{self.path}/{self.file_name}_{update_i}.eqx"
        eqx.tree_serialise_leaves(concrete_file_name, stepper)
        return True
__init__
__init__(
    every: int,
    path: str,
    file_name: str,
    name: str = "network_saved",
)

Callback to write the network state to a file at an interval of every update steps.

Arguments:

  • every: The frequency of the callback.
  • path: The path to save the network state.
  • file_name: The file name to save the network state.
  • name: The name of the callback
Source code in trainax/callback/_save_network.py
def __init__(
    self,
    every: int,
    path: str,
    file_name: str,
    name: str = "network_saved",
):
    """
    Callback to write the network state to a file `every` update step.

    **Arguments:**

    - `every`: The frequency of the callback.
    - `path`: The path to save the network state.
    - `file_name`: The file name to save the network state.
    - `name`: The name of the callback
    """
    self.path = path
    self.file_name = file_name
    super().__init__(every, name)
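
A minimal usage sketch follows. The stepper below is a stand-in Equinox module, the checkpoints directory is assumed to already exist, and how the callback is handed to a particular trainer is not shown here.

import equinox as eqx
import jax

from trainax.callback import SaveNetwork

# Stand-in for the stepper being trained (any eqx.Module works).
stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))

save_network = SaveNetwork(every=100, path="checkpoints", file_name="my_stepper")

# Callbacks are invoked with (update_i, stepper, data). On a matching step
# this writes "checkpoints/my_stepper_200.eqx" and returns
# {"network_saved": True}; on all other steps it returns {"network_saved": None}.
result = save_network(200, stepper, data=None)

# The saved state can later be restored with
# eqx.tree_deserialise_leaves("checkpoints/my_stepper_200.eqx", like=stepper).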

trainax.callback.GetNetwork

Bases: BaseCallback

Source code in trainax/callback/_get_network.py
class GetNetwork(BaseCallback):
    def __init__(self, every: int, name: str = "network"):
        """Callback to write out the network state `every` update step."""
        super().__init__(every, name)

    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> Any:
        """Write out the network state."""
        return stepper
__init__
__init__(every: int, name: str = 'network')

Callback to write out the network state at an interval of every update steps.

Source code in trainax/callback/_get_network.py
def __init__(self, every: int, name: str = "network"):
    """Callback to write out the network state `every` update step."""
    super().__init__(every, name)
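
A short sketch of calling the callback directly (the stepper is again a stand-in module): the returned dictionary holds the module itself on matching update steps and None otherwise.

import equinox as eqx
import jax

from trainax.callback import GetNetwork

stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))
get_network = GetNetwork(every=10)

snapshot = get_network(20, stepper, data=None)  # {"network": stepper}
skipped = get_network(15, stepper, data=None)   # {"network": None}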

trainax.callback.WeightNorm

Bases: BaseCallback

Source code in trainax/callback/_weight_norm.py
class WeightNorm(BaseCallback):
    squared: bool = False

    def __init__(self, every: int, squared: bool = False, name: str = "weight_norm"):
        """
        Callback to save the weight norm `every` update steps.

        **Arguments:**

        - `every`: The frequency of the callback.
        - `squared`: Whether to return the squared weight norm.
        - `name`: The name of the callback
        """
        self.squared = squared
        super().__init__(every, name)

    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> eqx.Module:
        weights = jtu.tree_leaves(eqx.filter(stepper, eqx.is_array))
        norms_squared = [jnp.sum(w**2) for w in weights]
        norm_squared = sum(norms_squared)

        if self.squared:
            return norm_squared
        else:
            return jnp.sqrt(norm_squared)
__init__
__init__(
    every: int,
    squared: bool = False,
    name: str = "weight_norm",
)

Callback to save the weight norm at an interval of every update steps.

Arguments:

  • every: The frequency of the callback.
  • squared: Whether to return the squared weight norm.
  • name: The name of the callback
Source code in trainax/callback/_weight_norm.py
def __init__(self, every: int, squared: bool = False, name: str = "weight_norm"):
    """
    Callback to save the weight norm `every` update steps.

    **Arguments:**

    - `every`: The frequency of the callback.
    - `squared`: Whether to return the squared weight norm.
    - `name`: The name of the callback
    """
    self.squared = squared
    super().__init__(every, name)
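
A sketch of the computation on a stand-in module; the callback reduces all array leaves of the stepper into a single (optionally squared) L2 norm.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from trainax.callback import WeightNorm

stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))
weight_norm = WeightNorm(every=1)

norm = weight_norm(0, stepper, data=None)["weight_norm"]

# Equivalent manual computation over all array leaves of the module.
leaves = jtu.tree_leaves(eqx.filter(stepper, eqx.is_array))
manual_norm = jnp.sqrt(sum(jnp.sum(leaf**2) for leaf in leaves))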

trainax.callback.Loss

Bases: BaseCallback

Source code in trainax/callback/_loss.py
class Loss(BaseCallback):
    loss_configuration: BaseConfiguration
    with_grad: bool

    ref_stepper: eqx.Module
    residuum_fn: eqx.Module

    def __init__(
        self,
        every: int,
        loss_configuration: BaseConfiguration,
        *,
        with_grad: bool = False,
        ref_stepper: eqx.Module = None,
        residuum_fn: eqx.Module = None,
        name: str,
    ):
        """
        Callback to save the loss associated with `loss_configuration` `every`
        update steps.

        Use this to measure a stepper's performance on a different configuration
        than the one used for the training loss.

        **Arguments:**

        - `every`: The frequency of the callback.
        - `loss_configuration`: The loss configuration to compute the loss.
        - `with_grad`: Whether to also return the associated gradient. If only
            the gradient norm is desired, set this to `False` and consider using
            [`trainax.callback.GradNorm`]().
        - `ref_stepper`: A reference stepper that is used to compute the residuum.
            Supply this if the loss configuration requires a reference stepper.
        - `residuum_fn`: A residuum function that computes the discrete residuum
            between two consecutive states. Supply this if the loss configuration
            requires a residuum function.
        - `name`: The name of the callback.
        """
        self.loss_configuration = loss_configuration
        self.with_grad = with_grad
        self.ref_stepper = ref_stepper
        self.residuum_fn = residuum_fn
        super().__init__(every, name)

    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> Union[eqx.Module, tuple[eqx.Module, eqx.Module]]:
        """
        Compute the loss and optionally the associated gradient.
        """
        if self.with_grad:
            loss, grad = eqx.filter_value_and_grad(self.loss_configuration)(
                stepper,
                data,
                ref_stepper=self.ref_stepper,
                residuum_fn=self.residuum_fn,
            )
            return loss, grad
        else:
            loss = self.loss_configuration(
                stepper,
                data,
                ref_stepper=self.ref_stepper,
                residuum_fn=self.residuum_fn,
            )
            return loss
__init__
__init__(
    every: int,
    loss_configuration: BaseConfiguration,
    *,
    with_grad: bool = False,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module = None,
    name: str
)

Callback to save the loss associated with loss_configuration at an interval of every update steps.

Use this to measure a stepper's performance on a different configuration than the one used for the training loss.

Arguments:

  • every: The frequency of the callback.
  • loss_configuration: The loss configuration to compute the loss.
  • with_grad: Whether to also return the associated gradient. If only the gradient norm is desired, set this to False and consider using trainax.callback.GradNorm.
  • ref_stepper: A reference stepper that is used to compute the residuum. Supply this if the loss configuration requires a reference stepper.
  • residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
  • name: The name of the callback.
Source code in trainax/callback/_loss.py
def __init__(
    self,
    every: int,
    loss_configuration: BaseConfiguration,
    *,
    with_grad: bool = False,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module = None,
    name: str,
):
    """
    Callback to save the loss associated with `loss_configuration` `every`
    update steps.

    Use this to measure a stepper's performance on a different configuration
    than the one used for the training loss.

    **Arguments:**

    - `every`: The frequency of the callback.
    - `loss_configuration`: The loss configuration to compute the loss.
    - `with_grad`: Whether to also return the associated gradient. If only
        the gradient norm is desired, set this to `False` and consider using
        [`trainax.callback.GradNorm`]().
    - `ref_stepper`: A reference stepper that is used to compute the residuum.
        Supply this if the loss configuration requires a reference stepper.
    - `residuum_fn`: A residuum function that computes the discrete residuum
        between two consecutive states. Supply this if the loss configuration
        requires a residuum function.
    - `name`: The name of the callback.
    """
    self.loss_configuration = loss_configuration
    self.with_grad = with_grad
    self.ref_stepper = ref_stepper
    self.residuum_fn = residuum_fn
    super().__init__(every, name)
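
A hedged sketch of calling the callback directly. The dummy_config function and the toy data below are placeholders, not part of trainax; they only mirror the (stepper, data, ref_stepper=..., residuum_fn=...) call signature that the callback expects from a loss configuration.

import equinox as eqx
import jax
import jax.numpy as jnp

from trainax.callback import Loss


def dummy_config(stepper, data, *, ref_stepper=None, residuum_fn=None):
    # Illustrative one-step MSE; a real trainax BaseConfiguration goes here.
    inputs, targets = data
    return jnp.mean((jax.vmap(stepper)(inputs) - targets) ** 2)


stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))
data = (jnp.ones((8, 2)), jnp.zeros((8, 2)))

val_loss = Loss(every=10, loss_configuration=dummy_config, name="val_loss")
val_loss(10, stepper, data)  # {"val_loss": <scalar loss>}
val_loss(13, stepper, data)  # {"val_loss": None} (off-interval step)

# With with_grad=True, matching steps return a (loss, gradient) tuple instead.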

trainax.callback.GradNorm

Bases: BaseCallback

Source code in trainax/callback/_grad_norm.py
class GradNorm(BaseCallback):
    loss_configuration: BaseConfiguration
    squared: bool

    ref_stepper: eqx.Module
    residuum_fn: eqx.Module

    def __init__(
        self,
        every: int,
        loss_configuration: BaseConfiguration,
        *,
        squared: bool = False,
        ref_stepper: eqx.Module = None,
        residuum_fn: eqx.Module = None,
        name: str,
    ):
        """
        Callback to save the gradient norm associated with `loss_configuration`
        `every` update steps.

        **Arguments:**

        - `every`: The frequency of the callback.
        - `loss_configuration`: The loss configuration to compute the gradient
            norm. If the gradient norm associated with the training loss is
            desired, the corresponding loss configuration has to be re-supplied.
        - `squared`: Whether to return the squared gradient norm.
        - `ref_stepper`: A reference stepper that is used to compute the residuum.
            Supply this if the loss configuration requires a reference stepper.
        - `residuum_fn`: A residuum function that computes the discrete residuum
            between two consecutive states. Supply this if the loss configuration
            requires a residuum function.
        - `name`: The name of the callback.
        """
        self.loss_configuration = loss_configuration
        self.squared = squared
        self.ref_stepper = ref_stepper
        self.residuum_fn = residuum_fn
        super().__init__(every, name)

    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> eqx.Module:
        """Compute the gradient norm."""
        grad = eqx.filter_grad(self.loss_configuration)(
            stepper,
            data,
            ref_stepper=self.ref_stepper,
            residuum_fn=self.residuum_fn,
        )
        grad_weights = jtu.tree_leaves(eqx.filter(grad, eqx.is_array))
        grad_norms_squared = [jnp.sum(g**2) for g in grad_weights]
        grad_norm_squared = sum(grad_norms_squared)
        if self.squared:
            return grad_norm_squared
        else:
            return jnp.sqrt(grad_norm_squared)
__init__
__init__(
    every: int,
    loss_configuration: BaseConfiguration,
    *,
    squared: bool = False,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module = None,
    name: str
)

Callback to save the gradient norm associated with loss_configuration at an interval of every update steps.

Arguments:

  • every: The frequency of the callback.
  • loss_configuration: The loss configuration to compute the gradient norm. If the gradient norm associated with the training loss is desired, the corresponding loss configuration has to be re-supplied.
  • squared: Whether to return the squared gradient norm.
  • ref_stepper: A reference stepper that is used to compute the residuum. Supply this if the loss configuration requires a reference stepper.
  • residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
  • name: The name of the callback.
Source code in trainax/callback/_grad_norm.py
def __init__(
    self,
    every: int,
    loss_configuration: BaseConfiguration,
    *,
    squared: bool = False,
    ref_stepper: eqx.Module = None,
    residuum_fn: eqx.Module = None,
    name: str,
):
    """
    Callback to save the gradient norm associated with `loss_configuration`
    `every` update steps.

    **Arguments:**

    - `every`: The frequency of the callback.
    - `loss_configuration`: The loss configuration to compute the gradient
        norm. If the gradient norm associated with the training loss is
        desired, the corresponding loss configuration has to be re-supplied.
    - `squared`: Whether to return the squared gradient norm.
    - `ref_stepper`: A reference stepper that is used to compute the residuum.
        Supply this if the loss configuration requires a reference stepper.
    - `residuum_fn`: A residuum function that computes the discrete residuum
        between two consecutive states. Supply this if the loss configuration
        requires a residuum function.
    - `name`: The name of the callback.
    """
    self.loss_configuration = loss_configuration
    self.squared = squared
    self.ref_stepper = ref_stepper
    self.residuum_fn = residuum_fn
    super().__init__(every, name)
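
A brief sketch along the same lines as the Loss example above; any callable with the (stepper, data, ref_stepper=..., residuum_fn=...) signature can stand in for the loss configuration, and the callback reports a single scalar gradient norm.

import equinox as eqx
import jax
import jax.numpy as jnp

from trainax.callback import GradNorm


def dummy_config(stepper, data, *, ref_stepper=None, residuum_fn=None):
    # Illustrative stand-in, not a real trainax configuration.
    inputs, targets = data
    return jnp.mean((jax.vmap(stepper)(inputs) - targets) ** 2)


stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))
data = (jnp.ones((8, 2)), jnp.zeros((8, 2)))

grad_norm = GradNorm(every=10, loss_configuration=dummy_config, name="grad_norm")
grad_norm(10, stepper, data)  # {"grad_norm": <scalar L2 norm of the gradient>}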

trainax.callback.CompositeCallback

Bases: Module

Source code in trainax/callback/_composite.py
class CompositeCallback(eqx.Module):
    callbacks: list[BaseCallback]

    def __init__(self, callbacks: list[BaseCallback]):
        """Callback to combine multiple callbacks."""
        self.callbacks = callbacks

    def __call__(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> Any:
        res = {}
        for callback in self.callbacks:
            res.update(callback(update_i, stepper, data))
        return res
__init__
__init__(callbacks: list[BaseCallback])

Callback to combine multiple callbacks.

Source code in trainax/callback/_composite.py
def __init__(self, callbacks: list[BaseCallback]):
    """Callback to combine multiple callbacks."""
    self.callbacks = callbacks
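
A small sketch combining two of the callbacks documented above; on every call the individual result dictionaries are merged into one, keyed by each callback's name.

import equinox as eqx
import jax

from trainax.callback import CompositeCallback, GetNetwork, WeightNorm

stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))

combined = CompositeCallback([
    WeightNorm(every=1),
    GetNetwork(every=100),
])

# At update step 1: {"weight_norm": <scalar>, "network": None}.
combined(1, stepper, data=None)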

trainax.callback.BaseCallback

Bases: Module, ABC

Source code in trainax/callback/_base.py
class BaseCallback(eqx.Module, ABC):
    every: int
    name: str

    def __init__(self, every: int, name: str):
        """Base class for callbacks."""
        self.every = every
        self.name = name

    @abstractmethod
    def callback(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ):
        pass

    def __call__(
        self,
        update_i: int,
        stepper: eqx.Module,
        data: PyTree,
    ) -> Dict[str, Any]:
        """
        Evaluate the Callback.

        **Arguments:**

        - `update_i`: The current update step.
        - `stepper`: The equinox.Module to evaluate the callback on.
        - `data`: The data to evaluate the callback on.

        **Returns:**

        - The result of the callback wrapped into a dictionary.
        """
        if update_i % self.every == 0:
            res = self.callback(update_i, stepper, data)
            return {self.name: res}
        else:
            return {self.name: None}
__init__
__init__(every: int, name: str)

Base class for callbacks.

Source code in trainax/callback/_base.py
def __init__(self, every: int, name: str):
    """Base class for callbacks."""
    self.every = every
    self.name = name
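
Custom callbacks are written by subclassing BaseCallback and implementing callback; the parameter-count metric below is only an illustration.

import equinox as eqx
import jax
import jax.tree_util as jtu

from trainax.callback import BaseCallback


class ParameterCount(BaseCallback):
    def __init__(self, every: int, name: str = "parameter_count"):
        super().__init__(every, name)

    def callback(self, update_i, stepper, data):
        # Total number of entries across all array leaves of the stepper.
        arrays = jtu.tree_leaves(eqx.filter(stepper, eqx.is_array))
        return sum(a.size for a in arrays)


stepper = eqx.nn.MLP(2, 2, 32, 2, key=jax.random.PRNGKey(0))
ParameterCount(every=1)(0, stepper, data=None)  # {"parameter_count": <int>}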