LongLoRA論文まとめ
概要
LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
のまとめ
訓練済みモデルのコンテキスト長を効率良く伸ばす手法について書かれている。
-
Shifted Sparse Attention
通常の実装では全てのトークンにアテンションをかけるが、グルーピングして一部のトークンにのみアテンションすることで効率的に学習する仕組み。
たった2行のコードで実装できる。 -
Improved LoRA for Long Context
Shifted Sparse Attentionをする場合はAttention層だけではなく、embedding層とnormalization層にもLoRAを適用したほうが精度が出る。 -
LongAlpacaを用いた教師あり学習
https://github.com/dvlab-research/LongLoRA?tab=readme-ov-file#longalpaca-data
学習に使ったデータセット。
実装
既存の方法との精度やメモリ消費量の比較
introduction
LongLoRAの概要。
embedding層とnormalization層を訓練することも重要。
perplexityもFullコンテキストの訓練に比べて遜色ない。
計算資源
フルファインチューニングでコンテキスト長を伸ばす訓練を行う場合、以下の様な膨大な計算資源が必要になる。
- Llama: 2k to 8k
A100(80G)×32
LongLoRAを使うと以下の様な計算資源で行える。
- Llama2 7B: 4k -> 100k
- Llama2 70B: 4k -> 32k
A100(80G)×8
Shifted Sparse Attention
トークングループ数: 4
トークン数: 8
ヘッド数: 4
の例
- attention headを2つのチャンクに分割する
- 片方のチャンクのトークンをグループサイズ(トークン数 / トークングループ数)の半分シフトする(末尾は先頭に回る)
- トークンをグループに分割する
- アテンションはグループ内のみで行う
Flash-Attention2と互換性がある。
dilated attention、sparse attentionとは相性が悪い。
pytorchでの実装例
Improved LoRA for Long Context
embedding層とnormalization層もLoRAチューニングしたほうが精度が出る。
コンテキスト長を伸ばしてもPerplexityが劣化しない。
LoRA+とは(embedding層とnormalization層もLoRAチューニングしたパターン)
Discussion