Callbacks
trainax.callback.SaveNetwork
Bases: BaseCallback
Source code in trainax/callback/_save_network.py
__init__
__init__(
every: int,
path: str,
file_name: str,
name: str = "network_saved",
)
Callback to write the network state to a file every update step.
Arguments:
- every: The frequency of the callback.
- path: The path to save the network state.
- file_name: The file name to save the network state.
- name: The name of the callback.
Source code in trainax/callback/_save_network.py
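A minimal construction sketch following the signature above; the path and file name values are placeholders, and how the callback is handed to a Trainax trainer should be checked against the trainer documentation.

```python
import trainax as tx

# Save the network state to disk every 100 update steps.
# "checkpoints" and "unet" are placeholder values for path and file_name.
save_callback = tx.callback.SaveNetwork(
    every=100,
    path="checkpoints",
    file_name="unet",
)
```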
trainax.callback.GetNetwork
Bases: BaseCallback
Source code in trainax/callback/_get_network.py
__init__
__init__(every: int, name: str = 'network')
Callback to write out the network state every update step.
Source code in trainax/callback/_get_network.py
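A construction sketch under the same caveat (the trainer-side wiring is not shown):

```python
import trainax as tx

# Record the network state every 500 update steps under the default
# name "network".
get_network = tx.callback.GetNetwork(every=500)
```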
trainax.callback.WeightNorm
Bases: BaseCallback
Source code in trainax/callback/_weight_norm.py
__init__
__init__(
every: int,
squared: bool = False,
name: str = "weight_norm",
)
Callback to save the weight norm every update step.
Arguments:
- every: The frequency of the callback.
- squared: Whether to return the squared weight norm.
- name: The name of the callback.
Source code in trainax/callback/_weight_norm.py
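A construction sketch following the signature above:

```python
import trainax as tx

# Track the squared weight norm every 10 update steps, e.g. to monitor
# parameter growth during training.
weight_norm = tx.callback.WeightNorm(every=10, squared=True)
```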
trainax.callback.Loss
Bases: BaseCallback
Source code in trainax/callback/_loss.py
__init__
__init__(
every: int,
loss_configuration: BaseConfiguration,
*,
with_grad: bool = False,
ref_stepper: eqx.Module = None,
residuum_fn: eqx.Module = None,
name: str
)
Callback to save the loss associated with loss_configuration every update step.
Use this to measure a stepper's performance on a different configuration than the training loss.
Arguments:
- every: The frequency of the callback.
- loss_configuration: The loss configuration to compute the loss.
- with_grad: Whether to also return the associated gradient. If only the gradient norm is desired, set this to False and consider using trainax.callback.GradNorm.
- ref_stepper: A reference stepper that is used to compute the residuum. Supply this if the loss configuration requires a reference stepper.
- residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
- name: The name of the callback.
Source code in trainax/callback/_loss.py
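A construction sketch; the Supervised configuration and its num_rollout_steps argument are assumptions here, so substitute whichever BaseConfiguration subclass the project actually provides.

```python
import trainax as tx

# Track a five-step supervised rollout loss next to the training loss.
rollout_loss = tx.callback.Loss(
    every=50,
    loss_configuration=tx.configuration.Supervised(num_rollout_steps=5),
    name="rollout_5_loss",
)
```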
trainax.callback.GradNorm
Bases: BaseCallback
Source code in trainax/callback/_grad_norm.py
__init__
__init__(
every: int,
loss_configuration: BaseConfiguration,
*,
squared: bool = False,
ref_stepper: eqx.Module = None,
residuum_fn: eqx.Module = None,
name: str
)
Callback to save the gradient norm associated with loss_configuration every update step.
Arguments:
- every: The frequency of the callback.
- loss_configuration: The loss configuration to compute the gradient norm. If the gradient norm associated with the training loss is desired, the corresponding loss configuration has to be re-supplied.
- squared: Whether to return the squared gradient norm.
- ref_stepper: A reference stepper that is used to compute the residuum. Supply this if the loss configuration requires a reference stepper.
- residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
- name: The name of the callback.
Source code in trainax/callback/_grad_norm.py
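A construction sketch; as with Loss above, the Supervised configuration is an assumed stand-in for the training configuration being re-supplied.

```python
import trainax as tx

# Track the gradient norm every 50 update steps; per the docstring, the
# training loss configuration has to be re-supplied to match it.
grad_norm = tx.callback.GradNorm(
    every=50,
    loss_configuration=tx.configuration.Supervised(num_rollout_steps=1),
    squared=False,
    name="grad_norm",
)
```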
trainax.callback.CompositeCallback
Bases: Module
Source code in trainax/callback/_composite.py
__init__
__init__(callbacks: list[BaseCallback])
Callback to combine multiple callbacks.
Source code in trainax/callback/_composite.py
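A construction sketch following the signature above:

```python
import trainax as tx

# Bundle several callbacks so a trainer only needs one callback object.
combined = tx.callback.CompositeCallback(
    [
        tx.callback.WeightNorm(every=10),
        tx.callback.GetNetwork(every=100),
    ]
)
```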
trainax.callback.BaseCallback
Bases: Module, ABC
Source code in trainax/callback/_base.py
__init__
__init__(every: int, name: str)
Base class for callbacks.
Source code in trainax/callback/_base.py
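A hedged sketch of a custom callback. The constructor arguments match the signature above, but the name and signature of the abstract hook implemented below (callback) are assumptions and should be verified against trainax/callback/_base.py.

```python
import equinox as eqx
import jax
import trainax as tx


class ParameterCount(tx.callback.BaseCallback):
    def __init__(self, every: int, name: str = "param_count"):
        super().__init__(every=every, name=name)

    # Assumed hook name and signature; verify against the abstract
    # method declared in trainax/callback/_base.py.
    def callback(self, update_i, stepper, data):
        # Count the elements of all inexact (floating-point) array
        # leaves, i.e. the trainable parameters of the stepper.
        params = eqx.filter(stepper, eqx.is_inexact_array)
        return sum(p.size for p in jax.tree_util.tree_leaves(params))
```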