Feed-Forward Conv Net
pdequinox.arch.ConvNet
Bases: Module
Source code in pdequinox/arch/_conv_net.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    depth: int = 10,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    final_activation: Callable = _identity,
    use_bias: bool = True,
    use_final_bias: bool = True,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    key: PRNGKeyArray,
    zero_bias_init: bool = False
)
A simple feed-forward convolutional neural network.
Arguments:
num_spatial_dims
: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels
: The number of input channels.
out_channels
: The number of output channels.
hidden_channels
: The number of channels in the hidden layers. Default is 16.
depth
: The number of hidden layers. Default is 10. If depth == 0, there is only one linear convolution from the input channels to the output channels. In general, the number of convolutions performed is depth + 1.
activation
: The activation function to use in the hidden layers. Default is jax.nn.relu.
kernel_size
: The size of the convolutional kernel. Default is 3.
final_activation
: The activation function to use in the final layer. Default is the identity function.
use_bias
: If True, uses a bias in the hidden layers. Default is True.
use_final_bias
: If True, uses a bias in the final layer. Default is True.
boundary_mode
: The boundary mode to use; one of "periodic", "dirichlet", or "neumann". Default is "periodic".
key
: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
zero_bias_init
: If True, initialises the biases to zero. Default is False.
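To make the depth and kernel_size semantics concrete, the plain-Python sketch below lists the (in, out) channel pair of each of the depth + 1 convolutions, counts trainable parameters, and computes the receptive field of the stack. The helper names are hypothetical and not part of pdequinox. The parameter count assumes that every hidden layer has hidden_channels channels, that each layer is a standard dense convolution with a kernel of shape (out, in, kernel_size, ..., kernel_size) plus an optional per-channel bias, and that with depth == 0 the single convolution is governed by use_final_bias; pdequinox may store or count its parameters differently.

```python
def conv_layer_layout(in_channels, out_channels, hidden_channels, depth):
    """Hypothetical helper: (in, out) channels of each of the depth + 1 convolutions."""
    if depth == 0:
        # Single convolution straight from input to output channels.
        return [(in_channels, out_channels)]
    layout = [(in_channels, hidden_channels)]  # lift to the hidden width
    layout += [(hidden_channels, hidden_channels)] * (depth - 1)  # hidden to hidden
    layout.append((hidden_channels, out_channels))  # project to the output
    return layout


def count_parameters(num_spatial_dims, in_channels, out_channels, *,
                     hidden_channels=16, depth=10, kernel_size=3,
                     use_bias=True, use_final_bias=True):
    """Hypothetical helper: parameter count assuming plain dense conv kernels."""
    layout = conv_layer_layout(in_channels, out_channels, hidden_channels, depth)
    total = 0
    for i, (c_in, c_out) in enumerate(layout):
        total += c_out * c_in * kernel_size**num_spatial_dims  # kernel weights
        is_final = i == len(layout) - 1
        # Assumption: the last convolution uses use_final_bias, all others use_bias.
        if (use_final_bias if is_final else use_bias):
            total += c_out  # one bias value per output channel
    return total


def receptive_field(depth, kernel_size):
    """Width along each spatial axis seen by one output value (stride-1 convs)."""
    return (depth + 1) * (kernel_size - 1) + 1


print(len(conv_layer_layout(1, 1, 16, depth=10)))  # 11 convolutions
print(count_parameters(1, 1, 1))  # 7169 under the assumptions above
print(receptive_field(depth=10, kernel_size=3))  # 23
```

Under these assumptions, the default configuration stacks 11 convolutions, so each output value depends on 23 neighbouring grid points along every spatial axis.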
__call__
__call__(x: jax.Array) -> jax.Array
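The reference lists only the signature of __call__. The JAX sketch below illustrates what a forward pass of this kind of network could look like for boundary_mode="periodic": depth + 1 same-size convolutions, with activation after every layer except the last, which uses final_activation. It is a hypothetical reimplementation, not pdequinox's code. The names periodic_conv, init_params, and conv_net_forward, the uniform weight initialisation, and the input layout (a single unbatched sample of shape (channels, *spatial), following the Equinox convention) are all assumptions. The "dirichlet" and "neumann" modes are not modelled here.

```python
import jax
import jax.numpy as jnp


def periodic_conv(x, weight, bias):
    """Same-size convolution with periodic ('wrap') padding.

    x: (c_in, *spatial), unbatched; weight: (c_out, c_in, k, ..., k); bias: (c_out,).
    """
    num_spatial_dims = x.ndim - 1
    k = weight.shape[-1]
    # Wrap values around every spatial axis so the output keeps the input size.
    pad_width = [(0, 0)] + [((k - 1) // 2, k // 2)] * num_spatial_dims
    x_padded = jnp.pad(x, pad_width, mode="wrap")
    y = jax.lax.conv_general_dilated(
        x_padded[None],  # add a batch axis: (1, c_in, *padded_spatial)
        weight,
        window_strides=(1,) * num_spatial_dims,
        padding="VALID",
    )[0]  # drop the batch axis again
    return y + bias.reshape((-1,) + (1,) * num_spatial_dims)


def init_params(key, num_spatial_dims, in_channels, out_channels, *,
                hidden_channels=16, depth=10, kernel_size=3):
    """Hypothetical init: uniform weights scaled by 1/sqrt(fan_in), zero biases."""
    sizes = [in_channels] + [hidden_channels] * depth + [out_channels]
    params = []
    for c_in, c_out in zip(sizes[:-1], sizes[1:]):  # depth + 1 layers
        key, subkey = jax.random.split(key)
        shape = (c_out, c_in) + (kernel_size,) * num_spatial_dims
        fan_in = c_in * kernel_size**num_spatial_dims
        w = jax.random.uniform(subkey, shape, minval=-1.0, maxval=1.0) / jnp.sqrt(fan_in)
        params.append((w, jnp.zeros((c_out,))))
    return params


def conv_net_forward(params, x, activation=jax.nn.relu,
                     final_activation=lambda y: y):
    for w, b in params[:-1]:
        x = activation(periodic_conv(x, w, b))  # hidden layers
    w, b = params[-1]
    return final_activation(periodic_conv(x, w, b))  # final layer


key = jax.random.PRNGKey(0)
params = init_params(key, num_spatial_dims=1, in_channels=1, out_channels=1)
x = jnp.sin(jnp.linspace(0.0, 2 * jnp.pi, 64, endpoint=False))[None]  # (1, 64)
y = conv_net_forward(params, x)
print(y.shape)  # (1, 64)
```

Because the sketch handles one unbatched sample, a batch of shape (batch, channels, *spatial) could be processed with jax.vmap(lambda s: conv_net_forward(params, s))(batch). With periodic padding and unit strides, circularly shifting the input shifts the output by the same amount.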