🐥

LoRAの論文を読む

2023/07/29に公開

LoRAとはfinetuningを効率的に低コストでできる手法。詳しくは他にわかりやすい解説をしている記事がたくさんあるのでそちらを参考にしてください。

今回は自分の知らなかった細かなところのみを書きます。
(分からなかった点のメモも残っているので注意)

論文
https://arxiv.org/abs/2106.09685

実装はこのあたり
https://github.com/microsoft/LoRA
https://github.com/huggingface/peft

表記はtransfomerのattentionのquery/key/value/outputがW_q, W_k, W_v,W_oに対応

手法

学習済重みW_0に対し、fine tuningで新たに学習するweightをΔWとする
ただし、W_0 ∈ \mathbb{R}^{d*k}

新たに学習させたweightを含めたすべてのweightは次の様に表せる
W_0 + ΔW

ΔWは次のように表せる
ΔW = BA

ただし、B ∈ \mathbb{R}^{d*r} , B ∈ \mathbb{R}^{r*k} 。ここでrはthe rank of a LoRA moduleで、r << min(d, k)である

Aに対してはa random Gaussian initialization、Bに対してはzeroで初期化する。
したがって、学習の最初はΔW = BAはゼロとなる。

入力xに対する出力は以下のように表せる。
h = W_0 x + ΔW x

α/rΔWxをスケールする。ただし、αはrにおける定数。
最適化手法としてAdamを使った場合、スケールを適切に行なった場合αはおおよそlearning rateと同じになる。
(このあたりの説明が不明。実装だと毎回スケールしてそう?)

実装はこのあたり
https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L112
https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L134
https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L148

この論文ではtransfomerのattentionにのみ適応する。
MLPレイヤーやLayerNormレイヤー、biasesへの適応は今後の課題。

7章 UNDERSTANDING THE LOW-RANK UPDATES

5章の実験結果と6章はスキップ

7章ではいくつかの疑問に答えるために、実験を行う

  • transformerのどのweightを学習すれば、性能を最大化できるか
  • ΔWはrankdeficientなのか?(rankdeficientってなに?)
  • ΔWとWの相関関係、ΔWとWのサイズ比較

parameter budgetの場合、transformerのどのweightを学習すれば性能を最大化できるか

GPT-3 175Bで18Mのparameter budgetとする。データセットは以下を使う。

  • WikiSQL: sql queryと自然言語のセット
  • MultiNLI: 文同士の関係

W_qもしくはW_kのみの場合は低いスコアになる、一方でW_qW_vの両方を適応させた場合はがスコアが高くなる。これは、r=4でもΔWに十分な情報を含んでいることを示唆している。

qとvの2つのattention weight typeに適応することで、r=8よりも小さいr=4で十分ってことかな?
ただ、r=8でW_oの場合と、r=4でW_q, W_vの場合が近いスコアなのが気になる
論文中であまり言及されてないが、r=2で4種類のweight typeに適応させたのが一番スコア高い

以上より、大きなrで1種類の重みを適応させるより、多くの重み行列適応させる方が望ましい。

LoRAにおけるrは何か?

rank rの影響を調べる

weight typeがW_qW_vのときの結果に注目する

  • WikiSQLでは、r=1, r=2を比べるとr=1の方がスコアが高い
  • MultiNLIでは、r=1, r=2を比べるとr=2の方がスコアが高い

次の文は上記の結果を見て言ってる? To our surprise, a rank as small as one suffices for adapting both Wq and Wv on these datasets while training Wq alone needs a larger r.
だとしたら、W_q, W_vのr=2よりr=4の方がスコア高くなってるけどなぁ...
確かに、WikiSQLでW_q, W_v, W_k, W_oのときは、r=1の場合が一番スコア高い
r=64まで上げる必要はないけど、r=1だと小さすぎで、r=4がちょうど良さそう
W_oを含めるとスコアが高くなるのは、行列サイズが大きいのでΔWへの影響が大きいから?

小さなrの場合でも、よりcompetitivelyなパフォーマンスになっている。W_qだけでなくW_q, W_vでより強調される。(ここは何が言いたい?)

複数種類のweight typeに適応させると、rは小さくても十分。

ΔWとWの相関関係、ΔWとWのサイズ比較

ΔWの特異値分解する。左特異ベクトル、右特異ベクトルがそれぞれU,V
U^TWV^TWのフロベニウスノルムを算出。
UとVのtop rのU^TWV^Tとランダム行列のノルムも計算

ランダム行列と比べΔWはWと相関がある。これはΔWがすでにあるWに対する増幅を示唆する。
Wの最大特異値に対応する特異ベクトルを繰り返す代わりに、Wの強調されてない方向を強調する。 (∆W only amplifies directions that are not emphasized in W. )

r=4では6.91/0.32≒21.5、r=64では3.57/1.90≒1.88。r=4の方がより効率的にタスクに対して適応できている。

以上より、事前学習で学ぶ一般的な知識に対し、低ランク適応では特定のタスクの重要な特徴を強調するのではないか?

結論

今後の研究の方向性はたくさんある。

  1. LoRAは他の効率的な適応手法と組み合わせることができ、直交的な改善を提供できる可能性がある。
  2. ファインチューニングやLoRAの背後にあるメカニズムは明確ではない。我々は、LoRAが完全なファインチューニングよりも答えやすいと信じている。
    3)LoRAを適用する重み行列の選択は、ほとんどヒューリスティックに頼っている。もっと原理的な方法はありますか?
  3. 最後に、∆W のランク不足は、W もランク不足である可能性を示唆している。

ここはそのまま翻訳

所感

個人的に発見だったのは、rとattention weightの種類の選び方について。
実装だとr=8、W_kW_vが多いが、タスクに応じてちゃんと選ぶべきっぽい。W_kW_vW_qW_oの4つパターンが良さそう。
https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L69-L76

r=8でW_oがスコア高かったので、r=4でW_qW_oパターンとか割と良い感じになるのでは??
実験では175Bモデルで学習可能パラメタが18Mに制限された状態だった。llama2 13Bでは、学習可能パラメタは6.5Mなので、タスクにも依るがr=2とかで十分そう?

rankdeficientの話どこにあった?

結局スケールがわらなかった。実装的にはlayerのforwardが呼び出されるところで毎回スケールされてる気がする。ただし、デフォルト値はr=8でrola_alpha=8なので、defaultではscalingはされない。

実装ではdropoutあるけどこれの影響は?
zeroの行列から始まるからdropoutしなくても割とzeroの箇所が多そうだけど、どのくらい効果あるんだろう

わからないところ結構あったので分かる人いればtwitterとかで教えてください

Discussion