Open8

DyT: LayerNormのmesn-std正規化をtanhで置換する

bilzardbilzard

提案手法の位置付け

  • RMSNorm: LNのcenterizationを省略し、stdをRMSで置換することで計算量を削減
  • DyT (This method): LNのcenterizationとscalingをtanhで置換することで計算量を削減
\text{LN(x)} = \frac{x - \text{mean}(x)}{\text{std}(x)} \odot g + b
\text{RMSNorm(x)} = \frac{x}{\text{RMS}(x)} \odot g + b
\text{DyT(x)} = \tanh (\alpha \cdot x) \odot g + b

where,

\text{std}(x) = \sqrt{ \frac{1}{n} \sum _ i (x _ i - \mu)^2 }
\text{RMS}(x) = \sqrt{ \frac{1}{n} \sum _ i x _ i^2 }
bilzardbilzard

この手法を思いついた経緯

  • 実際のTransformerのLNの(re-normalization前の)入力・出力の関係をグラフにプロットしてみると、tanhに似たS字カーブを観察できる。
  • さらに細かく、tokenごとの入力・出力の関係をグラフにプロットしてみると、mean-std normの部分は 入力の外れ値を規定の範囲に収める役割を果たしている ことがわかる→ この役割はtanhで代替可能では? というのが本手法のアイデア(🪓)。


bilzardbilzard

提案手法

  • 追加で tanhの線形部分の傾きをコントロールするパラメータ \alpha を導入する。これは1つのDyTレイヤにつき1自由度のみ追加される。
\text{DyT(x)} = \tanh (\alpha \cdot x) \odot g + b


bilzardbilzard

結果

論文参照。

議論と考察

1. \alpha の初期値と外れ値との関係について

本手法の本質は、LNでアクティベーションの統計値から動的に計算していたスケール(std)を、データから経験的に学習した定数 \alpha で置換する というものである。したがって、 本手法は外れ値に対して脆弱である可能性がある 。この仮説は、ViT-Lの学習で大きな \alpha_0 を選ぶと学習が発散する、といった事実や、以下の実験結果を説明する。

  1. 小さい \alpha _ 0 は学習の安定をもたらす(そのかわりパフォーマンスを犠牲にする)
  2. modelのwidth(次元)が広いほど小さい \alpha _ 0 を好む
  3. modelのdepth(レイヤ数)は \alpha _ 0 の選び方の感応性(sensitivity)は小さい

2は直感的には、モデルの重みの量子化において、統計をとるブロック数を多くとることで量子化誤差が小さくなる、という事実と関係しているように思える。つまり、 modelのwidthが大きくなるほどサンプル数が増えて外れ値が入り込む確率が上がるため、より保守的なスケーリング(小さい \alpha _ 0 を選択)の方が学習結果が安定する 、という説明である。

外れ値の影響を緩和するには?

上述の仮説が正しく、実際に外れ値が学習の不安定化の原因ならば、以下のような応用が考えられるかもしれない。

  1. NレイヤごとにLNを挿入して外れ値の影響を緩和する
  2. 全token、全次元で1つのパラメータ \alpha を計算するのでなく、次元やトークンをブロック化し、スケール対象の区画を細分化する

2はモデルの重みの量子化で一般的に量子化誤差を抑えるために導入されている。

2. (リスク) LLMでは\alpha の初期値のチューニングが必要

LLM以外のTransformerでは「学習結果は \alpha の初期値の選び方にほとんど影響しない」という結果だったが(ViT-Lを除き)、「LLMでは \alpha の初期値をチューニングするとゲインが得られた」としている。裏を返せば \alpha の初期値をチューニングしないと最適な結果が得られない ということである。学習に膨大なコストがかかるLLMに本手法を適用するには慎重になる必要があるかもしれない。

bilzardbilzard

感想

  • 外れ値とかアクティベーションの分布といった話は量子化学習でよく考察される内容だが、一般的なモデルの学習の安定化にも寄与しているのではないか?→そして、分布の安定化に関係しているのはおそらくnorm layerであり、この部分の設計がTransformer、というか層数・次元数の大きい大規模の設計で重要なのでは?といったことを考えた。
bilzardbilzard

疑問

以下の論文では層数の大きいLLMほど、LNがLatencyに占める割合は相対に小さいようなので、ここの処理効率化は大規模LLMでは、全体の効率化にさほど寄与しないのではないか?と思うのだが、RMSNormは現代の一般的なLLM(GPT-3, Llama, Qwen2.5, DeepSeek-V3 etc.)で広く採用されているのでなんでだろう?という感じ。安定化により寄与するとか、並列化が容易とか、別のメリットがある?

https://arxiv.org/abs/2401.14489