Multi-Layer Perceptron (MLP)
pdequinox.arch.MLP
Bases: Module
Source code in pdequinox/arch/_mlp.py
__init__
__init__(
num_spatial_dims: int,
in_channels: int,
out_channels: int,
*,
num_points: int,
width_size: int = 64,
depth: int = 3,
activation: Callable = jax.nn.relu,
final_activation: Callable = _identity,
use_bias: bool = True,
use_final_bias: bool = True,
boundary_mode: Optional[
Literal["periodic", "dirichlet", "neumann"]
] = None,
key: PRNGKeyArray
)
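An MLP for the conv format.
Takes states of shape (in_channels, num_points, ..., num_points), i.e. a leading
in_channels axis followed by num_spatial_dims spatial axes of length num_points.
Internally, the input is flattened into a vector of length
in_channels * num_points**num_spatial_dims and passed to a classical MLP. The
output is then reshaped back to the conv shape (out_channels, num_points, ..., num_points).
Unlike convolutional networks, this architecture requires num_points to be supplied,
because the sizes of the first and last dense layers depend on the flattened state size.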
Arguments:
num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels: The number of input channels.
out_channels: The number of output channels.
num_points: The number of points in each spatial dimension. Must be supplied.
width_size: The width of the hidden layers. Default is 64.
depth: The number of hidden layers. Default is 3. The number of linear-affine transformations performed is depth + 1.
activation: The activation function to use in the hidden layers. Default is jax.nn.relu.
final_activation: The activation function to use in the final layer. Default is the identity function.
use_bias: If True, uses bias in the hidden layers. Default is True.
use_final_bias: If True, uses bias in the final layer. Default is True.
boundary_mode: Unused; only present for compatibility with other architectures.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
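Because the first and last layers act on the flattened state, their sizes grow as num_points**num_spatial_dims. The hypothetical helpers below (mlp_layer_sizes and count_parameters, not part of pdequinox) sketch that bookkeeping in plain Python. They assume each of the depth + 1 affine maps is a dense weight matrix plus an optional bias, with use_bias applying to every layer except the last and use_final_bias to the last one.

```python
def mlp_layer_sizes(num_spatial_dims: int, in_channels: int, out_channels: int,
                    *, num_points: int, width_size: int = 64,
                    depth: int = 3) -> list[tuple[int, int]]:
    """Return (fan_in, fan_out) for each affine layer of the flattened MLP."""
    # Flattening (channels, N, ..., N) gives channels * N**num_spatial_dims features
    num_cells = num_points**num_spatial_dims
    in_size = in_channels * num_cells
    out_size = out_channels * num_cells
    # `depth` hidden layers of width `width_size` -> `depth + 1` affine maps
    sizes = [in_size] + [width_size] * depth + [out_size]
    return list(zip(sizes[:-1], sizes[1:]))


def count_parameters(num_spatial_dims: int, in_channels: int, out_channels: int,
                     *, num_points: int, width_size: int = 64, depth: int = 3,
                     use_bias: bool = True, use_final_bias: bool = True) -> int:
    """Count weights plus biases, assuming plain dense layers."""
    layers = mlp_layer_sizes(num_spatial_dims, in_channels, out_channels,
                             num_points=num_points, width_size=width_size,
                             depth=depth)
    total = 0
    for i, (fan_in, fan_out) in enumerate(layers):
        is_final = i == len(layers) - 1
        # Hidden layers follow use_bias, the last layer follows use_final_bias
        has_bias = use_final_bias if is_final else use_bias
        total += fan_in * fan_out + (fan_out if has_bias else 0)
    return total


# 1D, 64 points, default width and depth: four 64 -> 64 layers
print(count_parameters(1, 1, 1, num_points=64))  # 16640
# 2D, 32 x 32 grid, 2 -> 1 channels: layers 2048 -> 64 -> 64 -> 64 -> 1024
print(count_parameters(2, 2, 1, num_points=32))  # 206016
```

In the 2D example, roughly 96% of the parameters sit in the first and last layers, and that share grows further as num_points increases.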
__call__
__call__(x)