Fourier Neural Operator (FNO)
A Fourier pseudo-spectral architecture.
pdequinox.arch.ClassicFNO
Bases: Sequential
Source code in pdequinox/arch/_classic_fno.py
__init__
__init__(
    num_spatial_dims: int,
    in_channels: int,
    out_channels: int,
    *,
    hidden_channels: int = 32,
    num_modes: int = 12,
    num_blocks: int = 4,
    activation: Callable = jax.nn.gelu,
    boundary_mode: Optional[
        Literal["periodic", "dirichlet", "neumann"]
    ] = None,
    key: PRNGKeyArray
)
The vanilla Fourier Neural Operator, closely following the original paper by Li et al. (2020). It performs spectral convolutions in Fourier space to obtain a global receptive field.
Note that this architecture does not use the boundary_mode argument; the authors argue that arbitrary boundary conditions can be learned.
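To illustrate what a spectral convolution does, here is a minimal 1D sketch in JAX. It is not the pdequinox implementation: the helper name spectral_conv_1d, the weight layout (out_channels, in_channels, num_modes), and the use of a real FFT are assumptions made for illustration. The input is transformed with jnp.fft.rfft, only the lowest num_modes coefficients are kept and mixed across channels by complex weights, and the result is transformed back with jnp.fft.irfft. Because every Fourier coefficient depends on all grid points, each output point depends on the whole input, which is the global receptive field mentioned above.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, num_modes):
    """Hypothetical 1D spectral convolution (illustration only, not the pdequinox API).

    x:       real array of shape (in_channels, N)
    weights: complex64 array of shape (out_channels, in_channels, num_modes)
    """
    n = x.shape[-1]
    # Real FFT along the spatial axis -> (in_channels, N // 2 + 1) coefficients
    x_hat = jnp.fft.rfft(x, axis=-1)
    # Keep only the lowest num_modes modes and mix channels independently per mode
    out_low = jnp.einsum("oim,im->om", weights, x_hat[:, :num_modes])
    # Discarded higher modes are set to zero
    out_hat = jnp.zeros((weights.shape[0], n // 2 + 1), dtype=x_hat.dtype)
    out_hat = out_hat.at[:, :num_modes].set(out_low)
    # Inverse real FFT back to physical space with the original length N
    return jnp.fft.irfft(out_hat, n=n, axis=-1)


# Example: 2 input channels, 3 output channels, 64 grid points, 12 modes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
in_ch, out_ch, n, modes = 2, 3, 64, 12
x = jax.random.normal(k1, (in_ch, n))
w = (
    jax.random.normal(k2, (out_ch, in_ch, modes))
    + 1j * jax.random.normal(k3, (out_ch, in_ch, modes))
).astype(jnp.complex64)
y = spectral_conv_1d(x, w, modes)
print(y.shape)  # (3, 64)
```

Since the FFT treats the domain as periodic, this operation commutes with circular shifts of the input, and its output contains no Fourier modes above num_modes.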
Arguments:
num_spatial_dims: The number of spatial dimensions. For example, traditional convolutions for image processing have this set to 2.
in_channels: The number of input channels.
out_channels: The number of output channels.
hidden_channels: The number of channels in the hidden layers. Default is 32.
num_modes: The number of modes to use in the Fourier basis. Think of modes as the equivalent of the kernel size in classical convolutions. Default is 12.
num_blocks: The number of blocks to use. One block consists of one spectral convolution with a bypass via a 1x1 convolution, followed by the activation function. Default is 4.
activation: The activation function to use in the blocks. Default is jax.nn.gelu. This is often preferable over jax.nn.relu because it recovers more of the higher modes.
boundary_mode: Unused; only present for compatibility with other architectures.
key: A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword-only argument.)
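The block structure described under num_blocks can be sketched as follows, again in 1D and with hypothetical helper names (spectral_conv_1d, classic_fno_block_1d, init_block_params) that are not part of pdequinox. The sketch assumes, as in Li et al. (2020), that the output of the 1x1-convolution bypass is added to the spectral convolution output before the activation is applied, and it uses an arbitrary random initialisation. Stacking num_blocks such blocks at width hidden_channels gives the core of the network; the layers that map in_channels to hidden_channels and back to out_channels are omitted here.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, num_modes):
    """Hypothetical 1D spectral convolution: x (C_in, N), weights (C_out, C_in, modes)."""
    n = x.shape[-1]
    x_hat = jnp.fft.rfft(x, axis=-1)
    out_low = jnp.einsum("oim,im->om", weights, x_hat[:, :num_modes])
    out_hat = jnp.zeros((weights.shape[0], n // 2 + 1), dtype=x_hat.dtype)
    out_hat = out_hat.at[:, :num_modes].set(out_low)
    return jnp.fft.irfft(out_hat, n=n, axis=-1)


def classic_fno_block_1d(x, spectral_w, bypass_w, bypass_b, num_modes,
                         activation=jax.nn.gelu):
    """One block: spectral convolution plus a 1x1-convolution bypass, then activation."""
    spectral = spectral_conv_1d(x, spectral_w, num_modes)
    # A 1x1 convolution is a pointwise linear map across channels
    bypass = bypass_w @ x + bypass_b[:, None]
    return activation(spectral + bypass)


def init_block_params(key, hidden_channels, num_modes):
    """Arbitrary illustrative initialisation (not the scheme pdequinox uses)."""
    k1, k2, k3 = jax.random.split(key, 3)
    scale = 1.0 / hidden_channels
    shape = (hidden_channels, hidden_channels, num_modes)
    spectral_w = (
        scale * (jax.random.normal(k1, shape) + 1j * jax.random.normal(k2, shape))
    ).astype(jnp.complex64)
    bypass_w = scale * jax.random.normal(k3, (hidden_channels, hidden_channels))
    bypass_b = jnp.zeros((hidden_channels,))
    return spectral_w, bypass_w, bypass_b


# Defaults from the signature: hidden_channels=32, num_modes=12, num_blocks=4
hidden_channels, num_modes, num_blocks, n = 32, 12, 4, 64
keys = jax.random.split(jax.random.PRNGKey(0), num_blocks + 1)
params = [init_block_params(k, hidden_channels, num_modes) for k in keys[:-1]]

# Hidden state assumed to already have hidden_channels channels
h = jax.random.normal(keys[-1], (hidden_channels, n))
for spectral_w, bypass_w, bypass_b in params:
    h = classic_fno_block_1d(h, spectral_w, bypass_w, bypass_b, num_modes)
print(h.shape)  # (32, 64)
```

Swapping the activation, for example to jax.nn.relu, only requires passing a different callable to classic_fno_block_1d.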
See also the reference implementation in PyTorch:
https://github.com/neuraloperator/neuraloperator/blob/af93f781d5e013f8ba5c52baa547f2ada304ffb0/fourier_1d.py#L62
__call__
__call__(x)
Source code in pdequinox/_sequential.py