🐈

ReLoRA: Stack More Layers Differently のメモ

2023/07/18に公開

https://arxiv.org/abs/2307.05695

LLM のフルの(事前)学習, 計算量おおくてつらいね...
でも LoRA でファインチューンとか, 追加事前学習(incremental pre-training)は一定の成果を見せてるね.

事前学習(pre-training)は, ネットワークに新しい task を適合させるために少量の修正を可能にするステップである.

Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning
https://arxiv.org/abs/2012.13255

によれば, ネットワークを事前学習するほど, タスクの学習に必要な変更のランク(rank)が減少することが示されている.

そして, LLM で多く使われている Transformer の Attention 機構は多くの場合低ランクになる

したがって, 最初はフルで学習しつつ, 途中から LoRA で low-rank(より少ない trainable parameters)で学習を継続するとよいのでは?

ただ, LoRA が一個だけだと性能でないかもなので, rank の理論

rank(A + B) ≤ rank(A) + rank(B)

に従って, step が進むごとに LoRA レイヤーを追加(stack)していって, rank を上げるようにしてみたよ.

LoRA のスタッキング

学習ステップの区間(T1, T2, T3)ごとに LoRA レイヤーをくっつけていくよ.
(すでに過ぎた区間は W_i にフリーズ)

ここでそれぞれの区間の和(LoRA モジュール)は十分に独立している(=> 学習をうまくリセットする必要がある)とする.

=> 最初の区間は LoRA 一個だけなのでランクが低いが, ステップ進むにつれてランク数は増えていき性能向上が見込める. しかし, 計算量は各区間で同じなので計算量は一定, という感じカナ

=> ただそれだと Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning で pretrain するほどだんだんランクが下がるというのと矛盾するようなきもするが... 最初に複数 LoRA で, だんだん LoRA レイヤー減らしていくのじゃダメなのかしらん?

Jagged-cosine learning rate multiplier

あと, 区間ごとに学習のリセットをうまくやらないとダメ.

既存の手法として,

  • cosine 減衰する learning rate multipler scheduler がよくあります.
  • また, 最初に warm-up 期間としてジョジョに learning rate を上げていくのもよくあります.

これを参考に, ReLoRA では, jagged-cosine scheduler を提案しています.


(jagged-cosine: ギザギザ cosine)

ReLoRA のリセットのたびに warm-up しつつも全体的には cosine で減衰するような learning rate scheduler となっています.

実験

数字が少ないほうが性能がよい.
Control は, LoRA と同じ trainable parameters を, LoRA ではなく full training と同じように学習したもの.
60M では性能向上なかったけど, それ以上のケースでは
Control よりも ReLoRA でいい感じになったよ.
(full traning にはさすがにかなわないけど)

計算リソースあんまりないから 350M までしか試していないよ...
(8 x 4090)

Limitation and future work

とりまステップ数少ないけど 1.3B サイズで試してみたけど, ReLoRA で trainable parameters はすごいへるけど, 学習のメモリと計算量はそんなに減らなかったよ...
(メモリ 30% 減, 計算時間は 2 倍効率よいくらい)

よりネットワーク規模大きくなったり, 量子化とか使ったらいい感じになるかもはしれません!

感想

現状ではすごい事前学習が早くなったり, メモリ少なくできたりではありませんでしたが,
ここからいろいろアルゴリズム改善でいい感じになりそうな片鱗は味わえました.

https://github.com/guitaricet/peft_pretraining

コードもありますので, いろんな人が改善なりしていくことでしょう.
メモリ 1/3, 計算量 1/5 くらいにできると実務的に使っていけそう.

今後に期待ですね.

Discussion