Skip to content

Dilated ResNet¤

A sequential convolutional architecture employing dilated convolutions for long-range interactions.

pdequinox.arch.DilatedResNet ¤

Bases: Sequential

Source code in pdequinox/arch/_dilated_res_net.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class DilatedResNet(Sequential):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 32,
        num_blocks: int = 2,
        dilation_rates: tuple[int, ...] = (1, 2, 4, 8, 4, 2, 1),
        use_norm: bool = True,
        activation: Callable = jax.nn.relu,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
        key: PRNGKeyArray,
    ):
        """
        ResNet with varying dilation rates; very close to publication of
        Stachenfeld et al. (2021), based on the implementation of PDEArena.

        Each block consists of a sequence of convolutions with different
        dilation rates as defined by `dilation_rates`. There is an addiitonal
        skip connection.


        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `32`.
        - `num_blocks`: The number of blocks to use. Default is `2`.
        - `dilation_rates`: The dilation rates to use. Default is `(1, 2, 4, 8,
            4, 2, 1)`.
        - `use_norm`: If `True`, uses group norm.
        - `activation`: The activation function to use in the blocks. Default is
            `jax.nn.relu`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The boundary mode to use. Default is `periodic`.
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            activation=activation,
            key=key,
            boundary_mode=boundary_mode,
            lifting_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
            block_factory=DilatedResBlockFactory(
                dilation_rates=dilation_rates,
                use_norm=use_norm,
                use_bias=True,
                zero_bias_init=False,
            ),
            projection_factory=LinearChannelAdjustBlockFactory(
                use_bias=True,
                zero_bias_init=False,
            ),
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 2,
    dilation_rates: tuple[int, ...] = (1, 2, 4, 8, 4, 2, 1),
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray
)

ResNet with varying dilation rates; very close to publication of Stachenfeld et al. (2021), based on the implementation of PDEArena.

Each block consists of a sequence of convolutions with different dilation rates as defined by dilation_rates. There is an addiitonal skip connection.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 32.
  • num_blocks: The number of blocks to use. Default is 2.
  • dilation_rates: The dilation rates to use. Default is (1, 2, 4, 8, 4, 2, 1).
  • use_norm: If True, uses group norm.
  • activation: The activation function to use in the blocks. Default is jax.nn.relu.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The boundary mode to use. Default is periodic.
Source code in pdequinox/arch/_dilated_res_net.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_blocks: int = 2,
    dilation_rates: tuple[int, ...] = (1, 2, 4, 8, 4, 2, 1),
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    key: PRNGKeyArray,
):
    """
    ResNet with varying dilation rates; very close to publication of
    Stachenfeld et al. (2021), based on the implementation of PDEArena.

    Each block consists of a sequence of convolutions with different
    dilation rates as defined by `dilation_rates`. There is an addiitonal
    skip connection.


    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `hidden_channels`: The number of channels in the hidden layers.
        Default is `32`.
    - `num_blocks`: The number of blocks to use. Default is `2`.
    - `dilation_rates`: The dilation rates to use. Default is `(1, 2, 4, 8,
        4, 2, 1)`.
    - `use_norm`: If `True`, uses group norm.
    - `activation`: The activation function to use in the blocks. Default is
        `jax.nn.relu`.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `boundary_mode`: The boundary mode to use. Default is `periodic`.
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        hidden_channels=hidden_channels,
        num_blocks=num_blocks,
        activation=activation,
        key=key,
        boundary_mode=boundary_mode,
        lifting_factory=LinearChannelAdjustBlockFactory(
            use_bias=True,
            zero_bias_init=False,
        ),
        block_factory=DilatedResBlockFactory(
            dilation_rates=dilation_rates,
            use_norm=use_norm,
            use_bias=True,
            zero_bias_init=False,
        ),
        projection_factory=LinearChannelAdjustBlockFactory(
            use_bias=True,
            zero_bias_init=False,
        ),
    )
__call__ ¤
__call__(x)
Source code in pdequinox/_sequential.py
111
112
113
114
115
116
def __call__(self, x):
    x = self.lifting(x)
    for block in self.blocks:
        x = block(x)
    x = self.projection(x)
    return x