🚀 FSDP2で大規模言語モデルを効率的にファインチューニング:ZeRO-3の実装と実践ガイド
📝 概要
この記事では、PyTorchの最新分散学習機能であるFSDP2(Fully Sharded Data Parallel 2)を使って、限られたGPUメモリで大規模言語モデル(LLM)をファインチューニングする方法を詳しく解説します。7BパラメータのLlamaモデルを4台のGPUで効率的に学習させる実装例を通して、ZeRO-3アルゴリズムの理論と実践の両方をカバーします 🤖
🔍 なぜFSDP2が必要なのか?
GPUメモリの課題 💾
大規模言語モデルの学習では、GPUメモリが以下の要素で消費されます:
- モデルパラメータ(例:7Bモデル = bf16で約14GB)
- 勾配(パラメータと同サイズ = bf16で約14GB)
- オプティマイザ状態(Adam/AdamWでパラメータの12倍 = fp32で約84GB)
- アクティベーション(バッチサイズとシーケンス長に依存)
7Bモデルの場合、1台のGPUあたり110GB以上のメモリが必要になってしまいます!😱
オプティマイザ状態が12倍になる理由
Adam/AdamWオプティマイザは、各パラメータに対してFP32で3つのコピーを保持します:
- マスターコピー(パラメータ本体:4バイト)
- モーメンタム状態(4バイト)
- 分散状態(4バイト)
合計:4 + 4 + 4 = 12バイト/パラメータ
FSDP2のメモリ使用量公式 📊
FSDP2のフルシャーディングを使用した場合、GPU1台あたりのピークメモリ使用量は以下で推定できます:
ここで:
- P: 総パラメータ数
- N: GPU数
- 16バイト/パラメータ: 2(パラメータbf16)+ 2(勾配bf16)+ 12(オプティマイザfp32)
- 2bshL: 勾配チェックポイント使用時のアクティベーションメモリ(bf16)
- O: システムオーバーヘッド(約6.5GB)
重要なポイント: FSDP2は16PをN台のGPUで分散するため、大規模モデルの学習が可能になります!⚡
🏗️ ZeRO-3とFSDP2の理論
ZeRO-3アルゴリズム
ZeRO-3(Zero Redundancy Optimizer Stage 3)は、Microsoft DeepSpeedで提案された分散学習手法です:
- パラメータシャーディング: 各GPUが異なるパラメータ部分を保持
- Just-in-Time収集: 計算時のみ必要なパラメータを収集
- 即座に再シャーディング: 計算後すぐにパラメータを分散
FSDP2の進化点
FSDP2は、PyTorchによるZeRO-3の最新実装で、以下の改善があります:
- DTensorベース: パラメータごとのきめ細かいシャーディング制御
- コンポーザブル: 他の並列化戦略との簡単な統合
- 高性能: 通信オーバーラップの改善とオーバーヘッド削減
💻 実装:モデル初期化が鍵
FSDP2の複雑さは主にモデル初期化にあります。正しく設定すれば、学習ループは通常のPyTorchと同じです!
ステップ1: デバイスメッシュのセットアップ
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
# 分散環境の初期化
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
# 1次元メッシュでピュアFSDPを構成
mesh = init_device_mesh("cuda", (world_size,))
ステップ2: メタデバイス初期化(大規模モデル用)
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
# 重みなしでコンフィグをロード
model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
config.torch_dtype = torch.bfloat16
# メタデバイス上でモデル構造を作成(メモリ割り当てなし)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
なぜメタデバイス? 7Bモデルを通常通りロードすると、各GPUで14GB以上のRAMが必要です。メタデバイスはメモリ割り当てなしで構造のみ作成します 🧠
ステップ3: FSDP2シャーディングの適用
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp.api import MixedPrecisionPolicy
# FSDP2パラメータの設定
fsdp_kwargs = {
"mesh": mesh,
"reshard_after_forward": True, # ZeRO-3動作: フォワード後に再シャーディング
"mp_policy": MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32, # 勾配により高い精度を使用
cast_forward_inputs=True,
),
}
# ボトムアップでシャーディング適用: レイヤー → ルート
for i, layer in enumerate(model.model.layers):
model.model.layers[i] = fully_shard(layer, **fsdp_kwargs)
model = fully_shard(model, **fsdp_kwargs) # ルートシャーディング
重要な原則: トランスフォーマーレイヤーレベルでシャーディングし、小さなモジュール単位では行いません。これにより通信効率とメモリ節約のバランスを取ります ⚖️
ステップ4: 事前学習済み重みのロード
from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions
# ランク0のみで重みをロード
if local_rank == 0:
temp_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
state_dict = temp_model.state_dict()
del temp_model # メモリ解放
else:
state_dict = None
# 自動的にブロードキャストとシャーディング
set_model_state_dict(
model,
model_state_dict=state_dict,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
ステップ5: 勾配チェックポイントの有効化(オプション)
# HuggingFaceモデルの場合、ロード後に有効化
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
🏃♂️ 学習:通常のPyTorchと同じ!
初期化が完了すれば、学習ループは通常と同じです:
# オプティマイザの設定
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
# 学習ループ
for batch_idx, batch in enumerate(dataloader):
# フォワードパス
outputs = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"]
)
loss = outputs.loss
# バックワードパス
loss.backward()
# 勾配クリッピング(オプション)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# パラメータ更新
optimizer.step()
optimizer.zero_grad()
if batch_idx % 10 == 0:
print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
🔧 実装時の注意点とトラブルシューティング
よくある問題と解決策
-
OOM (Out of Memory) エラー
- バッチサイズを減らす
- 勾配蓄積ステップを増やす
- 勾配チェックポイントを有効化
-
通信エラー
- NCCLバックエンドの設定確認
- ファイアウォール設定の確認
-
学習が遅い
- シャーディング戦略の見直し
- 通信オーバーラップの最適化
🎯 まとめ
FSDP2により大規模モデル学習が身近になりました:
✅ 馴染みやすいAPI: 学習ループは通常のPyTorchと変わらず
✅ 自動シャーディング: FSDP2が全ての通信を自動処理
✅ 柔軟性: 勾配チェックポイントなど他の最適化との簡単な組み合わせ
✅ 高いメモリ効率: 従来の4分の1のメモリで同等の学習が可能
重要なポイント
- ZeRO-3の理解: 全てをシャーディングし、必要時のみ収集
- 慎重な初期化: メタデバイス → シャーディング → 重みロード
- レイヤーレベルシャーディング: 効率とメモリのバランス
- 通常の学習: 設定後は普通のPyTorch!
複雑さは初期化に集中していますが、その見返りは大きいです。本来GPUに載らないモデルを、最小限のコード変更で学習できるようになります 🎉
この記事が皆さんの大規模モデル学習の助けになれば幸いです!質問やコメントお待ちしています 💬
Discussion