Low-Level Convolution Routines¤

They wrap the `equinox.nn` convolution modules to provide an interface based on "same" convolutions with selectable boundary modes, or they implement spectral convolutions.

pdequinox.conv.PhysicsConv ¤

Bases: MorePaddingConv

Source code in pdequinox/conv/_physics_conv.py
class PhysicsConv(MorePaddingConv):
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = field(static=True)

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        # no padding because it is always chosen to retain spatial size
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKeyArray,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        zero_bias_init: bool = False,
    ):
        """
        General n-dimensional convolution with "same" padding to operate on
        fields. Allows choosing a `boundary_mode` that determines the type of
        padding used. There is no option to set the padding manually.

        This is a thin wrapper around `equinox.nn.Conv`.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example,
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `kernel_size`: The size of the convolutional kernel.
        - `stride`: The stride of the convolution.
        - `dilation`: The dilation of the convolution.
        - `groups`: The number of input channel groups. At `groups=1`,
            all input channels contribute to all output channels. Values higher
            than `1` are equivalent to running `groups` independent `Conv`
            operations side-by-side, each having access only to `in_channels` //
            `groups` input channels, and concatenating the results along the
            output channel dimension. `in_channels` must be divisible by
            `groups`.
        - `use_bias`: Whether to add on a bias after the convolution.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The type of boundary padding to use. Use one of
            ["periodic", "dirichlet", "neumann"]. (Keyword only argument.)
        - `zero_bias_init`: Whether to initialise the bias to zero. (Keyword
            only argument.)

        !!! info

            All of `kernel_size`, `stride`, `dilation` can be either an integer
            or a sequence of integers. If they are a sequence then the sequence
            should be of length equal to `num_spatial_dims`, and specify the
            value of each property down each spatial dimension in turn.

            If they are an integer then the same kernel size / stride / dilation
            will be used along every spatial dimension.
        """
        self.boundary_mode = boundary_mode.lower()

        if self.boundary_mode == "periodic":
            padding_mode = "circular"
        elif self.boundary_mode == "dirichlet":
            padding_mode = "zeros"
        elif self.boundary_mode == "neumann":
            padding_mode = "reflect"
        else:
            raise ValueError(
                f"Only 'periodic', 'dirichlet', 'neumann' boundary modes are supported, got {boundary_mode}"
            )

        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=compute_same_padding(num_spatial_dims, kernel_size, dilation),
            padding_mode=padding_mode,
            dilation=dilation,
            groups=groups,
            use_bias=use_bias,
            key=key,
        )

        if use_bias and zero_bias_init:
            self.bias = jnp.zeros_like(self.bias)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]] = 1,
    dilation: Union[int, Sequence[int]] = 1,
    groups: int = 1,
    use_bias: bool = True,
    *,
    key: PRNGKeyArray,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    zero_bias_init: bool = False
)

General n-dimensional convolution with "same" padding to operate on fields. Allows choosing a boundary_mode that determines the type of padding used. There is no option to set the padding manually.

This is a thin wrapper around equinox.nn.Conv.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the convolutional kernel.
  • stride: The stride of the convolution.
  • dilation: The dilation of the convolution.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent Conv operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the convolution.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The type of boundary padding to use. Use one of ["periodic", "dirichlet", "neumann"]. (Keyword only argument.)
  • zero_bias_init: Whether to initialise the bias to zero. (Keyword only argument.)

Info

All of kernel_size, stride, dilation can be either an integer or a sequence of integers. If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.

If they are an integer then the same kernel size / stride / dilation will be used along every spatial dimension.
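For intuition, "same" padding can be derived from the kernel size and dilation as follows. This is a sketch of the standard convention; the actual `compute_same_padding` helper lives elsewhere in pdequinox and may differ in detail:

```python
def same_padding_1d(kernel_size: int, dilation: int = 1) -> tuple[int, int]:
    # Effective kernel extent after dilation is d * (k - 1) + 1, so a total
    # of d * (k - 1) padding is needed; split it across both sides.
    total = dilation * (kernel_size - 1)
    return (total // 2, total - total // 2)

def same_padding(num_spatial_dims, kernel_size, dilation=1):
    # Broadcast ints to per-dimension tuples, as described in the info box.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * num_spatial_dims
    if isinstance(dilation, int):
        dilation = (dilation,) * num_spatial_dims
    return tuple(same_padding_1d(k, d) for k, d in zip(kernel_size, dilation))

print(same_padding(2, 3))     # ((1, 1), (1, 1))
print(same_padding(1, 3, 2))  # ((2, 2),)
```

With this padding, an odd kernel at stride 1 leaves the spatial size unchanged, which is why the constructor exposes no `padding` argument.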

__call__ ¤
__call__(
    x: Array, *, key: Optional[PRNGKeyArray] = None
) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).

Source code in pdequinox/conv/_conv.py
@jax.named_scope("eqx.nn.Conv")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape
        `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)

    **Returns:**

    A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.
    """

    unbatched_rank = self.num_spatial_dims + 1
    if x.ndim != unbatched_rank:
        raise ValueError(
            f"Input to `Conv` needs to have rank {unbatched_rank},"
            f" but input has shape {x.shape}."
        )
    if self.padding_mode == "circular":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="wrap")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "zeros":
        x = x
        padding_lax = self.padding
    elif self.padding_mode == "reflect":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="reflect")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    elif self.padding_mode == "replicate":
        x = jnp.pad(x, ((0, 0),) + self.padding, mode="edge")
        padding_lax = ((0, 0),) * self.num_spatial_dims
    else:
        raise ValueError(
            f"`padding_mode` must be one of ['zeros', 'reflect', 'replicate', 'circular'],"
            f" but got {self.padding_mode}."
        )

    x = jnp.expand_dims(x, axis=0)
    x = lax.conv_general_dilated(
        lhs=x,
        rhs=self.weight,
        window_strides=self.stride,
        padding=padding_lax,
        rhs_dilation=self.dilation,
        feature_group_count=self.groups,
    )
    x = jnp.squeeze(x, axis=0)
    if self.use_bias:
        x = x + self.bias
    return x
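The mapping from `boundary_mode` to a concrete padding behaviour can be illustrated with `numpy.pad`, whose `mode` argument `jnp.pad` mirrors:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])

# periodic -> "wrap": the field continues cyclically across the boundary
periodic = np.pad(x, (1, 1), mode="wrap")       # [4, 1, 2, 3, 4, 1]

# dirichlet -> "zeros": the field is clamped to zero outside the domain
dirichlet = np.pad(x, (1, 1), mode="constant")  # [0, 1, 2, 3, 4, 0]

# neumann -> "reflect": mirrored about the edge sample, approximating a
# zero normal derivative at the boundary
neumann = np.pad(x, (1, 1), mode="reflect")     # [2, 1, 2, 3, 4, 3]
```

After this padding, the convolution itself runs with `padding_lax = ((0, 0), ...)`, i.e. as a "valid" convolution on the pre-padded field (except for the "zeros" case, where the zero padding is delegated to `lax.conv_general_dilated` directly).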

pdequinox.conv.PhysicsConvTranspose ¤

Bases: MorePaddingConvTranspose

Source code in pdequinox/conv/_physics_conv.py
class PhysicsConvTranspose(MorePaddingConvTranspose):
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = field(static=True)

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        # no padding because it is always chosen to retain spatial size
        output_padding: Union[int, Sequence[int]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKeyArray,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        zero_bias_init: bool = False,
    ):
        """
        General n-dimensional transposed convolution with "same" padding to
        operate on fields. Allows choosing a `boundary_mode` that determines
        the type of padding used. There is no option to set the padding
        manually.

        This is a thin wrapper around `equinox.nn.ConvTranspose`.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example,
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `kernel_size`: The size of the convolutional kernel.
        - `stride`: The stride of the convolution.
        - `output_padding`: Additional padding for the output shape.
        - `dilation`: The dilation of the convolution.
        - `groups`: The number of input channel groups. At `groups=1`,
            all input channels contribute to all output channels. Values higher
            than `1` are equivalent to running `groups` independent `Conv`
            operations side-by-side, each having access only to `in_channels` //
            `groups` input channels, and concatenating the results along the
            output channel dimension. `in_channels` must be divisible by
            `groups`.
        - `use_bias`: Whether to add on a bias after the convolution.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The type of boundary padding to use. Use one of
            ["periodic", "dirichlet", "neumann"]. (Keyword only argument.)
        - `zero_bias_init`: Whether to initialise the bias to zero. (Keyword
            only argument.)

        !!! info

            All of `kernel_size`, `stride`, `dilation` can be either an integer
            or a sequence of integers. If they are a sequence then the sequence
            should be of length equal to `num_spatial_dims`, and specify the
            value of each property down each spatial dimension in turn.

            If they are an integer then the same kernel size / stride / dilation
            will be used along every spatial dimension.

        !!! tip

            Transposed convolutions are often used to go in the "opposite
            direction" to a normal convolution. That is, from something with the
            shape of the output of a convolution to something with the shape of
            the input to a convolution. Moreover, to do so with the same
            "connectivity", i.e. which inputs can affect which outputs.

            Relative to a [`pdequinox.conv.PhysicsConv`][] layer, this can be
            accomplished by switching the values of `in_channels` and
            `out_channels`, whilst keeping `kernel_size`, `stride`, `dilation`,
            and `groups` the same.

            When `stride > 1` then [`pdequinox.conv.PhysicsConv`][] maps multiple input shapes
            to the same output shape. `output_padding` is provided to resolve
            this ambiguity, by adding a little extra padding to just the
            bottom/right edges of the input.

            See [these
            animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations)
            and [this report](https://arxiv.org/abs/1603.07285) for a nice
            reference.
        """
        self.boundary_mode = boundary_mode.lower()

        if self.boundary_mode == "periodic":
            padding_mode = "circular"
        elif self.boundary_mode == "dirichlet":
            padding_mode = "zeros"
        elif self.boundary_mode == "neumann":
            padding_mode = "reflect"
        else:
            raise ValueError(
                f"Only 'periodic', 'dirichlet', 'neumann' boundary modes are supported, got {boundary_mode}"
            )

        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=compute_same_padding(num_spatial_dims, kernel_size, dilation),
            padding_mode=padding_mode,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            use_bias=use_bias,
            key=key,
        )

        if use_bias and zero_bias_init:
            self.bias = jnp.zeros_like(self.bias)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Sequence[int]],
    stride: Union[int, Sequence[int]] = 1,
    output_padding: Union[int, Sequence[int]] = 0,
    dilation: Union[int, Sequence[int]] = 1,
    groups: int = 1,
    use_bias: bool = True,
    *,
    key: PRNGKeyArray,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    zero_bias_init: bool = False
)

General n-dimensional transposed convolution with "same" padding to operate on fields. Allows choosing a boundary_mode that determines the type of padding used. There is no option to set the padding manually.

This is a thin wrapper around equinox.nn.ConvTranspose.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • kernel_size: The size of the convolutional kernel.
  • stride: The stride of the convolution.
  • output_padding: Additional padding for the output shape.
  • dilation: The dilation of the convolution.
  • groups: The number of input channel groups. At groups=1, all input channels contribute to all output channels. Values higher than 1 are equivalent to running groups independent Conv operations side-by-side, each having access only to in_channels // groups input channels, and concatenating the results along the output channel dimension. in_channels must be divisible by groups.
  • use_bias: Whether to add on a bias after the convolution.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The type of boundary padding to use. Use one of ["periodic", "dirichlet", "neumann"]. (Keyword only argument.)
  • zero_bias_init: Whether to initialise the bias to zero. (Keyword only argument.)

Info

All of kernel_size, stride, dilation can be either an integer or a sequence of integers. If they are a sequence then the sequence should be of length equal to num_spatial_dims, and specify the value of each property down each spatial dimension in turn.

If they are an integer then the same kernel size / stride / dilation will be used along every spatial dimension.

Tip

Transposed convolutions are often used to go in the "opposite direction" to a normal convolution. That is, from something with the shape of the output of a convolution to something with the shape of the input to a convolution. Moreover, to do so with the same "connectivity", i.e. which inputs can affect which outputs.

Relative to a pdequinox.conv.PhysicsConv layer, this can be accomplished by switching the values of in_channels and out_channels, whilst keeping kernel_size, stride, dilation, and groups the same.

When stride > 1 then pdequinox.conv.PhysicsConv maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity, by adding a little extra padding to just the bottom/right edges of the input.

See these animations (https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations) and this report (https://arxiv.org/abs/1603.07285) for a nice reference.
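The shape ambiguity described in the tip can be sketched with integer arithmetic. The two helpers below are illustrative, not part of the library, and assume the ceil-division output-size convention of "same"-padded strided convolutions:

```python
def conv_out_size(n: int, stride: int) -> int:
    # Output size of a "same"-padded strided convolution: ceil(n / stride).
    return -(-n // stride)

def conv_transpose_out_size(m: int, stride: int, output_padding: int = 0) -> int:
    # Inverts conv_out_size, up to the ambiguity resolved by output_padding.
    return (m - 1) * stride + 1 + output_padding

# Stride 2 maps input sizes 7 and 8 to the same output size 4 ...
print(conv_out_size(7, 2), conv_out_size(8, 2))          # 4 4
# ... and output_padding selects which size the transpose recovers.
print(conv_transpose_out_size(4, 2, output_padding=0))   # 7
print(conv_transpose_out_size(4, 2, output_padding=1))   # 8
```

This is why `output_padding` only ever adds to the bottom/right edges: it picks one preimage out of the `stride` possible input sizes.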

__call__ ¤
__call__(
    x: Array,
    *,
    output_padding: Optional[
        Union[int, Sequence[int]]
    ] = None,
    key: Optional[PRNGKeyArray] = None
) -> Array

Arguments:

  • x: The input. Should be a JAX array of shape (in_channels, dim_1, ..., dim_N), where N = num_spatial_dims.
  • key: Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)
  • output_padding: Additional padding for the output shape. If not provided, the output_padding used in the initialisation is used.

Returns:

A JAX array of shape (out_channels, new_dim_1, ..., new_dim_N).

Source code in pdequinox/conv/_conv.py
@jax.named_scope("eqx.nn.ConvTranspose")
def __call__(
    self,
    x: Array,
    *,
    output_padding: Optional[Union[int, Sequence[int]]] = None,
    key: Optional[PRNGKeyArray] = None,
) -> Array:
    """**Arguments:**

    - `x`: The input. Should be a JAX array of shape
        `(in_channels, dim_1, ..., dim_N)`, where `N = num_spatial_dims`.
    - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
        (Keyword only argument.)
    - `output_padding`: Additional padding for the output shape. If not provided,
        the `output_padding` used in the initialisation is used.

    **Returns:**

    A JAX array of shape `(out_channels, new_dim_1, ..., new_dim_N)`.
    """

    unbatched_rank = self.num_spatial_dims + 1
    if x.ndim != unbatched_rank:
        raise ValueError(
            f"Input to `ConvTranspose` needs to have rank {unbatched_rank},"
            f" but input has shape {x.shape}."
        )

    if output_padding is None:
        output_padding = self.output_padding

    # Given by Relationship 14 of https://arxiv.org/abs/1603.07285
    transpose_padding = tuple(
        (d * (k - 1) - p0, d * (k - 1) - p1 + o)
        for k, (p0, p1), o, d in zip(
            self.kernel_size, self.padding, output_padding, self.dilation
        )
    )
    # Decide how much has to be pre-padded (for every padding mode other
    # than "zeros")
    if self.padding_mode != "zeros":
        pre_dilation_padding = tuple(
            (
                (p_l + (s - 1)) // s,
                (p_r + (s - 1)) // s,
            )
            for (p_l, p_r), s in zip(transpose_padding, self.stride)
        )
        # Can also be negative
        post_dilation_padding = tuple(
            (
                p_l - pd_l * s,
                p_r - pd_r * s + o,
            )
            for (p_l, p_r), (pd_l, pd_r), s, o in zip(
                self.padding, pre_dilation_padding, self.stride, output_padding
            )
        )
        if self.padding_mode == "circular":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="wrap")
        elif self.padding_mode == "reflect":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="reflect")
        elif self.padding_mode == "replicate":
            x = jnp.pad(x, ((0, 0),) + pre_dilation_padding, mode="edge")
        else:
            raise ValueError(
                f"`padding_mode` must be one of ['zeros', 'reflect', 'replicate', 'circular'],"
                f" but got {self.padding_mode}."
            )
    else:
        post_dilation_padding = tuple(
            (
                p_l,
                p_r + o,
            )
            for (p_l, p_r), o in zip(self.padding, output_padding)
        )
        x = x

    x = jnp.expand_dims(x, axis=0)
    x = lax.conv_general_dilated(
        lhs=x,
        rhs=self.weight,
        window_strides=(1,) * self.num_spatial_dims,
        padding=post_dilation_padding,
        lhs_dilation=self.stride,
        rhs_dilation=self.dilation,
        feature_group_count=self.groups,
    )
    x = jnp.squeeze(x, axis=0)
    if self.use_bias:
        x = x + self.bias
    return x

pdequinox.conv.SpectralConv ¤

Bases: Module

Huge credit to the Serket library for this implementation: https://github.com/ASEM000/serket

Source code in pdequinox/conv/_spectral_conv.py
class SpectralConv(eqx.Module):
    """
    Huge credit to the Serket library for this implementation:
    https://github.com/ASEM000/serket
    """

    num_spatial_dims: int
    num_modes: tuple[int]
    weights_real: Float[Array, "G C_o C_i ..."]
    weights_imag: Float[Array, "G C_o C_i ..."]

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_modes: Union[tuple[int, ...], int],
        *,
        key: PRNGKeyArray,
    ):
        """
        General n-dimensional spectral convolution on **real** fields.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example,
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `num_modes`: The number of modes to use in the Fourier representation
            of the input. If an integer is passed, the same number of modes will
            be used for each spatial dimension.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        """
        if isinstance(num_modes, int):
            num_modes = (num_modes,) * num_spatial_dims

        if len(num_modes) != num_spatial_dims:
            raise ValueError("num_modes must have the same length as num_spatial_dims")

        self.num_spatial_dims = num_spatial_dims
        self.num_modes = num_modes

        weight_shape = (
            2 ** (num_spatial_dims - 1),
            in_channels,
            out_channels,
        ) + num_modes

        real_key, imag_key = jr.split(key)
        scale = 1 / (in_channels * out_channels)
        self.weights_real = scale * jr.normal(real_key, weight_shape)
        self.weights_imag = scale * jr.normal(imag_key, weight_shape)

    def __call__(self, x: Float[Array, "C_i ..."]) -> Float[Array, "C_o ..."]:
        return spectral_conv_nd(x, self.weights_real, self.weights_imag, self.num_modes)

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        return tuple(((jnp.inf, jnp.inf),) * self.num_spatial_dims)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    num_modes: Union[tuple[int, ...], int],
    *,
    key: PRNGKeyArray
)

General n-dimensional spectral convolution on real fields.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • num_modes: The number of modes to use in the Fourier representation of the input. If an integer is passed, the same number of modes will be used for each spatial dimension.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
Source code in pdequinox/conv/_spectral_conv.py
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    num_modes: Union[tuple[int, ...], int],
    *,
    key: PRNGKeyArray,
):
    """
    General n-dimensional spectral convolution on **real** fields.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example,
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `num_modes`: The number of modes to use in the Fourier representation
        of the input. If an integer is passed, the same number of modes will
        be used for each spatial dimension.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    """
    if isinstance(num_modes, int):
        num_modes = (num_modes,) * num_spatial_dims

    if len(num_modes) != num_spatial_dims:
        raise ValueError("num_modes must have the same length as num_spatial_dims")

    self.num_spatial_dims = num_spatial_dims
    self.num_modes = num_modes

    weight_shape = (
        2 ** (num_spatial_dims - 1),
        in_channels,
        out_channels,
    ) + num_modes

    real_key, imag_key = jr.split(key)
    scale = 1 / (in_channels * out_channels)
    self.weights_real = scale * jr.normal(real_key, weight_shape)
    self.weights_imag = scale * jr.normal(imag_key, weight_shape)
__call__ ¤
__call__(
    x: Float[Array, "C_i ..."]
) -> Float[Array, "C_o ..."]
Source code in pdequinox/conv/_spectral_conv.py
65
66
def __call__(self, x: Float[Array, "C_i ..."]) -> Float[Array, "C_o ..."]:
    return spectral_conv_nd(x, self.weights_real, self.weights_imag, self.num_modes)
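
To make the behaviour of a spectral convolution concrete, here is a rough 1D sketch of the core idea: transform the real field to Fourier space, mix channels with complex weights on the lowest `num_modes` frequencies, zero the rest, and transform back. This is not the library's `spectral_conv_nd`; `spectral_conv_1d` and its weight layout `(c_out, c_in, modes)` are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def spectral_conv_1d(x, w_real, w_imag, num_modes):
    """Sketch of a 1D spectral convolution on a real field.

    x: (c_in, N); w_real / w_imag: (c_out, c_in, num_modes).
    """
    n = x.shape[-1]
    x_hat = jnp.fft.rfft(x, axis=-1)              # (c_in, N//2 + 1)
    w = w_real + 1j * w_imag                      # complex weights
    # channel mixing on the retained low modes -> (c_out, num_modes)
    out_hat_low = jnp.einsum("oim,im->om", w, x_hat[:, :num_modes])
    out_hat = jnp.zeros((w.shape[0], x_hat.shape[-1]), dtype=x_hat.dtype)
    out_hat = out_hat.at[:, :num_modes].set(out_hat_low)
    return jnp.fft.irfft(out_hat, n=n, axis=-1)   # back to real space

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (2, 32))                # 2 channels, 32 points
w_real = jax.random.normal(k2, (4, 2, 8)) / (2 * 4)
w_imag = jax.random.normal(k3, (4, 2, 8)) / (2 * 4)
y = spectral_conv_1d(x, w_real, w_imag, num_modes=8)
```

Because every Fourier mode couples all spatial points, the receptive field of such a layer is unbounded, which matches the infinite receptive field reported by `SpectralConv` above.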

pdequinox.conv.PointwiseLinearConv ¤

Bases: Conv

aka 1x1 Convolution; used primarily for channel adjustment

Source code in pdequinox/conv/_pointwise_linear_conv.py
class PointwiseLinearConv(eqx.nn.Conv):
    """
    aka 1x1 Convolution; used primarily for channel adjustment
    """

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        use_bias: bool = True,
        *,
        zero_bias_init: bool = False,
        key: PRNGKeyArray,
    ):
        """
        General n-dimensional pointwise linear convolution (=1x1 convolution).
        This is primarily used for channel adjustment.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example,
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `use_bias`: Whether to use a bias term. (Default: `True`)
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `zero_bias_init`: Whether to initialise the bias to zero. (Default:
            `False`)
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            dilation=1,
            padding=0,
            use_bias=use_bias,
            key=key,
        )
        if use_bias and zero_bias_init:
            self.bias = jnp.zeros_like(self.bias)

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        return tuple(((0.0, 0.0),) * self.num_spatial_dims)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    use_bias: bool = True,
    *,
    zero_bias_init: bool = False,
    key: PRNGKeyArray
)

General n-dimensional pointwise linear convolution (=1x1 convolution). This is primarily used for channel adjustment.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • use_bias: Whether to use a bias term. (Default: True)
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • zero_bias_init: Whether to initialise the bias to zero. (Default: False)
Source code in pdequinox/conv/_pointwise_linear_conv.py
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    use_bias: bool = True,
    *,
    zero_bias_init: bool = False,
    key: PRNGKeyArray,
):
    """
    General n-dimensional pointwise linear convolution (=1x1 convolution).
    This is primarily used for channel adjustment.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example,
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `use_bias`: Whether to use a bias term. (Default: `True`)
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `zero_bias_init`: Whether to initialise the bias to zero. (Default:
        `False`)
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        padding=0,
        use_bias=use_bias,
        key=key,
    )
    if use_bias and zero_bias_init:
        self.bias = jnp.zeros_like(self.bias)
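
A 1x1 convolution, as implemented above via `kernel_size=1`, is equivalent to applying one linear map independently at every spatial point: it mixes channels but never neighbouring points, which is why its receptive field is zero. A minimal sketch of that equivalence (the helper `pointwise_linear_conv` is a hypothetical name, not part of the library):

```python
import jax
import jax.numpy as jnp

def pointwise_linear_conv(x, weight, bias=None):
    """Channel mixing at every spatial point, equivalent to a 1x1 conv.

    x: (c_in, *spatial); weight: (c_out, c_in); bias: (c_out,) or None.
    """
    # contract the channel axis only; spatial axes pass through untouched
    y = jnp.tensordot(weight, x, axes=([1], [0]))  # (c_out, *spatial)
    if bias is not None:
        y = y + bias.reshape((-1,) + (1,) * (y.ndim - 1))
    return y

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3, 8, 8))   # 3 channels on an 8x8 grid
weight = jnp.eye(3)[:2]                 # project onto the first 2 channels
y = pointwise_linear_conv(x, weight)
```

This is the "channel adjustment" use case from the docstring: choosing `weight` of shape `(out_channels, in_channels)` changes the channel count while leaving the spatial grid untouched.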