PyTorch Lightning Fabric 調べる

Lightning Fabric は生の PyTorch よりも高レイヤーだが PyTorch Lightning の Trainer よりも低いレベルでの操作をできるようにするライブラリ。HuggingFace の accelerate みたいなもの。
複数 GPU でのモデルの扱いを簡単にしたり、ロガーをまとめて管理したりとほぼほぼ huggingface/accelerate と同じような機能がついている。
個人的に、PyTorch Lightning は safetensors での保存に対応してなかったり、多くのものを隠蔽しすぎてしまって拡張性がそこまで高くない感じがするので、より薄く使える Fabric をちょっと調べてみたいと思った
単一 GPU のみを想定して学習コードを書くと、複数 GPU 対応させるときに大変だったりするので、いい感じのラッパーを使ってお手軽に複数 GPU 対応 & 細かい設定を両立させたいモチベーション

共通セットアップ
accelerate と同じ雰囲気で行う
fabric = Fabric(devices=2)
fabric.launch()
# Set up model and optimizer for accelerated training
model, optimizer = fabric.setup(model, optimizer)
# If you don't want Fabric to set the device
model, optimizer = fabric.setup(model, optimizer, move_to_device=False)
Fabric()
で devices=2
を渡しているが、多分 CLI の引数経由で渡すのが楽で良さそう?
fabric run ./path/to/train.py \
--strategy=ddp \
--devices=8 \
--accelerator=cuda \
--precision="bf16"

fabric run ./train.py
で実行するときは fabric.launch()
を実行しなくていい(するとエラーになる)ようだ
RuntimeError: This script was launched through the CLI, and processes have already been created. Calling
.launch()
again is not allowed.
fabric.launch()
はコード中でデバイスを指定する場合に呼び出して、セットアップを行うみたいな感じっぽい
The
launch()
method should only be used if you intend to specify accelerator, devices, and so on in the code (programmatically). If you are launching with the Lightning CLI,fabric run
..., removelaunch()
from your code.
とのこと。
launch()
を使うのはノートブックとかそこらへんのシチュエーションになりそう。

torch compile
accelerate と異なり、手動で torch.comple()
して通常通り fabric.setup()
に渡すらしい。

setup()
fabric.setup()
や fabric.setup_module()
、fabric.setup_dataloaders()
に通すと、主に以下が行われる
- 自動でマルチGPU・分散学習に対応
-
forward
の書き換え & 自動 autocast - dataloader のテンソルの自動 move
というように、デバイスや混合精度を深く意識せずに学習コードを書けるようになる。