🚀 FSDP2で大規模言語モデルを効率的にファインチューニング:ZeRO-3の実装と実践ガイド

に公開

📝 概要

この記事では、PyTorchの最新分散学習機能であるFSDP2(Fully Sharded Data Parallel 2)を使って、限られたGPUメモリで大規模言語モデル(LLM)をファインチューニングする方法を詳しく解説します。7BパラメータのLlamaモデルを4台のGPUで効率的に学習させる実装例を通して、ZeRO-3アルゴリズムの理論と実践の両方をカバーします 🤖

🔍 なぜFSDP2が必要なのか?

GPUメモリの課題 💾

大規模言語モデルの学習では、GPUメモリが以下の要素で消費されます:

  • モデルパラメータ(例:7Bモデル = bf16で約14GB)
  • 勾配(パラメータと同サイズ = bf16で約14GB)
  • オプティマイザ状態(Adam/AdamWでパラメータの12倍 = fp32で約84GB)
  • アクティベーション(バッチサイズとシーケンス長に依存)

7Bモデルの場合、1台のGPUあたり110GB以上のメモリが必要になってしまいます!😱

オプティマイザ状態が12倍になる理由

Adam/AdamWオプティマイザは、各パラメータに対してFP32で3つのコピーを保持します:

  • マスターコピー(パラメータ本体:4バイト)
  • モーメンタム状態(4バイト)
  • 分散状態(4バイト)

合計:4 + 4 + 4 = 12バイト/パラメータ

FSDP2のメモリ使用量公式 📊

FSDP2のフルシャーディングを使用した場合、GPU1台あたりのピークメモリ使用量は以下で推定できます:

M = \frac{16P}{N} + 2bshL + O

ここで:

  • P: 総パラメータ数
  • N: GPU数
  • 16バイト/パラメータ: 2(パラメータbf16)+ 2(勾配bf16)+ 12(オプティマイザfp32)
  • 2bshL: 勾配チェックポイント使用時のアクティベーションメモリ(bf16)
  • O: システムオーバーヘッド(約6.5GB)

重要なポイント: FSDP2は16PをN台のGPUで分散するため、大規模モデルの学習が可能になります!⚡

🏗️ ZeRO-3とFSDP2の理論

ZeRO-3アルゴリズム

ZeRO-3(Zero Redundancy Optimizer Stage 3)は、Microsoft DeepSpeedで提案された分散学習手法です:

  1. パラメータシャーディング: 各GPUが異なるパラメータ部分を保持
  2. Just-in-Time収集: 計算時のみ必要なパラメータを収集
  3. 即座に再シャーディング: 計算後すぐにパラメータを分散

FSDP2の進化点

FSDP2は、PyTorchによるZeRO-3の最新実装で、以下の改善があります:

  • DTensorベース: パラメータごとのきめ細かいシャーディング制御
  • コンポーザブル: 他の並列化戦略との簡単な統合
  • 高性能: 通信オーバーラップの改善とオーバーヘッド削減

💻 実装:モデル初期化が鍵

FSDP2の複雑さは主にモデル初期化にあります。正しく設定すれば、学習ループは通常のPyTorchと同じです!

ステップ1: デバイスメッシュのセットアップ

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 分散環境の初期化
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])

# 1次元メッシュでピュアFSDPを構成
mesh = init_device_mesh("cuda", (world_size,))

ステップ2: メタデバイス初期化(大規模モデル用)

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# 重みなしでコンフィグをロード
model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
config.torch_dtype = torch.bfloat16

# メタデバイス上でモデル構造を作成(メモリ割り当てなし)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

なぜメタデバイス? 7Bモデルを通常通りロードすると、各GPUで14GB以上のRAMが必要です。メタデバイスはメモリ割り当てなしで構造のみ作成します 🧠

ステップ3: FSDP2シャーディングの適用

from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp.api import MixedPrecisionPolicy

# FSDP2パラメータの設定
fsdp_kwargs = {
    "mesh": mesh,
    "reshard_after_forward": True,  # ZeRO-3動作: フォワード後に再シャーディング
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,  # 勾配により高い精度を使用
        cast_forward_inputs=True,
    ),
}

# ボトムアップでシャーディング適用: レイヤー → ルート
for i, layer in enumerate(model.model.layers):
    model.model.layers[i] = fully_shard(layer, **fsdp_kwargs)

model = fully_shard(model, **fsdp_kwargs)  # ルートシャーディング

重要な原則: トランスフォーマーレイヤーレベルでシャーディングし、小さなモジュール単位では行いません。これにより通信効率とメモリ節約のバランスを取ります ⚖️

ステップ4: 事前学習済み重みのロード

from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions

# ランク0のみで重みをロード
if local_rank == 0:
    temp_model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.bfloat16, 
        device_map="cpu"
    )
    state_dict = temp_model.state_dict()
    del temp_model  # メモリ解放
else:
    state_dict = None

# 自動的にブロードキャストとシャーディング
set_model_state_dict(
    model,
    model_state_dict=state_dict,
    options=StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
    ),
)

ステップ5: 勾配チェックポイントの有効化(オプション)

# HuggingFaceモデルの場合、ロード後に有効化
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

🏃‍♂️ 学習:通常のPyTorchと同じ!

初期化が完了すれば、学習ループは通常と同じです:

# オプティマイザの設定
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# 学習ループ
for batch_idx, batch in enumerate(dataloader):
    # フォワードパス
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"]
    )
    
    loss = outputs.loss
    
    # バックワードパス
    loss.backward()
    
    # 勾配クリッピング(オプション)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    # パラメータ更新
    optimizer.step()
    optimizer.zero_grad()
    
    if batch_idx % 10 == 0:
        print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

🔧 実装時の注意点とトラブルシューティング

よくある問題と解決策

  1. OOM (Out of Memory) エラー

    • バッチサイズを減らす
    • 勾配蓄積ステップを増やす
    • 勾配チェックポイントを有効化
  2. 通信エラー

    • NCCLバックエンドの設定確認
    • ファイアウォール設定の確認
  3. 学習が遅い

    • シャーディング戦略の見直し
    • 通信オーバーラップの最適化

🎯 まとめ

FSDP2により大規模モデル学習が身近になりました:

馴染みやすいAPI: 学習ループは通常のPyTorchと変わらず
自動シャーディング: FSDP2が全ての通信を自動処理
柔軟性: 勾配チェックポイントなど他の最適化との簡単な組み合わせ
高いメモリ効率: 従来の4分の1のメモリで同等の学習が可能

重要なポイント

  1. ZeRO-3の理解: 全てをシャーディングし、必要時のみ収集
  2. 慎重な初期化: メタデバイス → シャーディング → 重みロード
  3. レイヤーレベルシャーディング: 効率とメモリのバランス
  4. 通常の学習: 設定後は普通のPyTorch!

複雑さは初期化に集中していますが、その見返りは大きいです。本来GPUに載らないモデルを、最小限のコード変更で学習できるようになります 🎉

この記事が皆さんの大規模モデル学習の助けになれば幸いです!質問やコメントお待ちしています 💬

Discussion