Skip to content

Blocks

pdequinox.blocks.ClassicDoubleConvBlock

Bases: Block

Source code in pdequinox/blocks/_classic_double_conv_block.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class ClassicDoubleConvBlock(Block):
    # Layer sequence: conv -> norm -> activation, applied twice.
    # NOTE(review): norm_1/norm_2 hold `eqx.nn.Identity` when the block is
    # built with `use_norm=False`, so the GroupNorm annotations below are
    # only accurate for the default `use_norm=True` configuration.
    conv_1: PhysicsConv
    norm_1: eqx.nn.GroupNorm
    conv_2: PhysicsConv
    norm_2: eqx.nn.GroupNorm
    activation: Callable

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        activation: Callable = jax.nn.relu,
        kernel_size: int = 3,
        use_norm: bool = True,
        num_groups: int = 1,  # for GroupNorm
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Block that performs two sequential convolutions with activation and
        optional group normalization in between.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `activation`: The activation function to use after each convolution.
            Default is `jax.nn.relu`.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `use_norm`: Whether to use group normalization. Default is `True`.
        - `num_groups`: The number of groups to use for group normalization.
            Default is `1`.
        - `use_bias`: Whether to use bias in the convolutional layers. Default is
            `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        """

        def conv_constructor(i, o, b, k):
            # Closes over the shared hyperparameters; only the channel
            # counts, bias flag, and PRNG key differ between the two convs.
            return PhysicsConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=i,
                out_channels=o,
                kernel_size=kernel_size,
                stride=1,
                dilation=1,
                boundary_mode=boundary_mode,
                use_bias=b,
                zero_bias_init=zero_bias_init,
                key=k,
            )

        k_1, k_2 = jax.random.split(key)  # independent keys per convolution
        self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, k_1)
        if use_norm:
            self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
        else:
            # Identity keeps __call__ uniform when normalisation is disabled.
            self.norm_1 = eqx.nn.Identity()
        self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, k_2)
        if use_norm:
            self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
        else:
            self.norm_2 = eqx.nn.Identity()

        self.activation = activation

    def __call__(self, x):
        # Twice: convolution, (optional) group norm, then activation.
        x = self.activation(self.norm_1(self.conv_1(x)))
        x = self.activation(self.norm_2(self.conv_2(x)))
        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        """Per spatial dimension the (backward, forward) receptive field:
        the element-wise sum of the two convolutions' receptive fields."""
        conv_1_receptive_field = self.conv_1.receptive_field
        conv_2_receptive_field = self.conv_2.receptive_field
        return tuple(
            (
                c_1_i_backward + c_2_i_backward,
                c_1_i_forward + c_2_i_forward,
            )
            for (c_1_i_backward, c_1_i_forward), (c_2_i_backward, c_2_i_forward) in zip(
                conv_1_receptive_field, conv_2_receptive_field
            )
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = True,
    num_groups: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)

Block that performs two sequential convolutions with activation and optional group normalization in between.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • activation: The activation function to use after each convolution. Default is jax.nn.relu.
  • kernel_size: The size of the convolutional kernel. Default is 3.
  • use_norm: Whether to use group normalization. Default is True.
  • num_groups: The number of groups to use for group normalization. Default is 1.
  • use_bias: Whether to use bias in the convolutional layers. Default is True.
  • zero_bias_init: Whether to initialise the bias to zero. Default is False.
Source code in pdequinox/blocks/_classic_double_conv_block.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = True,
    num_groups: int = 1,  # for GroupNorm
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Block that performs two sequential convolutions with activation and
    optional group normalization in between.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `activation`: The activation function to use after each convolution.
        Default is `jax.nn.relu`.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `use_norm`: Whether to use group normalization. Default is `True`.
    - `num_groups`: The number of groups to use for group normalization.
        Default is `1`.
    - `use_bias`: Whether to use bias in the convolutional layers. Default is
        `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    """

    def conv_constructor(i, o, b, k):
        # Closes over the shared hyperparameters; only the channel counts,
        # bias flag, and PRNG key differ between the two convolutions.
        return PhysicsConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=b,
            zero_bias_init=zero_bias_init,
            key=k,
        )

    k_1, k_2 = jax.random.split(key)  # independent keys per convolution
    self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, k_1)
    if use_norm:
        self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
    else:
        # Identity keeps __call__ uniform when normalisation is disabled.
        self.norm_1 = eqx.nn.Identity()
    self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, k_2)
    if use_norm:
        self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
    else:
        self.norm_2 = eqx.nn.Identity()

    self.activation = activation
__call__
__call__(x)
Source code in pdequinox/blocks/_classic_double_conv_block.py
87
88
89
90
def __call__(self, x):
    # Twice: convolution, (optional) group norm, then activation.
    x = self.activation(self.norm_1(self.conv_1(x)))
    x = self.activation(self.norm_2(self.conv_2(x)))
    return x

pdequinox.blocks.ClassicResBlock

Bases: Module

Source code in pdequinox/blocks/_classic_res_block.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class ClassicResBlock(eqx.Module):
    # Main path: conv -> norm -> activation -> conv -> norm.
    conv_1: eqx.Module
    norm_1: eqx.Module
    conv_2: eqx.Module
    norm_2: eqx.Module
    # Residual path; both are eqx.nn.Identity when in_channels == out_channels.
    bypass_conv: eqx.Module
    bypass_norm: eqx.Module
    activation: Callable

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        activation: Callable = jax.nn.relu,
        kernel_size: int = 3,
        use_norm: bool = False,
        num_groups: int = 1,  # for group norm
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Classical Block of a ResNet with postactivation and optional group
        normalization in between (Default: off)

        If in_channels != out_channels, a bypass convolution (1x1 conv) and
        group normalization (if `use_norm=True`) is added to the residual
        connection.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `activation`: The activation function to use after each convolution.
            Default is `jax.nn.relu`.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `use_norm`: Whether to use group normalization. Default is `False`.
        - `num_groups`: The number of groups to use for group normalization.
            Default is `1`.
        - `use_bias`: Whether to use bias in the convolutional layers. Default
            is `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        """

        def conv_constructor(i, o, b, k):
            # Closes over the shared hyperparameters; only the channel
            # counts, bias flag, and PRNG key differ between the two convs.
            return PhysicsConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=i,
                out_channels=o,
                kernel_size=kernel_size,
                stride=1,
                dilation=1,
                boundary_mode=boundary_mode,
                use_bias=b,
                zero_bias_init=zero_bias_init,
                key=k,
            )

        # Retain `key` so the optional bypass conv gets its own randomness.
        k_1, k_2, key = jax.random.split(key, 3)
        self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, k_1)
        if use_norm:
            self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
        else:
            # Identity keeps __call__ uniform when normalisation is disabled.
            self.norm_1 = eqx.nn.Identity()

        self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, k_2)
        if use_norm:
            self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
        else:
            self.norm_2 = eqx.nn.Identity()

        if out_channels != in_channels:
            # Channel mismatch: a pointwise (1x1) conv matches the skip
            # connection to `out_channels`.
            bypass_conv_key, _ = jax.random.split(key)
            self.bypass_conv = PointwiseLinearConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                use_bias=False,  # Following PDEArena
                key=bypass_conv_key,
            )
            if use_norm:
                self.bypass_norm = eqx.nn.GroupNorm(
                    groups=num_groups, channels=out_channels
                )
            else:
                self.bypass_norm = eqx.nn.Identity()
        else:
            self.bypass_conv = eqx.nn.Identity()
            self.bypass_norm = eqx.nn.Identity()

        self.activation = activation

    def __call__(self, x):
        x_skip = x
        x = self.conv_1(x)
        x = self.norm_1(x)
        x = self.activation(x)
        x = self.conv_2(x)
        x = self.norm_2(x)
        # Post-activation residual: the skip is added *before* the final
        # activation.
        x = x + self.bypass_norm(self.bypass_conv(x_skip))
        x = self.activation(x)
        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        """Per spatial dimension the (backward, forward) receptive field:
        the element-wise sum of the two main-path convolutions' receptive
        fields. The pointwise bypass is not included."""
        conv_1_receptive_field = self.conv_1.receptive_field
        conv_2_receptive_field = self.conv_2.receptive_field
        return tuple(
            (
                c_1_i_backward + c_2_i_backward,
                c_1_i_forward + c_2_i_forward,
            )
            for (c_1_i_backward, c_1_i_forward), (c_2_i_backward, c_2_i_forward) in zip(
                conv_1_receptive_field, conv_2_receptive_field
            )
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = False,
    num_groups: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)

Classical Block of a ResNet with postactivation and optional group normalization in between (Default: off)

If in_channels != out_channels, a bypass convolution (1x1 conv) and group normalization (if use_norm=True) is added to the residual connection.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • activation: The activation function to use after each convolution. Default is jax.nn.relu.
  • kernel_size: The size of the convolutional kernel. Default is 3.
  • use_norm: Whether to use group normalization. Default is False.
  • num_groups: The number of groups to use for group normalization. Default is 1.
  • use_bias: Whether to use bias in the convolutional layers. Default is True.
  • zero_bias_init: Whether to initialise the bias to zero. Default is False.
Source code in pdequinox/blocks/_classic_res_block.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = False,
    num_groups: int = 1,  # for group norm
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Classical Block of a ResNet with postactivation and optional group
    normalization in between (Default: off)

    If in_channels != out_channels, a bypass convolution (1x1 conv) and
    group normalization (if `use_norm=True`) is added to the residual
    connection.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `activation`: The activation function to use after each convolution.
        Default is `jax.nn.relu`.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `use_norm`: Whether to use group normalization. Default is `False`.
    - `num_groups`: The number of groups to use for group normalization.
        Default is `1`.
    - `use_bias`: Whether to use bias in the convolutional layers. Default
        is `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    """

    def conv_constructor(i, o, b, k):
        # Closes over the shared hyperparameters; only the channel counts,
        # bias flag, and PRNG key differ between the two convolutions.
        return PhysicsConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=b,
            zero_bias_init=zero_bias_init,
            key=k,
        )

    # Retain `key` so the optional bypass conv gets its own randomness.
    k_1, k_2, key = jax.random.split(key, 3)
    self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, k_1)
    if use_norm:
        self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
    else:
        # Identity keeps __call__ uniform when normalisation is disabled.
        self.norm_1 = eqx.nn.Identity()

    self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, k_2)
    if use_norm:
        self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
    else:
        self.norm_2 = eqx.nn.Identity()

    if out_channels != in_channels:
        # Channel mismatch: a pointwise (1x1) conv matches the skip
        # connection to `out_channels`.
        bypass_conv_key, _ = jax.random.split(key)
        self.bypass_conv = PointwiseLinearConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            use_bias=False,  # Following PDEArena
            key=bypass_conv_key,
        )
        if use_norm:
            self.bypass_norm = eqx.nn.GroupNorm(
                groups=num_groups, channels=out_channels
            )
        else:
            self.bypass_norm = eqx.nn.Identity()
    else:
        self.bypass_conv = eqx.nn.Identity()
        self.bypass_norm = eqx.nn.Identity()

    self.activation = activation
__call__
__call__(x)
Source code in pdequinox/blocks/_classic_res_block.py
112
113
114
115
116
117
118
119
120
121
def __call__(self, x):
    x_skip = x
    x = self.conv_1(x)
    x = self.norm_1(x)
    x = self.activation(x)
    x = self.conv_2(x)
    x = self.norm_2(x)
    # Post-activation residual: the skip is added *before* the final
    # activation.
    x = x + self.bypass_norm(self.bypass_conv(x_skip))
    x = self.activation(x)
    return x

pdequinox.blocks.ModernResBlock

Bases: Module

Source code in pdequinox/blocks/_modern_res_block.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class ModernResBlock(eqx.Module):
    # Main path (pre-activation order): norm -> activation -> conv, twice.
    conv_1: eqx.Module
    norm_1: eqx.Module
    conv_2: eqx.Module
    norm_2: eqx.Module
    # Residual path; both are eqx.nn.Identity when in_channels == out_channels.
    bypass_conv: eqx.Module
    bypass_norm: eqx.Module
    activation: Callable

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        activation: Callable = jax.nn.relu,
        kernel_size: int = 3,
        use_norm: bool = True,
        num_groups: int = 1,  # for GroupNorm
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Block that performs two sequential convolutions with activation and
        optional group normalization in between. The order of operations is
        based on "pre-activation" to allow for a clean bypass/residual
        connection.

        If the number of input channels is different from the number of output
        channels, a pointwise convolution (without bias) is used to match the
        number of channels.

        If `use_norm` is `True`, group normalization is used after each
        convolution. If there is a convolution that matches the number of
        channels, the bypass will also have group normalization.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `activation`: The activation function to use after each convolution.
            Default is `jax.nn.relu`.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `use_norm`: Whether to use group normalization. Default is `True`.
        - `num_groups`: The number of groups to use for group normalization.
            Default is `1`.
        - `use_bias`: Whether to use bias in the convolutional layers. Default
            is `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        """

        def conv_constructor(i, o, b, k):
            # Closes over the shared hyperparameters; only the channel
            # counts, bias flag, and PRNG key differ between the two convs.
            return PhysicsConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=i,
                out_channels=o,
                kernel_size=kernel_size,
                stride=1,
                dilation=1,
                boundary_mode=boundary_mode,
                use_bias=b,
                zero_bias_init=zero_bias_init,
                key=k,
            )

        # Retain `key` so the optional bypass conv gets its own randomness.
        conv_1_key, conv_2_key, key = jax.random.split(key, 3)

        # Pre-activation: the first norm acts on the block's input, hence it
        # is sized with `in_channels` (not `out_channels`).
        if use_norm:
            self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=in_channels)
        else:
            self.norm_1 = eqx.nn.Identity()
        self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, conv_1_key)

        # NOTE(review): an earlier comment claimed PDEArena always applies
        # this second group norm even when use_norm is False; here it is
        # replaced by Identity when `use_norm=False` — confirm whether that
        # divergence from PDEArena is intentional.
        if use_norm:
            self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
        else:
            self.norm_2 = eqx.nn.Identity()
        self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, conv_2_key)

        self.activation = activation

        if out_channels != in_channels:
            # Channel mismatch: normalise the input (on `in_channels`) and
            # project it with a pointwise (1x1) conv to `out_channels`.
            bypass_conv_key, _ = jax.random.split(key)

            if use_norm:
                self.bypass_norm = eqx.nn.GroupNorm(
                    groups=num_groups, channels=in_channels
                )
            else:
                self.bypass_norm = eqx.nn.Identity()

            self.bypass_conv = PointwiseLinearConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                use_bias=False,  # Following PDEArena
                key=bypass_conv_key,
            )
        else:
            self.bypass_norm = eqx.nn.Identity()
            self.bypass_conv = eqx.nn.Identity()

    def __call__(self, x):
        x_skip = x
        # Using pre-activation instead of post-activation
        x = self.conv_1(self.activation(self.norm_1(x)))
        x = self.conv_2(self.activation(self.norm_2(x)))

        # No activation after the residual sum (clean identity path).
        x = x + self.bypass_conv(self.bypass_norm(x_skip))
        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        """Per spatial dimension the (backward, forward) receptive field:
        the element-wise sum of the two main-path convolutions' receptive
        fields. The pointwise bypass is not included."""
        conv_1_receptive_field = self.conv_1.receptive_field
        conv_2_receptive_field = self.conv_2.receptive_field
        return tuple(
            (
                c_1_i_backward + c_2_i_backward,
                c_1_i_forward + c_2_i_forward,
            )
            for (c_1_i_backward, c_1_i_forward), (c_2_i_backward, c_2_i_forward) in zip(
                conv_1_receptive_field, conv_2_receptive_field
            )
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = True,
    num_groups: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)

Block that performs two sequential convolutions with activation and optional group normalization in between. The order of operations is based on "pre-activation" to allow for a clean bypass/residual connection.

If the number of input channels is different from the number of output channels, a pointwise convolution (without bias) is used to match the number of channels.

If use_norm is True, group normalization is used after each convolution. If there is a convolution that matches the number of channels, the bypass will also have group normalization.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • activation: The activation function to use after each convolution. Default is jax.nn.relu.
  • kernel_size: The size of the convolutional kernel. Default is 3.
  • use_norm: Whether to use group normalization. Default is True.
  • num_groups: The number of groups to use for group normalization. Default is 1.
  • use_bias: Whether to use bias in the convolutional layers. Default is True.
  • zero_bias_init: Whether to initialise the bias to zero. Default is False.
Source code in pdequinox/blocks/_modern_res_block.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    use_norm: bool = True,
    num_groups: int = 1,  # for GroupNorm
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Block that performs two sequential convolutions with activation and
    optional group normalization in between. The order of operations is
    based on "pre-activation" to allow for a clean bypass/residual
    connection.

    If the number of input channels is different from the number of output
    channels, a pointwise convolution (without bias) is used to match the
    number of channels.

    If `use_norm` is `True`, group normalization is used after each
    convolution. If there is a convolution that matches the number of
    channels, the bypass will also have group normalization.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `activation`: The activation function to use after each convolution.
        Default is `jax.nn.relu`.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `use_norm`: Whether to use group normalization. Default is `True`.
    - `num_groups`: The number of groups to use for group normalization.
        Default is `1`.
    - `use_bias`: Whether to use bias in the convolutional layers. Default
        is `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    """

    def conv_constructor(i, o, b, k):
        # Closes over the shared hyperparameters; only the channel counts,
        # bias flag, and PRNG key differ between the two convolutions.
        return PhysicsConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=b,
            zero_bias_init=zero_bias_init,
            key=k,
        )

    # Retain `key` so the optional bypass conv gets its own randomness.
    conv_1_key, conv_2_key, key = jax.random.split(key, 3)

    # Pre-activation: the first norm acts on the block's input, hence it
    # is sized with `in_channels` (not `out_channels`).
    if use_norm:
        self.norm_1 = eqx.nn.GroupNorm(groups=num_groups, channels=in_channels)
    else:
        self.norm_1 = eqx.nn.Identity()
    self.conv_1 = conv_constructor(in_channels, out_channels, use_bias, conv_1_key)

    # NOTE(review): an earlier comment claimed PDEArena always applies this
    # second group norm even when use_norm is False; here it is replaced by
    # Identity when `use_norm=False` — confirm whether that divergence from
    # PDEArena is intentional.
    if use_norm:
        self.norm_2 = eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
    else:
        self.norm_2 = eqx.nn.Identity()
    self.conv_2 = conv_constructor(out_channels, out_channels, use_bias, conv_2_key)

    self.activation = activation

    if out_channels != in_channels:
        # Channel mismatch: normalise the input (on `in_channels`) and
        # project it with a pointwise (1x1) conv to `out_channels`.
        bypass_conv_key, _ = jax.random.split(key)

        if use_norm:
            self.bypass_norm = eqx.nn.GroupNorm(
                groups=num_groups, channels=in_channels
            )
        else:
            self.bypass_norm = eqx.nn.Identity()

        self.bypass_conv = PointwiseLinearConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            use_bias=False,  # Following PDEArena
            key=bypass_conv_key,
        )
    else:
        self.bypass_norm = eqx.nn.Identity()
        self.bypass_conv = eqx.nn.Identity()
__call__
__call__(x)
Source code in pdequinox/blocks/_modern_res_block.py
133
134
135
136
137
138
139
140
def __call__(self, x):
    x_skip = x
    # Using pre-activation instead of post-activation
    x = self.conv_1(self.activation(self.norm_1(x)))
    x = self.conv_2(self.activation(self.norm_2(x)))

    # No activation after the residual sum (clean identity path).
    x = x + self.bypass_conv(self.bypass_norm(x_skip))
    return x

pdequinox.blocks.ClassicSpectralBlock

Bases: Block

Source code in pdequinox/blocks/_classic_spectral_block.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class ClassicSpectralBlock(Block):
    # Spectral (Fourier-space) convolution path.
    spectral_conv: SpectralConv
    # Pointwise (1x1) convolution bypass path.
    by_pass_conv: PointwiseLinearConv
    activation: Callable

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        # Uses gelu because it likely recovers more modes
        activation: Callable = jax.nn.gelu,
        num_modes: int = 8,
        use_bias: bool = True,
        zero_bias_init: bool = False,
        key: PRNGKeyArray,
    ):
        """
        Residual-style block as used in vanilla FNOs; combines a spectral
        convolution with a bypass.

        Does not have argument `boundary_mode` because it would not respect it.
        In the original FNO paper it is argued that the bypass helps recover the
        boundary condition.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `activation`: The activation function to use after each convolution.
            Default is `jax.nn.gelu`.
        - `num_modes`: How many modes to consider in Fourier space. At max this
            can be N//2+1, with N being the number of spatial points. Think of
            it as the analogy of the kernel size.
        - `use_bias`: Whether to use a bias in the bypass convolution. Default
            `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        """
        # Independent keys for the two parallel paths.
        k_1, k_2 = jax.random.split(key)
        self.spectral_conv = SpectralConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            num_modes=num_modes,
            key=k_1,
        )
        self.by_pass_conv = PointwiseLinearConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            use_bias=use_bias,
            zero_bias_init=zero_bias_init,
            key=k_2,
        )
        self.activation = activation

    def __call__(self, x):
        # Parallel spectral and pointwise paths are summed, then activated.
        x = self.spectral_conv(x) + self.by_pass_conv(x)
        x = self.activation(x)
        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        # The bypass is pointwise, so the spectral conv dictates the field.
        return self.spectral_conv.receptive_field
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.gelu,
    num_modes: int = 8,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    key: PRNGKeyArray
)

Residual-style block as used in vanilla FNOs; combines a spectral convolution with a bypass.

Does not have argument boundary_mode because it would not respect it. In the original FNO paper it is argued that the bypass helps recover the boundary condition.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • activation: The activation function to use after each convolution. Default is jax.nn.gelu.
  • num_modes: How many modes to consider in Fourier space. At max this can be N//2+1, with N being the number of spatial points. Think of it as the analogy of the kernel size.
  • use_bias: Whether to use a bias in the bypass convolution. Default True.
  • zero_bias_init: Whether to initialise the bias to zero. Default is False.
Source code in pdequinox/blocks/_classic_spectral_block.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    # Uses gelu because it likely recovers more modes
    activation: Callable = jax.nn.gelu,
    num_modes: int = 8,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    key: PRNGKeyArray,
):
    """
    Residual-style block as used in vanilla FNOs; combines a spectral
    convolution with a bypass.

    Does not have argument `boundary_mode` because it would not respect it.
    In the original FNO paper it is argued that the bypass helps recover the
    boundary condition.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `activation`: The activation function to use after each convolution.
        Default is `jax.nn.gelu`.
    - `num_modes`: How many modes to consider in Fourier space. At max this
        can be N//2+1, with N being the number of spatial points. Think of
        it as the analogy of the kernel size.
    - `use_bias`: Whether to use a bias in the bypass convolution. Default
        `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    """
    # Independent keys for the two parallel paths.
    k_1, k_2 = jax.random.split(key)
    self.spectral_conv = SpectralConv(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        num_modes=num_modes,
        key=k_1,
    )
    self.by_pass_conv = PointwiseLinearConv(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        use_bias=use_bias,
        zero_bias_init=zero_bias_init,
        key=k_2,
    )
    self.activation = activation
__call__ ¤
__call__(x)
Source code in pdequinox/blocks/_classic_spectral_block.py
72
73
74
75
def __call__(self, x):
    """Sum the spectral and bypass paths, then apply the activation."""
    combined = self.spectral_conv(x) + self.by_pass_conv(x)
    return self.activation(combined)

pdequinox.blocks.DilatedResBlock ¤

Bases: Module

Source code in pdequinox/blocks/_dilated_res_block.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class DilatedResBlock(eqx.Module):
    # One norm per convolution (identities if `use_norm=False`).
    norm_layers: tuple[eqx.nn.GroupNorm, ...]
    # One convolution per entry in `dilation_rates`.
    conv_layers: tuple[PhysicsConv, ...]
    activation: Callable
    # Residual path; identities when `in_channels == out_channels`.
    bypass_conv: eqx.Module
    bypass_norm: eqx.Module

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        activation: Callable = jax.nn.relu,
        kernel_size: int = 3,
        dilation_rates: tuple[int, ...] = (1, 2, 4, 8, 4, 2, 1),
        use_norm: bool = True,
        num_groups: int = 1,  # for GroupNorm
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Block that performs a sequence of convolutions with varying dilation
        rates. Dilation refers to how many (virtual) zeros are inserted between
        kernel elements, effectively resulting in a larger receptive field. A
        bypass is added turning this block into a residual element.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `activation`: The activation function to use after each convolution.
            Default is `jax.nn.relu`.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `dilation_rates`: A sequence of integers. Their length identifies the
            number of sequential convolutions performed. Each integer is the
            dilation performed at that convolution. Typically, this list follows
            the pattern of first increasing in dilation rate, and then
            decreasing again. Default is `(1, 2, 4, 8, 4, 2, 1)`.
        - `use_norm`: Whether to use group normalization. Default is `True`.
        - `num_groups`: The number of groups to use for group normalization.
            Default is `1`.
        - `use_bias`: Whether to use bias in the convolutional layers. Default is
            `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        """

        def conv_constructor(i, o, d, b, k):
            # Local factory: every conv shares the hyperparameters captured
            # from the enclosing scope; only channels, dilation, bias flag and
            # key vary per layer.
            return PhysicsConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=i,
                out_channels=o,
                kernel_size=kernel_size,
                stride=1,
                dilation=d,
                boundary_mode=boundary_mode,
                use_bias=b,
                zero_bias_init=zero_bias_init,
                key=k,
            )

        if use_norm:
            norm_layers = []
            # The first norm acts on the block input, hence `in_channels`.
            norm_layers.append(
                eqx.nn.GroupNorm(groups=num_groups, channels=in_channels)
            )

            # All later norms sit after convs that produce `out_channels`.
            for _ in dilation_rates[1:]:
                norm_layers.append(
                    eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
                )

            self.norm_layers = tuple(norm_layers)
        else:
            self.norm_layers = tuple(eqx.nn.Identity() for _ in dilation_rates)

        # One key per convolution; the leftover `key` seeds the bypass conv
        # below. The split order matters for reproducibility.
        key, *keys = jax.random.split(key, len(dilation_rates) + 1)

        conv_layers = []
        # The first conv changes the channel count; the rest keep it fixed.
        conv_layers.append(
            conv_constructor(
                in_channels, out_channels, dilation_rates[0], use_bias, keys[0]
            )
        )
        for d, k in zip(dilation_rates[1:], keys[1:]):
            conv_layers.append(
                conv_constructor(out_channels, out_channels, d, use_bias, k)
            )

        self.conv_layers = tuple(conv_layers)

        self.activation = activation

        if out_channels != in_channels:
            if use_norm:
                self.bypass_norm = eqx.nn.GroupNorm(
                    groups=num_groups, channels=in_channels
                )
            else:
                self.bypass_norm = eqx.nn.Identity()

            bypass_conv_key, _ = jax.random.split(key)
            self.bypass_conv = PointwiseLinearConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                use_bias=use_bias,  # Todo: should this be True or False by default?
                zero_bias_init=zero_bias_init,
                key=bypass_conv_key,
            )
        else:
            # Channel counts already match; the residual is a plain identity.
            self.bypass_conv = eqx.nn.Identity()
            self.bypass_norm = eqx.nn.Identity()

    def __call__(self, x):
        """Apply the dilated convolution stack and add the residual bypass."""
        x_skip = x
        for norm, conv in zip(self.norm_layers, self.conv_layers):
            x = norm(x)
            x = conv(x)
            x = self.activation(x)

        x_skip = self.bypass_conv(self.bypass_norm(x_skip))
        x = x + x_skip

        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        # Sequential convolutions: the individual receptive fields add up.
        individual_receptive_fields = tuple(
            (conv.receptive_field for conv in self.conv_layers)
        )
        return sum_receptive_fields(individual_receptive_fields)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    dilation_rates: tuple[int] = (1, 2, 4, 8, 4, 2, 1),
    use_norm: bool = True,
    num_groups: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)

Block that performs a sequence of convolutions with varying dilation rates. Dilation refers to how many (virtual) zeros are inserted between kernel elements, effectively resulting in a larger receptive field. A bypass is added turning this block into a residual element.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • activation: The activation function to use after each convolution. Default is jax.nn.relu.
  • kernel_size: The size of the convolutional kernel. Default is 3.
  • dilation_rates: A sequence of integers. Their length identifies the number of sequential convolutions performed. Each integer is the dilation performed at that convolution. Typically, this list follows the pattern of first increasing in dilation rate, and then decreasing again. Default is (1, 2, 4, 8, 4, 2, 1).
  • use_norm: Whether to use group normalization. Default is True.
  • num_groups: The number of groups to use for group normalization. Default is 1.
  • use_bias: Whether to use bias in the convolutional layers. Default is True.
  • zero_bias_init: Whether to initialise the bias to zero. Default is False.
Source code in pdequinox/blocks/_dilated_res_block.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    dilation_rates: tuple[int, ...] = (1, 2, 4, 8, 4, 2, 1),
    use_norm: bool = True,
    num_groups: int = 1,  # for GroupNorm
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Block that performs a sequence of convolutions with varying dilation
    rates. Dilation refers to how many (virtual) zeros are inserted between
    kernel elements, effectively resulting in a larger receptive field. A
    bypass is added turning this block into a residual element.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `activation`: The activation function to use after each convolution.
        Default is `jax.nn.relu`.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `dilation_rates`: A sequence of integers. Their length identifies the
        number of sequential convolutions performed. Each integer is the
        dilation performed at that convolution. Typically, this list follows
        the pattern of first increasing in dilation rate, and then
        decreasing again. Default is `(1, 2, 4, 8, 4, 2, 1)`.
    - `use_norm`: Whether to use group normalization. Default is `True`.
    - `num_groups`: The number of groups to use for group normalization.
        Default is `1`.
    - `use_bias`: Whether to use bias in the convolutional layers. Default is
        `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    """

    def conv_constructor(i, o, d, b, k):
        # Local factory: only channels, dilation, bias flag and key vary.
        return PhysicsConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=1,
            dilation=d,
            boundary_mode=boundary_mode,
            use_bias=b,
            zero_bias_init=zero_bias_init,
            key=k,
        )

    if use_norm:
        norm_layers = []
        # The first norm acts on the block input, hence `in_channels`.
        norm_layers.append(
            eqx.nn.GroupNorm(groups=num_groups, channels=in_channels)
        )

        for _ in dilation_rates[1:]:
            norm_layers.append(
                eqx.nn.GroupNorm(groups=num_groups, channels=out_channels)
            )

        self.norm_layers = tuple(norm_layers)
    else:
        self.norm_layers = tuple(eqx.nn.Identity() for _ in dilation_rates)

    # One key per convolution; the leftover `key` seeds the bypass conv
    # below. The split order matters for reproducibility.
    key, *keys = jax.random.split(key, len(dilation_rates) + 1)

    conv_layers = []
    # The first conv changes the channel count; the rest keep it fixed.
    conv_layers.append(
        conv_constructor(
            in_channels, out_channels, dilation_rates[0], use_bias, keys[0]
        )
    )
    for d, k in zip(dilation_rates[1:], keys[1:]):
        conv_layers.append(
            conv_constructor(out_channels, out_channels, d, use_bias, k)
        )

    self.conv_layers = tuple(conv_layers)

    self.activation = activation

    if out_channels != in_channels:
        if use_norm:
            self.bypass_norm = eqx.nn.GroupNorm(
                groups=num_groups, channels=in_channels
            )
        else:
            self.bypass_norm = eqx.nn.Identity()

        bypass_conv_key, _ = jax.random.split(key)
        self.bypass_conv = PointwiseLinearConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            use_bias=use_bias,  # Todo: should this be True or False by default?
            zero_bias_init=zero_bias_init,
            key=bypass_conv_key,
        )
    else:
        # Channel counts already match; the residual is a plain identity.
        self.bypass_conv = eqx.nn.Identity()
        self.bypass_norm = eqx.nn.Identity()
__call__ ¤
__call__(x)
Source code in pdequinox/blocks/_dilated_res_block.py
141
142
143
144
145
146
147
148
149
150
151
def __call__(self, x):
    """Run the norm→conv→activation stack, then add the residual bypass."""
    skip = x
    for normalize, convolve in zip(self.norm_layers, self.conv_layers):
        x = self.activation(convolve(normalize(x)))
    return x + self.bypass_conv(self.bypass_norm(skip))

pdequinox.blocks.LinearConvBlock ¤

Bases: PhysicsConv

Source code in pdequinox/blocks/_linear_conv_block.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class LinearConvBlock(PhysicsConv):
    """A `PhysicsConv` fixed to stride 1 and dilation 1; no activation or
    normalization is applied (hence "linear")."""

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Convolution with stride 1 and dilation 1.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `use_bias`: Whether to use a bias. Default is `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for
            parameter initialisation. (Keyword only argument.)
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=use_bias,
            zero_bias_init=zero_bias_init,
            key=key,
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)
Source code in pdequinox/blocks/_linear_conv_block.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Convolution with stride 1 and dilation 1; no activation or normalization.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `use_bias`: Whether to use a bias. Default is `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=1,
        dilation=1,
        boundary_mode=boundary_mode,
        use_bias=use_bias,
        zero_bias_init=zero_bias_init,
        key=key,
    )
__call__ ¤
__call__(
    x: Array, *, key: Optional[PRNGKeyArray] = None
) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).

Source code in pdequinox/conv/_conv.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
@jax.named_scope("eqx.nn.Conv")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape
        `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)

    **Returns:**

    A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.

    **Raises:**

    A `ValueError` if `x` does not have rank `num_spatial_dims + 1`, or if
    `self.padding_mode` is not one of the supported modes.
    """

    unbatched_rank = self.num_spatial_dims + 1
    if x.ndim != unbatched_rank:
        # Fix: previously two comma-separated f-strings were passed, so the
        # exception carried a 2-tuple instead of one readable message.
        raise ValueError(
            f"Input to `Conv` needs to have rank {unbatched_rank},"
            f" but input has shape {x.shape}."
        )
    # Non-"zeros" modes are realised by explicitly padding the spatial axes
    # (never the channel axis) and then running an unpadded convolution.
    if self.padding_mode == "circular":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="wrap")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "zeros":
        # Zero padding is handled natively by `lax.conv_general_dilated`.
        padding_lax = self.padding
    elif self.padding_mode == "reflect":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="reflect")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "replicate":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="edge")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    else:
        raise ValueError(
            f"`padding_mode` must be one of ['zeros', 'reflect', 'replicate', 'circular'],"
            f" but got {self.padding_mode}."
        )

    # `lax.conv_general_dilated` expects a batch axis; add it, convolve,
    # then remove it again.
    x = jnp.expand_dims(x, axis=0)
    x = lax.conv_general_dilated(
        lhs=x,
        rhs=self.weight,
        window_strides=self.stride,
        padding=padding_lax,
        rhs_dilation=self.dilation,
        feature_group_count=self.groups,
    )
    x = jnp.squeeze(x, axis=0)
    if self.use_bias:
        x = x + self.bias
    return x

pdequinox.blocks.LinearConvDownBlock ¤

Bases: PhysicsConv

Source code in pdequinox/blocks/_linear_conv_down_block.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class LinearConvDownBlock(PhysicsConv):
    """A `PhysicsConv` whose stride equals `factor`; used as a strided
    downsampling convolution without activation or normalization."""

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        factor: int = 2,
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Strided convolution; the stride is set to `factor`.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `factor`: The stride of the convolution. Default is `2`.
        - `use_bias`: Whether to use a bias. Default is `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for
            parameter initialisation. (Keyword only argument.)
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=factor,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=use_bias,
            zero_bias_init=zero_bias_init,
            key=key,
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    factor: int = 2,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)
Source code in pdequinox/blocks/_linear_conv_down_block.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    factor: int = 2,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Strided convolution; the stride is set to `factor`.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `factor`: The stride of the convolution. Default is `2`.
    - `use_bias`: Whether to use a bias. Default is `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
        dilation=1,
        boundary_mode=boundary_mode,
        use_bias=use_bias,
        zero_bias_init=zero_bias_init,
        key=key,
    )
__call__ ¤
__call__(
    x: Array, *, key: Optional[PRNGKeyArray] = None
) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).

Source code in pdequinox/conv/_conv.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
@jax.named_scope("eqx.nn.Conv")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape
        `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)

    **Returns:**

    A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.

    **Raises:**

    A `ValueError` if `x` does not have rank `num_spatial_dims + 1`, or if
    `self.padding_mode` is not one of the supported modes.
    """

    unbatched_rank = self.num_spatial_dims + 1
    if x.ndim != unbatched_rank:
        # Fix: previously two comma-separated f-strings were passed, so the
        # exception carried a 2-tuple instead of one readable message.
        raise ValueError(
            f"Input to `Conv` needs to have rank {unbatched_rank},"
            f" but input has shape {x.shape}."
        )
    # Non-"zeros" modes are realised by explicitly padding the spatial axes
    # (never the channel axis) and then running an unpadded convolution.
    if self.padding_mode == "circular":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="wrap")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "zeros":
        # Zero padding is handled natively by `lax.conv_general_dilated`.
        padding_lax = self.padding
    elif self.padding_mode == "reflect":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="reflect")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "replicate":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="edge")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    else:
        raise ValueError(
            f"`padding_mode` must be one of ['zeros', 'reflect', 'replicate', 'circular'],"
            f" but got {self.padding_mode}."
        )

    # `lax.conv_general_dilated` expects a batch axis; add it, convolve,
    # then remove it again.
    x = jnp.expand_dims(x, axis=0)
    x = lax.conv_general_dilated(
        lhs=x,
        rhs=self.weight,
        window_strides=self.stride,
        padding=padding_lax,
        rhs_dilation=self.dilation,
        feature_group_count=self.groups,
    )
    x = jnp.squeeze(x, axis=0)
    if self.use_bias:
        x = x + self.bias
    return x

pdequinox.blocks.LinearConvUpBlock ¤

Bases: PhysicsConvTranspose

Source code in pdequinox/blocks/_linear_conv_up_block.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class LinearConvUpBlock(PhysicsConvTranspose):
    """A `PhysicsConvTranspose` whose stride equals `factor`; used as a
    transposed (upsampling) convolution without activation or normalization."""

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        factor: int = 2,
        output_padding: int = 1,
        use_bias: bool = True,
        zero_bias_init: bool = False,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        key: PRNGKeyArray,
    ):
        """
        Transposed convolution; the stride is set to `factor`.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `factor`: The stride of the transposed convolution. Default is `2`.
        - `output_padding`: Additional padding for the output shape. Default
            is `1`.
        - `use_bias`: Whether to use a bias. Default is `True`.
        - `zero_bias_init`: Whether to initialise the bias to zero. Default is
            `False`.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for
            parameter initialisation. (Keyword only argument.)
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=factor,
            output_padding=output_padding,
            dilation=1,
            boundary_mode=boundary_mode,
            use_bias=use_bias,
            zero_bias_init=zero_bias_init,
            key=key,
        )
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    factor: int = 2,
    output_padding: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    key: PRNGKeyArray
)
Source code in pdequinox/blocks/_linear_conv_up_block.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    kernel_size: int = 3,
    factor: int = 2,
    output_padding: int = 1,
    use_bias: bool = True,
    zero_bias_init: bool = False,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    key: PRNGKeyArray,
):
    """
    Transposed convolution; the stride is set to `factor`.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `factor`: The stride of the transposed convolution. Default is `2`.
    - `output_padding`: Additional padding for the output shape. Default is
        `1`.
    - `use_bias`: Whether to use a bias. Default is `True`.
    - `zero_bias_init`: Whether to initialise the bias to zero. Default is
        `False`.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
        output_padding=output_padding,
        dilation=1,
        boundary_mode=boundary_mode,
        use_bias=use_bias,
        zero_bias_init=zero_bias_init,
        key=key,
    )
__call__ ¤
__call__(
    x: Array,
    *,
    output_padding: Optional[
        Union[int, Sequence[int]]
    ] = None,
    key: Optional[PRNGKeyArray] = None
) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
  • output_padding: Additional padding for the output shape. If not provided, the output_padding used in the initialisation is used.

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).

Source code in pdequinox/conv/_conv.py
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
@jax.named_scope("eqx.nn.ConvTranspose")
def __call__(
    self,
    x: Array,
    *,
    output_padding: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[PRNGKeyArray] = None,
) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape
        `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)
    - `output_padding`: Additional padding for the output shape. If not provided,
        the `output_padding` used in the initialisation is used.

    **Returns:**

    A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.
    """

    # The input is unbatched: one channel axis plus the spatial axes.
    unbatched_rank = self.num_spatial_dims + 1
    if x.ndim != unbatched_rank:
        # BUGFIX: previously two comma-separated f-strings were passed to
        # ValueError, so the exception message rendered as a tuple of two
        # strings; concatenate them into a single message instead.
        raise ValueError(
            f"Input to `ConvTranspose` needs to have rank {unbatched_rank},"
            f" but input has shape {x.shape}."
        )

    if output_padding is None:
        output_padding = self.output_padding

    # Padding of the equivalent forward (dilated) convolution.
    # Given by Relationship 14 of https://arxiv.org/abs/1603.07285
    transpose_padding = tuple(
        (d * (k - 1) - p0, d * (k - 1) - p1 + o)
        for k, (p0, p1), o, d in zip(
            self.kernel_size, self.padding, output_padding, self.dilation
        )
    )
    # Decide how much has to be pre-padded (for every non-"zeros" padding
    # mode): the physical padding must be applied *before* input dilation,
    # so it is divided by the stride (rounded up).
    if self.padding_mode != "zeros":
        pre_dilation_padding = tuple(
            (
                (p_l + (s - 1)) // s,
                (p_r + (s - 1)) // s,
            )
            for (p_l, p_r), s in zip(transpose_padding, self.stride)
        )
        # Remainder applied after dilation; entries can also be negative
        # (negative padding crops the array).
        # NOTE(review): this uses `self.padding` (not `transpose_padding`)
        # as the minuend and adds `o` again — looks intentional given the
        # shapes this produces downstream, but verify against the library's
        # reference implementation before touching it.
        post_dilation_padding = tuple(
            (
                p_l - pd_l * s,
                p_r - pd_r * s + o,
            )
            for (p_l, p_r), (pd_l, pd_r), s, o in zip(
                self.padding, pre_dilation_padding, self.stride, output_padding
            )
        )
        if self.padding_mode == "circular":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="wrap")
        elif self.padding_mode == "reflect":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="reflect")
        elif self.padding_mode == "replicate":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="edge")
        else:
            raise ValueError(
                f"`padding_mode` must be one of ['zeros', 'reflect', 'replicate', 'circular'],"
                f" but got {self.padding_mode}."
            )
    else:
        # "zeros" mode: no physical pre-padding; the zero padding is handed
        # directly to the convolution primitive below.
        post_dilation_padding = tuple(
            (
                p_l,
                p_r + o,
            )
            for (p_l, p_r), o in zip(self.padding, output_padding)
        )
        # (removed dead `x = x` no-op that was here)

    # `lax.conv_general_dilated` expects a batch axis; add a singleton one.
    x = jnp.expand_dims(x, axis=0)
    # `lhs_dilation=self.stride` turns the strided forward convolution into
    # its transpose (fractionally-strided convolution).
    x = lax.conv_general_dilated(
        lhs=x,
        rhs=self.weight,
        window_strides=(1,) * self.num_spatial_dims,
        padding=post_dilation_padding,
        lhs_dilation=self.stride,
        rhs_dilation=self.dilation,
        feature_group_count=self.groups,
    )
    x = jnp.squeeze(x, axis=0)
    if self.use_bias:
        x = x + self.bias
    return x

pdequinox.blocks.Block ¤

Bases: Module, ABC

Source code in pdequinox/blocks/_base_block.py
8
9
class Block(eqx.Module, ABC):
    """Abstract marker base class for pdequinox blocks.

    Carries no behaviour of its own; concrete blocks (e.g.
    `ClassicDoubleConvBlock`) subclass it to share a common type.
    """

    pass