Zenn
🐡

RMSNorm が主要な LLM で標準的に採用されているのはなぜか?

に公開
2

現在主流の多くのLLMでは LayerNorm のかわりに RMSNorm [10]が採用されています。ただ、大規模なパラメータの LLM ではむしろ MLP や Attention など GEMM 演算が占める割合が大部分を占めるので、Normalization の改善で全体的な学習・推論効率の改善が見込めるとは考えにくいです。そこで、技術レポートが公開されている主要なLLMでRMSNormを採用した理由について調べました。

その結果、少なくとも今回調べた限りでは、RMSNormの論文で言及されているような効率化や学習の安定化といった一般的な記述はされているものの、一定規模以上のLLMで定量的な比較に基づいてRMSNormの優位性を主張した論文は見つかりませんでした

従って、この記事の仮の結論としては単純に「GPT-3の設計に倣っただけ」としました。
反例(十分大きなパラメータのLLMでRMSNormが明確に優位であることを主張する論文、記事、GitHub repo etc.)を知っている方はコメントやTwitterで教えていただけると幸いです。

経緯

RMSNormとは?

  • RMSNorm [10]: LNのcenterizationを省略し、stdをRMSで置換することで計算量を削減する手法。
LN(x)=xmean(x)std(x)g+b \text{LN(x)} = \frac{x - \text{mean}(x)}{\text{std}(x)} \odot g + b
RMSNorm(x)=xRMS(x)g+b \text{RMSNorm(x)} = \frac{x}{\text{RMS}(x)} \odot g + b

where,

std(x)=1ni(xiμ)2 \text{std}(x) = \sqrt{ \frac{1}{n} \sum _ i (x _ i - \mu)^2 }
RMS(x)=1nixi2 \text{RMS}(x) = \sqrt{ \frac{1}{n} \sum _ i x _ i^2 }

元の論文では、計算効率の改善や、学習の安定化について言及されている。ただ、ここで扱われているモデルでは、モデル全体の計算時間に占めるLayerNormの計算時間の割合がそれなりに多いケースである。

大規模パラメータの LLM では RMSNorm の導入による計算効率的なメリットはない

[1]によると、GPT3-LまではLayerNormに比べてRMSNormが5%程度の改善が見込めたことを示しているが、「GPT3-XLより大きくなるとLNが全体の推論時間に占める割合が1%未満であるため、RMSNormにしても改善は見込めない」と報告している。つまり、 一定規模以上のLLMでRMSNormが採用される理由は少なくとも計算効率ではない


(source: [1])

主要なLLMの技術レポートにおけるRMSNormについての言及について

  • GPT3 [2]: 記載なし
  • Llama [3]: 「GPT-3に従った」
  • OLMo [4]: 「non-parametric なLNが最も安全で効率的な選択肢だと信じている」
  • OLMoE [5]: 「non-parametric なLNは勾配のスパイクが多く学習が安定しなかったのでRMNormを採用した」
  • OLMo 2 [6]: 「OLMo-0424 ではパフォーマンス上の優位性からnon-parametric LNを採用していたが、現在のRMSNormの実装ではパフォーマンスは同等で、また、ablationによりnon-parametric なLNとRMSNormでは性能に大差がなかったのでLLMでは『標準的な』RMSNormを採用した」
  • Qwen [7]: 「RMSNormはLNと比べて同等の性能で計算効率が良い」という記述はあるが、具体的な数値の話やAblationはなし。
  • DeepSeek LLM [8]: 「Llamaに倣った」
  • Gemma [9]: 「学習の安定のためにRMSNormを採用した」と記載はあるが、ablationはない。

QwenとGemmaは安定性や計算効率についての言及はあるが、ablationにより具体的に比較を行なったかは不明である。

なお、OLMoEでは「安定性の理由からRMSNormに戻した」と記載されているが、これは元々パラメータなしのnon-parametric LNという特殊なLNを採用していたためで、通常のパラメータありのLNとの比較ではないことに注意。

参考: OLMoEで観測されたnon-parametric LNの勾配スパイク

(source: [5])

結論

調査した限りでは、一定規模以上のLLMで定量的な比較に基づいてRMSNormの優位性を主張した論文は見つからなかった
従ってGPT-3以降のアーキテクチャでRMSNormが採用された理由は単に「GPT-3が採用していたから」という理由ではないかと推測する。

References

GitHubで編集を提案
2

Discussion

ログインするとコメントできます