Architectural Constructors¤
There are two primary architectural constructors, one for sequential and one for
hierarchical networks, that allow for composability with the PDEquinox
blocks.
Sequential Constructor¤
The pdequinox.Sequential network constructor is defined by:
- a lifting block \(\mathcal{L}\)
- \(N\) blocks \(\left \{ \mathcal{B}_i \right\}_{i=1}^N\)
- a projection block \(\mathcal{P}\)
- the number of hidden channels within the sequential processing (one can also supply a list of hidden channels if they should differ between blocks)
- the number of blocks \(N\)
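The composition described by the list above can be sketched in plain Python. This is only an illustration of the lifting, blocks, projection ordering and does not use PDEquinox itself; `make_sequential`, `lift`, `residual_block`, and `project` are hypothetical names, and the "features" are a bare list of channel values without spatial axes.

```python
from typing import Callable, List, Sequence

Features = List[float]
Block = Callable[[Features], Features]


def make_sequential(lifting: Block, blocks: Sequence[Block], projection: Block) -> Block:
    """Compose lifting -> blocks[0] -> ... -> blocks[-1] -> projection."""

    def network(x: Features) -> Features:
        h = lifting(x)          # input channels -> hidden channels
        for block in blocks:    # N blocks operating on the hidden channels
            h = block(h)
        return projection(h)    # hidden channels -> output channels

    return network


# Toy stand-ins for real blocks (assumptions for illustration only).
def lift(x: Features) -> Features:
    return [x[0], x[0]]                  # 1 channel -> 2 hidden channels


def residual_block(h: Features) -> Features:
    return [v + 0.5 * v for v in h]      # h + f(h) with f(h) = 0.5 * h


def project(h: Features) -> Features:
    return [sum(h)]                      # 2 hidden channels -> 1 channel


net = make_sequential(lift, [residual_block] * 3, project)
print(net([1.0]))  # [6.75]
```

Swapping any of the three ingredients for another callable with matching channel counts leaves the rest of the composition untouched, which is the kind of composability the constructor is aiming for.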
Hierarchical Constructor¤
The pdequinox.Hierarchical network constructor is defined by:
- a lifting block \(\mathcal{L}\)
- the number of levels \(D\) (i.e., the number of additional hierarchies). Setting \(D = 0\) recovers the sequential processing.
- a list of \(D\) blocks \(\left \{ \mathcal{D}_i \right\}_{i=1}^D\) for downsampling, i.e., mapping downwards to the lower hierarchy (oftentimes these halve the spatial axes while keeping the number of channels)
- a list of \(D\) blocks \(\left \{ \mathcal{B}_i^l \right\}_{i=1}^D\) for processing in the left arc (oftentimes these change the number of channels, e.g., double it, such that the combination of downsampling and left processing halves the spatial resolution and doubles the feature count)
- a list of \(D\) blocks \(\left \{ \mathcal{U}_i \right\}_{i=1}^D\) for upsampling, i.e., mapping upwards to the higher hierarchy (oftentimes these double the spatial resolution and, at the same time, halve the feature count such that a skip connection can be concatenated)
- a list of \(D\) blocks \(\left \{ \mathcal{B}_i^r \right\}_{i=1}^D\) for processing in the right arc (oftentimes these change the number of channels, e.g., halve it, such that the combination of upsampling and right processing doubles the spatial resolution and halves the feature count)
- a projection block \(\mathcal{P}\)
- the hidden channels within the hierarchical processing (if just an integer is provided, this is assumed to be the number of hidden channels in the highest hierarchy)
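To make the bookkeeping of the list above concrete, the following sketch traces channel counts and spatial resolution through a hierarchical network. It is plain Python and does not call PDEquinox. It assumes the "oftentimes" behaviour described above (downsampling divides the resolution, the left arc multiplies the channels, upsampling reverts both, and the right arc reduces the concatenated skip back to the level's width); `trace_hierarchical_shapes` is a hypothetical helper.

```python
from typing import List, Tuple


def trace_hierarchical_shapes(
    in_channels: int,
    out_channels: int,
    hidden_channels: int,
    num_points: int,
    num_levels: int,
    reduction_factor: int = 2,
) -> List[Tuple[str, int, int]]:
    """Return (stage, channels, points per spatial axis) for every stage."""
    trace = [("input", in_channels, num_points)]
    channels, points = hidden_channels, num_points
    trace.append(("lifting", channels, points))

    skips = []  # channel counts stored for the skip connections
    for level in range(1, num_levels + 1):
        skips.append(channels)
        points //= reduction_factor           # downsampling: coarser grid
        trace.append((f"down {level}", channels, points))
        channels *= reduction_factor          # left arc: more features
        trace.append((f"left {level}", channels, points))

    for level in range(num_levels, 0, -1):
        points *= reduction_factor            # upsampling: finer grid
        channels //= reduction_factor         # ... and fewer features
        trace.append((f"up {level}", channels, points))
        skip = skips.pop()
        trace.append((f"concat {level}", channels + skip, points))
        channels = skip                       # right arc reduces the concatenation
        trace.append((f"right {level}", channels, points))

    trace.append(("projection", out_channels, points))
    return trace


for stage, channels, points in trace_hierarchical_shapes(1, 1, 16, 64, num_levels=2):
    print(f"{stage:<12} channels={channels:<3} points={points}")
```

With `num_levels=0` the trace collapses to input, lifting, and projection, matching the remark that \(D = 0\) recovers the sequential processing.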
Beyond Architectural Constructors¤
For completeness, pdequinox.arch also provides pdequinox.arch.ConvNet, which is a
simple feed-forward convolutional network, and pdequinox.arch.MLP, which is a dense
network that additionally requires pre-defining the number of resolution points.
API¤
pdequinox.Sequential¤
Bases: Module
Source code in pdequinox/_sequential.py
__init__¤
__init__(
num_spatial_dims: int,
in_channels: int,
out_channels: int,
*,
hidden_channels: Union[Sequence[int], int],
num_blocks: int,
activation: Callable,
key: PRNGKeyArray,
boundary_mode: Literal[
"periodic", "dirichlet", "neumann"
],
lifting_factory: BlockFactory = LinearChannelAdjustBlockFactory(),
block_factory: BlockFactory = ClassicResBlockFactory(),
projection_factory: BlockFactory = LinearChannelAdjustBlockFactory()
)
Generic constructor for sequential block-based architectures like ResNets.
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- hidden_channels: The number of channels in the hidden layers. Either an integer to have the same number of hidden channels in the layers between all blocks, or a list of num_blocks + 1 integers.
- num_blocks: The number of blocks to use. Must be an integer greater than or equal to 1.
- activation: The activation function to use in the blocks.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- boundary_mode: The boundary mode to use for the convolution. (Keyword-only argument.)
- lifting_factory: The factory to use for the lifting block. Default is LinearChannelAdjustBlockFactory, which is simply a linear 1x1 convolution for channel adjustment.
- block_factory: The factory to use for the blocks. Default is ClassicResBlockFactory, which is a classic ResNet block (with postactivation).
- projection_factory: The factory to use for the projection block. Default is LinearChannelAdjustBlockFactory, which is simply a linear 1x1 convolution for channel adjustment.
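As a rough illustration of the two accepted forms of hidden_channels, the sketch below expands an integer into num_blocks + 1 widths and pairs them up per block. It is plain Python, not the library's code. The wiring (block i maps width i to width i + 1, lifting feeds the first width, projection consumes the last) is an assumption inferred from the "num_blocks + 1 integers" requirement, and both helper names are hypothetical.

```python
from typing import Dict, Sequence, Tuple, Union


def expand_hidden_channels(
    hidden_channels: Union[Sequence[int], int], num_blocks: int
) -> Tuple[int, ...]:
    """Turn an int or a sequence into a tuple of num_blocks + 1 channel widths."""
    if num_blocks < 1:
        raise ValueError("num_blocks must be greater than or equal to 1")
    if isinstance(hidden_channels, int):
        return (hidden_channels,) * (num_blocks + 1)
    widths = tuple(hidden_channels)
    if len(widths) != num_blocks + 1:
        raise ValueError("expected num_blocks + 1 hidden channel entries")
    return widths


def plan_channels(
    in_channels: int,
    out_channels: int,
    hidden_channels: Union[Sequence[int], int],
    num_blocks: int,
) -> Dict[str, object]:
    """Assumed (in, out) channel pairs for lifting, each block, and projection."""
    widths = expand_hidden_channels(hidden_channels, num_blocks)
    return {
        "lifting": (in_channels, widths[0]),
        "blocks": list(zip(widths[:-1], widths[1:])),
        "projection": (widths[-1], out_channels),
    }


print(plan_channels(1, 1, 32, num_blocks=3))
print(plan_channels(1, 1, [16, 32, 64, 32], num_blocks=3))
```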
Source code in pdequinox/_sequential.py
__call__¤
__call__(x)
Source code in pdequinox/_sequential.py
pdequinox.Hierarchical¤
Bases: Module
Source code in pdequinox/_hierarchical.py
__init__¤
__init__(
num_spatial_dims: int,
in_channels: int,
out_channels: int,
*,
hidden_channels: int,
num_levels: int,
num_blocks: int,
activation: Callable,
key: PRNGKeyArray,
reduction_factor: int = 2,
boundary_mode: Literal[
"periodic", "dirichlet", "neumann"
],
channel_multipliers: Optional[tuple[int, ...]] = None,
lifting_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
down_sampling_factory: BlockFactory = LinearConvDownBlockFactory(),
left_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
up_sampling_factory: BlockFactory = LinearConvUpBlockFactory(),
right_arc_factory: BlockFactory = ClassicDoubleConvBlockFactory(),
projection_factory: BlockFactory = LinearChannelAdjustBlockFactory()
)
Generic constructor for hierarchical block-based architectures like
UNets. (For the classic UNet, use pdequinox.arch.ClassicUNet
instead.)
Hierarchical architectures use a number of different spatial resolutions. The lower the resolution, the wider the receptive field of convolutions.
The number of blocks per level can be increased via the num_blocks
argument. This number will be identical for the left arc (= encoder) and the
right arc (= decoder). There is no multi-skip as in PDEArena.
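The claim about receptive fields can be made concrete with a small calculation. The sketch assumes that each level keeps every reduction_factor-th grid point of the level above, so neighbouring points at a given level are reduction_factor ** level fine-grid points apart. It only measures the extent of a single convolution stencil, not the accumulated receptive field of stacked layers, and the helper name and the kernel size of 3 are chosen purely for illustration.

```python
def stencil_span_on_finest_grid(
    kernel_size: int, level: int, reduction_factor: int = 2
) -> int:
    """Finest-grid points spanned by one convolution stencil applied at `level`."""
    spacing = reduction_factor**level  # distance between neighbours at this level
    return (kernel_size - 1) * spacing + 1


spans = [stencil_span_on_finest_grid(3, level) for level in range(4)]
print(spans)  # [3, 5, 9, 17]
```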
Arguments:
- num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
- in_channels: The number of input channels.
- out_channels: The number of output channels.
- hidden_channels: The number of channels in the hidden layers. This refers to the highest resolution. Right after the input, the input channels will be lifted to this feature dimension without changing the spatial resolution.
- num_levels: The number of levels in the hierarchy. This is the number of down and up sampling blocks. If set to 0, this will just be a classical conv net. If set to 1, this will be a single down and up sampling block, etc. The total number of resolutions is num_levels + 1.
- num_blocks: The number of blocks to use at each level. (Also affects the number of blocks in the bottleneck.)
- activation: The activation function to use in the blocks.
- key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
- reduction_factor: The factor by which the spatial resolution is reduced at each level. This has to be an integer. In order to avoid ambiguities in shapes, it is best if the input spatial resolution is a multiple of reduction_factor ** num_levels. Default is 2.
- boundary_mode: The boundary mode to use for the convolution. (Keyword-only argument.)
- channel_multipliers: The factor by which the number of channels is multiplied at each level. If set to None, the channels will grow by a factor of reduction_factor at each level. This is similar to the classical UNet, which trades spatial resolution for feature dimension. Note, however, that the number of parameters of a convolution scales with the mapped channels; hence, the majority of parameters will then be in the coarsest representation. Alternatively, supply a tuple of integer multipliers that determine the number of channels at each resolution other than the original one. The length of the tuple must match num_levels. For example, to not change the number of channels at any level, set this to (1,) * num_levels. Default is None.
- lifting_factory: The factory to use for the lifting block. Default is ClassicDoubleConvBlockFactory, which is a classic double convolution block.
- down_sampling_factory: The factory to use for the down sampling blocks. This must be a block that is able to change the spatial resolution. Default is LinearConvDownBlockFactory, which is a simple linear strided convolution block.
- left_arc_factory: The factory to use for the left arc (encoder) blocks. Default is ClassicDoubleConvBlockFactory, which is a classic double convolution block.
- up_sampling_factory: The factory to use for the up sampling blocks. This must be a block that is able to change the spatial resolution. It should work in conjunction with the down_sampling_factory. Default is LinearConvUpBlockFactory, which is a simple linear strided transposed convolution block.
- right_arc_factory: The factory to use for the right arc (decoder) blocks. Default is ClassicDoubleConvBlockFactory, which is a classic double convolution block.
- projection_factory: The factory to use for the projection block. Default is LinearChannelAdjustBlockFactory, which is simply a linear 1x1 convolution for channel adjustment.
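The interplay of hidden_channels, reduction_factor, num_levels, and channel_multipliers can be sketched as follows. This is plain Python, not the library's implementation. It assumes each multiplier is applied to hidden_channels (rather than to the previous level), which is consistent with both the default growth and the (1,) * num_levels example; the helper names are hypothetical.

```python
from typing import Optional, Tuple


def channels_per_level(
    hidden_channels: int,
    num_levels: int,
    reduction_factor: int = 2,
    channel_multipliers: Optional[Tuple[int, ...]] = None,
) -> Tuple[int, ...]:
    """Channel count at each of the num_levels + 1 resolutions, finest first."""
    if channel_multipliers is None:
        # Default: grow by reduction_factor at every level.
        channel_multipliers = tuple(
            reduction_factor**level for level in range(1, num_levels + 1)
        )
    if len(channel_multipliers) != num_levels:
        raise ValueError("channel_multipliers must have num_levels entries")
    return (hidden_channels,) + tuple(hidden_channels * m for m in channel_multipliers)


def resolution_is_unambiguous(
    num_points: int, num_levels: int, reduction_factor: int = 2
) -> bool:
    """True if num_points is a multiple of reduction_factor ** num_levels."""
    return num_points % (reduction_factor**num_levels) == 0


print(channels_per_level(16, 3))                                 # (16, 32, 64, 128)
print(channels_per_level(16, 3, channel_multipliers=(1, 1, 1)))  # (16, 16, 16, 16)
print(resolution_is_unambiguous(64, 3))                          # True
print(resolution_is_unambiguous(100, 3))                         # False
```

Keeping the multipliers at 1 avoids concentrating parameters in the coarsest level, at the cost of fewer features where the receptive field is widest.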
Source code in pdequinox/_hierarchical.py
__call__¤
__call__(x: Any) -> Any
Source code in pdequinox/_hierarchical.py