Created with ❤️ by Machine Learning & Simulation.
Click the 🔗 to see the corresponding derivation video. 👉 Full Playlist.
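Runnable JAX sketches that numerically check several of the rules are collected below the table.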
Primitive | Primal | Pushforward/Jvp | Pullback/vJp | Notes |
---|---|---|---|---|
**Explicit Scalar Rules** | | | | |
Scalar Addition | $z=x+y$ | $\dot{z}=\dot{x} + \dot{y}$ 🔗 | $ \begin{align} \bar{x} &= \bar{z} \\ \bar{y} &= \bar{z} \end{align} $ 🔗 | |
Scalar Multiplication | $z=x \cdot y$ | $\dot{z}=y \cdot \dot{x} + x \cdot \dot{y}$ 🔗 | $ \begin{align} \bar{x} &= \bar{z} \cdot y \\ \bar{y} &= \bar{z} \cdot x \end{align} $ 🔗 | |
Scalar Negation | $z=-x$ | $\dot{z} = - \dot{x}$ | $\bar{x} = - \bar{z}$ | |
Scalar Inversion | $z=\frac{1}{x}$ | $\dot{z} = - \frac{\dot{x}}{x^2}$ | $\bar{x} = - \frac{\bar{z}}{x^2}$ | |
Scalar Power | $z = x^l$ | $\dot{z} = l \, x^{l-1} \, \dot{x}$ | $\bar{x} = \bar{z} \, l \, x^{l-1}$ | |
Scalar Unary Function | $z=f(x)$ | $\dot{z} = \frac{\partial f}{\partial x} \dot{x}$ 🔗 | $\bar{x} = \bar{z} \frac{\partial f}{\partial x}$ 🔗 | |
General Scalar Function | $ \left\{ z^{[i]} \right\}_{i=1}^n = f\left(\left\{ x^{[j]} \right\}_{j=1}^m\right) $ | $ \dot{z}^{[i]} = \sum_{j=1}^m \frac{\partial f^{[i]}}{\partial x^{[j]}} \dot{x}^{[j]} $ | $ \bar{x}^{[j]} = \sum_{i=1}^n \bar{z}^{[i]} \frac{\partial f^{[i]}}{\partial x^{[j]}} $ | |
**Explicit Tensor Rules** | | | | |
Matrix-Vector Product | $ \mathbf{z} = \mathbf{A} \mathbf{x} $ | $ \dot{\mathbf{z}} = \dot{\mathbf{A}} \mathbf{x} + \mathbf{A} \dot{\mathbf{x}} $ 🔗 | $ \begin{align} \bar{\mathbf{x}} &= \mathbf{A}^T \bar{\mathbf{z}} \\ \bar{\mathbf{A}} &= \bar{\mathbf{z}} \mathbf{x}^T \end{align} $ 🔗 | |
Matrix-Matrix Product | $ \mathbf{C} = \mathbf{A} \mathbf{B} $ | $ \dot{\mathbf{C}} = \dot{\mathbf{A}} \mathbf{B} + \mathbf{A} \dot{\mathbf{B}} $ 🔗 | $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \mathbf{B}^T \\ \bar{\mathbf{B}} &= \mathbf{A}^T \bar{\mathbf{C}} \end{align} $ 🔗 | |
Scalar-Vector Product | $ \mathbf{z} = \alpha \mathbf{x} $ | $ \dot{\mathbf{z}} = \dot{\alpha} \mathbf{x} + \alpha \dot{\mathbf{x}} $ | $ \begin{align} \bar{\mathbf{x}} &= \bar{\mathbf{z}} \alpha \\ \bar{\alpha} &= \bar{\mathbf{z}}^T \mathbf{x} \end{align} $ | |
Scalar-Matrix Product | $ \mathbf{C} = \alpha \mathbf{A} $ | $ \dot{\mathbf{C}} = \dot{\alpha} \mathbf{A} + \alpha \dot{\mathbf{A}} $ | $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \alpha \\ \bar{\alpha} &= \bar{\mathbf{C}} : \mathbf{A} \end{align} $ | |
Matrix Transposition | $ \mathbf{C} = \mathbf{A}^T $ | $ \dot{\mathbf{C}} = \dot{\mathbf{A}}^T $ | $ \bar{\mathbf{A}} = \bar{\mathbf{C}}^T $ | |
Vector Inner Product | $ \alpha = \mathbf{x}^T \mathbf{y} $ | $ \dot{\alpha} = \dot{\mathbf{x}}^T \mathbf{y} + \mathbf{x}^T \dot{\mathbf{y}} $ | $ \begin{align} \bar{\mathbf{x}} &= \bar{\alpha} \mathbf{y} \\ \bar{\mathbf{y}} &= \bar{\alpha} \mathbf{x} \end{align} $ | |
Vector Outer Product | $ \mathbf{C} = \mathbf{x} \mathbf{y}^T $ | $ \dot{\mathbf{C}} = \dot{\mathbf{x}} \mathbf{y}^T + \mathbf{x} \dot{\mathbf{y}}^T $ | $ \begin{align} \bar{\mathbf{x}} &= \bar{\mathbf{C}} \mathbf{y} \\ \bar{\mathbf{y}} &= \bar{\mathbf{C}}^T \mathbf{x} \end{align} $ | |
Elementwise Multiplication | $ \mathbf{C} = \mathbf{A} \odot \mathbf{B} $ | $ \dot{\mathbf{C}} = \dot{\mathbf{A}} \odot \mathbf{B} + \mathbf{A} \odot \dot{\mathbf{B}} $ | $ \begin{align} \bar{\mathbf{A}} &= \bar{\mathbf{C}} \odot \mathbf{B} \\ \bar{\mathbf{B}} &= \bar{\mathbf{C}} \odot \mathbf{A} \end{align} $ | |
L2 Loss (unscaled) | $ l = \frac{1}{2}||\mathbf{x}||_2^2 $ | $ \dot{l} = \mathbf{x}^T \dot{\mathbf{x}} $ 🔗 | $ \bar{\mathbf{x}} = \, \bar{l} \, \mathbf{x} $ 🔗 | |
L2 Loss (MSE) | $ l = \frac{1}{2N}||\mathbf{x}||_2^2 $ | $ \dot{l} = \frac{1}{N}\mathbf{x}^T \dot{\mathbf{x}} $ | $ \bar{\mathbf{x}} = \, \frac{\bar{l}}{N} \, \mathbf{x} $ | |
Scalar Function Broadcasting | $ \mathbf{z} = \sigma.( \mathbf{x} ) $ | $ \dot{\mathbf{z}} = \sigma'.( \mathbf{x} ) \odot \dot{\mathbf{x}} $ 🔗 | $ \bar{\mathbf{x}} = \bar{\mathbf{z}} \odot \sigma'.( \mathbf{x} ) $ 🔗 | |
Softmax | $\begin{align} \mathbf{z} &= \frac{\exp .(\mathbf{x})}{\mathbf{1}^T \exp . (\mathbf{x})} \\ z_i &= \frac{\exp(x_i)}{\sum_j \exp(x_j)} \end{align}$ | $ \dot{\mathbf{z}} = \mathbf{z} \odot \dot{\mathbf{x}} - \mathbf{z} (\mathbf{z}^T \dot{\mathbf{x}}) $ 🔗 | $ \bar{\mathbf{x}} = \bar{\mathbf{z}} \odot \mathbf{z} - \mathbf{z} (\bar{\mathbf{z}}^T \mathbf{z}) $ 🔗 | |
FFT | $ \mathbf{z} = \text{fft}( \mathbf{x} ) $ | $ \dot{\mathbf{z}} = \text{fft}( \dot{\mathbf{x}} ) $ | $ \begin{align} \text{JAX:} \quad \bar{\mathbf{x}} &= \text{fft}( \bar{\mathbf{z}} ) \\ \text{PyTorch:} \quad \bar{\mathbf{x}} &= \text{fft}( \bar{\mathbf{z}}^* ) \end{align} $ | Details on Differences |
IFFT | $ \mathbf{z} = \text{ifft}( \mathbf{x} ) $ | $ \dot{\mathbf{z}} = \text{ifft}( \dot{\mathbf{x}} ) $ | $ \begin{align} \text{JAX:} \quad \bar{\mathbf{x}} &= \text{ifft}( \bar{\mathbf{z}} ) \\ \text{PyTorch:} \quad \bar{\mathbf{x}} &= \text{ifft}( \bar{\mathbf{z}}^* ) \end{align} $ | |
RFFT | $ \mathbf{z} = \text{rfft}( \mathbf{x} ) $ | $ \dot{\mathbf{z}} = \text{rfft}( \dot{\mathbf{x}} ) $ | JAX: $ \bar{\mathbf{x}} = \mathcal{R}(\text{fft}( \text{pad}(\bar{\mathbf{z}}, (0, N//2 - 1) ))) $ | |
IRFFT | $ \mathbf{z} = \text{irfft}( \mathbf{x} ) $ | $ \dot{\mathbf{z}} = \text{irfft}( \dot{\mathbf{x}} ) $ | JAX: $ \begin{align} \mathbf{t} &= \text{rfft}(\bar{\mathbf{z}}) \\ t_{1:(N//2)} &= 2 \cdot t_{1:(N//2)} \\ \bar{\mathbf{x}} &= \frac{1}{N} \, \mathbf{t} \end{align} $ | |
Convolution with Arbitrary Zero Padding<br>(convolution in the sense of CNNs, i.e., actually cross-correlation) | `z = conv(x, w, padding=(p_l, p_r))` | `dot_z = conv(dot_x, w, padding=(p_l, p_r)) + conv(x, dot_w, padding=(p_l, p_r))` | `x_cot = conv(z_cot, w, padding=(K-1-p_l, K-1-p_r), flip_kernel=True, permute_kernel=True)`<br>`w_cot = conv(x, z_cot, padding=(p_l, p_r), permute_kernel=True, permute_input=True)` | More Details |
General Convolution<br>(holds per spatial axis) | `z = conv(x, w, padding=(p_l, p_r), stride=s, dilation=d, fill=f)` | `dot_z = conv(dot_x, w, padding=(p_l, p_r), stride=s, dilation=d, fill=f) + conv(x, dot_w, padding=(p_l, p_r), stride=s, dilation=d, fill=f)` | `c = s*((N - K + p_l + p_r) // s)`<br>`x_cot = conv(z_cot, w, padding=(K - p_l - 1, N + p_l - 1 - c), stride=f, dilation=d, fill=s, flip_kernel=True, permute_kernel=True)`<br>`w_cot = conv(x, z_cot, padding=(p_l, K - N - p_l + c), stride=d, dilation=s, fill=f, permute_kernel=True, permute_input=True)` | |
Vector Unary Function | $ \mathbf{z} = f(\mathbf{x}) $ | $ \dot{\mathbf{z}} = \frac{\partial f}{\partial \mathbf{x}} \dot{\mathbf{x}} $ | $ \begin{align} \bar{\mathbf{x}}^T &= \bar{\mathbf{z}}^T \frac{\partial f}{\partial \mathbf{x}} \\ &\text{or} \\ \bar{\mathbf{x}} &= \left( \frac{\partial f}{\partial \mathbf{x}} \right)^T \bar{\mathbf{z}} \end{align} $ | |
General Vector Function | $ \left\{ \mathbf{z}^{[i]}\right\}_{i=1}^n = f(\left\{ \mathbf{x}^{[j]}\right\}_{j=1}^m) $ | $ \dot{\mathbf{z}}^{[i]} = \sum_{j=1}^m \frac{\partial f^{[i]}}{\partial \mathbf{x}^{[j]}} \dot{\mathbf{x}}^{[j]} $ | $ \bar{\mathbf{x}}^{[j],T} = \sum_{i=1}^n \bar{\mathbf{z}}^{[i],T} \frac{\partial f^{[i]}}{\partial \mathbf{x}^{[j]}} $ | |
**Array Manipulations** | | | | |
Concatenation | $ \mathbf{z} = \text{concat}(\mathbf{x}, \mathbf{y}) $ | $ \dot{\mathbf{z}} = \text{concat}(\dot{\mathbf{x}}, \dot{\mathbf{y}}) $ | $ \begin{align} \bar{\mathbf{x}} &= \text{slice}(\bar{\mathbf{z}}, 0, \text{len}(\mathbf{x})) \\ \bar{\mathbf{y}} &= \text{slice}(\bar{\mathbf{z}}, \text{len}(\mathbf{x}), \text{len}(\mathbf{z})) \end{align} $ | |
Slicing | $ \mathbf{z} = \text{slice}(\mathbf{x}, a, b) $ | $ \dot{\mathbf{z}} = \text{slice}(\dot{\mathbf{x}}, a, b) $ | $ \begin{align} \bar{\mathbf{x}} &= \text{concat}(\text{zeros}(a), \bar{\mathbf{z}}, \text{zeros}(\text{len}(\mathbf{x}) - b)) \\ &\text{or} \\ \bar{\mathbf{x}} &= \text{pad}(\bar{\mathbf{z}}, (a, \text{len}(\mathbf{x}) - b)) \end{align} $ | |
(Zero) Padding | $ \mathbf{z} = \text{pad}(\mathbf{x}, (a, b)) $ | $ \dot{\mathbf{z}} = \text{pad}(\dot{\mathbf{x}}, (a, b)) $ | $ \bar{\mathbf{x}} = \text{slice}(\bar{\mathbf{z}}, a, \text{len}(\mathbf{z}) - b) $ | |
Reshaping | $ \mathbf{z} = \text{reshape}(\mathbf{x}, \text{shape}(\mathbf{z})) $ | $ \dot{\mathbf{z}} = \text{reshape}(\dot{\mathbf{x}}, \text{shape}(\mathbf{z})) $ | $ \bar{\mathbf{x}} = \text{reshape}(\bar{\mathbf{z}}, \text{shape}(\mathbf{x})) $ |
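
All of the rules above can be checked numerically against an autodiff framework. The sketches below use JAX as one concrete framework and compare the outputs of `jax.jvp`/`jax.vjp` against the closed-form rules from the table; every variable name is illustrative. Starting with scalar multiplication:

```python
import jax
import jax.numpy as jnp

def mul(x, y):
    return x * y

x, y = 2.0, 3.0
x_dot, y_dot = 0.7, -1.3  # arbitrary tangents

# Pushforward: z_dot = y * x_dot + x * y_dot
z, z_dot = jax.jvp(mul, (x, y), (x_dot, y_dot))
print(jnp.allclose(z_dot, y * x_dot + x * y_dot))  # True

# Pullback: x_bar = z_bar * y, y_bar = z_bar * x
z_bar = 1.5  # arbitrary cotangent
z, mul_vjp = jax.vjp(mul, x, y)
x_bar, y_bar = mul_vjp(z_bar)
print(jnp.allclose(x_bar, z_bar * y), jnp.allclose(y_bar, z_bar * x))  # True True
```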
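Scalar function broadcasting works the same way; a sketch with $\sigma = \tanh$, whose elementwise derivative is $1 - \tanh^2(x)$:

```python
import jax
import jax.numpy as jnp

key_x, key_c = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (6,))
z_bar = jax.random.normal(key_c, (6,))  # arbitrary cotangent

# Pullback: x_bar = z_bar ⊙ sigma'.(x)
z, tanh_vjp = jax.vjp(jnp.tanh, x)
(x_bar,) = tanh_vjp(z_bar)
print(jnp.allclose(x_bar, z_bar * (1.0 - jnp.tanh(x) ** 2)))  # True
```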
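The tensor rules follow the same pattern; for instance, the matrix-matrix product pullback $\bar{\mathbf{A}} = \bar{\mathbf{C}} \mathbf{B}^T$, $\bar{\mathbf{B}} = \mathbf{A}^T \bar{\mathbf{C}}$:

```python
import jax
import jax.numpy as jnp

key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (3, 4))
B = jax.random.normal(key_b, (4, 5))
C_bar = jax.random.normal(key_c, (3, 5))  # arbitrary cotangent, same shape as C

C, matmul_vjp = jax.vjp(jnp.matmul, A, B)
A_bar, B_bar = matmul_vjp(C_bar)
print(jnp.allclose(A_bar, C_bar @ B.T))  # True
print(jnp.allclose(B_bar, A.T @ C_bar))  # True
```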
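The softmax rules exploit that the Jacobian $\text{diag}(\mathbf{z}) - \mathbf{z}\mathbf{z}^T$ never has to be materialized; a sketch checking both directions, with softmax written out as in the primal column:

```python
import jax
import jax.numpy as jnp

def softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x))

key_x, key_t, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (5,))
x_dot = jax.random.normal(key_t, (5,))  # arbitrary tangent
z_bar = jax.random.normal(key_c, (5,))  # arbitrary cotangent

# Pushforward: z_dot = z ⊙ x_dot - z (z^T x_dot)
z, z_dot = jax.jvp(softmax, (x,), (x_dot,))
print(jnp.allclose(z_dot, z * x_dot - z * (z @ x_dot)))  # True

# Pullback: x_bar = z_bar ⊙ z - z (z_bar^T z)
z, softmax_vjp = jax.vjp(softmax, x)
(x_bar,) = softmax_vjp(z_bar)
print(jnp.allclose(x_bar, z_bar * z - z * (z_bar @ z)))  # True
```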
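The FFT rows document a convention difference between frameworks: JAX's pullback applies `fft` to the raw cotangent, while PyTorch's applies it to the conjugated cotangent. A sketch probing the JAX side; per the FFT row above, the comparison should print `True`:

```python
import jax
import jax.numpy as jnp

key_re, key_im, key_cr, key_ci = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(key_re, (8,)) + 1j * jax.random.normal(key_im, (8,))
z_bar = jax.random.normal(key_cr, (8,)) + 1j * jax.random.normal(key_ci, (8,))

z, fft_vjp = jax.vjp(jnp.fft.fft, x)
(x_bar,) = fft_vjp(z_bar)
# JAX convention from the table: x_bar = fft(z_bar), with no conjugation
print(jnp.allclose(x_bar, jnp.fft.fft(z_bar)))  # True
```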
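For the convolution row, here is a minimal single-channel 1D sketch, assuming `conv` means zero-padded cross-correlation as in the row above. The `permute_kernel`/`permute_input` flags concern channel axes and are moot with a single channel, so only the kernel flip and the padding arithmetic remain; `conv` and all other names are illustrative:

```python
import jax
import jax.numpy as jnp

def conv(x, w, p_l, p_r):
    """Zero-padded 1D cross-correlation (a 'convolution' in the CNN sense)."""
    x_padded = jnp.pad(x, (p_l, p_r))
    return jnp.convolve(x_padded, w[::-1], mode="valid")

key_x, key_w, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
N, K, p_l, p_r = 8, 3, 1, 2
x = jax.random.normal(key_x, (N,))
w = jax.random.normal(key_w, (K,))

z, conv_vjp = jax.vjp(lambda x, w: conv(x, w, p_l, p_r), x, w)
z_bar = jax.random.normal(key_c, z.shape)  # arbitrary cotangent
x_bar, w_bar = conv_vjp(z_bar)

# x_cot: cross-correlate z_cot with the flipped kernel under adjusted padding
print(jnp.allclose(x_bar, conv(z_bar, w[::-1], K - 1 - p_l, K - 1 - p_r)))  # True
# w_cot: cross-correlate x with z_cot acting as the kernel, same padding
print(jnp.allclose(w_bar, conv(x, z_bar, p_l, p_r)))  # True
```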
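Finally, the array-manipulation rules express that slicing and zero padding are each other's pullbacks. A sketch with illustrative slice bounds `a`, `b`:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)
a, b = 2, 7  # slice bounds, chosen arbitrarily

z, slice_vjp = jax.vjp(lambda x: x[a:b], x)
z_bar = jnp.ones_like(z)  # arbitrary cotangent
(x_bar,) = slice_vjp(z_bar)

# Pullback of slicing = zero padding back to the original length
print(jnp.allclose(x_bar, jnp.pad(z_bar, (a, len(x) - b))))  # True
```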