Convolution Autodiff Primitive Rules in Frameworks

Created with ❤️ by Machine Learning & Simulation.

Follow @felix_m_koehler

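Assumes "convolution" in the sense of CNNs, which is actually cross-correlation. In the formulas below, N is the spatial length of the input x and K is the spatial length of the filter w.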

Interactive App

Description | Primal | Pullback/vJp into input | Pullback/vJp into filter
Valid Padding, K=3
z = conv(
  x, w,
  padding=(0, 0),
)
x_cot = conv(
  z_cot, w,
  padding=(2, 2),
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(0, 0),
  permute_kernel=True,
  permute_input=True,
)
Valid Padding, Variable Kernel Size
z = conv(
  x, w,
  padding=(0, 0),
)
x_cot = conv(
  z_cot, w,
  padding=(K-1, K-1),
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(0, 0),
  permute_kernel=True,
  permute_input=True,
)
Variable Padding & Kernel Size
z = conv(
  x, w,
  padding=(p_l, p_r),
)
x_cot = conv(
  z_cot, w,
  padding=(K-1-p_l, K-1-p_r),
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, p_r),
  permute_kernel=True,
  permute_input=True,
)
Strided Convolution, No Padding
z = conv(
  x, w,
  padding=(0, 0),
  stride=s,
  dilation=1,
  fill=1,
)
x_cot = conv(
  z_cot, w,
  padding=(K-1, K-1 + (N-K)%s),
  stride=1,
  dilation=1,
  fill=s,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(0, -((N-K)%s)),
  stride=1,
  dilation=s,
  fill=1,
  permute_kernel=True,
  permute_input=True,
)
Strided Convolution
z = conv(
  x, w,
  padding=(p_l, p_r),
  stride=s,
  dilation=1,
  fill=1,
)
x_cot = conv(
  z_cot, w,
  padding=(K-p_l-1, N+p_l-1-s*((N-K+p_l+p_r)//s)),
  stride=1,
  dilation=1,
  fill=s,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, K-N-p_l+s*((N-K+p_l+p_r)//s)),
  stride=1,
  dilation=s,
  fill=1,
  permute_kernel=True,
  permute_input=True,
)
Dilated Convolution (=Atrous Convolution)
z = conv(
  x, w,
  padding=(p_l, p_r),
  stride=1,
  dilation=d,
  fill=1,
)
x_cot = conv(
  z_cot, w,
  padding=(d*(K-1)-p_l, d*(K-1)-p_r),
  stride=1,
  dilation=d,
  fill=1,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, p_r),
  stride=d,
  dilation=1,
  fill=1,
  permute_kernel=True,
  permute_input=True,
)
Transposed Convolution (Filled Convolution)
z = conv(
  x, w,
  padding=(p_l, p_r),
  stride=1,
  dilation=1,
  fill=f,
)
x_cot = conv(
  z_cot, w,
  padding=(K-1-p_l, K-1-p_r),
  stride=f,
  dilation=1,
  fill=1,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, p_r),
  stride=1,
  dilation=1,
  fill=f,
  permute_kernel=True,
  permute_input=True,
)
General Convolution
z = conv(
  x, w,
  padding=(p_l, p_r),
  stride=s,
  dilation=d,
  fill=f,
)
x_cot = conv(
  z_cot, w,
  padding=(d*(K-1)-p_l, f*(N-1)+p_l-s*((f*(N-1)-d*(K-1)+p_l+p_r)//s)),
  stride=f,
  dilation=d,
  fill=s,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, d*(K-1)-f*(N-1)-p_l+s*((f*(N-1)-d*(K-1)+p_l+p_r)//s)),
  stride=d,
  dilation=s,
  fill=f,
  permute_kernel=True,
  permute_input=True,
)