🦙

LongLoRA論文まとめ

2024/04/17に公開

概要

https://arxiv.org/abs/2309.12307
論文
LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models
のまとめ

訓練済みモデルのコンテキスト長を効率良く伸ばす手法について書かれている。

  • Shifted Sparse Attention
    通常の実装では全てのトークンにアテンションをかけるが、グルーピングして一部のトークンにのみアテンションすることで効率的に学習する仕組み。
    たった2行のコードで実装できる。

  • Improved LoRA for Long Context
    Shifted Sparse Attentionをする場合はAttention層だけではなく、embedding層とnormalization層にもLoRAを適用したほうが精度が出る。

  • LongAlpacaを用いた教師あり学習
    https://github.com/dvlab-research/LongLoRA?tab=readme-ov-file#longalpaca-data

学習に使ったデータセット。

https://github.com/dvlab-research/LongLoRA
実装


既存の方法との精度やメモリ消費量の比較

introduction


LongLoRAの概要。
embedding層とnormalization層を訓練することも重要。


perplexityもFullコンテキストの訓練に比べて遜色ない。

計算資源

フルファインチューニングでコンテキスト長を伸ばす訓練を行う場合、以下の様な膨大な計算資源が必要になる。

- Llama: 2k to 8k
A100(80G)×32

LongLoRAを使うと以下の様な計算資源で行える。

- Llama2 7B: 4k -> 100k
- Llama2 70B: 4k -> 32k
A100(80G)×8

Shifted Sparse Attention

トークングループ数: 4
トークン数: 8
ヘッド数: 4

の例

  • attention headを2つのチャンクに分割する
  • 片方のチャンクのトークンをグループサイズ(トークン数 / トークングループ数)の半分シフトする(末尾は先頭に回る)
  • トークンをグループに分割する
  • アテンションはグループ内のみで行う

Flash-Attention2と互換性がある。
dilated attention、sparse attentionとは相性が悪い。


pytorchでの実装例

Improved LoRA for Long Context

embedding層とnormalization層もLoRAチューニングしたほうが精度が出る。

コンテキスト長を伸ばしてもPerplexityが劣化しない。
LoRA+とは(embedding層とnormalization層もLoRAチューニングしたパターン)

Discussion