Callbacks
trainax.callback.SaveNetwork
Bases: BaseCallback
Source code in trainax/callback/_save_network.py
__init__
__init__(
every: int,
path: str,
file_name: str,
name: str = "network_saved",
)
Callback to write the network state to a file every N update steps, where N is the value of every.
Arguments:
every: The frequency of the callback.
path: The path to save the network state.
file_name: The file name to save the network state.
name: The name of the callback.
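A minimal plain-Python sketch of the "save every N steps" behaviour described above, under stated assumptions. It does not import trainax or Equinox: the class SaveEveryN, the use of pickle, and the file naming scheme <path>/<file_name>_<update_i>.pkl are all hypothetical. An Equinox-based implementation would more likely serialize with equinox.tree_serialise_leaves, and the real on-disk naming may differ.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


class SaveEveryN:
    """Hypothetical save-to-disk callback sketch (not the trainax implementation)."""

    def __init__(self, every: int, path: str, file_name: str, name: str = "network_saved"):
        self.every = every
        self.path = path
        self.file_name = file_name
        self.name = name

    def __call__(self, update_i: int, network: Any) -> Dict[str, Optional[str]]:
        # Skip every update step that is not a multiple of `every`.
        if update_i % self.every != 0:
            return {self.name: None}
        # Assumed naming scheme: <path>/<file_name>_<update_i>.pkl
        target = os.path.join(self.path, f"{self.file_name}_{update_i}.pkl")
        with open(target, "wb") as f:
            pickle.dump(network, f)
        # Report where the state was written.
        return {self.name: target}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        cb = SaveEveryN(every=5, path=tmp, file_name="net")
        print(cb(3, {"w": [1.0]}))   # step 3 is skipped: {'network_saved': None}
        print(cb(10, {"w": [1.0]}))  # step 10 writes <tmp>/net_10.pkl
```

Including the update step in the file name keeps one file per checkpoint instead of overwriting a single file; whether that matches the library's behaviour is an assumption.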
trainax.callback.GetNetwork
Bases: BaseCallback
Source code in trainax/callback/_get_network.py
__init__
__init__(every: int, name: str = 'network')
Callback to output the network state every N update steps, where N is the value of every.
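A small plain-Python sketch of collecting network snapshots at a fixed interval. GetEveryN and collect are hypothetical names, and the "network" is a mutable dict standing in for a model. The deepcopy exists only because of that mutable stand-in: Equinox modules are immutable, so an Equinox-based callback would not need to copy.

```python
import copy
from typing import Any, Dict, List, Optional


class GetEveryN:
    """Hypothetical snapshot callback sketch (not the trainax implementation)."""

    def __init__(self, every: int, name: str = "network"):
        self.every = every
        self.name = name

    def __call__(self, update_i: int, network: Any) -> Dict[str, Optional[Any]]:
        if update_i % self.every != 0:
            return {self.name: None}
        # Copy so later in-place edits of the mutable stand-in do not alter the snapshot.
        return {self.name: copy.deepcopy(network)}


def collect(cb: GetEveryN, networks: List[Any]) -> List[Any]:
    """Run the callback over per-step networks and keep only the recorded snapshots."""
    outputs = [cb(i, net)[cb.name] for i, net in enumerate(networks)]
    return [o for o in outputs if o is not None]
```

For example, with every=3 over update steps 0 to 6, snapshots are kept at steps 0, 3 and 6.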
trainax.callback.WeightNorm
Bases: BaseCallback
Source code in trainax/callback/_weight_norm.py
__init__
__init__(
every: int,
squared: bool = False,
name: str = "weight_norm",
)
Callback to save the weight norm every N update steps, where N is the value of every.
Arguments:
every: The frequency of the callback.
squared: Whether to return the squared weight norm.
name: The name of the callback.
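The documentation does not say which norm is meant, so this sketch assumes the global L2 norm over all parameter entries, i.e. the square root of the sum of squares, with squared=True returning the sum of squares directly. Parameters are modelled as a nested dict/list/tuple; in JAX, jax.tree_util.tree_leaves would replace the hand-written iter_leaves. WeightNormEveryN is a hypothetical name, not the library class.

```python
import math
from typing import Any, Dict, Iterator, Optional


def iter_leaves(tree: Any) -> Iterator[float]:
    """Yield every numeric leaf of a nested dict/list/tuple parameter tree."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from iter_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from iter_leaves(value)
    else:
        yield float(tree)


def weight_norm(params: Any, squared: bool = False) -> float:
    """Global L2 norm over all leaves; the squared variant skips the square root."""
    sum_sq = sum(x * x for x in iter_leaves(params))
    return sum_sq if squared else math.sqrt(sum_sq)


class WeightNormEveryN:
    """Hypothetical weight-norm callback sketch (not the trainax implementation)."""

    def __init__(self, every: int, squared: bool = False, name: str = "weight_norm"):
        self.every = every
        self.squared = squared
        self.name = name

    def __call__(self, update_i: int, params: Any) -> Dict[str, Optional[float]]:
        if update_i % self.every != 0:
            return {self.name: None}
        return {self.name: weight_norm(params, self.squared)}
```

For params {"w": [3.0], "b": [4.0]} this gives 5.0, or 25.0 with squared=True. The squared form avoids a square root and is additive across parameter groups.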
trainax.callback.Loss
Bases: BaseCallback
Source code in trainax/callback/_loss.py
__init__
__init__(
every: int,
loss_configuration: BaseConfiguration,
*,
with_grad: bool = False,
ref_stepper: eqx.Module = None,
residuum_fn: eqx.Module = None,
name: str
)
Callback to save the loss associated with loss_configuration every N update steps, where N is the value of every.
Use this to measure a stepper's performance on a different configuration than the one used for the training loss.
Arguments:
every: The frequency of the callback.
loss_configuration: The loss configuration used to compute the loss.
with_grad: Whether to also return the associated gradient. If only the gradient norm is desired, set this to False and consider using trainax.callback.GradNorm.
ref_stepper: A reference stepper. Supply this if the loss configuration requires a reference stepper.
residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
name: The name of the callback.
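A toy, plain-Python sketch of monitoring a configuration other than the training one, under heavy assumptions: the "stepper" is the scalar linear map u_{n+1} = a * u_n, a "loss configuration" is just a function (a, trajectory) -> loss, and the gradient for with_grad is a central finite difference standing in for JAX autodiff. LossEveryN, one_step_mse and rollout_mse are hypothetical; ref_stepper and residuum_fn are omitted.

```python
from typing import Callable, Dict, List, Optional, Tuple, Union

# Hypothetical "loss configuration": a callable (a, trajectory) -> loss,
# where the "stepper" is the scalar linear map u_{n+1} = a * u_n.
Config = Callable[[float, List[float]], float]


def one_step_mse(a: float, traj: List[float]) -> float:
    """Teacher-forced loss: predict traj[n + 1] from the true traj[n]."""
    errs = [(a * traj[n] - traj[n + 1]) ** 2 for n in range(len(traj) - 1)]
    return sum(errs) / len(errs)


def rollout_mse(a: float, traj: List[float]) -> float:
    """Autoregressive loss: roll out from traj[0] and compare against every later state."""
    u, errs = traj[0], []
    for target in traj[1:]:
        u = a * u
        errs.append((u - target) ** 2)
    return sum(errs) / len(errs)


def finite_diff_grad(f: Callable[[float], float], x: float, h: float = 1e-6) -> float:
    """Central difference, standing in for automatic differentiation such as jax.grad."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


class LossEveryN:
    """Hypothetical loss-monitoring callback sketch (not the trainax implementation)."""

    def __init__(self, every: int, loss_configuration: Config, *, with_grad: bool = False, name: str):
        self.every = every
        self.loss_configuration = loss_configuration
        self.with_grad = with_grad
        self.name = name

    def __call__(
        self, update_i: int, a: float, traj: List[float]
    ) -> Dict[str, Optional[Union[float, Tuple[float, float]]]]:
        if update_i % self.every != 0:
            return {self.name: None}
        loss = self.loss_configuration(a, traj)
        if not self.with_grad:
            return {self.name: loss}
        grad = finite_diff_grad(lambda p: self.loss_configuration(p, traj), a)
        return {self.name: (loss, grad)}
```

With a = 1.0 on the trajectory [1.0, 2.0, 4.0, 8.0], the one-step loss is 7.0 while the rollout loss is 59/3 ≈ 19.67. Monitoring the rollout configuration exposes error accumulation that a one-step training loss understates.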
trainax.callback.GradNorm
Bases: BaseCallback
Source code in trainax/callback/_grad_norm.py
__init__
__init__(
every: int,
loss_configuration: BaseConfiguration,
*,
squared: bool = False,
ref_stepper: eqx.Module = None,
residuum_fn: eqx.Module = None,
name: str
)
Callback to save the gradient norm associated with loss_configuration every N update steps, where N is the value of every.
Arguments:
every: The frequency of the callback.
loss_configuration: The loss configuration used to compute the gradient norm. If the gradient norm associated with the training loss is desired, the corresponding loss configuration has to be re-supplied.
squared: Whether to return the squared gradient norm.
ref_stepper: A reference stepper. Supply this if the loss configuration requires a reference stepper.
residuum_fn: A residuum function that computes the discrete residuum between two consecutive states. Supply this if the loss configuration requires a residuum function.
name: The name of the callback.
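A plain-Python sketch of computing a gradient norm for a supplied configuration. It assumes parameters are a flat list of floats, a configuration is a function (params, data) -> loss, and the gradient is a central finite difference standing in for JAX autodiff. GradNormEveryN and squared_error are hypothetical names. The callback only differentiates the configuration it is given, which is why the training loss configuration must be re-supplied to monitor its own gradient norm.

```python
import math
from typing import Callable, Dict, List, Optional

# Hypothetical "loss configuration": a callable (params, data) -> loss.
Config = Callable[[List[float], List[float]], float]


def finite_diff_grad(loss: Callable[[List[float]], float], params: List[float], h: float = 1e-6) -> List[float]:
    """Central-difference gradient, standing in for automatic differentiation such as jax.grad."""
    grad = []
    for i in range(len(params)):
        plus, minus = list(params), list(params)
        plus[i] += h
        minus[i] -= h
        grad.append((loss(plus) - loss(minus)) / (2.0 * h))
    return grad


def squared_error(params: List[float], targets: List[float]) -> float:
    """Toy loss whose gradient is 2 * (params - targets)."""
    return sum((p - t) ** 2 for p, t in zip(params, targets))


class GradNormEveryN:
    """Hypothetical gradient-norm callback sketch (not the trainax implementation)."""

    def __init__(self, every: int, loss_configuration: Config, *, squared: bool = False, name: str):
        self.every = every
        self.loss_configuration = loss_configuration
        self.squared = squared
        self.name = name

    def __call__(self, update_i: int, params: List[float], data: List[float]) -> Dict[str, Optional[float]]:
        if update_i % self.every != 0:
            return {self.name: None}
        # Differentiate the supplied configuration, not whatever loss the trainer optimizes.
        grad = finite_diff_grad(lambda p: self.loss_configuration(p, data), params)
        sum_sq = sum(g * g for g in grad)
        return {self.name: sum_sq if self.squared else math.sqrt(sum_sq)}
```

For params [1.0, 1.0] and targets [4.0, 5.0], the gradient is [-6, -8], so the norm is 10 (or 100 when squared).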
trainax.callback.CompositeCallback
Bases: Module
Source code in trainax/callback/_composite.py
__init__
__init__(callbacks: list[BaseCallback])
Callback to combine multiple callbacks.
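A plain-Python sketch of combining callbacks. It assumes each callback returns a dict keyed by its own name (None on inactive steps) and that the composite simply merges these dicts. That merging strategy is an assumption consistent with the name argument, not a description of the library source. EveryN and Composite are hypothetical names.

```python
from typing import Any, Callable, Dict, List

# A callback here is any callable (update_i, network) -> {name: value_or_None}.
Callback = Callable[[int, Any], Dict[str, Any]]


class EveryN:
    """Minimal hypothetical callback: reports fn(network) every `every` steps."""

    def __init__(self, every: int, name: str, fn: Callable[[Any], Any]):
        self.every = every
        self.name = name
        self.fn = fn

    def __call__(self, update_i: int, network: Any) -> Dict[str, Any]:
        if update_i % self.every != 0:
            return {self.name: None}
        return {self.name: self.fn(network)}


class Composite:
    """Hypothetical composite: calls every child and merges their result dicts."""

    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def __call__(self, update_i: int, network: Any) -> Dict[str, Any]:
        merged: Dict[str, Any] = {}
        for cb in self.callbacks:
            # A later callback overwrites an earlier one on a name clash,
            # so each callback should get a distinct name.
            merged.update(cb(update_i, network))
        return merged
```

Because results are keyed by name, giving two callbacks the same name would silently drop one of them in this sketch.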
trainax.callback.BaseCallback
Bases: Module, ABC
Source code in trainax/callback/_base.py
__init__
__init__(every: int, name: str)
Base class for callbacks.
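A plain-Python sketch of how a base class built around every and name can centralize the frequency check, so subclasses only implement the quantity they report. It uses abc.ABC instead of Equinox's Module, simplifies the call signature to (update_i, network), and the names SketchBaseCallback, callback and ParamCount are hypothetical; the real base class may receive further arguments such as training data.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class SketchBaseCallback(ABC):
    """Hypothetical analogue of a callback base class (not the trainax source)."""

    def __init__(self, every: int, name: str):
        self.every = every
        self.name = name

    @abstractmethod
    def callback(self, update_i: int, network: Any) -> Any:
        """Compute the reported quantity; only invoked on active steps."""

    def __call__(self, update_i: int, network: Any) -> Dict[str, Optional[Any]]:
        # Frequency gating shared by all subclasses.
        if update_i % self.every == 0:
            return {self.name: self.callback(update_i, network)}
        return {self.name: None}


class ParamCount(SketchBaseCallback):
    """Example subclass: reports the number of entries in a flat parameter list."""

    def __init__(self, every: int, name: str = "param_count"):
        super().__init__(every, name)

    def callback(self, update_i: int, network: Any) -> int:
        return len(network)
```

Returning {name: None} on inactive steps keeps the output shape identical on every step, which simplifies logging code that iterates over the results.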