FSDP2

FSDPについて、新しいデザインの実装が進められている。次のようなAPIの仕様変更が進められている。
この記事にあるように、FSDPは次のように呼び出すようになっていた。
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
# Call `reset_parameters()` on every module
model = FSDP(model, auto_wrap_policy=policy)
# Call `param_init_fn` on every module
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)
これがこうなる。
with torch.device("meta"):
model = Transformer()
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# Allocate buffers and sharded parameters on GPU
model.to_empty(device="cuda")
# Run user-defined initializers
model.init_weights() # or `model.apply(init_weights)`

変更点を少しずつ見ていく。疑問点を並べて、調べながら進めていこう。
FSDP2関連のプログラムは pytorch/torch/distributed/fsdp/_fully_shard のディレクトリに格納されている。

(なぜ?)FSDP2はもはやnn.Moduleラッパーではなくなった
- そのため、FSDP1ではauto_wrap_policyで指定できたユニット分割(パラメータ集約と勾配共有の単位)は今のところ手動的に行う必要がある
for module in model.modules(): if isinstance(module, TransformerBlock): fully_shard(module)
fully_shard(model)を実行すると、FSDPModuleクラスの関数が優先されて実行されるようになる(ラップされているのではない)。分かりにくいのだが、
にてtypeを用いて、第二引数のタプルのクラスをもとに、新しくクラスを作成している。合体させているようなイメージ。もしタプル内のクラスで同じ関数があった場合は、よりタプルの先頭に近いクラスメソッドが優先される。(こんな風にクラスを作ることがあるのか....)
そのため、例えばnn.Moduleをfully_shardに入力した場合、変換されたモデルに対して、nn.Moduleで定義された関数はほとんどそのまま使うことができる。

FSDP1でもそうだが、FSDPではhookと呼ばれる関数を登録し、forwardとbackwardの前後で通信処理が行われるようにしている。
- foward: unshardと呼ばれるパラメータ収集(Allgather)
- backward: unshardと呼ばれるパラメータ収集(Allgather)と、勾配収集(ReduceScatter)
で、このhookの登録がどこで行われているかというと、_fully_shard.pyのここ。このstateはFSDPState

FSDPStateによるhookの登録
_pre_forward
- 当該moduleのパラメータのunshardの完了処理
- 次のmoduleのパラメータunshardの発行
の処理を行う。これらのunshardは非同期で行うことができるが self.unshard_async_op: bool = Falseになっているっぽい(宿題)。一応、FSDPModule::_set_unshard_async_opでこれらを全てのParamに対してTrueにできる。
1. 当該moduleのパラメータのunshardの完了処理
この中でFSDPParamGroupのpre_forwardを呼び出している。
中身。パラメータ収集に関しては、大きく次の処理が行われる。
-
unshard
- 今からforwardの計算を行うモジュールのパラメータ集約用のAllgatherを行う。ここでasync_op引数があり、これがTrueであればAllgatherの完了を待たずに次の処理に進む。すでにAllgatherがprefetch等で発行済みであれば何もしない。
-
wait_for_unshard
- 前回発行したalgatherのstateがFSDPParamGroupの_all_gather_resultに格納されており、これをwaitしてAllgatherを確実に終わらせる。その後、foreach_all_gather_copy_outにてFSDPParamにAllgatherの結果をコピーする。これをcopy-outと表現している。(逆にAllgather用のinputにデータをコピーするのはcopy-inと表現されている)
2. 次のmoduleのパラメータunshardの発行
self._states_to_forward_prefetchにあらかじめパラメータを収集すべきFSDPParamGroupが格納されており、それらをprefetchする。_states_to_forward_prefetchはどこで設定しているんだろう(宿題)。
ただし注意書き
However, you must use explicit prefetching (e.g. via :meth:
unshard
)
in forward to still get overlap, and the pre-all-gather ops like dtype
casting and copy-in will not overlap with compute.
にもあるように、
- async_op=True にしても、何もしないと overlap は起きない
- overlap を狙うには、手動で unshard(async_op=True) を呼んで先取り通信が必要
- また、dtype変換や copy-in 処理は依然として同じストリーム上で実行されるため、完全な通信/計算オーバーラップは実現しないことに注意

_pre_backward
- self._fsdp_param_group.pre_backward(default_prefetch)
- unshard (prefetchされていれば何もしない)
- wait_for_unshard(発行済みのunshardのprefetchを完了させる)
- prefetch_unshard