
UNet

A hierarchical multi-scale convolutional network.
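Both architectures documented on this page are exported from pdequinox.arch; a minimal import sketch (assuming pdequinox and JAX are installed):

import jax
import jax.numpy as jnp
from pdequinox.arch import ClassicUNet, ModernUNet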

pdequinox.arch.ClassicUNet

Bases: Hierarchical

Source code in pdequinox/arch/_classic_u_net.py
class ClassicUNet(Hierarchical):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 16,
        num_levels: int = 4,
        use_norm: bool = True,
        activation: Callable = jax.nn.relu,
        key: PRNGKeyArray,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    ):
        """
        The vanilla UNet architecture, very close to the original Ronneberger
        et al. (2015) paper.

        Uses a hierarchy of spatial resolutions to obtain a wide receptive
        field.

        This version does **not** use maxpool for downsampling but instead uses
        a strided convolution. Up- and downsampling use 3x3 operations (instead
        of 2x2 operations). If activated (`use_norm=True`), group norm is used
        instead of batch norm.


        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `16`. This is the number of channels at the finest
            (input) spatial resolution.
        - `num_levels`: The number of levels in the hierarchy. Default is `4`.
            Each level halves the spatial resolution while doubling the number
            of channels.
        - `use_norm`: If `True`, uses group norm as part of double convolutions.
            Default is `True`.
        - `activation`: The activation function to use in the blocks. Default is
            `jax.nn.relu`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The boundary mode to use. Default is `periodic`.
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_levels=num_levels,
            num_blocks=1,
            activation=activation,
            key=key,
            boundary_mode=boundary_mode,
            lifting_factory=ClassicDoubleConvBlockFactory(
                use_norm=use_norm,
            ),
            down_sampling_factory=LinearConvDownBlockFactory(),
            left_arc_factory=ClassicDoubleConvBlockFactory(
                use_norm=use_norm,
            ),
            up_sampling_factory=LinearConvUpBlockFactory(),
            right_arc_factory=ClassicDoubleConvBlockFactory(
                use_norm=use_norm,
            ),
            projection_factory=LinearChannelAdjustBlockFactory(),
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic"
)

The vanilla UNet architecture, very close to the original Ronneberger et al. (2015) paper.

Uses a hierarchy of spatial resolutions to obtain a wide receptive field.

This version does not use maxpool for downsampling but instead uses a strided convolution. Up- and downsampling use 3x3 operations (instead of 2x2 operations). If activated (use_norm=True), group norm is used instead of batch norm.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 16. This is the number of channels at the finest (input) spatial resolution.
  • num_levels: The number of levels in the hierarchy. Default is 4. Each level halves the spatial resolution while doubling the number of channels.
  • use_norm: If True, uses group norm as part of double convolutions. Default is True.
  • activation: The activation function to use in the blocks. Default is jax.nn.relu.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The boundary mode to use. Default is periodic.
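A minimal construction sketch; the 1D setting, channel counts, and the 64-point grid are illustrative only:

import jax
import jax.numpy as jnp
from pdequinox.arch import ClassicUNet

# 1D UNet mapping one input channel to one output channel
unet = ClassicUNet(
    num_spatial_dims=1,
    in_channels=1,
    out_channels=1,
    hidden_channels=16,
    num_levels=4,
    key=jax.random.PRNGKey(0),
)

# Single sample: channels first, no batch axis
x = jnp.zeros((1, 64))  # 64 is divisible by 2**num_levels = 16
y = unet(x)             # shape (1, 64)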
Source code in pdequinox/arch/_classic_u_net.py
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
):
    """
    The vanilla UNet architecture, very close to the original Ronneberger
    et al. (2015) paper.

    Uses a hierarchy of spatial resolutions to obtain a wide receptive
    field.

    This version does **not** use maxpool for downsampling but instead uses
    a strided convolution. Up- and downsampling use 3x3 operations (instead
    of 2x2 operations). If activated (`use_norm=True`), group norm is used
    instead of batch norm.


    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `hidden_channels`: The number of channels in the hidden layers.
        Default is `16`. This is the number of channels at the finest
        (input) spatial resolution.
    - `num_levels`: The number of levels in the hierarchy. Default is `4`.
        Each level halves the spatial resolution while doubling the number
        of channels.
    - `use_norm`: If `True`, uses group norm as part of double convolutions.
        Default is `True`.
    - `activation`: The activation function to use in the blocks. Default is
        `jax.nn.relu`.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `boundary_mode`: The boundary mode to use. Default is `periodic`.
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        hidden_channels=hidden_channels,
        num_levels=num_levels,
        num_blocks=1,
        activation=activation,
        key=key,
        boundary_mode=boundary_mode,
        lifting_factory=ClassicDoubleConvBlockFactory(
            use_norm=use_norm,
        ),
        down_sampling_factory=LinearConvDownBlockFactory(),
        left_arc_factory=ClassicDoubleConvBlockFactory(
            use_norm=use_norm,
        ),
        up_sampling_factory=LinearConvUpBlockFactory(),
        right_arc_factory=ClassicDoubleConvBlockFactory(
            use_norm=use_norm,
        ),
        projection_factory=LinearChannelAdjustBlockFactory(),
    )
__call__
__call__(x: Any) -> Any
Source code in pdequinox/_hierarchical.py
def __call__(self, x: Any) -> Any:
    spatial_shape = x.shape[1:]
    for dims in spatial_shape:
        if dims % self.reduction_factor**self.num_levels != 0:
            raise ValueError("Spatial dim issue")

    x = self.lifting(x)

    skips = []
    for down, left_list in zip(self.down_sampling_blocks, self.left_arc_blocks):
        # If num_levels is 0, the loop will not run
        skips.append(x)
        x = down(x)

        # The last in the loop is the bottleneck block
        for left in left_list:
            x = left(x)

    for up, right_list in zip(
        reversed(self.up_sampling_blocks),
        reversed(self.right_arc_blocks),
    ):
        # If num_levels is 0, the loop will not run
        skip = skips.pop()
        x = up(x)
        # Equinox models are by default single batch, hence the channels are
        # at axis=0
        x = jnp.concatenate((skip, x), axis=0)
        for right in right_list:
            x = right(x)

    x = self.projection(x)

    return x
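The check at the top of __call__ requires every spatial extent of the input to be divisible by reduction_factor**num_levels; with the default strided-convolution down-sampling (factor 2) and num_levels=4 this means divisibility by 16. A quick sketch of the condition (the concrete sizes are illustrative):

num_levels = 4
reduction_factor = 2  # halving per level with the default down-sampling blocks
for size in (64, 96, 100):
    ok = size % reduction_factor**num_levels == 0
    print(size, "valid" if ok else "invalid")  # 64 and 96 pass, 100 does not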

pdequinox.arch.ModernUNet

Bases: Hierarchical

Source code in pdequinox/arch/_modern_u_net.py
class ModernUNet(Hierarchical):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 16,
        num_levels: int = 4,
        num_blocks: int = 2,
        channel_multipliers: Optional[tuple[int, ...]] = None,
        use_norm: bool = True,
        activation: Callable = jax.nn.relu,
        key: PRNGKeyArray,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    ):
        """
        A modern UNet version close to the ones used by Gupta & Brandstetter
        (2023) in PDEArena.

        Uses ResNet blocks for the left and right arc of the UNet.

        In comparison to the version in PDEArena, the number of blocks
        (`num_blocks`) is identical in the left and right arcs of the UNet
        (PDEArena uses one additional block in the right arc). Here, we also do
        not use multi-skips; only the last state in the processing of one
        hierarchy level is skip-connected to the decoder.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `16`. This is the number of channels at the finest
            (input) spatial resolution.
        - `num_levels`: The number of levels in the hierarchy. Default is `4`.
            Each level halves the spatial resolution. By default, it also
            doubles the number of channels. This can be changed by setting
            `channel_multipliers`.
        - `num_blocks`: The number of blocks in the left and right arcs of the
            UNet, for each level. One block is a single modern ResNet block
            (using pre-activation) consisting of two convolutions. The default
            value of `num_blocks` is `2`, meaning that for each level in the
            encoder, bottleneck, and decoder, two blocks are used. Hence, a
            total of four convolutions contribute to the receptive field per
            level.
        - `channel_multipliers`: A tuple of integers that specify the channel
            multipliers for each level. If `None`, the default is to double the
            number of channels at each level (for `num_levels=4` this would mean
            `(2, 4, 8, 16)`). The length of the tuple should be equal to
            `num_levels`.
        - `use_norm`: If `True`, uses group norm as part of the ResNet blocks.
            Default is `True`.
        - `activation`: The activation function to use in the blocks. Default is
            `jax.nn.relu`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `boundary_mode`: The boundary mode to use. Default is `periodic`.
        """
        super().__init__(
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_levels=num_levels,
            num_blocks=num_blocks,
            channel_multipliers=channel_multipliers,
            activation=activation,
            key=key,
            boundary_mode=boundary_mode,
            lifting_factory=ModernResBlockFactory(
                use_norm=use_norm,
            ),
            down_sampling_factory=LinearConvDownBlockFactory(),
            left_arc_factory=ModernResBlockFactory(
                use_norm=use_norm,
            ),
            up_sampling_factory=LinearConvUpBlockFactory(),
            right_arc_factory=ModernResBlockFactory(
                use_norm=use_norm,
            ),
            projection_factory=LinearChannelAdjustBlockFactory(),
        )
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    num_blocks: int = 2,
    channel_multipliers: Optional[tuple[int, ...]] = None,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic"
)

A modern UNet version close to the ones used by Gupta & Brandstetter (2023) in PDEArena.

Uses ResNet blocks for the left and right arc of the UNet.

In comparison to the version in PDEArena, the number of blocks (num_blocks) is identical in the left and right arcs of the UNet (PDEArena uses one additional block in the right arc). Here, we also do not use multi-skips; only the last state in the processing of one hierarchy level is skip-connected to the decoder.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 16. This is the number of channels at the finest (input) spatial resolution.
  • num_levels: The number of levels in the hierarchy. Default is 4. Each level halves the spatial resolution. By default, it also doubles the number of channels. This can be changed by setting channel_multipliers.
  • num_blocks: The number of blocks in the left and right arcs of the UNet, for each level. One block is a single modern ResNet block (using pre-activation) consisting of two convolutions. The default value of num_blocks is 2, meaning that for each level in the encoder, bottleneck, and decoder, two blocks are used. Hence, a total of four convolutions contribute to the receptive field per level.
  • channel_multipliers: A tuple of integers that specify the channel multipliers for each level. If None, the default is to double the number of channels at each level (for num_levels=4 this would mean (2, 4, 8, 16)). The length of the tuple should be equal to num_levels.
  • use_norm: If True, uses group norm as part of the ResNet blocks. Default is True.
  • activation: The activation function to use in the blocks. Default is jax.nn.relu.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • boundary_mode: The boundary mode to use. Default is periodic.
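A construction sketch with a custom channel schedule; the multipliers (1, 2, 2, 4), channel counts, and grid size are illustrative, not prescriptive:

import jax
import jax.numpy as jnp
from pdequinox.arch import ModernUNet

unet = ModernUNet(
    num_spatial_dims=2,
    in_channels=3,
    out_channels=3,
    hidden_channels=32,
    num_levels=4,
    num_blocks=2,
    channel_multipliers=(1, 2, 2, 4),  # one multiplier per level
    key=jax.random.PRNGKey(0),
)

x = jnp.zeros((3, 64, 64))  # each spatial extent divisible by 2**num_levels
y = unet(x)                 # shape (3, 64, 64)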
Source code in pdequinox/arch/_modern_u_net.py
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    num_blocks: int = 2,
    channel_multipliers: Optional[tuple[int, ...]] = None,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
):
    """
    A modern UNet version close to the ones used by Gupta & Brandstetter
    (2023) in PDEArena.

    Uses ResNet blocks for the left and right arc of the UNet.

    In comparison to the version in PDEArena, the number of blocks
    (`num_blocks`) is identical in the left and right arcs of the UNet
    (PDEArena uses one additional block in the right arc). Here, we also do
    not use multi-skips; only the last state in the processing of one
    hierarchy level is skip-connected to the decoder.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `hidden_channels`: The number of channels in the hidden layers.
        Default is `16`. This is the number of channels at the finest
        (input) spatial resolution.
    - `num_levels`: The number of levels in the hierarchy. Default is `4`.
        Each level halves the spatial resolution. By default, it also
        doubles the number of channels. This can be changed by setting
        `channel_multipliers`.
    - `num_blocks`: The number of blocks in the left and right arcs of the
        UNet, for each level. One block is a single modern ResNet block
        (using pre-activation) consisting of two convolutions. The default
        value of `num_blocks` is `2`, meaning that for each level in the
        encoder, bottleneck, and decoder, two blocks are used. Hence, a
        total of four convolutions contribute to the receptive field per
        level.
    - `channel_multipliers`: A tuple of integers that specify the channel
        multipliers for each level. If `None`, the default is to double the
        number of channels at each level (for `num_levels=4` this would mean
        `(2, 4, 8, 16)`). The length of the tuple should be equal to
        `num_levels`.
    - `use_norm`: If `True`, uses group norm as part of the ResNet blocks.
        Default is `True`.
    - `activation`: The activation function to use in the blocks. Default is
        `jax.nn.relu`.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `boundary_mode`: The boundary mode to use. Default is `periodic`.
    """
    super().__init__(
        num_spatial_dims=num_spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        hidden_channels=hidden_channels,
        num_levels=num_levels,
        num_blocks=num_blocks,
        channel_multipliers=channel_multipliers,
        activation=activation,
        key=key,
        boundary_mode=boundary_mode,
        lifting_factory=ModernResBlockFactory(
            use_norm=use_norm,
        ),
        down_sampling_factory=LinearConvDownBlockFactory(),
        left_arc_factory=ModernResBlockFactory(
            use_norm=use_norm,
        ),
        up_sampling_factory=LinearConvUpBlockFactory(),
        right_arc_factory=ModernResBlockFactory(
            use_norm=use_norm,
        ),
        projection_factory=LinearChannelAdjustBlockFactory(),
    )
__call__
__call__(x: Any) -> Any
Source code in pdequinox/_hierarchical.py
def __call__(self, x: Any) -> Any:
    spatial_shape = x.shape[1:]
    for dims in spatial_shape:
        if dims % self.reduction_factor**self.num_levels != 0:
            raise ValueError("Spatial dim issue")

    x = self.lifting(x)

    skips = []
    for down, left_list in zip(self.down_sampling_blocks, self.left_arc_blocks):
        # If num_levels is 0, the loop will not run
        skips.append(x)
        x = down(x)

        # The last in the loop is the bottleneck block
        for left in left_list:
            x = left(x)

    for up, right_list in zip(
        reversed(self.up_sampling_blocks),
        reversed(self.right_arc_blocks),
    ):
        # If num_levels is 0, the loop will not run
        skip = skips.pop()
        x = up(x)
        # Equinox models are by default single batch, hence the channels are
        # at axis=0
        x = jnp.concatenate((skip, x), axis=0)
        for right in right_list:
            x = right(x)

    x = self.projection(x)

    return x
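Because both UNets operate on a single sample (channels at axis 0, no batch axis), batching can be added with jax.vmap. A sketch, assuming default constructor arguments; the batch size, channel counts, and grid size are illustrative:

import jax
import jax.numpy as jnp
from pdequinox.arch import ModernUNet

unet = ModernUNet(
    num_spatial_dims=2,
    in_channels=3,
    out_channels=3,
    key=jax.random.PRNGKey(0),
)

x_batch = jnp.zeros((8, 3, 64, 64))  # (batch, channels, height, width)
y_batch = jax.vmap(unet)(x_batch)    # maps over the batch axis; shape (8, 3, 64, 64)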