UNet
A hierarchical multi-scale convolutional network.
pdequinox.arch.ClassicUNet
Bases: Hierarchical
Source code in pdequinox/arch/_classic_u_net.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic"
)
The vanilla UNet architecture, very close to the original Ronneberger et al. (2015) paper.
Uses a hierarchy of spatial resolutions to obtain a wide receptive field.
This version does not use max pooling for downsampling but a strided convolution instead. Up- and downsampling use 3x3 operations (instead of 2x2 operations). If use_norm is active, group norm is used instead of batch norm.
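The strided-convolution downsampling described above can be sketched in NumPy for one spatial dimension, where the 3x3 kernel becomes a 3-tap kernel. This is an illustrative sketch, not pdequinox's implementation: the function name periodic_strided_conv1d is hypothetical, and periodic padding is assumed to mirror the default boundary_mode="periodic".

```python
import numpy as np


def periodic_strided_conv1d(x: np.ndarray, w: np.ndarray, stride: int = 2) -> np.ndarray:
    """Hypothetical helper: strided 1D cross-correlation with periodic padding.

    x: input of shape (C_in, N)
    w: kernel of shape (C_out, C_in, K) with odd K (here K = 3)
    For K = 3 and stride = 2 the output has shape (C_out, ceil(N / 2)).
    """
    c_out, c_in, k = w.shape
    if x.shape[0] != c_in:
        raise ValueError("number of input channels does not match the kernel")
    pad = k // 2
    # "wrap" padding copies values from the opposite end (periodic domain).
    x_pad = np.pad(x, ((0, 0), (pad, pad)), mode="wrap")
    n_out = (x_pad.shape[1] - k) // stride + 1
    out = np.empty((c_out, n_out), dtype=np.result_type(x, w))
    for i in range(n_out):
        window = x_pad[:, i * stride : i * stride + k]  # shape (C_in, K)
        # Sum over input channels and kernel taps for every output channel.
        out[:, i] = np.einsum("oik,ik->o", w, window)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64))  # 4 channels on 64 grid points
w = rng.standard_normal((4, 4, 3))  # 4 -> 4 channels, 3-tap kernel
print(periodic_strided_conv1d(x, w).shape)  # (4, 32)
```

Unlike max pooling, which has no parameters, the strided convolution learns how to aggregate each neighbourhood while halving the resolution (64 to 32 points here).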
Arguments:
num_spatial_dims
: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels
: The number of input channels.
out_channels
: The number of output channels.
hidden_channels
: The number of channels in the hidden layers. Default is 16. This is the number of channels at the finest (input) spatial resolution.
num_levels
: The number of levels in the hierarchy. Default is 4. Each level halves the spatial resolution while doubling the number of channels.
use_norm
: If True, uses group norm as part of the double convolutions. Default is True.
activation
: The activation function to use in the blocks. Default is jax.nn.relu.
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
boundary_mode
: The boundary mode to use, one of "periodic", "dirichlet", or "neumann". Default is "periodic".
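The interplay of hidden_channels and num_levels can be tabulated with a short sketch. It assumes that num_levels counts the halvings below the finest level (so num_levels=4 yields five resolutions, consistent with the four default channel multipliers of ModernUNet) and that the grid size divides exactly by 2**num_levels; the helper classic_unet_schedule is hypothetical, and whether pdequinox itself enforces divisibility is not stated here.

```python
def classic_unet_schedule(
    num_spatial_points: int, hidden_channels: int = 16, num_levels: int = 4
) -> list[tuple[int, int, int]]:
    """Hypothetical helper: (level, spatial points, channels) for every level.

    Level 0 is the finest resolution with hidden_channels channels; each
    further level halves the resolution and doubles the channel count.
    """
    if num_spatial_points % (2**num_levels) != 0:
        raise ValueError(
            f"{num_spatial_points} points cannot be halved {num_levels} times exactly"
        )
    return [
        (level, num_spatial_points // 2**level, hidden_channels * 2**level)
        for level in range(num_levels + 1)
    ]


for level, points, channels in classic_unet_schedule(64):
    print(f"level {level}: {points:3d} points x {channels:3d} channels")
```

With the defaults, a 64-point input reaches 4 points with 256 channels at the coarsest level, which illustrates how the hierarchy trades resolution for width.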
__call__
__call__(x: Any) -> Any
Source code in pdequinox/_hierarchical.py
pdequinox.arch.ModernUNet
Bases: Hierarchical
Source code in pdequinox/arch/_modern_u_net.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    num_levels: int = 4,
    num_blocks: int = 2,
    channel_multipliers: Optional[tuple[int, ...]] = None,
    use_norm: bool = True,
    activation: Callable = jax.nn.relu,
    key: PRNGKeyArray,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic"
)
A modern UNet version close to the ones used by Gupta & Brandstetter (2023) in PDEArena.
Uses ResNet blocks for the left and right arcs of the UNet.
In contrast to the PDEArena version, the number of blocks (num_blocks) is identical in the left and right arcs of the UNet (PDEArena uses one additional block in the right arc). This version also does not use multi-skips: only the last state in the processing of each hierarchy level is skip-connected to the decoder.
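The single-skip design can be illustrated with a small trace of the forward pass. The layout used here (a lifting layer, num_blocks blocks per encoder level, a bottleneck at the coarsest level, a mirrored decoder, and a final projection) is an assumption for illustration; unet_trace is a hypothetical helper, and the exact ordering inside pdequinox's Hierarchical class may differ.

```python
def unet_trace(num_levels: int, num_blocks: int) -> list[str]:
    """Hypothetical helper: operation order of a UNet with one skip per level."""
    ops = ["lift"]
    skips = []  # stack of levels whose last state is kept for the decoder
    for level in range(num_levels):
        ops += [f"block@{level}"] * num_blocks
        skips.append(level)  # only the final state of the level is stored
        ops.append(f"save_skip@{level}")
        ops.append(f"down {level}->{level + 1}")
    ops += [f"block@{num_levels}"] * num_blocks  # bottleneck at the coarsest level
    while skips:
        level = skips.pop()  # the most recently saved skip is consumed first
        ops.append(f"up {level + 1}->{level}")
        ops.append(f"concat_skip@{level}")
        ops += [f"block@{level}"] * num_blocks
    ops.append("project")
    return ops


for op in unet_trace(num_levels=2, num_blocks=2):
    print(op)
```

Each encoder level pushes exactly one state onto the skip stack, so the decoder receives num_levels skip connections; a multi-skip variant would additionally store intermediate states from within a level.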
Arguments:
num_spatial_dims
: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels
: The number of input channels.
out_channels
: The number of output channels.
hidden_channels
: The number of channels in the hidden layers. Default is 16. This is the number of channels at the finest (input) spatial resolution.
num_levels
: The number of levels in the hierarchy. Default is 4. Each level halves the spatial resolution. By default, it also doubles the number of channels; this can be changed by setting channel_multipliers.
num_blocks
: The number of blocks per level in the left and right arcs of the UNet. One block is a single modern ResNet block (using pre-activation) consisting of two convolutions. The default value of num_blocks is 2, meaning that two blocks are used at each level of the encoder, the bottleneck, and the decoder. Hence, with the default, each level contributes a total of four convolutions (two blocks of two convolutions each) to the receptive field.
channel_multipliers
: A tuple of integers that specifies the channel multiplier for each level. If None, the default is to double the number of channels at each level (for num_levels=4 this means (2, 4, 8, 16)). The length of the tuple must equal num_levels.
use_norm
: If True, uses group norm as part of the ResNet blocks. Default is True.
activation
: The activation function to use in the blocks. Default is jax.nn.relu.
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
boundary_mode
: The boundary mode to use, one of "periodic", "dirichlet", or "neumann". Default is "periodic".
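The effect of channel_multipliers on the width of each level can be sketched as follows. The sketch assumes the multipliers scale hidden_channels for the num_levels coarser levels, which matches the default (2, 4, 8, 16) meaning "double at every level"; the helper modern_unet_channels is hypothetical and not part of pdequinox.

```python
from typing import Optional


def modern_unet_channels(
    hidden_channels: int = 16,
    num_levels: int = 4,
    channel_multipliers: Optional[tuple[int, ...]] = None,
) -> list[int]:
    """Hypothetical helper: channel count per level, level 0 being the finest."""
    if channel_multipliers is None:
        # Default described above: double at every level, e.g. (2, 4, 8, 16).
        channel_multipliers = tuple(2**i for i in range(1, num_levels + 1))
    if len(channel_multipliers) != num_levels:
        raise ValueError("len(channel_multipliers) must equal num_levels")
    # Multipliers are relative to hidden_channels, not cumulative.
    return [hidden_channels] + [hidden_channels * m for m in channel_multipliers]


print(modern_unet_channels())  # [16, 32, 64, 128, 256]
print(modern_unet_channels(channel_multipliers=(1, 2, 2, 4)))  # [16, 16, 32, 32, 64]
```

A flatter profile such as (1, 2, 2, 4) keeps the coarse levels narrower, which reduces the parameter count of the deepest blocks compared with the default doubling.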
__call__
__call__(x: Any) -> Any