😎

論文紹介 : Language Models are Super Mario

2024/04/21に公開

概要

Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch という論文を読んだので簡単に説明します
https://arxiv.org/abs/2311.03099

こちらの論文では、言語モデル(LM)がGPUを使ったファインチューニングなどを行わなくても、同じタイプのモデルからパラメータを同化することで、新たな能力を獲得できることを明らかにします。

論文のタイトルには「スーパーマリオ」が入っていますが、これは何を意味するでしょうか? スーパーマリオはゲーム中のアイテムを使用することで、火の玉を投げるような新しい能力を取得することができます。提案手法では、言語モデル(LM)がスーパーマリオと同様に再トレーニングやGPUを必要とせずに、他のモデルを吸収することで、その能力を向上できることを発見しました。

提案手法ではデルタパラメータを扱います。これはファインチューニングされたパラメータと、事前学習されたパラメータとの差を意味します。ほとんどのデルタパラメータをゼロに設定するためにDAREという手法を導入します。DAREでは比率pによってデルタパラメータをランダムに削除し、残りのパラメータを1/(1-p)で再スケールして、元の埋め込みを近似するようにします。

DAREを使って複数の同系統のモデルをスパース化することで、モデルマージ時のパラメータの干渉を緩和することができます。

(1)デルタパラメータの値の範囲は一般的に小さく(0.005以内)で、極端な冗長性があります。DAREはその90%、あるいは99%を問題なく除去することができます。
(2)DAREは複数のタスクに特化したLMを多様な能力を持つ1つのLMに統合することができます。例えばWizardLMとWizardMathを融合することで、WizardLMのGSM8Kゼロショット精度は2.2から66.2に向上し、WizardMathの64.2を超える命令追従性を維持したまま、WizardLMの性能を向上させることができました。

導入

(1)ファインチューニングされたLMはバックボーンに関係なく、相当数の冗長デルタパラメータを示します(例:BERT、RoBERTa、LLaMA、Llama 2、Code Llamaなど)。DAREは、モデル性能に大きな影響を与えることなく、90%または99%のデルタパラメータを削除することができます。DAREは元の埋め込みをうまく近似することができ、LMの各層に対して非常によく似た埋め込みを提供することができます。リスケール操作はDAREの成功を保証するために重要で、リスケールせずに30%または40%のデルタパラメータを削除すると、顕著に悪い結果に繋がります。

(2)DAREはエンコーダーベースのLMにおいて、様々なモデルマージの手法の性能を向上できる場合が多いです。下の図に示すように、WizardLMとWizardMathをDAREとパラメータ平均を組み合わせて統合すると、GSM8KではWizardLMの数学的推論能力が2.2から66.3とゼロショット精度が大幅に向上し、AlpacaEvalでは命令追従能力も若干向上することがわかりました。

(3)教師ありファインチューニング(SFT)済みモデルのデルタパラメータは通常0.005以内にとどまり、事前学習済みLMの修正が少ないことを示し、比較的小さな値域のデルタパラメータに対してはDAREが有効です。しかし、モデルが継続的な事前学習を受けると、デルタパラメータは急速に0.03程度まで達し、DAREを実行不可能にします。さらにfine-tunedパラメータ(事前学習済みパラメータとデルタパラメータの組み合わせ)のわずか10%を削除するだけで、パフォーマンスが大幅に低下し、ほぼゼロに近づきます。これは教師ありファインチューニングが主に事前学習済み言語モデルの能力を引き出すものであり、新しい能力を導入するものではないことを裏付けています。

手法

デルタパラメータは「教師ありファインチューイング後のパラメータ」から「事前学習済みのパラメータ」を引いたものです。デルタパラメータはSFTプロセス中のパラメータの変化を反映するため、デルタパラメータの特性を分析することで、SFTの理解を深めることができます。

DAREについて

提案手法では、デルタパラメータが与えられると、DAREはまずドロップ率pに基いて、デルタパラメータでランダムにドロップします。そして残りの値を1/(1-p)の倍数でスケールします。論文中ではDAREによって埋め込みの期待値がなぜ維持されるかを理論的に証明しています。

最後にDAREとドロップアウトの違いを考えます。両方ともランダムなドロップアウトとリスケーリングの操作を含みますが、2つの側面で異なります。(1)DAREはデルタパラメータを処理し、ドロップアウトはモデル出力を操作します。(2)DAREは学習せずにデルタパラメータの冗長性を低減することを目的とし、デルタパラメータを永久に排除し、他の推論のみ保持します。ドロップアウトはモデルがオーバーフィットするのを防ぐために使用され、学習中に出力の一部が一時的に削除されますが、推論のためにすべての出力が保存されます。

モデルとDAREの融合

DAREはデルタパラメータのほとんどをゼロにすることで、効果的に冗長性を低減するため、DAREは複数のモデルをマージする際にパラメータの干渉を役立つと仮定しています。図2(b)を例にとると、数学とコード関連のモデルをマージする場合、DAREは既存のモデルマージ手法を支援し、パラメータ干渉が少ないか全くない2つのモデルの能力をよりよく吸収することができます。

図2(b)の左側はDARE、右側が通常のモデルマージです。DAREのほうはパラメータの冗長性を減らしているため、2つのモデルをマージするときの干渉が少なくなっていることを表しています。

実験

SFTデルタパラメータの極端な冗長性

ドロップ率pを0.0, 0.1, ..., 0.9, 0.99と変化させ、DAREを適用してモデルを得ます。結論として(1)SFTデルタパラメータは非常に冗長です。DAREは90%のデルタパラメータを効果的に除去することができ、性能を大幅に低下させることはありません。場合によってはドロップアウト率pが99%に達することがあります。(2)ドロップアウト率の許容範囲は、LMが大きくなるほど大きくなります。すなわちパラメータが多いLMは、より高いドロップアウト率に耐えることができます。例えば、WizardMath-70Bはp=0.99のときにも良好な結果を発揮しますが、WizardMath-7BとWizardMath-13Bは失敗します。これはLMのスケーリング則といくつかの関連を示しており、モデルサイズとそれらが与えることができるドロップアウト率の間に定量化可能な相関が存在する可能性を示しています。

図3,図4はドロップアウト率を変化させたときのLMの性能をグラフにしたものです。ドロップアウト率を0.9まで上げても、性能低下がほとんどないことがわかります。

DAREとモデルマージ

著者らは7Bのパラメータを持つ2つのマージされたモデルを提供しました(スーパーマリオv1, スーパーマリオv2)こちらはOpen LLM Leaderboardで評価します。表2よりマージされたLMは構築された個々のモデルに勝り、大幅な改善を達成していることがわかります。注目すべきは、2024年1月28日まで、スーパーマリオv2がOpen LLM Leaderboardで1位を獲得していることです。これらの利点は、CPUのみで安価に得られることが期待できます。

リスケール操作の重要性

DAREのリスケール演算は元の埋め込みを近似するために不可欠です。これを検証するために、デルタパラメータをランダムに削除するDropOnlyを導入し「リスケーリングを行わずに」デルタパラメータをランダムに削除します。そして元のLMと、DAREとDropOnlyを用いたLMの埋め込みの類似度を計算します。具体的には、各入力トークンの埋め込みをレイヤーごとに取得し、平均コサイン類似度を計算します。図6がその結果です。90%のデルタパラメータを削除しても、DAREは各層の元の埋め込みを完全に維持でき、類似度は0.95であることがわかります。しかしDropOnlyはp=0.1の元の埋め込みを保存するだけで、pが高くなると類似度は急速に減少します。

図7に埋め込みコサイン類似度の分布を示します。DAREが元の埋め込みを近似する能力を持つことがわかります。

まとめ

提案手法では、まずLMにおけるSFTデルタパラメータの極めて冗長な特性について議論し、データ、再トレーニング、GPUなしでSFTに必要なデルタパラメータの数を効果的に削減するシンプルなアプローチのDAREを提案しました。DAREはすべてのSFTデルタパラメータを使用した場合と比較して、性能をあまり犠牲にすることなく、90%または99%のSFTデルタパラメータを印象的に低下させることができます。DAREを用いた広範な実験を行って、DAREがSFTデルタパラメータの冗長性を低減し、モデルマージ性能を向上させる効果があることが実証されました。またDAREがなぜ機能するのか、またDAREを使用する前提条件について深い分析を行いました。

Discussion