(Convolutional) ResNet

pdequinox.arch.ClassicResNet

Bases: Sequential

Source code in pdequinox/arch/_classic_res_net.py
class ClassicResNet(Sequential):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 32,
        num_blocks: int = 6,
        use_norm: bool = False,
        activation: Callable = jax.nn.relu,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
        key: PRNGKeyArray,
    ):
        """
        Vanilla ResNet architecture, very close to the original He et al. (2016)
        paper.

        Performs a sequence of blocks, each consisting of two convolutions and a
        bypass. The structure of the blocks is "post-activation" (original
        ResNet paper). For the modern "pre-activation" ResNet, see
        `ModernResNet`. By default, no group normalization is used. The original
        paper used batch normalization.

        The network contains `2 * num_blocks` convolutions inside the blocks
        (3x3) plus two 1x1 convolutions for the lifting and projection.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `32`.
        - `num_blocks`: The number of blocks to use. Must be an integer greater
            than or equal to `1`. Default is `6`.
        - `use_norm`: Whether to use group normalization. Default is `False`.
        - `activation`: The activation function to use in the blocks. Default is
            `jax.nn.relu`. Lifting and projection are **not** activated.
        - `boundary_mode`: The boundary mode to use for the convolution. Default
            is `"periodic"`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            activation=activation,
            key=key,
            boundary_mode=boundary_mode,
            lifting_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
            block_factory=ClassicResBlockFactory(
                use_norm=use_norm,
                use_bias=True,
                zero_bias_init=False,
            ),
            projection_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = False,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray
)

Vanilla ResNet architecture, very close to the original He et al. (2016) paper.

Performs a sequence of blocks, each consisting of two convolutions and a bypass. The structure of the blocks is "post-activation" (original ResNet paper). For the modern "pre-activation" ResNet, see ModernResNet. By default, no group normalization is used. The original paper used batch normalization.
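
To make the "post-activation" ordering concrete, here is a minimal pure-Python sketch on scalars. It assumes the ordering of He et al. (2016), where the second activation is applied after the bypass has been added. The names `post_activation_block`, `conv1`, and `conv2` are hypothetical stand-ins and not part of the pdequinox API; real blocks act on arrays, and normalization is left out.

```python
def relu(x: float) -> float:
    """Scalar ReLU, standing in for the `activation` argument."""
    return max(0.0, x)


def post_activation_block(x, conv1, conv2, activation=relu):
    """Classic (post-activation) residual block on a scalar.

    Assumed ordering: conv -> activation -> conv, then add the bypass,
    then apply the activation once more to the sum.
    """
    residual = conv2(activation(conv1(x)))
    return activation(x + residual)


# Hypothetical "convolutions": plain scalar maps, for illustration only.
conv1 = lambda v: 2.0 * v - 1.0
conv2 = lambda v: -0.5 * v

y = post_activation_block(3.0, conv1, conv2)
print(y)
```

Because the activation comes last, the output of such a block is never negative when the activation is ReLU, even if the bypassed input is.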

The network contains 2 * num_blocks convolutions inside the blocks (3x3) plus two 1x1 convolutions for the lifting and projection.
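
The layer count stated above can be worked out with a few lines of Python. This is only a sketch of the arithmetic; `count_convolutions` is a hypothetical helper, and the `ValueError` mirrors the documented requirement on `num_blocks` rather than any check known to exist in the library.

```python
def count_convolutions(num_blocks: int) -> dict:
    """Count the convolution layers implied by the description above."""
    if num_blocks < 1:
        # Assumption: mirrors the documented constraint `num_blocks >= 1`.
        raise ValueError("num_blocks must be >= 1")
    in_blocks = 2 * num_blocks  # two convolutions per residual block
    pointwise = 2               # one 1x1 lifting + one 1x1 projection
    return {
        "in_blocks": in_blocks,
        "pointwise": pointwise,
        "total": in_blocks + pointwise,
    }


counts = count_convolutions(6)  # the default `num_blocks`
print(counts)  # {'in_blocks': 12, 'pointwise': 2, 'total': 14}
```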

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 32.
  • num_blocks: The number of blocks to use. Must be an integer greater than or equal to 1. Default is 6.
  • use_norm: Whether to use group normalization. Default is False.
  • activation: The activation function to use in the blocks. Default is jax.nn.relu. Lifting and projection are not activated.
  • boundary_mode: The boundary mode to use for the convolution. Default is "periodic".
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
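
As a rough illustration of what the three `boundary_mode` values typically mean for a convolution, the sketch below pads a 1D list with one ghost cell per side. The padding conventions are assumptions based on common practice and are not taken from the pdequinox source; `pad_1d` is a hypothetical helper.

```python
def pad_1d(values, mode):
    """Add one ghost cell on each side of a non-empty 1D list.

    Assumed conventions (common practice, not verified against pdequinox):
    - "periodic":  ghost cells wrap around to the opposite end
    - "dirichlet": ghost cells are zero
    - "neumann":   ghost cells copy the adjacent edge value
    """
    if not values:
        raise ValueError("values must not be empty")
    if mode == "periodic":
        left, right = values[-1], values[0]
    elif mode == "dirichlet":
        left, right = 0.0, 0.0
    elif mode == "neumann":
        left, right = values[0], values[-1]
    else:
        raise ValueError(f"unknown boundary mode: {mode!r}")
    return [left, *values, right]


u = [1.0, 2.0, 3.0]
for mode in ("periodic", "dirichlet", "neumann"):
    print(mode, pad_1d(u, mode))
```

With a kernel of size 3, one ghost cell per side is enough for the convolution output to keep the spatial size of the input.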
__call__
__call__(x)
Source code in pdequinox/_sequential.py
def __call__(self, x):
    x = self.lifting(x)
    for block in self.blocks:
        x = block(x)
    x = self.projection(x)
    return x
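
The forward pass above is a plain composition: lifting, then each block in order, then projection. A toy version with scalar functions shows the data flow. `ToySequential` and the lambdas are hypothetical stand-ins, not the real `Sequential` class or its layers.

```python
from typing import Callable, Sequence


class ToySequential:
    """Toy stand-in that mirrors the lifting -> blocks -> projection flow."""

    def __init__(
        self,
        lifting: Callable,
        blocks: Sequence[Callable],
        projection: Callable,
    ):
        self.lifting = lifting
        self.blocks = list(blocks)
        self.projection = projection

    def __call__(self, x):
        x = self.lifting(x)        # in_channels -> hidden_channels in the real model
        for block in self.blocks:  # blocks are applied one after another
            x = block(x)
        x = self.projection(x)     # hidden_channels -> out_channels in the real model
        return x


net = ToySequential(
    lifting=lambda x: x * 10,
    blocks=[lambda x: x + 1, lambda x: x + 2],
    projection=lambda x: x / 10,
)
result = net(1.0)  # ((1.0 * 10) + 1 + 2) / 10
print(result)
```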

pdequinox.arch.ModernResNet

Bases: Sequential

Source code in pdequinox/arch/_modern_res_net.py
class ModernResNet(Sequential):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 32,
        num_blocks: int = 6,
        use_norm: bool = True,
        activation: Callable = jax.nn.relu,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
        key: PRNGKeyArray,
    ):
        """
        Modern ResNet using pre-activation residual blocks. Based on the
        implementation of PDEArena.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `32`.
        - `num_blocks`: The number of blocks to use. Default is `6`.
        - `use_norm`: If `True`, uses group normalization. Default is `True`.
        - `activation`: The activation function to use in the blocks. Default is
            `jax.nn.relu`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The boundary mode to use for the convolution. Default
            is `"periodic"`.
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            activation=activation,
            key=key,
            boundary_mode=boundary_mode,
            lifting_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
            block_factory=ModernResBlockFactory(
                use_norm=use_norm,
                use_bias=True,
                zero_bias_init=False,
            ),
            projection_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 6,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray
)

Modern ResNet using pre-activation residual blocks. Based on the implementation of PDEArena.
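
For comparison with the classic variant, a "pre-activation" block can be sketched on scalars as follows. The ordering is an assumption based on the usual pre-activation design, in which the activation precedes each convolution and the bypass is added last without a further activation. The names `pre_activation_block`, `conv1`, and `conv2` are hypothetical, and normalization is left out.

```python
def relu(x: float) -> float:
    """Scalar ReLU, standing in for the `activation` argument."""
    return max(0.0, x)


def pre_activation_block(x, conv1, conv2, activation=relu):
    """Pre-activation residual block on a scalar.

    Assumed ordering: activation -> conv -> activation -> conv,
    then add the untouched input (identity bypass).
    """
    residual = conv2(activation(conv1(activation(x))))
    return x + residual


# Hypothetical "convolutions": plain scalar maps, for illustration only.
conv1 = lambda v: 2.0 * v - 1.0
conv2 = lambda v: -0.5 * v

y_pos = pre_activation_block(3.0, conv1, conv2)
y_neg = pre_activation_block(-1.0, conv1, conv2)
print(y_pos, y_neg)
```

Since nothing is applied after the addition, a negative input can pass through the bypass unchanged, which is the main structural difference from a post-activation block.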

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 32.
  • num_blocks: The number of blocks to use. Default is 6.
  • use_norm: If True, uses group normalization. Default is True.
  • activation: The activation function to use in the blocks. Default is jax.nn.relu.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The boundary mode to use for the convolution. Default is "periodic".
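
Since `use_norm` switches on group normalization, a small pure-Python sketch of that operation may help. It normalizes each group of channels to zero mean and unit variance over all channels and spatial points in the group. The function `group_norm`, its `num_groups` parameter, and the `eps` value are assumptions for illustration; the number of groups used by the library is not stated here, and the learnable scale and offset are omitted.

```python
import math


def group_norm(channels, num_groups, eps=1e-5):
    """Group normalization on channel-first data.

    `channels` is a list with one list of spatial values per channel.
    """
    num_channels = len(channels)
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    group_size = num_channels // num_groups

    out = []
    for g in range(num_groups):
        group = channels[g * group_size:(g + 1) * group_size]
        # Statistics are shared across all channels and points of the group.
        flat = [v for channel in group for v in channel]
        mean = sum(flat) / len(flat)
        var = sum((v - mean) ** 2 for v in flat) / len(flat)
        scale = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * scale for v in channel] for channel in group])
    return out


x = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [20.0, 20.0]]
y = group_norm(x, num_groups=2)
print(y)
```

Unlike batch normalization, the statistics here depend only on a single sample, so the operation works without a batch axis.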
__call__
__call__(x)
Source code in pdequinox/_sequential.py
def __call__(self, x):
    x = self.lifting(x)
    for block in self.blocks:
        x = block(x)
    x = self.projection(x)
    return x