Skip to content

Feed-Forward Conv Net¤

pdequinox.arch.ConvNet ¤

Bases: Module

Source code in pdequinox/arch/_conv_net.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class ConvNet(eqx.Module):
    layers: tuple[PhysicsConv, ...]
    activation: Callable
    final_activation: Callable

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        *,
        hidden_channels: int = 16,
        depth: int = 10,
        activation: Callable = jax.nn.relu,
        kernel_size: int = 3,
        final_activation: Callable = _identity,
        use_bias: bool = True,
        use_final_bias: bool = True,
        boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
        key: PRNGKeyArray,
        zero_bias_init: bool = False,
    ):
        """
        A simple feed-forward convolutional neural network.

        **Arguments:**

        - `num_spatial_dims`: The number of spatial dimensions. For example
            traditional convolutions for image processing have this set to `2`.
        - `in_channels`: The number of input channels.
        - `out_channels`: The number of output channels.
        - `hidden_channels`: The number of channels in the hidden layers.
            Default is `16`.
        - `depth`: The number of hidden layers. Default is `10`. If `depth ==
            0`, there will only be one **linear** convolution from the input
            channels to the output channels. Hence, `depth` denotes the number
            of hidden layers. The number of convolutions performed is `depth +
            1`.
        - `activation`: The activation function to use in the hidden layers.
            Default is `jax.nn.relu`.
        - `kernel_size`: The size of the convolutional kernel. Default is `3`.
        - `final_activation`: The activation function to use in the final layer.
            Default is the identity function.
        - `use_bias`: If `True`, uses bias in the hidden layers. Default is
            `True`.
        - `use_final_bias`: If `True`, uses bias in the final layer. Default is
            `True`.
        - `boundary_mode`: The boundary mode to use. Default is `periodic`.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        - `zero_bias_init`: If `True`, initialises the bias to zero. Default is
            `False`.
        """

        self.activation = activation
        self.final_activation = final_activation

        keys = jax.random.split(key, depth + 1)

        def conv_constructor(i, o, b, k):
            return PhysicsConv(
                num_spatial_dims=num_spatial_dims,
                in_channels=i,
                out_channels=o,
                kernel_size=kernel_size,
                stride=1,
                dilation=1,
                use_bias=b,
                zero_bias_init=zero_bias_init,
                boundary_mode=boundary_mode,
                key=k,
            )

        layers = []
        if depth == 0:
            layers.append(
                conv_constructor(in_channels, out_channels, use_final_bias, keys[0])
            )
        else:
            layers.append(
                conv_constructor(in_channels, hidden_channels, use_bias, keys[0])
            )
            for i in range(depth - 1):
                layers.append(
                    conv_constructor(
                        hidden_channels, hidden_channels, use_bias, keys[i + 1]
                    )
                )
            layers.append(
                conv_constructor(
                    hidden_channels, out_channels, use_final_bias, keys[-1]
                )
            )

        self.layers = tuple(layers)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.final_activation(self.layers[-1](x))

    @property
    def receptive_field(self) -> tuple[tuple[float, float], ...]:
        individual_receptive_fields = tuple(
            conv.receptive_field for conv in self.layers
        )
        return sum_receptive_fields(individual_receptive_fields)
__init__ ¤
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    depth: int = 10,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    final_activation: Callable = _identity,
    use_bias: bool = True,
    use_final_bias: bool = True,
    boundary_mode: Literal[
        "periodic", "dirichlet", "neumann"
    ] = "periodic",
    key: PRNGKeyArray,
    zero_bias_init: bool = False
)

A simple feed-forward convolutional neural network.

Arguments:

  • num_spatial_dims: The number of spatial dimensions. For example traditional convolutions for image processing have this set to 2.
  • in_channels: The number of input channels.
  • out_channels: The number of output channels.
  • hidden_channels: The number of channels in the hidden layers. Default is 16.
  • depth: The number of hidden layers. Default is 10. If depth == 0, there will only be one linear convolution from the input channels to the output channels. Hence, depth denotes the number of hidden layers. The number of convolutions performed is depth + 1.
  • activation: The activation function to use in the hidden layers. Default is jax.nn.relu.
  • kernel_size: The size of the convolutional kernel. Default is 3.
  • final_activation: The activation function to use in the final layer. Default is the identity function.
  • use_bias: If True, uses bias in the hidden layers. Default is True.
  • use_final_bias: If True, uses bias in the final layer. Default is True.
  • boundary_mode: The boundary mode to use. Default is periodic.
  • key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)
  • zero_bias_init: If True, initialises the bias to zero. Default is False.
Source code in pdequinox/arch/_conv_net.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def __init__(
    self,
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 16,
    depth: int = 10,
    activation: Callable = jax.nn.relu,
    kernel_size: int = 3,
    final_activation: Callable = _identity,
    use_bias: bool = True,
    use_final_bias: bool = True,
    boundary_mode: Literal["periodic", "dirichlet", "neumann"] = "periodic",
    key: PRNGKeyArray,
    zero_bias_init: bool = False,
):
    """
    A simple feed-forward convolutional neural network.

    **Arguments:**

    - `num_spatial_dims`: The number of spatial dimensions. For example
        traditional convolutions for image processing have this set to `2`.
    - `in_channels`: The number of input channels.
    - `out_channels`: The number of output channels.
    - `hidden_channels`: The number of channels in the hidden layers.
        Default is `16`.
    - `depth`: The number of hidden layers. Default is `10`. If `depth ==
        0`, there will only be one **linear** convolution from the input
        channels to the output channels. Hence, `depth` denotes the number
        of hidden layers. The number of convolutions performed is `depth +
        1`.
    - `activation`: The activation function to use in the hidden layers.
        Default is `jax.nn.relu`.
    - `kernel_size`: The size of the convolutional kernel. Default is `3`.
    - `final_activation`: The activation function to use in the final layer.
        Default is the identity function.
    - `use_bias`: If `True`, uses bias in the hidden layers. Default is
        `True`.
    - `use_final_bias`: If `True`, uses bias in the final layer. Default is
        `True`.
    - `boundary_mode`: The boundary mode to use. Default is `periodic`.
    - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
        initialisation. (Keyword only argument.)
    - `zero_bias_init`: If `True`, initialises the bias to zero. Default is
        `False`.
    """

    self.activation = activation
    self.final_activation = final_activation

    keys = jax.random.split(key, depth + 1)

    def conv_constructor(i, o, b, k):
        return PhysicsConv(
            num_spatial_dims=num_spatial_dims,
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=1,
            dilation=1,
            use_bias=b,
            zero_bias_init=zero_bias_init,
            boundary_mode=boundary_mode,
            key=k,
        )

    layers = []
    if depth == 0:
        layers.append(
            conv_constructor(in_channels, out_channels, use_final_bias, keys[0])
        )
    else:
        layers.append(
            conv_constructor(in_channels, hidden_channels, use_bias, keys[0])
        )
        for i in range(depth - 1):
            layers.append(
                conv_constructor(
                    hidden_channels, hidden_channels, use_bias, keys[i + 1]
                )
            )
        layers.append(
            conv_constructor(
                hidden_channels, out_channels, use_final_bias, keys[-1]
            )
        )

    self.layers = tuple(layers)
__call__ ¤
__call__(x: jax.Array) -> jax.Array
Source code in pdequinox/arch/_conv_net.py
111
112
113
114
def __call__(self, x: jax.Array) -> jax.Array:
    for layer in self.layers[:-1]:
        x = self.activation(layer(x))
    return self.final_activation(self.layers[-1](x))