
Architectural Constructors

PDEquinox offers two primary architectural constructors, one for sequential and one for hierarchical networks. Both are assembled from (and therefore composable with) the PDEquinox blocks.

Sequential Constructor

[Figure: sequential network (sequential_net)]

The pdequinox.Sequential network constructor is defined by:

  • a lifting block \(\mathcal{L}\)
  • \(N\) blocks \(\left \{ \mathcal{B}_i \right\}_{i=1}^N\)
  • a projection block \(\mathcal{P}\)
  • the hidden channels within the sequential processing (either one integer shared by all blocks, or a list of \(N+1\) integers if the width should differ between blocks)
  • the number of blocks \(N\)
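The following plain-Python sketch shows how the channel counts are threaded through the lifting block, the \(N\) blocks and the projection block. It mirrors the channel bookkeeping of the `pdequinox.Sequential` constructor source reproduced in the API section; `sequential_channel_plan` is a hypothetical helper, not part of pdequinox, and no layers are built.

```python
from typing import List, Sequence, Tuple, Union


def sequential_channel_plan(
    in_channels: int,
    out_channels: int,
    hidden_channels: Union[Sequence[int], int],
    num_blocks: int,
) -> List[Tuple[str, int, int]]:
    """Return (stage, fan_in, fan_out) for lifting, each block and projection.

    Hypothetical helper: only reproduces the channel bookkeeping.
    """
    if num_blocks < 1:
        raise ValueError("num_blocks must be at least 1")
    if isinstance(hidden_channels, int):
        # A single integer: the same width between every pair of blocks
        hidden = (hidden_channels,) * (num_blocks + 1)
    else:
        hidden = tuple(hidden_channels)
        if len(hidden) != num_blocks + 1:
            raise ValueError("hidden_channels must have num_blocks + 1 entries")

    plan = [("lifting", in_channels, hidden[0])]
    # Block i maps hidden[i - 1] -> hidden[i]
    for i, (fan_in, fan_out) in enumerate(zip(hidden[:-1], hidden[1:]), start=1):
        plan.append((f"block_{i}", fan_in, fan_out))
    plan.append(("projection", hidden[-1], out_channels))
    return plan


for stage in sequential_channel_plan(1, 1, [32, 64, 64], num_blocks=2):
    print(stage)
# ('lifting', 1, 32)
# ('block_1', 32, 64)
# ('block_2', 64, 64)
# ('projection', 64, 1)
```

The list form needs one more entry than there are blocks because each block sits between two hidden widths.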

Hierarchical Constructor

[Figure: hierarchical network (hierarchical_net)]

The pdequinox.Hierarchical network constructor is defined by:

  • a lifting block \(\mathcal{L}\)
  • the number of levels \(D\) (i.e., the number of additional hierarchies). Setting \(D = 0\) recovers the sequential processing.
  • a list of \(D\) blocks \(\left \{ \mathcal{D}_i \right\}_{i=1}^D\) for downsampling, i.e., mapping downwards to the lower hierarchy (typically these halve each spatial axis while keeping the number of channels)
  • a list of \(D\) blocks \(\left \{ \mathcal{B}_i^l \right\}_{i=1}^D\) for processing in the left arc (typically these change the number of channels, e.g., double it, so that downsampling followed by left processing halves the spatial resolution and doubles the feature count)
  • a list of \(D\) blocks \(\left \{ \mathcal{U}_i \right\}_{i=1}^D\) for upsampling, i.e., mapping upwards to the higher hierarchy (typically these double the spatial resolution and halve the feature count so that a skip connection can be concatenated)
  • a list of \(D\) blocks \(\left \{ \mathcal{B}_i^r \right\}_{i=1}^D\) for processing in the right arc (typically these halve the number of channels, so that upsampling followed by right processing doubles the spatial resolution and halves the feature count)
  • a projection block \(\mathcal{P}\)
  • the hidden channels within the hierarchical processing (an integer giving the number of channels at the highest hierarchy, i.e., the finest resolution)
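The following plain-Python sketch lists which channel counts the down, left, up and right blocks receive at every level. It mirrors the channel bookkeeping of the `pdequinox.Hierarchical` constructor source in the API section (including the default `channel_multipliers`); `hierarchical_channel_plan` is a hypothetical helper, not part of pdequinox, and no layers are built.

```python
from typing import Dict, List, Optional, Sequence, Tuple

ChannelPair = Tuple[int, int]  # (in_channels, out_channels) of one block


def hierarchical_channel_plan(
    hidden_channels: int,
    num_levels: int,
    num_blocks: int = 1,
    reduction_factor: int = 2,
    channel_multipliers: Optional[Sequence[int]] = None,
) -> List[Dict[str, List[ChannelPair]]]:
    """Channel counts of the down, left, up and right blocks at every level.

    Entry 0 describes the level just below the highest resolution.
    Hypothetical helper: it only reproduces the channel bookkeeping.
    """
    if num_blocks < 1:
        raise ValueError("num_blocks must be at least 1")
    if channel_multipliers is None:
        # Default: the feature count grows by reduction_factor per level
        channel_multipliers = [reduction_factor**i for i in range(1, num_levels + 1)]
    if len(channel_multipliers) != num_levels:
        raise ValueError("len(channel_multipliers) must match num_levels")
    channels = [hidden_channels] + [hidden_channels * m for m in channel_multipliers]

    plan = []
    for fan_in, fan_out in zip(channels[:-1], channels[1:]):
        up_out = fan_out // reduction_factor
        plan.append(
            {
                # Downsampling only changes the spatial resolution
                "down": [(fan_in, fan_in)],
                # The first left block changes the width, the rest keep it
                "left": [(fan_in, fan_out)] + [(fan_out, fan_out)] * (num_blocks - 1),
                "up": [(fan_out, up_out)],
                # The first right block sees upsampled features concatenated with the skip
                "right": [(up_out + fan_in, fan_in)] + [(fan_in, fan_in)] * (num_blocks - 1),
            }
        )
    return plan
```

With the defaults and `hidden_channels=32`, `num_levels=2`, the widths are 32, 64 and 128, and each first right block receives twice its output width. With `channel_multipliers=(1, 1)` and `hidden_channels=16`, the up block still divides by `reduction_factor`, so the first right block receives 8 + 16 = 24 channels.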

Beyond Architectural Constructors

For completeness, pdequinox.arch also provides pdequinox.arch.ConvNet, a simple feed-forward convolutional network, and pdequinox.arch.MLP, a dense network that additionally requires the number of resolution points to be fixed in advance.

API

pdequinox.Sequential

Bases: Module

Source code in pdequinox/_sequential.py
class Sequential(eqx.Module):
    lifting: Block
    blocks: List[Block]
    projection: Block

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: Union[Sequence[int], int],
        num_blocks: int,
        activation: Callable,
        key: PRNGKeyArray,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        lifting_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
        block_factory: BlockFactory = ClassicResBlockFactory(),
        projection_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
    ):
        """
        Generic constructor for sequential block-based architectures like
        ResNets.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers. Either
            an integer to have the same number of hidden channels in the layers
            between all blocks, or a list of `num_blocks + 1` integers.
        - `num_blocks`: The number of blocks to use. Must be an integer greater
            than or equal to `1`.
        - `activation`: The activation function to use in the blocks.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `lifting_factory`: The factory to use for the lifting block.
            Default is `LinearChannelAdjustBlockFactory` which is simply a
            linear 1x1 convolution for channel adjustment.
        - `block_factory`: The factory to use for the blocks. Default is
            `ClassicResBlockFactory` which is a classic ResNet block (with
            postactivation)
        - `projection_factory`: The factory to use for the projection block.
            Default is `LinearChannelAdjustBlockFactory` which is simply a
            linear 1x1 convolution for channel adjustment.
        """
        subkey, key = jr.split(key)
        if num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")

        if isinstance(hidden_channels, int):
            hidden_channels = (hidden_channels,) * (num_blocks + 1)
        else:
            if len(hidden_channels) != (num_blocks + 1):
                raise ValueError(
                    "The list of hidden channels must be one longer than the number of blocks"
                )

        self.lifting = lifting_factory(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=hidden_channels[0],
            activation=activation,
            boundary_mode=boundary_mode,
            key=subkey,
        )
        self.blocks = []
        for fan_in, fan_out in zip(
            hidden_channels[:-1],
            hidden_channels[1:],
        ):
            subkey, key = jr.split(key)
            self.blocks.append(
                block_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_in,
                    out_channels=fan_out,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=subkey,
                )
            )
        self.projection = projection_factory(
            num_spatial_dims=num_spatial_dims,
            in_channels=hidden_channels[-1],
            out_channels=out_channels,
            activation=activation,
            boundary_mode=boundary_mode,
            key=key,
        )

    def __call__(self, x):
        x = self.lifting(x)
        for block in self.blocks:
            x = block(x)
        x = self.projection(x)
        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        lifting_receptive_field = self.lifting.receptive_field
        block_receptive_fields = tuple(block.receptive_field for block in self.blocks)
        projection_receptive_field = self.projection.receptive_field
        receptive_fields = (
            (lifting_receptive_field,)
            + block_receptive_fields
            + (projection_receptive_field,)
        )
        return sum_receptive_fields(receptive_fields)
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: Union[Sequence[int], int],
    num_blocks: int,
    activation: Callable,
    key: PRNGKeyArray,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    lifting_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
    block_factory: BlockFactory = ClassicResBlockFactory(),
    projection_factory: BlockFactory = LinearChannelAdjustBlockFactory()
)

Generic constructor for sequential block-based architectures like ResNets.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Either an integer to have the same number of hidden channels in the layers between all blocks, or a list of num_blocks + 1 integers.
  • num_blocks: The number of blocks to use. Must be an integer greater than or equal to 1.
  • activation: The activation function to use in the blocks.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • lifting_factory: The factory to use for the lifting block. Default is LinearChannelAdjustBlockFactory which is simply a linear 1x1 convolution for channel adjustment.
  • block_factory: The factory to use for the blocks. Default is ClassicResBlockFactory which is a classic ResNet block (with postactivation)
  • projection_factory: The factory to use for the projection block. Default is LinearChannelAdjustBlockFactory which is simply a linear 1x1 convolution for channel adjustment.
__call__
__call__(x)

pdequinox.Hierarchical

Bases: Module

Source code in pdequinox/_hierarchical.py
class Hierarchical(eqx.Module):
    lifting: Block
    down_sampling_blocks: List[Block]
    left_arc_blocks: List[List[Block]]  # Includes the bottleneck
    up_sampling_blocks: List[Block]
    right_arc_blocks: List[List[Block]]
    projection: PhysicsConv
    reduction_factor: int
    num_levels: int
    num_blocks: int
    channel_multipliers: tuple[int, ...]

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int,
        num_levels: int,
        num_blocks: int,
        activation: Callable,
        key: PRNGKeyArray,
        reduction_factor: int = 2,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"],
        channel_multipliers: Optional[tuple[int, ...]] = None,
        lifting_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
        down_sampling_factory: BlockFactory = LinearConvDownBlockFactory(),
        left_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
        up_sampling_factory: BlockFactory = LinearConvUpBlockFactory(),
        right_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
        projection_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
    ):
        """
        Generic constructor for hierarchical block-based architectures like
        UNets. (For the classic UNet, use `pdequinox.arch.ClassicUNet` instead.)

        Hierarchical architectures use a number of different spatial resolutions.
        The lower the resolution, the wider the receptive field of convolutions.

        The number of blocks per level can be increased via the `num_blocks`
        argument. This number is identical for the left arc (=encoder) and the
        right arc (=decoder). Unlike in PDEArena, there is **no multi-skip**:
        each level has exactly one skip connection.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers. This
            refers to the highest resolution. Right after the input, the input
            channels will be lifted to this feature dimension without changing
            the spatial resolution.
        - `num_levels`: The number of levels in the hierarchy. This is the
            number of down and up sampling blocks. If set to 0, this will just
            be a classical conv net. If set to 1, this will be a single down and
            up sampling block, etc. The total number of resolutions is
            `num_levels + 1`.
        - `num_blocks`: The number of blocks to use at each level. (Also affects
            the number of blocks in the bottleneck.)
        - `activation`: The activation function to use in the blocks.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `reduction_factor`: The factor by which the spatial resolution is
            reduced at each level. This has to be an integer. Each input
            spatial dimension must be a multiple of
            `reduction_factor ** num_levels`; otherwise the forward pass raises
            a `ValueError`. Default is `2`.
        - `boundary_mode`: The boundary mode to use for the convolution.
            (Keyword only argument)
        - `channel_multipliers`: The factors by which `hidden_channels` is
            multiplied to obtain the number of channels at each coarser level.
            If set to `None`, the channels will grow by a factor of
            `reduction_factor` at each level. This is similar to the classical
            UNet which trades spatial resolution for feature dimension. Note,
            however, that the parameter count of convolutions scales with the
            mapped channels, hence the majority of parameters will then be in
            the coarsest representation. Supply a tuple of integers, one
            multiplier for each resolution other than the original one. The
            length of the tuple must match `num_levels`. For example, to keep
            the number of channels unchanged at every level, set this to
            `(1,) * num_levels`. Default is `None`.
        - `lifting_factory`: The factory to use for the lifting block.
            Default is `ClassicDoubleConvBlockFactory` which is a classic double
            convolution block.
        - `down_sampling_factory`: The factory to use for the down sampling
            blocks. This must be a block that is able to change the spatial
            resolution. Default is `LinearConvDownBlockFactory` which is a
            simple linear strided convolution block.
        - `left_arc_factory`: The factory to use for the left arc blocks.
            Default is `ClassicDoubleConvBlockFactory` which is a classic
            double convolution block.
        - `up_sampling_factory`: The factory to use for the up sampling blocks.
            This must be a block that is able to change the spatial resolution.
            It should work in conjunction with the `down_sampling_factory`.
            Default is `LinearConvUpBlockFactory` which is a simple linear
            strided transposed convolution block.
        - `right_arc_factory`: The factory to use for the right arc blocks.
            Default is `ClassicDoubleConvBlockFactory` which is a classic
            double convolution block.
        - `projection_factory`: The factory to use for the projection block.
            Default is `LinearChannelAdjustBlockFactory` which is simply a
            linear 1x1 convolution for channel adjustment.
        """
        self.down_sampling_blocks = []
        self.left_arc_blocks = []
        self.up_sampling_blocks = []
        self.right_arc_blocks = []
        self.reduction_factor = reduction_factor
        self.num_levels = num_levels

        key, lifting_key, projection_key = jr.split(key, 3)

        self.lifting = lifting_factory(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=hidden_channels,
            activation=activation,
            boundary_mode=boundary_mode,
            key=lifting_key,
        )
        self.projection = projection_factory(
            num_spatial_dims=num_spatial_dims,
            in_channels=hidden_channels,
            out_channels=out_channels,
            activation=activation,
            boundary_mode=boundary_mode,
            key=projection_key,
        )

        if channel_multipliers is not None:
            if len(channel_multipliers) != num_levels:
                raise ValueError("len(channel_multipliers) must match num_levels")
        else:
            channel_multipliers = tuple(
                self.reduction_factor**i for i in range(1, num_levels + 1)
            )

        self.channel_multipliers = channel_multipliers

        channel_list = [
            hidden_channels,
        ] + [hidden_channels * m for m in channel_multipliers]

        if num_blocks < 1:
            raise ValueError("num_blocks must be at least 1")
        self.num_blocks = num_blocks

        for (
            fan_in,
            fan_out,
        ) in zip(
            channel_list[:-1],
            channel_list[1:],
        ):
            # If num_levels is 0, the loop will not run
            key, down_key, left_key, up_key, right_key = jr.split(key, 5)
            self.down_sampling_blocks.append(
                down_sampling_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_in,
                    out_channels=fan_in,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=down_key,
                )
            )

            this_level_left_arc_blocks = []
            # The first block changes the number of channels
            this_level_left_arc_blocks.append(
                left_arc_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_in,
                    out_channels=fan_out,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=left_key,
                )
            )
            for _ in range(num_blocks - 1):
                this_level_left_arc_blocks.append(
                    left_arc_factory(
                        num_spatial_dims=num_spatial_dims,
                        # All subsequent blocks have the same number of channels
                        in_channels=fan_out,
                        out_channels=fan_out,
                        activation=activation,
                        boundary_mode=boundary_mode,
                        key=left_key,
                    )
                )
            self.left_arc_blocks.append(this_level_left_arc_blocks)

            self.up_sampling_blocks.append(
                up_sampling_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_out,
                    out_channels=fan_out // self.reduction_factor,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=up_key,
                )
            )

            this_level_right_arc_blocks = []
            # The first block changes the number of channels, and operates
            # together with incoming skip connection
            this_level_right_arc_blocks.append(
                right_arc_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_out // self.reduction_factor + fan_in,
                    out_channels=fan_in,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=right_key,
                )
            )
            for _ in range(num_blocks - 1):
                # All subsequent blocks have the same number of channels
                this_level_right_arc_blocks.append(
                    right_arc_factory(
                        num_spatial_dims=num_spatial_dims,
                        in_channels=fan_in,
                        out_channels=fan_in,
                        activation=activation,
                        boundary_mode=boundary_mode,
                        key=right_key,
                    )
                )
            self.right_arc_blocks.append(this_level_right_arc_blocks)

    def __call__(self, x: Any) -> Any:
        spatial_shape = x.shape[1:]
        for dims in spatial_shape:
            if dims % self.reduction_factor**self.num_levels != 0:
                raise ValueError("Spatial dim issue")

        x = self.lifting(x)

        skips = []
        for down, left_list in zip(self.down_sampling_blocks, self.left_arc_blocks):
            # If num_levels is 0, the loop will not run
            skips.append(x)
            x = down(x)

            # The last in the loop is the bottleneck block
            for left in left_list:
                x = left(x)

        for up, right_list in zip(
            reversed(self.up_sampling_blocks),
            reversed(self.right_arc_blocks),
        ):
            # If num_levels is 0, the loop will not run
            skip = skips.pop()
            x = up(x)
            # Equinox models are by default single batch, hence the channels are
            # at axis=0
            x = jnp.concatenate((skip, x), axis=0)
            for right in right_list:
                x = right(x)

        x = self.projection(x)

        return x

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        lifting_receptive_field = self.lifting.receptive_field
        projection_receptive_fields = self.projection.receptive_field

        down_receptive_fields = tuple(
            block.receptive_field for block in self.down_sampling_blocks
        )
        left_receptive_fields = tuple(
            tuple(block.receptive_field for block in block_list)
            for block_list in self.left_arc_blocks
        )
        up_receptive_fields = tuple(
            block.receptive_field for block in self.up_sampling_blocks
        )
        right_receptive_fields = tuple(
            tuple(block.receptive_field for block in block_list)
            for block_list in self.right_arc_blocks
        )

        spatial_reduction = tuple(
            self.reduction_factor**level for level in range(0, self.num_levels + 1)
        )

        # Down block acts on the fan_in spatial resolution
        scaled_down_receptive_field = tuple(
            tuple(
                (c_i_backward * r, c_i_forward * r)
                for (c_i_backward, c_i_forward) in conv_receptive_field
            )
            for conv_receptive_field, r in zip(
                down_receptive_fields, spatial_reduction[:-1]
            )
        )

        # Left block acts on the fan_out spatial resolution
        scaled_left_receptive_field = tuple(
            tuple(
                (c_i_backward * r, c_i_forward * r)
                for (c_i_backward, c_i_forward) in conv_receptive_field
            )
            for conv_receptive_field_list, r in zip(
                left_receptive_fields, spatial_reduction[1:]
            )
            for conv_receptive_field in conv_receptive_field_list
        )

        # Up block acts on the fan_out spatial resolution
        scaled_up_receptive_field = tuple(
            tuple(
                (c_i_backward * r, c_i_forward * r)
                for (c_i_backward, c_i_forward) in conv_receptive_field
            )
            for conv_receptive_field, r in zip(
                up_receptive_fields, spatial_reduction[1:]
            )
        )

        # Right block acts on the fan_in spatial resolution
        scaled_right_receptive_field = tuple(
            tuple(
                (c_i_backward * r, c_i_forward * r)
                for (c_i_backward, c_i_forward) in conv_receptive_field
            )
            for conv_receptive_field_list, r in zip(
                right_receptive_fields, spatial_reduction[:-1]
            )
            for conv_receptive_field in conv_receptive_field_list
        )

        collection_of_receptive_fields = (
            lifting_receptive_field,
            *scaled_down_receptive_field,
            *scaled_left_receptive_field,
            *scaled_up_receptive_field,
            *scaled_right_receptive_field,
            projection_receptive_fields,
        )

        return sum_receptive_fields(collection_of_receptive_fields)
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int,
    num_levels: int,
    num_blocks: int,
    activation: Callable,
    key: PRNGKeyArray,
    reduction_factor: int = 2,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ],
    channel_multipliers: Optional[tuple[int, ...]] = None,
    lifting_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    down_sampling_factory: BlockFactory = LinearConvDownBlockFactory(),
    left_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    up_sampling_factory: BlockFactory = LinearConvUpBlockFactory(),
    right_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    projection_factory: BlockFactory = LinearChannelAdjustBlockFactory()
)

Generic constructor for hierarchical block-based architectures like UNets. (For the classic UNet, use pdequinox.arch.ClassicUNet instead.)

Hierarchical architectures use a number of different spatial resolutions. The lower the resolution, the wider the receptive field of convolutions.
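To make the receptive-field claim concrete: in the `receptive_field` property of the class source, each block's (backward, forward) extent is multiplied by `reduction_factor ** level` of the resolution it acts on before the contributions are combined by `sum_receptive_fields`. The sketch below reproduces that rescaling for a single spatial axis, under the assumption that combining means adding the extents (`sum_receptive_fields` is not shown on this page); `hierarchical_receptive_field` is a hypothetical name, not part of pdequinox.

```python
from typing import Sequence, Tuple

Extent = Tuple[float, float]  # (backward, forward) reach along one spatial axis


def hierarchical_receptive_field(
    per_level_extents: Sequence[Extent], reduction_factor: int = 2
) -> Extent:
    """Convert block extents measured on coarser grids into fine-grid points.

    per_level_extents[level] is the extent of one block measured in cells of
    that level's grid; one such cell spans reduction_factor ** level fine cells.
    Assumption: extents of consecutive blocks simply add up.
    """
    backward, forward = 0.0, 0.0
    for level, (b, f) in enumerate(per_level_extents):
        scale = reduction_factor**level
        backward += b * scale
        forward += f * scale
    return backward, forward


# Three blocks with a 3-point stencil, one per level
print(hierarchical_receptive_field([(1, 1)] * 3))  # (7.0, 7.0)
```

Placed at a single resolution (`reduction_factor=1`), the same three blocks would only reach 3 points in each direction, which is the sense in which coarser levels widen the receptive field.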

The number of blocks per level can be increased via the num_blocks argument. This number is identical for the left arc (=encoder) and the right arc (=decoder). Unlike in PDEArena, there is no multi-skip: each level has exactly one skip connection.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. This refers to the highest resolution. Right after the input, the input channels will be lifted to this feature dimension without changing the spatial resolution.
  • num_levels: The number of levels in the hierarchy. This is the number of down and up sampling blocks. If set to 0, this will just be a classical conv net. If set to 1, there will be a single down and up sampling block, etc. The total number of resolutions is num_levels + 1.
  • num_blocks: The number of blocks to use at each level. (Also affects the number of blocks in the bottleneck.)
  • activation: The activation function to use in the blocks.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • reduction_factor: The factor by which the spatial resolution is reduced at each level. This has to be an integer. Each input spatial dimension must be a multiple of reduction_factor ** num_levels; otherwise the forward pass raises a ValueError. Default is 2.
  • boundary_mode: The boundary mode to use for the convolution. (Keyword only argument)
  • channel_multipliers: The factor by which the number of channels is multiplied at each level. If set to None, the channels will grow by a factor of reduction_factor at each level. This is similar to the classical UNet, which trades spatial resolution for feature dimension. Note, however, that the parameter count of a convolution scales with the product of its input and output channels, hence the majority of parameters will then be in the coarsest representation. Supply a tuple of integers that represent the desired multiple of hidden_channels at each resolution other than the original one. The length of the tuple must match num_levels. For example, to not change the number of channels at any level, set this to (1,) * num_levels. Default is None.
  • lifting_factory: The factory to use for the lifting block. Default is ClassicDoubleConvBlockFactory which is a classic double convolution block.
  • down_sampling_factory: The factory to use for the down sampling blocks. This must be a block that is able to change the spatial resolution. Default is LinearConvDownBlockFactory which is a simple linear strided convolution block.
  • left_arc_factory: The factory to use for the left arc (encoder) blocks. Default is ClassicDoubleConvBlockFactory which is a classic double convolution block.
  • up_sampling_factory: The factory to use for the up sampling blocks. This must be a block that is able to change the spatial resolution. It should work in conjunction with the down_sampling_factory. Default is LinearConvUpBlockFactory which is a simple linear strided transposed convolution block.
  • right_arc_factory: The factory to use for the right arc (decoder) blocks. Default is ClassicDoubleConvBlockFactory which is a classic double convolution block.
  • projection_factory: The factory to use for the projection block. Default is LinearChannelAdjustBlockFactory which is simply a linear 1x1 convolution for channel adjustment.
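The interplay of hidden_channels, num_levels, reduction_factor and channel_multipliers fits in a few lines. The sketch below mirrors the default multiplier rule described above and adds a rough weight count for a single channel-preserving 3x3 convolution in 2D at each level. `channel_schedule` and `conv_weights` are hypothetical helpers; the weight count ignores biases and the actual design of the blocks.

```python
from typing import Optional, Sequence


def channel_schedule(
    hidden_channels: int,
    num_levels: int,
    reduction_factor: int = 2,
    channel_multipliers: Optional[Sequence[int]] = None,
) -> list:
    """Hypothetical helper: channels at each resolution, finest first."""
    if channel_multipliers is None:
        # Default: grow the channels by reduction_factor at every level
        channel_multipliers = tuple(
            reduction_factor**i for i in range(1, num_levels + 1)
        )
    elif len(channel_multipliers) != num_levels:
        raise ValueError("len(channel_multipliers) must match num_levels")
    return [hidden_channels] + [hidden_channels * m for m in channel_multipliers]


def conv_weights(channels: int, kernel_size: int = 3, num_spatial_dims: int = 2) -> int:
    """Weights of one channel-preserving convolution (bias ignored)."""
    return kernel_size**num_spatial_dims * channels * channels


default = channel_schedule(hidden_channels=32, num_levels=3)
flat = channel_schedule(hidden_channels=32, num_levels=3, channel_multipliers=(1,) * 3)
weights = [conv_weights(c) for c in default]
print(default)  # [32, 64, 128, 256]
print(flat)     # [32, 32, 32, 32]
print(weights)  # [9216, 36864, 147456, 589824]
print(round(weights[-1] / sum(weights), 2))  # 0.75
```

With the default growth, the coarsest level holds roughly three quarters of these weights, which is the effect the channel_multipliers note describes; `(1,) * num_levels` spreads them evenly across the levels.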
Source code in pdequinox/_hierarchical.py
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int,
    num_levels: int,
    num_blocks: int,
    activation: Callable,
    key: PRNGKeyArray,
    reduction_factor: int = 2,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"],
    channel_multipliers: Optional[tuple[int, ...]] = None,
    lifting_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    down_sampling_factory: BlockFactory = LinearConvDownBlockFactory(),
    left_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    up_sampling_factory: BlockFactory = LinearConvUpBlockFactory(),
    right_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
    projection_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
):
    """
    Generic constructor for hierarchical block-based architectures like
    UNets. (For the classic UNet, use `pdequinox.arch.ClassicUNet` instead.)

    Hierarchical architectures use a number of different spatial resolutions.
    The lower the resolution, the wider the receptive field of convolutions.

    The number of blocks per level can be increased via the `num_blocks`
    argument. The same number is used for the left arc (=encoder) and the
    right arc (=decoder). Only one skip connection is used per level
    (**no multi-skip** as in PDEArena).

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example,
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `hidden_channels`: The number of channels in the hidden layers. This
        refers to the highest resolution. Right after the input, the input
        channels will be lifted to this feature dimension without changing
        the spatial resolution.
    - `num_levels`: The number of levels in the hierarchy. This is the
        number of down and up sampling blocks. If set to 0, this will just
        be a classical conv net. If set to 1, there will be a single down and
        up sampling block, etc. The total number of resolutions is
        `num_levels + 1`.
    - `num_blocks`: The number of blocks to use at each level. (Also affects
        the number of blocks in the bottleneck.)
    - `activation`: The activation function to use in the blocks.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `reduction_factor`: The factor by which the spatial resolution is
        reduced at each level. This has to be an integer. In order to avoid
        ambiguities in shapes, it is best if the input spatial resolution is
        a multiple of `reduction_factor ** num_levels`. Default is `2`.
    - `boundary_mode`: The boundary mode to use for the convolution.
        (Keyword only argument)
    - `channel_multipliers`: The factor by which the number of channels is
        multiplied at each level. If set to `None`, the channels will grow
        by a factor of `reduction_factor` at each level. This is similar to
        the classical UNet, which trades spatial resolution for feature
        dimension. Note, however, that the parameter count of a convolution
        scales with the product of its input and output channels, hence the
        majority of parameters will then be in the coarsest representation.
        Supply a tuple of integers that represent the desired multiple of
        `hidden_channels` at each resolution other than the original one. The
        length of the tuple must match
        `num_levels`. For example, to not change the number of channels at
        any level, set this to `(1,) * num_levels`. Default is `None`.
    - `lifting_factory`: The factory to use for the lifting block.
        Default is `ClassicDoubleConvBlockFactory` which is a classic double
        convolution block.
    - `down_sampling_factory`: The factory to use for the down sampling
        blocks. This must be a block that is able to change the spatial
        resolution. Default is `LinearConvDownBlockFactory` which is a
        simple linear strided convolution block.
    - `left_arc_factory`: The factory to use for the left arc (encoder)
        blocks. Default is `ClassicDoubleConvBlockFactory` which is a
        classic double convolution block.
    - `up_sampling_factory`: The factory to use for the up sampling blocks.
        This must be a block that is able to change the spatial resolution.
        It should work in conjunction with the `down_sampling_factory`.
        Default is `LinearConvUpBlockFactory` which is a simple linear
        strided transposed convolution block.
    - `right_arc_factory`: The factory to use for the right arc (decoder)
        blocks. Default is `ClassicDoubleConvBlockFactory` which is a
        classic double convolution block.
    - `projection_factory`: The factory to use for the projection block.
        Default is `LinearChannelAdjustBlockFactory` which is simply a
        linear 1x1 convolution for channel adjustment.
    """
    self.down_sampling_blocks = []
    self.left_arc_blocks = []
    self.up_sampling_blocks = []
    self.right_arc_blocks = []
    self.reduction_factor = reduction_factor
    self.num_levels = num_levels

    key, lifting_key, projection_key = jr.split(key, 3)

    self.lifting = lifting_factory(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=hidden_channels,
        activation=activation,
        boundary_mode=boundary_mode,
        key=lifting_key,
    )
    self.projection = projection_factory(
        num_spatial_dims=num_spatial_dims,
        in_channels=hidden_channels,
        out_channels=out_channels,
        activation=activation,
        boundary_mode=boundary_mode,
        key=projection_key,
    )

    if channel_multipliers is not None:
        if len(channel_multipliers) != num_levels:
            raise ValueError("len(channel_multipliers) must match num_levels")
    else:
        channel_multipliers = tuple(
            self.reduction_factor**i for i in range(1, num_levels + 1)
        )

    self.channel_multipliers = channel_multipliers

    channel_list = [
        hidden_channels,
    ] + [hidden_channels * m for m in channel_multipliers]

    if num_blocks < 1:
        raise ValueError("num_blocks must be at least 1")
    self.num_blocks = num_blocks

    for (
        fan_in,
        fan_out,
    ) in zip(
        channel_list[:-1],
        channel_list[1:],
    ):
        # If num_levels is 0, the loop will not run
        key, down_key, left_key, up_key, right_key = jr.split(key, 5)
        self.down_sampling_blocks.append(
            down_sampling_factory(
                num_spatial_dims=num_spatial_dims,
                in_channels=fan_in,
                out_channels=fan_in,
                activation=activation,
                boundary_mode=boundary_mode,
                key=down_key,
            )
        )

        this_level_left_arc_blocks = []
        # The first block changes the number of channels
        this_level_left_arc_blocks.append(
            left_arc_factory(
                num_spatial_dims=num_spatial_dims,
                in_channels=fan_in,
                out_channels=fan_out,
                activation=activation,
                boundary_mode=boundary_mode,
                key=left_key,
            )
        )
        for _ in range(num_blocks - 1):
            this_level_left_arc_blocks.append(
                left_arc_factory(
                    num_spatial_dims=num_spatial_dims,
                    # All subsequent blocks have the same number of channels
                    in_channels=fan_out,
                    out_channels=fan_out,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=left_key,
                )
            )
        self.left_arc_blocks.append(this_level_left_arc_blocks)

        self.up_sampling_blocks.append(
            up_sampling_factory(
                num_spatial_dims=num_spatial_dims,
                in_channels=fan_out,
                out_channels=fan_out // self.reduction_factor,
                activation=activation,
                boundary_mode=boundary_mode,
                key=up_key,
            )
        )

        this_level_right_arc_blocks = []
        # The first block changes the number of channels, and operates
        # together with incoming skip connection
        this_level_right_arc_blocks.append(
            right_arc_factory(
                num_spatial_dims=num_spatial_dims,
                in_channels=fan_out // self.reduction_factor + fan_in,
                out_channels=fan_in,
                activation=activation,
                boundary_mode=boundary_mode,
                key=right_key,
            )
        )
        for _ in range(num_blocks - 1):
            # All subsequent blocks have the same number of channels
            this_level_right_arc_blocks.append(
                right_arc_factory(
                    num_spatial_dims=num_spatial_dims,
                    in_channels=fan_in,
                    out_channels=fan_in,
                    activation=activation,
                    boundary_mode=boundary_mode,
                    key=right_key,
                )
            )
        self.right_arc_blocks.append(this_level_right_arc_blocks)
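The channel bookkeeping of this loop is easy to lose track of. The hypothetical helper below (not part of PDEquinox) recomputes only the `in_channels`/`out_channels` pairs that the loop above passes to each factory, so the block sizes of a configuration can be checked without building a network.

```python
def level_block_channels(
    hidden_channels: int,
    channel_multipliers: tuple,
    num_blocks: int,
    reduction_factor: int = 2,
) -> list:
    """Hypothetical helper: (in_channels, out_channels) of every block per level,
    following the arithmetic of the __init__ loop."""
    channel_list = [hidden_channels] + [hidden_channels * m for m in channel_multipliers]
    levels = []
    for fan_in, fan_out in zip(channel_list[:-1], channel_list[1:]):
        up_out = fan_out // reduction_factor
        levels.append(
            {
                # Down-sampling only changes the resolution, not the channels
                "down": (fan_in, fan_in),
                # The first left-arc block changes the channel count
                "left": [(fan_in, fan_out)] + [(fan_out, fan_out)] * (num_blocks - 1),
                "up": (fan_out, up_out),
                # The first right-arc block sees the skip concatenated with the
                # up-sampled features
                "right": [(up_out + fan_in, fan_in)] + [(fan_in, fan_in)] * (num_blocks - 1),
            }
        )
    return levels


levels = level_block_channels(hidden_channels=32, channel_multipliers=(2, 4), num_blocks=2)
print(levels[1]["right"])  # [(128, 64), (64, 64)]
```

Because the up-sampling block always divides by `reduction_factor`, independent of `channel_multipliers`, a flat schedule such as `(1, 1)` with `hidden_channels=32` gives the first right-arc block 16 + 32 = 48 input channels rather than 64.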
__call__
__call__(x: Any) -> Any
Source code in pdequinox/_hierarchical.py
def __call__(self, x: Any) -> Any:
    spatial_shape = x.shape[1:]
    for dims in spatial_shape:
        if dims % self.reduction_factor**self.num_levels != 0:
            raise ValueError("Spatial dim issue")

    x = self.lifting(x)

    skips = []
    for down, left_list in zip(self.down_sampling_blocks, self.left_arc_blocks):
        # If num_levels is 0, the loop will not run
        skips.append(x)
        x = down(x)

        # The last in the loop is the bottleneck block
        for left in left_list:
            x = left(x)

    for up, right_list in zip(
        reversed(self.up_sampling_blocks),
        reversed(self.right_arc_blocks),
    ):
        # If num_levels is 0, the loop will not run
        skip = skips.pop()
        x = up(x)
        # Equinox models are by default single batch, hence the channels are
        # at axis=0
        x = jnp.concatenate((skip, x), axis=0)
        for right in right_list:
            x = right(x)

    x = self.projection(x)

    return x
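The data flow of `__call__` can be traced on shapes alone. The hypothetical helper below (not part of PDEquinox) follows one unbatched, channels-first sample through the same steps: the divisibility check, lifting, storing a skip before each down-sampling, popping skips in reverse order, concatenating along axis 0, and the final projection. It assumes the default block behaviour (down-sampling divides and up-sampling multiplies each spatial axis by `reduction_factor`; up-sampling divides the channels by `reduction_factor`) and propagates shapes only, no arrays.

```python
def trace_shapes(
    spatial_shape: tuple,
    channel_list: list,
    out_channels: int,
    reduction_factor: int = 2,
) -> list:
    """Hypothetical helper: (stage, shape) pairs of one unbatched sample."""
    num_levels = len(channel_list) - 1
    # Same divisibility check as __call__
    for n in spatial_shape:
        if n % reduction_factor**num_levels != 0:
            raise ValueError(f"spatial size {n} not divisible by {reduction_factor**num_levels}")

    shape = (channel_list[0], *spatial_shape)  # after lifting
    trace = [("lifting", shape)]
    skips = []
    for level in range(num_levels):
        skips.append(shape)  # the skip is stored before down-sampling
        coarse = tuple(n // reduction_factor for n in shape[1:])
        shape = (channel_list[level + 1], *coarse)  # down-sampling + left arc
        trace.append((f"left {level}", shape))

    for level in reversed(range(num_levels)):
        skip = skips.pop()  # last in, first out
        fine = tuple(n * reduction_factor for n in shape[1:])
        if fine != skip[1:]:
            raise ValueError("skip and up-sampled features must align")
        up_channels = shape[0] // reduction_factor
        # Concatenate (skip, up-sampled) along the channel axis (axis 0)
        trace.append((f"concat {level}", (skip[0] + up_channels, *fine)))
        shape = (channel_list[level], *fine)  # right arc
        trace.append((f"right {level}", shape))

    trace.append(("projection", (out_channels, *shape[1:])))
    return trace


for stage, shape in trace_shapes((64, 64), [32, 64, 128], out_channels=1):
    print(stage, shape)
```

Note that `__call__` does not merely recommend spatial sizes that are multiples of `reduction_factor ** num_levels`; it raises a `ValueError` otherwise. For example, `(62, 64)` with two levels fails because 62 is not a multiple of 4.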