(Convolutional) ResNet
pdequinox.arch.ClassicResNet
Bases: Sequential
Source code in pdequinox/arch/_classic_res_net.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = False,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    key: PRNGKeyArray
)
Vanilla ResNet architecture very close to the original He et al. (2016) paper.
Performs a sequence of blocks, each consisting of two convolutions and a bypass. The blocks are "post-activation" (as in the original ResNet paper). For the modern "pre-activation" ResNet, see ModernResNet. By default, no normalization is used; when enabled via use_norm, it is group normalization, whereas the original paper used batch normalization.
The total number of convolutions is 2 * num_blocks (3x3 convolutions) plus two 1x1 convolutions for the lifting and projection.
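The post-activation block structure described above can be sketched as follows. This is a minimal illustration of the general pattern from He et al. (2016), not the library's implementation: `post_activation_block`, `conv_1`, and `conv_2` are hypothetical names, the convolutions are passed in as plain callables, and normalization is omitted.

```python
import jax
import jax.numpy as jnp


def post_activation_block(x, conv_1, conv_2, activation=jax.nn.relu):
    """Hypothetical sketch of a post-activation residual block."""
    h = activation(conv_1(x))  # first convolution, then activation
    h = conv_2(h)              # second convolution, not yet activated
    return activation(h + x)   # add the bypass, then activate the sum


# Identity "convolutions" keep the example free of any parameters.
identity = lambda v: v
x = jnp.array([-1.0, 2.0])
y_post = post_activation_block(x, identity, identity)
```

Because the final activation acts on the sum of the bypass and the convolved branch, a ReLU block of this kind cannot output negative values.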
Arguments:
num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels: The number of input channels.
out_channels: The number of output channels.
hidden_channels: The number of channels in the hidden layers. Default is 32.
num_blocks: The number of blocks to use. Must be an integer greater than or equal to 1. Default is 6.
use_norm: Whether to use group normalization. Default is False.
activation: The activation function to use in the blocks. Default is jax.nn.relu. Lifting and projection are not activated.
boundary_mode: The boundary mode to use for the convolution. Default is "periodic".
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
__call__
__call__(x)
Source code in pdequinox/_sequential.py
pdequinox.arch.ModernResNet
Bases: Sequential
Source code in pdequinox/arch/_modern_res_net.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    key: PRNGKeyArray
)
Modern ResNet using pre-activation residual blocks. Based on the implementation of PDEArena.
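The pre-activation pattern can be sketched as follows. This is a minimal illustration of the general idea, not the library's or PDEArena's implementation: `pre_activation_block`, `conv_1`, and `conv_2` are hypothetical names, the convolutions are passed in as plain callables, and normalization is omitted.

```python
import jax
import jax.numpy as jnp


def pre_activation_block(x, conv_1, conv_2, activation=jax.nn.relu):
    """Hypothetical sketch of a pre-activation residual block."""
    h = conv_1(activation(x))  # activate first, then convolve
    h = conv_2(activation(h))  # activate again, then convolve
    return x + h               # the bypass is added without activation


# Identity "convolutions" keep the example free of any parameters.
identity = lambda v: v
x = jnp.array([-1.0, 2.0])
y_pre = pre_activation_block(x, identity, identity)
```

Since nothing is applied after the addition, the input passes through the bypass unchanged, so negative entries of `x` can survive the block even with a ReLU activation.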
Arguments:
num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels: The number of input channels.
out_channels: The number of output channels.
hidden_channels: The number of channels in the hidden layers. Default is 32.
num_blocks: The number of blocks to use. Default is 6.
use_norm: If True, uses group normalization. Default is True.
activation: The activation function to use in the blocks. Default is jax.nn.relu.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
boundary_mode: The boundary mode to use. Default is "periodic".
__call__
__call__(x)
Source code in pdequinox/_sequential.py