Created with ❤️ by Machine Learning & Simulation.
Also check out the pullback/vJp rules in more mathematical notation at fkoehler.site/conv-autodiff-table. However, those rules assume a one-channel to one-channel convolution of a single sample (i.e., no batched convolutions). Here, we consider the general convolution routines of the major deep learning frameworks.
Important: Note that some frameworks (e.g., PyTorch) use cross-correlation instead of convolution. This changes the pullback rule into the filter (see the link above).
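To see the difference concretely, here is a minimal sketch in plain NumPy (independent of the frameworks in the table below): a true convolution flips the kernel before sliding it over the input, while a cross-correlation slides it as-is.

```python
import numpy as np

x = np.array([0.0, 1.0, 0.0, 0.0])  # unit impulse input
w = np.array([1.0, 2.0, 3.0])       # asymmetric kernel

# np.convolve flips the kernel (true convolution), np.correlate does not
# (cross-correlation); the two agree only after flipping the kernel.
assert np.allclose(
    np.convolve(x, w, mode="same"),
    np.correlate(x, w[::-1], mode="same"),
)
```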
Legend: B = batch size, C_i = input channels, C_o = output channels, K = kernel size, N = spatial size. All rows assume K = 3, so a padding of 1 on each side yields a "same"-sized output. Runnable sanity checks for the JAX and PyTorch rows follow after the table.
Primitive | Primal | Pullback/vJp into filter | Pullback/vJp into input |
---|---|---|---|
Julia NNlib (true convolution) | x = (N, C_i, B) | w = (K, C_i, C_o) | y = (N, C_o, B) |
1D "Same"-Padding Convolution | y = conv(x, w, pad=1) | dw = conv(permutedims(dy, (1, 3, 2)), permutedims(x, (1, 3, 2)), pad=1, flipped=true) | dx = conv(dy, permutedims(w, (1, 3, 2)), pad=1, flipped=true) |
JAX (uses the cross-correlation of the XLA backend) | x = (B, C_i, N) | w = (C_o, C_i, K) | y = (B, C_o, N) |
1D "Same"-Padding Cross-Correlation | y = lax.conv_general_dilated(x, w, (1,), ((1, 1),)) | dw = jnp.flip(lax.conv_general_dilated(jnp.transpose(dy, (1, 0, 2)), jnp.transpose(x, (1, 0, 2)), (1,), ((1, 1),)), 2) | dx = lax.conv_general_dilated(dy, jnp.transpose(jnp.flip(w, 2), (1, 0, 2)), (1,), ((1, 1),)) |
PyTorch (uses cross-correlation) | x = (B, C_i, N) | w = (C_o, C_i, K) | y = (B, C_o, N) |
1D "Same"-Padding Cross-Correlation | y = torch.nn.functional.conv1d(x, w, padding=1) | todo | todo |