Convolution Autodiff Primitive Rules in Frameworks

Created with ❤️ by Machine Learning & Simulation.


Also check out the pullback/vJp rules in more mathematical notation at fkoehler.site/conv-autodiff-table. However, those rules assume a one-channel to one-channel convolution of only a single sample (i.e., no batched convolutions). Here, we consider the general convolution routines of the major deep learning frameworks.

Important: Take care that some frameworks (e.g., PyTorch) use cross-correlation instead of true convolution. This changes the pullback rule into the filter (see the link above).
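The two operations differ only in whether the kernel is traversed in reverse. A minimal NumPy sketch (function names are mine, "valid" padding for brevity) makes the relation explicit: cross-correlating with a kernel equals truly convolving with the flipped kernel, which is why a flip appears in (or disappears from) the filter pullback depending on the framework's convention.

```python
import numpy as np

def conv1d_valid(x, w):
    # true (kernel-flipping) convolution, "valid" padding
    K = len(w)
    return np.array([sum(x[n + K - 1 - k] * w[k] for k in range(K))
                     for n in range(len(x) - K + 1)])

def xcorr1d_valid(x, w):
    # cross-correlation (no kernel flip), "valid" padding
    K = len(w)
    return np.array([sum(x[n + k] * w[k] for k in range(K))
                     for n in range(len(x) - K + 1)])

x = np.arange(8.0)
w = np.array([1.0, -2.0, 3.0])

# sanity check against NumPy's built-in (flipping) convolution
assert np.allclose(conv1d_valid(x, w), np.convolve(x, w, mode="valid"))

# cross-correlation == true convolution with the flipped kernel
assert np.allclose(xcorr1d_valid(x, w), conv1d_valid(x, w[::-1]))
```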

Legend: B = batch size, C_i = input channels, C_o = output channels, K = kernel size, N = spatial size

Interactive App for JAX

Columns: Primitive | Primal | Pullback/vJp into filter | Pullback/vJp into input

Julia NNlib (true convolution)
Shapes: x = N × C_i × B, w = K × C_i × C_o, y = N × C_o × B

1D "same"-padding convolution:
  y  = conv(x, w, pad=1)
  dw = conv(permutedims(dy, (1, 3, 2)), permutedims(x, (1, 3, 2)), pad=1, flipped=true)
  dx = conv(dy, permutedims(w, (1, 3, 2)), pad=1, flipped=true)

JAX (uses the cross-correlation of the XLA backend)
Shapes: x = B × C_i × N, w = C_o × C_i × K, y = B × C_o × N

1D "same"-padding cross-correlation:
  y  = lax.conv_general_dilated(x, w, (1,), ((1, 1),))
  dw = jnp.flip(lax.conv_general_dilated(jnp.transpose(dy, (1, 0, 2)), jnp.transpose(x, (1, 0, 2)), (1,), ((1, 1),)), 2)
  dx = lax.conv_general_dilated(dy, jnp.transpose(jnp.flip(w, 2), (1, 0, 2)), (1,), ((1, 1),))

PyTorch (uses cross-correlation)
Shapes: x = B × C_i × N, w = C_o × C_i × K, y = B × C_o × N

1D "same"-padding cross-correlation:
  y  = torch.nn.functional.conv1d(x, w, padding=1)
  dw = todo
  dx = todo
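As a framework-independent sanity check of the pattern in the rows above, here is a NumPy sketch (all names are mine, assuming the PyTorch-style B × C_i × N / C_o × C_i × K layout, K = 3, padding 1). Since the primal is bilinear in input and filter, correct pullbacks must satisfy the adjoint identities ⟨dy, y⟩ = ⟨dw, w⟩ = ⟨dx, x⟩ for any cotangent dy, which the assertions verify:

```python
import numpy as np

def xcorr1d(x, w, pad):
    """Batched 1D cross-correlation with zero-padding on both ends.

    x: (B, C_i, N), w: (C_o, C_i, K)  -- PyTorch-style layouts.
    """
    B, Ci, N = x.shape
    Co, _, K = w.shape
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad)))
    L = N + 2 * pad - K + 1  # output length ("same" if K == 2*pad + 1)
    y = np.zeros((B, Co, L))
    for n in range(L):
        for k in range(K):
            y[:, :, n] += xp[:, :, n + k] @ w[:, :, k].T
    return y

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 8))   # B=2, C_i=3, N=8
w = rng.normal(size=(4, 3, 3))   # C_o=4, C_i=3, K=3
dy = rng.normal(size=(2, 4, 8))  # cotangent of y

y = xcorr1d(x, w, pad=1)

# Pullback into the filter: cross-correlate x with dy,
# with batch and channel axes swapped (no kernel flip).
dw = xcorr1d(x.transpose(1, 0, 2), dy.transpose(1, 0, 2), pad=1).transpose(1, 0, 2)

# Pullback into the input: cross-correlate dy with the
# channel-swapped, spatially flipped filter.
dx = xcorr1d(dy, w.transpose(1, 0, 2)[:, :, ::-1], pad=1)

# Adjoint identities of the bilinear primal: <dy, y> = <dw, w> = <dx, x>
assert np.isclose((y * dy).sum(), (w * dw).sum())
assert np.isclose((y * dy).sum(), (x * dx).sum())
```

This mirrors the structure of the JAX row: the filter pullback swaps batch and channel roles, and the input pullback swaps the two channel axes of the filter and flips it spatially.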