Open7

FSDP2

nariaki3551nariaki3551

FSDPについて、新しいデザインの実装が進められている。次のようなAPIの仕様変更が進められている。

https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md

この記事にあるように、FSDPは次のように呼び出すようになっていた。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
    model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
# Call `reset_parameters()` on every module
model = FSDP(model, auto_wrap_policy=policy)
# Call `param_init_fn` on every module
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)

これがこうなる。

with torch.device("meta"):
    model = Transformer()
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")
# Allocate buffers and sharded parameters on GPU
model.to_empty(device="cuda")
# Run user-defined initializers
model.init_weights() # or `model.apply(init_weights)`
nariaki3551nariaki3551

(なぜ?)FSDP2はもはやnn.Moduleラッパーではなくなった

  • そのため、FSDP1ではauto_wrap_policyで指定できたユニット分割(パラメータ集約と勾配共有の単位)は今のところ手動的に行う必要がある
    for module in model.modules():
     if isinstance(module, TransformerBlock):
         fully_shard(module)
    

fully_shard(model)を実行すると、FSDPModuleクラスの関数が優先されて実行されるようになる(ラップされているのではない)。分かりにくいのだが、

https://github.com/pytorch/pytorch/blob/d620fefb2c0e6a58ad7352189278d48d34b2fbb4/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L240

にてtypeを用いて、第二引数のタプルのクラスをもとに、新しくクラスを作成している。合体させているようなイメージ。もしタプル内のクラスで同じ関数があった場合は、よりタプルの先頭に近いクラスメソッドが優先される。(こんな風にクラスを作ることがあるのか....)

https://github.com/pytorch/pytorch/blob/d620fefb2c0e6a58ad7352189278d48d34b2fbb4/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L234-L242

そのため、例えばnn.Moduleをfully_shardに入力した場合、変換されたモデルに対して、nn.Moduleで定義された関数はほとんどそのまま使うことができる。
https://github.com/pytorch/pytorch/blob/d620fefb2c0e6a58ad7352189278d48d34b2fbb4/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L252-L552

nariaki3551nariaki3551

FSDP1でもそうだが、FSDPではhookと呼ばれる関数を登録し、forwardとbackwardの前後で通信処理が行われるようにしている。

  • foward: unshardと呼ばれるパラメータ収集(Allgather)
  • backward: unshardと呼ばれるパラメータ収集(Allgather)と、勾配収集(ReduceScatter)

で、このhookの登録がどこで行われているかというと、_fully_shard.pyのここ。このstateはFSDPState

https://github.com/pytorch/pytorch/blob/d620fefb2c0e6a58ad7352189278d48d34b2fbb4/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L210-L211

nariaki3551nariaki3551

FSDPStateによるhookの登録

https://github.com/pytorch/pytorch/blob/3c2bf247867457fd03603a67257f4fc9581f3899/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L106-L113

_pre_forward

  1. 当該moduleのパラメータのunshardの完了処理
  2. 次のmoduleのパラメータunshardの発行

の処理を行う。これらのunshardは非同期で行うことができるが self.unshard_async_op: bool = Falseになっているっぽい(宿題)。一応、FSDPModule::_set_unshard_async_opでこれらを全てのParamに対してTrueにできる。

1. 当該moduleのパラメータのunshardの完了処理

この中でFSDPParamGroupのpre_forwardを呼び出している。

https://github.com/pytorch/pytorch/blob/3c2bf247867457fd03603a67257f4fc9581f3899/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L239-L240

中身。パラメータ収集に関しては、大きく次の処理が行われる。

  1. unshard
    • 今からforwardの計算を行うモジュールのパラメータ集約用のAllgatherを行う。ここでasync_op引数があり、これがTrueであればAllgatherの完了を待たずに次の処理に進む。すでにAllgatherがprefetch等で発行済みであれば何もしない。
  2. wait_for_unshard
    • 前回発行したalgatherのstateがFSDPParamGroupの_all_gather_resultに格納されており、これをwaitしてAllgatherを確実に終わらせる。その後、foreach_all_gather_copy_outにてFSDPParamにAllgatherの結果をコピーする。これをcopy-outと表現している。(逆にAllgather用のinputにデータをコピーするのはcopy-inと表現されている)

2. 次のmoduleのパラメータunshardの発行

https://github.com/pytorch/pytorch/blob/3c2bf247867457fd03603a67257f4fc9581f3899/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L241-L243

self._states_to_forward_prefetchにあらかじめパラメータを収集すべきFSDPParamGroupが格納されており、それらをprefetchする。_states_to_forward_prefetchはどこで設定しているんだろう(宿題)。


ただし注意書き

However, you must use explicit prefetching (e.g. via :meth:unshard)
in forward to still get overlap, and the pre-all-gather ops like dtype
casting and copy-in will not overlap with compute.

にもあるように、

  • async_op=True にしても、何もしないと overlap は起きない
  • overlap を狙うには、手動で unshard(async_op=True) を呼んで先取り通信が必要
  • また、dtype変換や copy-in 処理は依然として同じストリーム上で実行されるため、完全な通信/計算オーバーラップは実現しないことに注意
nariaki3551nariaki3551

_pre_backward

  1. self._fsdp_param_group.pre_backward(default_prefetch)
    1. unshard (prefetchされていれば何もしない)
    2. wait_for_unshard(発行済みのunshardのprefetchを完了させる)
  2. prefetch_unshard