🔙

# LayerNormの誤差逆伝播法

2023/09/01に公開

LayerNorm[1]の逆伝播計算式備忘録

となるから、ヤコビ行列は

\begin{align*} \frac{\partial y_i}{\partial x_j}&= \frac{\delta_{ij}-1/H}{\sigma} - \frac{x_i - \mu}{\sigma^2}\frac{\partial\sigma}{\partial x_j}\\ &=\frac{\delta_{ij}-1/H}{\sigma} - \frac{x_i - \mu}{2\sigma^3}\frac{\partial}{\partial x_j}\frac1H\sum_{h=1}^{H}(x_h-\mu)^2\\ &=\frac{1}{\sigma}\left(\delta_{ij} - \frac1H - \frac{(x_i - \mu)(x_j - \mu)}{H\sigma^2}\right)\\ &=\frac{1}{\sigma}\left(\delta_{ij} - \frac{1 + y_iy_j}{H}\right)\\ \end{align*}

となる。勾配は

\begin{align*} \frac{\partial \mathcal{L}}{\partial x_j} &=\sum_i\frac{\partial y_i}{\partial x_j}\frac{\partial \mathcal{L}}{\partial y_i}\\ &=\frac1\sigma\left(\frac{\partial \mathcal{L}}{\partial y_j} - \frac1H\left(\sum_i\frac{\partial \mathcal{L}}{\partial y_i}+\left(\bm{y}\cdot\frac{\partial\mathcal{L}}{\partial\bm{y}}\right)y_j\right)\right) \end{align*}

となる。
Pythonでの実装を以下に示す。

import numpy as np

N, H = 10, 20

# forward
x = np.random.randn(N, H)
mu = x.mean(axis=-1, keepdims=True)
sigma = x.std(axis=-1, keepdims=True)
y = (x - mu) / sigma

# backward
dy = np.random.randn(N, H)
a = dy.sum(axis=-1, keepdims=True)
b = np.einsum('...i,...i->...', y, dy)[..., None]
dx = (dy - (a + b * y) / H) / sigma

assert x.shape == dx.shape