🧬

アニメーションでDeepSpeed (ZeRO1)の仕組みを完全に理解する

2023/12/14に公開

Turingのリサーチチームで完全自動運転の研究開発を行なっている棚橋です。Turingアドベントカレンダー14日目の記事として、DeepSpeedについて取り上げます。
DeepSpeedはMicrosoftによって開発されたライブラリで、一言で言うと、「1つのGPUに乗り切らないような巨大MLモデルをなんとか学習させるため」のツールです。特に、この論文で提案されているDeepSpeedのZero Redundancy Optimizer (ZeRO)という技術が非常に注目されています。
また、DeepSpeedが昨今の大規模言語モデルの学習に多く利用されています。詳細は過去の記事をご覧ください。

https://zenn.dev/turing_motors/articles/04c1328bf6095a

DeepSpeedのライブラリ自体が簡単に利用できる反面、内部でどのように動作しているのかまで理解する機会はなかなかありません。しかし、効率的に動作させるためのチューニングを行ったり、発展的な利用を行うには、内部の仕組みや実装について理解しておくことが重要ですし、何よりどんな実装なのか気になりますよね。とは言っても、DeepSpeedのリポジトリに含まれるPythonファイルだけで10万行以上あり、コードを読み解くのもひと苦労でしょう。
そこでこの記事ではDeepSpeedの処理を説明するためのアニメーションとコードを紹介し、自分でDeepSpeedが作れるくらいDeepSpeedについて理解できるようにすることを目指します。

DeepSpeedの使い方

DeepSpeedは様々な利用方法があります。最もポピュラーな使い方は、HuggingFaceのtransformersライブラリから利用する方法で、HuggingFace: DeepSpeed Integrationで詳しい使い方が紹介されています。

また、DeepSpeed単体で、Pytorchのモデルと組み合わせて使うには以下のように、数行を挟むだけでDeepSpeedに対応させることが可能となります。

DeepSpeed単体のサンプルコード
import deepspeed

# Pytorchのモデルを作成(ZeRO3の時はwith zero.Init()で囲む必要があります)
model = ...

# 学習データを準備
trainset = ...

# optimizerとモデルをDeepSpeed用にラップする
model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    args=args, model=model, training_data=trainset)
    
for i, data in enumerate(trainloader):
        inputs, labels = data[0].to(model_engine.local_rank), data[1].to(model_engine.local_rank)
        outputs = model_engine(inputs)
        loss = criterion(outputs, labels)
        model_engine.backward(loss)
        model_engine.step()

DeepSpeed ZeROとは

DeepSpeedのZeROではstage1~3の段階があり、どの程度モデルを分割するかの程度に対応しています。下図のP_osがZeRO1、P_os+gがZeRO2、P_os+g+pがZeRO3に対応しています。stage 1から3に上がるに連れて、必要なメモリ量(図中の緑やオレンジや青の面積)が小さくなっていることが確認できます。

大規模言語モデルなどはモデル自体をGPUに乗せるのが大変なのですが、学習時には勾配(上図のg)やOptimizer State (上図のos。FP32で保持された勾配やモメンタム、分散の情報)を扱う必要があり、特にOSはFP32で保持している上にモメンタムなど多種類の情報を持っているのでモデルパラメータ(FP16)の8倍のサイズになります。以下はZeRO1の処理をあらわすアニメーションです(下で詳しく説明します)

DeepSpeed ZeRO1を完全理解する

ZeRO1では、一番巨大なOSをGPUごとに分割して保持することによって必要なメモリサイズを1/4に削減可能です。この記事では、DeepSpeedがまだ生まれたばかりでコードが理解しやすいDeepSpeed v0.1.0のコードを元に、内部実装を見ていきたいと思います。v0.1.0ではZero1の技術しか実装されていませんが、最新のコードと基本は変わっていないため、ひとまずZero1を理解するには十分です。また、ZeRO2, 3についてもZeRO1を拡張した考え方に基づいて実装されているので、まずは最もシンプルなZeRO1を理解することが重要です。

プロセスの起動

ZeROの処理に入る前に、DeepSpeedの起動部分から見てみます。DeepSpeedの起動はdeepspeedコマンドにPythonファイルを渡すことで行えます。例えば、こちらのサンプルコードを動かすには、以下のようにコマンドを実行します。

$ deepspeed cifar10_deepspeed.py --deepspeed_config ds_config.json

deepspeedコマンドの内部では何が起きているのでしょうか。このコマンドを実行するとdeepspeed_launch.pyが呼ばれます。この処理では分散処理のノード数や全体のGPU数(world_size)などを確認し、GPUごとに複数のプロセスを起動します。例えば、ノード数=1でGPU数=2という状況だと、以下のように2つのプロセスが起動されます。それぞれのコマンドには--local_rankという引数が渡され、これがノード内での識別番号になっています。ちなみに、global_rankは全ノードにおけるプロセスの識別番号を表します。

/opt/conda/bin/python -u cifar10_deepspeed.py --local_rank=0 --deepspeed_config ds_config.json
/opt/conda/bin/python -u cifar10_deepspeed.py --local_rank=1 --deepspeed_config ds_config.json
複数プロセスを起動する部分のコード

DeepSpeedモデルの作成

前節の処理のよって、自分の書いたPythonプログラムが複数プロセスで起動します。DeepSpeedの使い方サンプルコードで示したように、通常のPytorchモデルの学習と同じようにモデルの作成を行い、これをdeepspeed.initialize関数に渡し、DeepSpeed用にラップされたモデル(これをDeepSpeedモデルと呼ぶことにします)とoptimizerが得られます。下記コードにあるように、initialize関数の中では、DeepSpeedLightというクラスのオブジェクトを作成しています。

DeepSpeedLightを作成するinitialize関数

DeepSpeedLightクラスはPytorchモデル(nn.Module)を継承しており、forward()backward()step()メソッドなど、学習時に必要となる基本的なメソッドを備えています。以下にDeepSpeedLightクラスを簡略化したコードを示します。

class DeepSpeedLight(Module):

  def __init__(self, args, model, ...):
    ...
    # pytorchのdist(分散処理モジュール)のバックエンドをNCCLに設定
    dist.init_process_group(backend="nccl")
    
    # GPUのデバイス、world_size、global_rankの設定
    self._init_distributed(dist_init_required)
    
    # モデルをデバイスに転送して同期する
    self._configure_distributed_model(model)
    
    # optimizerの設定
    self._configure_optimizer(optimizer, model_parameters)
  
  def forward(self, *inputs, **kwargs):
    loss = self.module(*inputs, **kwargs)
    return loss
  
  def backward(self, loss):
    # optimizerでpartitionごとに勾配を計算
    self.optimizer.backward(loss)
    
    self.allreduce_gradients()
  
  def step(self):
    # モデルパラメータ(重み)の更新
    self.optimizer.step()
    
    # 勾配の初期化
    self.optimizer.zero_grad()

初期化を行う__init__メソッドでは、以下のように分散学習の準備を行います。

GPUのデバイス、world_size、global_rankの設定を行う
モデルをデバイスに転送して同期する処理

また、_configure_optimizer()では、FP16_DeepSpeedZeroOptimizerインスタンスを作成して一般的なOptimizerをDeepSpeed用に拡張しています。ZeRO1の場合、Optimizer Stateの分割を行うのでエッセンスとなる処理はこのクラスに実装されています。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_light.py#L520-L531

DeepSpeed ZeRO1の流れ

ここでようやくZeRO1の処理のフローを確認するための準備ができました。ここからがこの記事の本番です!下のアニメーションはZeRO1の処理の全体像を示しています。

簡単に処理をまとめると以下のようになります。

  1. optimizerの初期化。optimizer stateに分割したパラメータだけを保持する。
  2. forward()を実行してlossを計算する
  3. backward()を実行して勾配を計算する
  4. 2で計算した勾配を全体で平均化する(reduce all)
  5. step()でOptimizer Stateを計算し、モデルの重みを更新する
  6. 部分ごとに計算したモデルの重みが全体に行き渡るようにbroadcastする

0. optimizerの初期化

FP16_DeepSpeedZeroOptimizerの初期化時に、各プロセスは自分の担当しているパラメータの勾配のみを計算するように、パラメータの分割を行います。これを行っているのが、以下のコードです。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_zero_optimizer.py#L135-L174

この処理では、parameterグループごとに分割の設定を行っています。paramterグループとはモデルの部分によって学習率を変えたりしたい時に、パラメータをグルーピングできるpytorchの機能ですが、ここでは単一のグループを想定して問題ありません。
最初にfp16_groupsにオリジナルの重みテンソルを格納しています。次に、fp16_groupsには様々なshapeのテンソルが入っているので、1次元のテンソルにフラット化して繋げた巨大なベクトルを作成してfp16_groups_flatに格納します。分散処理を行う時にテンソルの形を意識せずに分割や転送の処理を行いたいのでこのような操作を行っています。この巨大ベクトルをpartition数(=world_size)で分割することになるので、このベクトルのサイズはworld_sizeで割り切れる必要があります。そのために、キリが良くなるように余った領域を0で埋める処理を行っています。フラット化したテンソルをpartition数で分割したものがparallel_partitioned_fp16_groupsです。さらに、自分のpartition_idの重みテンソルを取り出してFP32に変換したものがsingle_partition_of_fp32_groupsです。最後にこのテンソルがparam_group['params']に格納されます。つまり、optimizerからすると、元々はparam_group['params']には様々な形のテンソルが入っていたのが、フラット化したテンソルの断片(自分のpartition分)が格納されたということになります。

# 元々はparam_group['params']には様々な形のテンソルが入っていたが、zero1では巨大なフラット化したテンソルの断片が格納される。
param_group['params'] = [self.single_partition_of_fp32_groups[i]]

この処理のなかで出てくる変数と格納されているものをまとめると以下となります。元々、モデルにWeight1~3というweightが入っていたという想定の図です。

  • fp16_groups: 分割前の重みパラメータが格納されているリスト
  • fp16_groups_flat: fp16_groupsの複数の重みを1次元にフラット化して繋げた一つのテンソル。ただし、次元の大きさはworld_sizeで割り切れるように、0ベクトルでpaddingされている。
  • parallel_partitioned_fp16_groups: partitionごとに分割した重みのフラット化したテンソルが入っている。
  • single_partition_of_fp32_groups: 自分のpartitionの重みと勾配のフラット化したテンソルが入っている。requires_grad=Trueとして、勾配情報も保持している。

1. forward()を実行してlossを計算

DeepSpeedではデータ並列処理を行うので、各プロセスに渡ってくるデータのミニバッチは異なっています。これを入力として学習ステップ内にてforward()を呼び、lossの値を計算します。ここまでは普通のpytorchの学習方法を変わりはありません。

# DeepSpeedLightの定義

  def forward(self, *inputs, **kwargs):
    loss = self.module(*inputs, **kwargs)
    return loss

2-3. backward()を実行して勾配を計算

モデルのbackward()を実行することで、各モデル重みの勾配が計算されます。ただし、各プロセスが計算する勾配は担当のデータに関する勾配なので、すべての勾配を足し合わせて平均化します。この処理を行っているのが、allreduce_gradients()です。

# DeepSpeedLightの定義

  def backward(self, loss):
    # optimizerでpartitionごとに勾配を計算
    self.optimizer.backward(loss)
    
    # 全プロセスで計算した勾配の平均をとる
    self.allreduce_gradients()

具体的に勾配の計算を行なっている処理は以下のコードで、分散処理用のdist.all_reduce()関数を使うことで実現しています。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_light.py#L819-L844

4-5 step()でOptimizer Stateを計算し、モデルの重みを更新する

以下のコードにあるように、DeepSpeedLightのstep()が呼ばれると、optimizer.step()がその中で呼ばれます。

# DeepSpeedLightの定義

  def step(self):
    # モデルパラメータ(重み)の更新
    self.optimizer.step()
    
    # 勾配の初期化
    self.optimizer.zero_grad()

それでは、FP16_DeepSpeedZeroOptimizerのstep()の実装を見てましょう。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_zero_optimizer.py#L343-L441

重要な部分だけを見ていくと、以下のコードにあるように、get_flat_partition()関数を使ってself.params_in_partition[i]からフラット化された情報を取り出しています。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_zero_optimizer.py#L373-L382

params_in_partitionのテンソルはフラット化されていないので、このテンソルの勾配をフラット化して、自分の担当の範囲を切り出すという操作を行っています。

get_flat_partition()関数の定義

この後、DeepSpeedにラップされる前のオリジナルのoptimizer、つまりoptimizerself.optimizer.step()を呼ぶことによって、フラット化された重み(つまりsingle_partition_of_fp32_groupsに対応するテンソル)の更新を行っています。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_zero_optimizer.py#L388

最後に、各パーティションで更新された重みをpartition間で共有し合っています。

https://github.com/microsoft/DeepSpeed/blob/c61e23b4b108df2af0dda7939ee59d4ae9090415/deepspeed/pt/deepspeed_zero_optimizer.py#L399-L425

ここでは、各partition内の重みをnum_shard個のシャードに分割しています。このシャード単位でdist.all_gather()によって他のpartitionとデータのやり取りを行っています。

この操作によってZero1の重みが更新され、一連の動きをすべて説明することができました。ここまでの手順を追うことで、ZeRO1の仕組みがコードレベルでおおまかに理解できると思います。詳細は各コードのGithubのコードのリンク先を参照してください。

まとめ

今回はZeRO1の技術についてアニメーションとコードを交えながら内部の仕組みを説明しました。内部で使われている操作は、テンソルをフラット化して分割したり、pytorchのdistパッケージを使ってデータを相互に共有したりしているだけで、すごく複雑なことをしているわけではありません。また機会があればZeRO2やZeRO3についても解説記事を書きたいと思います。

今回の記事を書くにあたり、以下の動画を大いに参考にしました。 Hands-on Tutorials DeepSpeed 02 @KDD2020

採用情報

Turing では自動運転モデルの学習や、自動運転を支えるための基盤モデルの作成のために分散並列学習の知見を取り入れた研究開発を行っています。興味がある方は、Turing の公式 Web サイト採用情報などをご覧ください。話を聞きたいという方はやAIチームのディレクターの山口さんの Twitter DM からでもお気軽にご連絡ください。

Tech Blog - Turing

Discussion