Explicit Autodiff Primitive Rules

Created with ❤️ by Machine Learning & Simulation.

Follow @felix_m_koehler

Click the 🔗 to see the corresponding derivation video. 👉 Full Playlist.

Each entry lists the primitive, its primal computation, its pushforward (JVP), and its pullback (VJP).
Explicit Scalar Rules
Scalar Addition $z=x+y$ $\dot{z}=\dot{x} + \dot{y}$ 🔗 $ \begin{align} \bar{x} &= \bar{z} \\ \bar{y} &= \bar{z} \end{align} $ 🔗
Scalar Multiplication $z=x \cdot y$ $\dot{z}=y \cdot \dot{x} + x \cdot \dot{y}$ 🔗 $ \begin{align} \bar{x} &= \bar{z} \cdot y \\ \bar{y} &= \bar{z} \cdot x \end{align} $ 🔗
Scalar Negation $z=-x$ $\dot{z} = - \dot{x}$ $\bar{x} = - \bar{z}$
Scalar Inversion $z=\frac{1}{x}$ $\dot{z} = - \frac{\dot{x}}{x^2}$ $\bar{x} = - \frac{\bar{z}}{x^2}$
Scalar Power $z = x^l$ $\dot{z} = l \, x^{l-1} \, \dot{x}$ $\bar{x} = \bar{z} \, l \, x^{l-1}$
Scalar Unary Function $z=f(x)$ $\dot{z} = \frac{\partial f}{\partial x} \dot{x}$ 🔗 $\bar{x} = \bar{z} \frac{\partial f}{\partial x}$ 🔗
General Scalar Function $ \left\{ z^{[i]} \right\}_{i=1}^n = f\left(\left\{ x^{[j]} \right\}_{j=1}^m\right) $ $ \dot{z}^{[i]} = \sum_{j=1}^m \frac{\partial f^{[i]}}{\partial x^{[j]}} \dot{x}^{[j]} $ $ \bar{x}^{[j]} = \sum_{i=1}^n \bar{z}^{[i]} \frac{\partial f^{[i]}}{\partial x^{[j]}} $
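As a concrete illustration of the scalar rules above, here is a minimal sketch (in JAX; the names `scalar_mul` and `scalar_mul_rev` are illustrative, not library APIs) that registers the scalar-multiplication pushforward and pullback as custom rules and checks them against JAX's built-in differentiation.

import jax

# Minimal sketch: register the scalar multiplication rules from the table
# as a custom JVP and a custom VJP in JAX (illustrative function names).

@jax.custom_jvp
def scalar_mul(x, y):
    return x * y

@scalar_mul.defjvp
def scalar_mul_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    z = x * y
    z_dot = y * x_dot + x * y_dot        # pushforward rule from the table
    return z, z_dot

@jax.custom_vjp
def scalar_mul_rev(x, y):
    return x * y

def scalar_mul_fwd(x, y):
    return x * y, (x, y)                 # save residuals for the pullback

def scalar_mul_bwd(res, z_bar):
    x, y = res
    return (z_bar * y, z_bar * x)        # pullback rule from the table

scalar_mul_rev.defvjp(scalar_mul_fwd, scalar_mul_bwd)

# Sanity check against JAX's built-in rules for `*`
print(jax.jvp(scalar_mul, (3.0, 5.0), (1.0, 0.0)))          # (15.0, 5.0)
print(jax.grad(scalar_mul_rev, argnums=(0, 1))(3.0, 5.0))   # (5.0, 3.0)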
Explicit Tensor Rules
Matrix-Vector Product $ \mathbf{z} = \mathbf{A} \mathbf{x} $ $ \dot{\mathbf{z}} = \dot{\mathbf{A}} \mathbf{x} + \mathbf{A} \dot{\mathbf{x}} $ 🔗 $ \begin{align} \bar{\mathbf{x}} &= \mathbf{A}^T \bar{\mathbf{z}} \\ \bar{\mathbf{A}} &= \bar{\mathbf{z}} \mathbf{x}^T \end{align} $ 🔗
Matrix-Matrix Product $ \mathbf{C} = \mathbf{A} \mathbf{B} $ $ \dot{\mathbf{C}} = \dot{\mathbf{A}} \mathbf{B} + \mathbf{A} \dot{\mathbf{B}} $ 🔗 $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \mathbf{B}^T \\ \bar{\mathbf{B}} &= \mathbf{A}^T \bar{\mathbf{C}} \end{align} $ 🔗
Scalar-Vector Product $ \mathbf{z} = \alpha \mathbf{x} $ $ \dot{\mathbf{z}} = \dot{\alpha} \mathbf{x} + \alpha \dot{\mathbf{x}} $ $ \begin{align} \bar{\mathbf{x}} &= \bar{\mathbf{z}} \alpha \\ \bar{\alpha} &= \bar{\mathbf{z}}^T \mathbf{x} \end{align} $
Scalar-Matrix Product $ \mathbf{C} = \alpha \mathbf{A} $ $ \dot{\mathbf{C}} = \dot{\alpha} \mathbf{A} + \alpha \dot{\mathbf{A}} $ $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \alpha \\ \bar{\alpha} &= \bar{\mathbf{C}} : \mathbf{A} \end{align} $
Matrix Transposition $ \mathbf{C} = \mathbf{A}^T $ $ \dot{\mathbf{C}} = \dot{\mathbf{A}}^T $ $ \bar{\mathbf{A}} = \bar{\mathbf{C}}^T $
Vector Inner Product $ \alpha = \mathbf{x}^T \mathbf{y} $ $ \dot{\alpha} = \dot{\mathbf{x}}^T \mathbf{y} + \mathbf{x}^T \dot{\mathbf{y}} $ $ \begin{align} \bar{\mathbf{x}} &= \bar{\alpha} \mathbf{y} \\ \bar{\mathbf{y}} &= \bar{\alpha} \mathbf{x} \end{align} $
Vector Outer Product $ \mathbf{C} = \mathbf{x} \mathbf{y}^T $ $ \dot{\mathbf{C}} = \dot{\mathbf{x}} \mathbf{y}^T + \mathbf{x} \dot{\mathbf{y}}^T $ $ \begin{align} \bar{\mathbf{x}} &= \bar{\mathbf{C}} \mathbf{y} \\ \bar{\mathbf{y}} &= \bar{\mathbf{C}}^T \mathbf{x} \end{align} $
Elementwise Multiplication $ \mathbf{C} = \mathbf{A} \odot \mathbf{B} $ $ \dot{\mathbf{C}} = \dot{\mathbf{A}} \odot \mathbf{B} + \mathbf{A} \odot \dot{\mathbf{B}} $ $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \odot \mathbf{B} \\ \bar{\mathbf{B}} &= \bar{\mathbf{C}} \odot \mathbf{A} \end{align} $
L2 Loss (unscaled) $ l = \frac{1}{2}||\mathbf{x}||_2^2 $ $ \dot{l} = \mathbf{x}^T \dot{\mathbf{x}} $ 🔗 $ \bar{\mathbf{x}} = \bar{l} \, \mathbf{x} $ 🔗
L2 Loss (MSE) $ l = \frac{1}{2N}||\mathbf{x}||_2^2 $ $ \dot{l} = \frac{1}{N}\mathbf{x}^T \dot{\mathbf{x}} $ $ \bar{\mathbf{x}} = \, \frac{\bar{l}}{N} \, \mathbf{x} $
Scalar Function Broadcasting $ \mathbf{z} = \sigma.( \mathbf{x} ) $ $ \dot{\mathbf{z}} = \sigma'.( \mathbf{x} ) \odot \dot{\mathbf{x}} $ 🔗 $ \bar{\mathbf{x}} = \bar{\mathbf{z}} \odot \sigma'.( \mathbf{x} ) $ 🔗
Softmax $\begin{align} \mathbf{z} &= \frac{\exp .(\mathbf{x})}{\mathbf{1}^T \exp . (\mathbf{x})} \\ z_i &= \frac{\exp(x_i)}{\sum_j \exp(x_j)} \end{align}$ $ \dot{\mathbf{z}} = \mathbf{z} \odot \dot{\mathbf{x}} - \mathbf{z} (\mathbf{z}^T \dot{\mathbf{x}}) $ 🔗 $ \bar{\mathbf{x}} = \bar{\mathbf{z}} \odot \mathbf{z} - \mathbf{z} (\bar{\mathbf{z}}^T \mathbf{z}) $ 🔗
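The softmax pullback is easy to verify numerically. Below is a minimal sketch (a plainly written softmax, illustrative names) comparing the rule $\bar{\mathbf{x}} = \bar{\mathbf{z}} \odot \mathbf{z} - \mathbf{z} (\bar{\mathbf{z}}^T \mathbf{z})$ against what JAX's reverse mode returns.

import jax
import jax.numpy as jnp

# Minimal sketch: check the softmax pullback rule against JAX's reverse mode.

def softmax(x):
    e = jnp.exp(x)
    return e / jnp.sum(e)

x = jax.random.normal(jax.random.PRNGKey(0), (5,))
z_bar = jax.random.normal(jax.random.PRNGKey(1), (5,))

z, vjp_fun = jax.vjp(softmax, x)
(x_bar_autodiff,) = vjp_fun(z_bar)

x_bar_rule = z_bar * z - z * jnp.dot(z_bar, z)   # rule from the table

print(jnp.allclose(x_bar_autodiff, x_bar_rule, atol=1e-6))   # True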
FFT $ \mathbf{z} = \text{fft}( \mathbf{x} ) $ $ \dot{\mathbf{z}} = \text{fft}( \dot{\mathbf{x}} ) $ $ \begin{align} \text{JAX} \quad \bar{\mathbf{x}} &= \text{fft}( \bar{\mathbf{z}} ) \\ \text{PyTorch} \quad \bar{\mathbf{x}} &= \text{fft}( \bar{\mathbf{z}}^* ) \end{align} $ (Details on Differences)
IFFT $ \mathbf{z} = \text{ifft}( \mathbf{x} ) $ $ \dot{\mathbf{z}} = \text{ifft}( \dot{\mathbf{x}} ) $ $ \begin{align} \text{JAX} \quad \bar{\mathbf{x}} &= \text{ifft}( \bar{\mathbf{z}} ) \\ \text{PyTorch} \quad \bar{\mathbf{x}} &= \text{ifft}( \bar{\mathbf{z}}^* ) \end{align} $
RFFT $ \mathbf{z} = \text{rfft}( \mathbf{x} ) $ $ \dot{\mathbf{z}} = \text{rfft}( \dot{\mathbf{x}} ) $ JAX: $ \bar{\mathbf{x}} = \mathcal{R}\left(\text{fft}\left( \text{pad}(\bar{\mathbf{z}}, (0, N//2 - 1)) \right)\right) $
IRFFT $ \mathbf{z} = \text{irfft}( \mathbf{x} ) $ $ \dot{\mathbf{z}} = \text{irfft}( \dot{\mathbf{x}} ) $ JAX: $ \begin{align} \mathbf{t} &= \text{rfft}(\bar{\mathbf{z}}) \\ t_{1:(N//2)} &= 2 \, t_{1:(N//2)} \\ \bar{\mathbf{x}} &= \frac{1}{N} \, \mathbf{t} \end{align} $
Convolution with Arbitrary Zero Padding: convolution in the sense of CNNs (actually cross-correlation). More Details
# primal
z = conv(
  x, w,
  padding=(p_l, p_r),
)

# pushforward / JVP
dot_z = conv(
  dot_x, w,
  padding=(p_l, p_r),
) + conv(
  x, dot_w,
  padding=(p_l, p_r),
)

# pullbacks / VJP (K = kernel size)
x_cot = conv(
  z_cot, w,
  padding=(K-1-p_l, K-1-p_r),
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, p_r),
  permute_kernel=True,
  permute_input=True,
)
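A quick way to convince oneself of these pullback rules is a single-channel 1-D check against JAX's autodiff. In the sketch below, `conv1d` is a hand-rolled cross-correlation (not a library call); with one input and one output channel the `permute_kernel`/`permute_input` flags are no-ops and `flip_kernel` reduces to reversing the kernel.

import jax
import jax.numpy as jnp

# Minimal single-channel 1-D sketch (stride 1, no dilation): check the conv
# pullback rules against JAX's autodiff. `conv1d` is a hand-rolled helper.

def conv1d(x, w, p_l, p_r):
    # cross-correlation ("convolution" in the CNN sense) with zero padding
    x_pad = jnp.pad(x, (p_l, p_r))
    K = w.shape[0]
    return jnp.stack([jnp.dot(x_pad[i:i + K], w)
                      for i in range(x_pad.shape[0] - K + 1)])

N, K, p_l, p_r = 8, 3, 1, 2
x = jax.random.normal(jax.random.PRNGKey(0), (N,))
w = jax.random.normal(jax.random.PRNGKey(1), (K,))

z, vjp_fun = jax.vjp(lambda x_, w_: conv1d(x_, w_, p_l, p_r), x, w)
z_bar = jax.random.normal(jax.random.PRNGKey(2), z.shape)
x_bar_autodiff, w_bar_autodiff = vjp_fun(z_bar)

# x_cot rule: convolve the cotangent with the flipped kernel, padding K-1-p
x_bar_rule = conv1d(z_bar, w[::-1], K - 1 - p_l, K - 1 - p_r)
# w_cot rule: "convolve" the input with the cotangent acting as the kernel
w_bar_rule = conv1d(x, z_bar, p_l, p_r)

print(jnp.allclose(x_bar_autodiff, x_bar_rule, atol=1e-5))   # True
print(jnp.allclose(w_bar_autodiff, w_bar_rule, atol=1e-5))   # True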
General Convolution (holds per spatial axis)
# primal (p_l/p_r = left/right zero padding, s = stride,
# d = kernel dilation, f = fill, i.e. zeros inserted between input elements)
z = conv(
  x, w,
  padding=(p_l, p_r),
  stride=s,
  dilation=d,
  fill=f
)

# pushforward / JVP
dot_z = conv(
  dot_x, w,
  padding=(p_l, p_r),
  stride=s,
  dilation=d,
  fill=f
) + conv(
  x, dot_w,
  padding=(p_l, p_r),
  stride=s,
  dilation=d,
  fill=f
)

# pullbacks / VJP (N = input length, K = kernel size)
c = s*((N - K + p_l + p_r) // s)
x_cot = conv(
  z_cot, w,
  padding=(K - p_l - 1, N + p_l - 1 - c),
  stride=f,
  dilation=d,
  fill=s,
  flip_kernel=True,
  permute_kernel=True,
)
w_cot = conv(
  x, z_cot,
  padding=(p_l, K - N - p_l + c),
  stride=d,
  dilation=s,
  fill=f,
  permute_kernel=True,
  permute_input=True,
)
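For the general rule, reference values to compare against can be produced directly with JAX. The sketch below is only a harness: it assumes the mapping `stride` → `window_strides`, `dilation` → `rhs_dilation`, and `fill` → `lhs_dilation` of `jax.lax.conv_general_dilated`, and uses `jax.vjp` to obtain the cotangents that a hand-coded implementation of the pullbacks above should reproduce.

import jax
import jax.numpy as jnp

# Harness (assumed mapping: stride -> window_strides, dilation -> rhs_dilation,
# fill -> lhs_dilation): obtain reference cotangents for the general 1-D conv
# via JAX's autodiff, to compare a hand-coded version of the rules against.

def general_conv(x, w, p_l, p_r, s, d, f):
    # x: (batch, channels_in, N), w: (channels_out, channels_in, K)
    return jax.lax.conv_general_dilated(
        x, w,
        window_strides=(s,),
        padding=((p_l, p_r),),
        lhs_dilation=(f,),     # fill
        rhs_dilation=(d,),     # kernel dilation
    )

N, K = 10, 3
p_l, p_r, s, d, f = 1, 2, 2, 2, 1
x = jax.random.normal(jax.random.PRNGKey(0), (1, 1, N))
w = jax.random.normal(jax.random.PRNGKey(1), (1, 1, K))

z, vjp_fun = jax.vjp(lambda x_, w_: general_conv(x_, w_, p_l, p_r, s, d, f), x, w)
z_bar = jax.random.normal(jax.random.PRNGKey(2), z.shape)
x_cot_ref, w_cot_ref = vjp_fun(z_bar)   # reference values for x_cot and w_cot
print(x_cot_ref.shape, w_cot_ref.shape)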
Vector Unary Function $ \mathbf{z} = f(\mathbf{x}) $ $ \dot{\mathbf{z}} = \frac{\partial f}{\partial \mathbf{x}} \dot{\mathbf{x}} $ $ \begin{align} \bar{\mathbf{x}}^T &= \bar{\mathbf{z}}^T \frac{\partial f}{\partial \mathbf{x}} \\ &\text{or} \\ \bar{\mathbf{x}} &= \left( \frac{\partial f}{\partial \mathbf{x}} \right)^T \bar{\mathbf{z}} \end{align} $
General Vector Function $ \left\{ \mathbf{z}^{[i]}\right\}_{i=1}^n = f(\left\{ \mathbf{x}^{[j]}\right\}_{j=1}^m) $ $ \dot{\mathbf{z}}^{[i]} = \sum_{j=1}^m \frac{\partial f^{[i]}}{\partial \mathbf{x}^{[j]}} \dot{\mathbf{x}}^{[j]} $ $ \bar{\mathbf{x}}^{[j],T} = \sum_{i=1}^n \bar{\mathbf{z}}^{[i],T} \frac{\partial f^{[i]}}{\partial \mathbf{x}^{[j]}} $
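The last two rows are exactly what `jax.jvp` and `jax.vjp` compute without ever forming the Jacobian. A minimal sketch (with an arbitrary example function `f`) comparing both against the explicit Jacobian:

import jax
import jax.numpy as jnp

# Minimal sketch: the JVP applies the Jacobian, the VJP applies its transpose.

def f(x):   # arbitrary example function R^3 -> R^3
    return jnp.stack([jnp.sin(x[0]) * x[1],
                      x[0] ** 2 + x[2],
                      jnp.exp(x[1] - x[2])])

x = jnp.array([0.3, -1.2, 0.7])
x_dot = jnp.array([1.0, 0.5, -2.0])
z_bar = jnp.array([0.1, 2.0, -0.3])

J = jax.jacfwd(f)(x)                        # explicit Jacobian (3 x 3)

z, z_dot = jax.jvp(f, (x,), (x_dot,))       # pushforward
_, vjp_fun = jax.vjp(f, x)
(x_bar,) = vjp_fun(z_bar)                   # pullback

print(jnp.allclose(z_dot, J @ x_dot))       # True
print(jnp.allclose(x_bar, J.T @ z_bar))     # True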
Array Manipulations
Concatenation $ \mathbf{z} = \text{concat}(\mathbf{x}, \mathbf{y}) $ $ \dot{\mathbf{z}} = \text{concat}(\dot{\mathbf{x}}, \dot{\mathbf{y}}) $ $ \begin{align} \bar{\mathbf{x}} &= \text{slice}(\bar{\mathbf{z}}, 0, \text{len}(\mathbf{x})) \\ \bar{\mathbf{y}} &= \text{slice}(\bar{\mathbf{z}}, \text{len}(\mathbf{x}), \text{len}(\mathbf{z})) \end{align} $
Slicing $ \mathbf{z} = \text{slice}(\mathbf{x}, a, b) $ $ \dot{\mathbf{z}} = \text{slice}(\dot{\mathbf{x}}, a, b) $ $ \begin{align} \bar{\mathbf{x}} &= \text{concat}(\text{zeros}(a), \bar{\mathbf{z}}, \text{zeros}(\text{len}(\mathbf{x}) - b)) \\ &\text{or} \\ \bar{\mathbf{x}} &= \text{pad}(\bar{\mathbf{z}}, (a, \text{len}(\mathbf{x}) - b)) \end{align} $
(Zero) Padding $ \mathbf{z} = \text{pad}(\mathbf{x}, (a, b)) $ $ \dot{\mathbf{z}} = \text{pad}(\dot{\mathbf{x}}, (a, b)) $ $ \bar{\mathbf{x}} = \text{slice}(\bar{\mathbf{z}}, a, \text{len}(\mathbf{z}) - b) $
Reshaping $ \mathbf{z} = \text{reshape}(\mathbf{x}, \text{shape}(\mathbf{z})) $ $ \dot{\mathbf{z}} = \text{reshape}(\dot{\mathbf{x}}, \text{shape}(\mathbf{z})) $ $ \bar{\mathbf{x}} = \text{reshape}(\bar{\mathbf{z}}, \text{shape}(\mathbf{x})) $
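The padding and slicing rules above form an adjoint pair, which can be checked directly. A minimal sketch, with `jnp.pad` and plain indexing standing in for the `pad` and `slice` primitives:

import jax
import jax.numpy as jnp

# Minimal sketch: the pullback of zero padding is slicing, and vice versa.

x = jax.random.normal(jax.random.PRNGKey(0), (7,))
a, b = 2, 3

# z = pad(x, (a, b)): the pullback should slice z_bar back to len(x).
z_pad, vjp_pad = jax.vjp(lambda x_: jnp.pad(x_, (a, b)), x)
z_bar = jax.random.normal(jax.random.PRNGKey(1), z_pad.shape)
(x_bar,) = vjp_pad(z_bar)
print(jnp.allclose(x_bar, z_bar[a:len(z_bar) - b]))             # True

# z = slice(x, a, b): the pullback should zero-pad z_bar to len(x).
z_slc, vjp_slc = jax.vjp(lambda x_: x_[a:b], x)
s_bar = jax.random.normal(jax.random.PRNGKey(2), z_slc.shape)
(x_bar2,) = vjp_slc(s_bar)
print(jnp.allclose(x_bar2, jnp.pad(s_bar, (a, len(x) - b))))    # True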