(Convolutional) ResNet
    
ClassicResNet
Bases: Sequential
Source code in pdequinox/arch/_classic_res_net.py
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = False,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray
)
Vanilla ResNet architecture, very close to the original He et al. (2016) paper.
Performs a sequence of blocks, each consisting of two convolutions and a
bypass (skip connection). The blocks use the "post-activation" structure of the
original ResNet paper, where the activation follows the addition of the bypass.
For the modern "pre-activation" ResNet, see
ModernResNet. By default, no group normalization is used. The original
paper used batch normalization.
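To make the "post-activation" ordering concrete, here is a minimal NumPy sketch of one such block in 1D. It is not pdequinox code: the real blocks are JAX/Equinox modules. The helper names `periodic_conv1d` and `post_activation_block` are hypothetical, and the sketch assumes periodic padding, kernel size 3, and no normalization (matching the `use_norm=False` default).

```python
import numpy as np


def periodic_conv1d(x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Same-size 1D cross-correlation with periodic (wrap-around) padding.

    Shapes: x is (c_in, n), w is (c_out, c_in, k) with odd k, b is (c_out,).
    """
    k = w.shape[-1]
    pad = k // 2
    n = x.shape[-1]
    xp = np.pad(x, ((0, 0), (pad, pad)), mode="wrap")
    # windows[c, i, j] = xp[c, i + j], i.e. the k neighbours of point i
    windows = np.stack([xp[:, j : j + n] for j in range(k)], axis=-1)
    return np.einsum("oik,ink->on", w, windows) + b[:, None]


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def post_activation_block(x, params, activation=relu):
    """Classic ResNet block: activation(x + conv2(activation(conv1(x))))."""
    (w1, b1), (w2, b2) = params
    h = activation(periodic_conv1d(x, w1, b1))
    h = periodic_conv1d(h, w2, b2)
    return activation(x + h)  # the activation comes *after* the bypass is added


rng = np.random.default_rng(0)
hidden, n = 4, 16
params = [
    (0.1 * rng.standard_normal((hidden, hidden, 3)), np.zeros(hidden))
    for _ in range(2)
]
x = rng.standard_normal((hidden, n))
y = post_activation_block(x, params)
print(y.shape)  # (4, 16)
```

Because the final activation sits after the addition, the block's output with ReLU is always non-negative, and the skip path is not a pure identity. This is the main difference from the pre-activation variant.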
The network contains 2 * num_blocks 3x3 convolutions (two per block) plus
two 1x1 convolutions for the lifting and projection, i.e. 2 * num_blocks + 2 convolutions in total.
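The convolution count above can be turned into a quick calculation, together with the per-axis receptive field it implies. This is a sketch under stated assumptions: the helper names are hypothetical, and the receptive-field formula assumes stride 1 and no dilation for the 3-wide convolutions. Under those assumptions each such convolution widens the field by 2 points, while the 1x1 lifting and projection add nothing.

```python
def conv_count(num_blocks: int) -> int:
    """2 * num_blocks 3x3 convolutions plus the 1x1 lifting and projection."""
    if num_blocks < 1:
        raise ValueError("num_blocks must be >= 1")
    return 2 * num_blocks + 2


def receptive_field(num_blocks: int, kernel_size: int = 3) -> int:
    """Grid points per spatial axis that can influence one output point.

    Assumes stride 1 and no dilation: every kernel_size-wide convolution
    widens the field by kernel_size - 1; the 1x1 convolutions add nothing.
    """
    return 1 + 2 * num_blocks * (kernel_size - 1)


print(conv_count(6), receptive_field(6))  # 14 25
```

With the default `num_blocks=6`, an output point therefore depends on at most 25 neighbouring points per axis. With periodic boundaries on a grid smaller than that, the receptive field wraps around the domain.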
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- hidden_channels: The number of channels in the hidden layers. Default is 32.
- num_blocks: The number of blocks to use. Must be an integer greater than or equal to 1. Default is 6.
- use_norm: Whether to use group normalization. Default is False.
- activation: The activation function to use in the blocks. Default is jax.nn.relu. Lifting and projection are not activated.
- boundary_mode: The boundary mode to use for the convolutions ("periodic", "dirichlet", or "neumann"). Default is "periodic".
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
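One way to picture `boundary_mode` is as the padding applied before each same-size convolution. The NumPy sketch below maps each mode to a common ghost-cell convention: wrap-around for periodic, zero values for (homogeneous) Dirichlet, and a copied boundary value (zero gradient) for Neumann. The mapping `PAD_MODES` and the helper `pad_for_conv` are hypothetical. This page does not specify which padding rule pdequinox applies for each mode, so treat the mapping as an assumption.

```python
import numpy as np

# Hypothetical boundary_mode -> numpy.pad arguments (one common convention,
# not necessarily the rule pdequinox uses internally).
PAD_MODES = {
    "periodic": {"mode": "wrap"},  # wrap around the domain
    "dirichlet": {"mode": "constant", "constant_values": 0.0},  # zero value outside
    "neumann": {"mode": "edge"},  # zero gradient: repeat the boundary value
}


def pad_for_conv(x: np.ndarray, kernel_size: int, boundary_mode: str) -> np.ndarray:
    """Pad the last (spatial) axis of a (channels, n) array for a same-size conv."""
    pad = kernel_size // 2
    return np.pad(x, ((0, 0), (pad, pad)), **PAD_MODES[boundary_mode])


x = np.array([[1.0, 2.0, 3.0, 4.0]])
print(pad_for_conv(x, 3, "periodic"))   # [[4. 1. 2. 3. 4. 1.]]
print(pad_for_conv(x, 3, "dirichlet"))  # [[0. 1. 2. 3. 4. 0.]]
print(pad_for_conv(x, 3, "neumann"))    # [[1. 1. 2. 3. 4. 4.]]
```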
__call__(x)
Source code in pdequinox/_sequential.py
ModernResNet
Bases: Sequential
Source code in pdequinox/arch/_modern_res_net.py
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray
)
Modern ResNet using pre-activation residual blocks. Based on the implementation of PDEArena.
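For contrast with the classic block, the NumPy sketch below shows the standard pre-activation ordering (normalization, activation, then convolution, twice) with a pure identity bypass. This is an illustration under assumptions, not pdequinox's or PDEArena's code. The names `periodic_conv1d`, `group_norm`, and `pre_activation_block` are hypothetical. The sketch assumes 1D periodic padding and a group norm without learned scale and shift, and its `num_groups` parameter is an invented illustration knob that does not appear in the signature above.

```python
import numpy as np


def periodic_conv1d(x, w, b):
    """Same-size 1D cross-correlation with wrap-around padding; x is (c_in, n)."""
    k, n = w.shape[-1], x.shape[-1]
    xp = np.pad(x, ((0, 0), (k // 2, k // 2)), mode="wrap")
    windows = np.stack([xp[:, j : j + n] for j in range(k)], axis=-1)
    return np.einsum("oik,ink->on", w, windows) + b[:, None]


def group_norm(x, num_groups, eps=1e-5):
    """Normalise each group of channels over its channels and points (no affine)."""
    c, n = x.shape
    g = x.reshape(num_groups, c // num_groups, n)
    mean = g.mean(axis=(1, 2), keepdims=True)
    var = g.var(axis=(1, 2), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(c, n)


def pre_activation_block(x, params, use_norm=True, num_groups=1, activation=None):
    """Pre-activation block: x + conv2(act(norm(conv1(act(norm(x))))))."""
    act = activation if activation is not None else (lambda h: np.maximum(h, 0.0))
    norm = (lambda h: group_norm(h, num_groups)) if use_norm else (lambda h: h)
    (w1, b1), (w2, b2) = params
    h = periodic_conv1d(act(norm(x)), w1, b1)
    h = periodic_conv1d(act(norm(h)), w2, b2)
    return x + h  # pure identity bypass: nothing is applied after the addition


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
zero_params = [(np.zeros((4, 4, 3)), np.zeros(4))] * 2
print(np.allclose(pre_activation_block(x, zero_params), x))  # True
```

With zeroed convolution weights the block reduces exactly to the identity, which is the property that motivates the pre-activation design. A post-activation block with the same weights would instead return `relu(x)`.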
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- hidden_channels: The number of channels in the hidden layers. Default is 32.
- num_blocks: The number of blocks to use. Default is 6.
- use_norm: If True, uses group normalization. Default is True.
- activation: The activation function to use in the blocks. Default is jax.nn.relu.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- boundary_mode: The boundary mode to use ("periodic", "dirichlet", or "neumann"). Default is "periodic".
__call__(x)