🎉

LoRA(Low-Rank Adaptation)の理論と実践:効率的な大規模言語モデルのファインチューニング手法

2024/05/25に公開

1. はじめに

近年、自然言語処理の分野では、大規模な言語モデル(Large Language Models; LLMs)が目覚ましい性能を示しています。しかし、これらのモデルは数十億から数千億のパラメータを持つため、ファインチューニングに膨大な計算リソースを必要とします。そこで注目されているのが、LoRA(Low-Rank Adaptation)と呼ばれる手法です。

LoRAは、大規模言語モデルを効率的にファインチューニングするための手法で、追加学習するパラメータ数を大幅に削減できます。本記事では、LoRAの理論的背景と実践的な使い方について、初心者にも分かりやすく解説します。

2. LoRAの基本概念

2.1 LoRAとは

LoRAは、事前学習済みの大規模言語モデルに対して、低ランク行列(Low-Rank Matrix)を適応させることで、効率的にファインチューニングを行う手法です。具体的には、モデルの各層の重み行列に、低ランクの適応行列を加算することで、ファインチューニング時に学習するパラメータ数を大幅に削減できます。

2.2 LoRAの数式表現

LoRAでは、事前学習済みの重み行列 W \in \mathbb{R}^{n_1 \times n_2} に対して、低ランクの適応行列 \Delta W = BA を加算します。ここで、B \in \mathbb{R}^{n_1 \times r}, A \in \mathbb{R}^{r \times n_2} であり、r \ll min(n_1, n_2) はLoRAの低ランクのサイズを表します。ファインチューニング時には、BA のみを学習し、W は固定します。

3. LoRAの理論的背景

3.1 ニューラルネットワークのスケーリング則

ニューラルネットワークの幅(隠れ層のユニット数)が大きくなると、初期化方法や学習率を適切に設定しないと、学習が不安定になることが知られています。例えば、Heの初期化では、重みの初期値の分散を 1/n にスケーリングすることで、幅 n が大きくなっても安定して学習できるようになります。

LoRAでは、このスケーリング則を利用して、適応行列 BA の学習率を適切に設定することで、効率的なファインチューニングを実現しています。

3.2 LoRAの理論的分析

LoRAの理論的分析では、簡単な線形モデルを用いて、LoRAの学習ダイナミクスを解析的に導出しています。その結果、以下のような知見が得られています。

  1. BA に同じ学習率を使うと、幅 n が大きい場合に非効率的な学習になる。
  2. BA の学習率を \eta_B = \Theta(1), \eta_A = \Theta(n^{-1}) に設定すると、効率的な学習が可能になる。

ここで、\Theta(·) は漸近記法であり、n \to \infty の極限で、定数倍の範囲内に収まることを表します。

4. LoRAの実装と使用方法

4.1 LoRAの実装

LoRAは、PyTorchやTensorFlowなどの深層学習フレームワークを用いて簡単に実装できます。以下は、PyTorchでLoRAを実装した疑似コードです。

class LoRA(nn.Module):
    def __init__(self, layer, rank):
        super().__init__()
        self.layer = layer
        self.rank = rank
        self.B = nn.Parameter(torch.zeros(layer.weight.shape[0], rank))
        self.A = nn.Parameter(torch.zeros(rank, layer.weight.shape[1]))
        
    def forward(self, x):
        return self.layer(x) + torch.matmul(self.B, self.A) 

ここでは、layerは事前学習済みのレイヤーを表し、rankはLoRAの低ランクのサイズを表します。BAは学習対象のパラメータで、forward関数で元の重み行列に加算されます。

4.2 LoRAの使用方法

LoRAを使ってファインチューニングを行う際は、以下の手順を踏みます。

  1. 事前学習済みのモデルをロードし、ファインチューニング対象のレイヤーを選択する。
  2. 選択したレイヤーに対して、LoRAを適用する。
  3. LoRAの適応行列 BA の学習率を適切に設定する。
  4. 通常のファインチューニングと同様に、損失関数を定義し、最適化アルゴリズムを選択する。
  5. ファインチューニングを実行する。

以下は、PyTorchでLoRAを使ったファインチューニングの疑似コードです。

# 事前学習済みモデルのロード
model = PretrainedModel()

# LoRAの適用
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        lora = LoRA(layer, rank=8)
        setattr(model, name, lora)
        
# 最適化アルゴリズムの設定
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-4},  # 元のモデルパラメータ
    {'params': [p for n, p in model.named_parameters() if 'lora' in n], 'lr': 1e-3}  # LoRAパラメータ
])

# ファインチューニングの実行
for epoch in range(num_epochs):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

ここでは、modelは事前学習済みのモデルを表し、named_modules()でモデルの各レイヤーを取得しています。nn.Linearレイヤーに対してLoRAを適用し、setattrでモデルのレイヤーを置き換えています。

最適化アルゴリズムには、Adamを使用し、元のモデルパラメータとLoRAパラメータに対して異なる学習率を設定しています。LoRAパラメータの学習率は、元のモデルパラメータの学習率よりも大きく設定することが推奨されています。

5. LoRAの実験結果

5.1 GLUEベンチマーク

著者らは、LoRAをRoBERTaとGPT-2に適用し、GLUEベンチマークでその性能を評価しました。その結果、LoRAは通常のファインチューニングと同等の性能を達成しながら、ファインチューニング時間を大幅に短縮できることが示されました。

特に、RoBERTaでは、LoRAを使ったファインチューニングが、通常のファインチューニングの約1/4の時間で同等の性能を達成しました。また、GPT-2では、LoRAを使ったファインチューニングが、通常のファインチューニングの約1/2の時間で同等の性能を達成しました。

5.2 LoRAの学習率の影響

著者らは、LoRAの適応行列 BA の学習率が、ファインチューニングの性能に与える影響についても調査しました。その結果、BA の学習率を適切に設定することが、効率的なファインチューニングに重要であることが分かりました。

具体的には、B の学習率を A の学習率よりも大きく設定することで、安定して高い性能が得られることが示されました。一方、BA の学習率を同じにした場合、ファインチューニングの性能が低下する傾向が見られました。

6. まとめ

本記事では、大規模言語モデルを効率的にファインチューニングするための手法であるLoRAについて解説しました。LoRAは、低ランクの適応行列を使って、ファインチューニング時に学習するパラメータ数を大幅に削減できる手法です。

理論的な分析から、LoRAの適応行列の学習率を適切に設定することが、効率的なファインチューニングに重要であることが分かりました。また、実験結果から、LoRAが通常のファインチューニングと同等の性能を達成しながら、ファインチューニング時間を大幅に短縮できることが示されました。

LoRAは、大規模言語モデルを効率的にファインチューニングするための有望な手法であり、今後さらなる発展が期待されます。本記事が、LoRAを理解し、活用するための一助となれば幸いです。

※ 用語解説など

低ランク行列(Low-Rank Matrix)

低ランク行列とは、行列のランク(独立な行または列の最大数)が行列のサイズに比べて小さい行列のことを指します。低ランク行列は、元の行列を近似的に表現できるため、データの圧縮や次元削減によく使われます。

Xavier初期化(Glorot初期化)

Xavier初期化は、活性化関数が線形である場合に、各層の出力の分散が一定になるように重みを初期化する手法です。重みの初期値は、以下の式で与えられます。
W \sim \mathcal{U}(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}, \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}})
ここで、\mathcal{U} は一様分布、n_{\text{in}} は入力ユニット数、n_{\text{out}} は出力ユニット数を表します。

He初期化

He初期化は、Xavier初期化をReLU型活性化関数向けに少しアレンジした手法で、CNNの重み初期化として広く使われています。重みの初期値は、以下の式で与えられます。
W \sim \mathcal{N}(0, \sqrt{\frac{2}{n_{\text{in}}}})
ここで、\mathcal{N} は正規分布、n_{\text{in}} は入力ユニット数を表します。

以下の数式の意味

W \in \mathbb{R}^{n_1 \times n_2}: 事前学習済みの重み行列 Wn_1 \times n_2 次元の実数値行列であることを表します。
\Delta W = BA: 低ランクの適応行列 \Delta W が、BA の積で表されることを示しています。
B \in \mathbb{R}^{n_1 \times r}, A \in \mathbb{R}^{r \times n_2}: 適応行列 Bn_1 \times r 次元、Ar \times n_2 次元の実数値行列であることを表します。
r \ll \min(n_1, n_2): LoRAの低ランクのサイズ r が、n_1n_2 のうち小さい方よりもはるかに小さいことを示しています。

参考

https://arxiv.org/abs/2402.12354v1

Discussion