Automatic Mixed Precision を理解してみた
はじめに
私は普段深層学習を用いて、画像分類や物体検出のタスクを行なっています。
その際に学習時間や推論時間を短縮するために高速化を行いたい場面があり、その際に利用した技術 Automatic Mixed Precisionについてより深く理解するために調べたことをまとめます。
この記事の内容
Automatic Mixed Precisionの理解に必要な以下の要素について述べ、Automatic Mixed Precisionの理解を深めます。
- Single-Precision / Half-Precision
- Mixed-Precision Computing
- Mixed-Precision Training
- Apex
Single-Precision / Half-Precision
Automatic Mixed Precisionの基本的な考えは、パラメータやデータの情報量を減らすことです。
情報量を減らすことには、主に以下の2つのメリットがあります。
-
メモリ使用量の削減: データを表現するために必要なビット数が少なくなるため、メモリ使用量を削減できます。例えば、Single-Precision (32bit) を Half-Precision (16bit) にすると、単純計算でメモリ使用量は半分になります。これにより、より多くのデータをGPUメモリに格納できるようになり、バッチサイズを大きくできる可能性が生まれます。大きなバッチサイズは、学習の安定化や高速化に寄与することが知られています。
-
計算の高速化: 近年のGPUは、16bit演算を32bit演算よりも高速に処理できるアーキテクチャを採用しています。例えば、NVIDIAのVoltaアーキテクチャ以降では、16bit演算専用のTensorコアが搭載されており、32bit演算と比較して大幅な高速化を実現しています。これは、データの転送量が削減されること、また、より少ないビット数で演算を行うことで、計算回路が簡略化され、処理が高速化されるためです。
これらの理由から、深層学習において、パラメータやデータの情報量を削減することで、学習や推論処理の高速化が期待できます。
コンピュータで数値を扱う場合、その精度(情報をどこまで保持するかのレベル)は、計算速度やメモリ使用量に影響を与えます。より多くの情報を保持すれば正確な計算が可能ですが、計算に時間がかかり、メモリも多く消費します。一方、精度を下げると、計算速度は向上し、メモリ使用量は削減されますが、計算結果の正確さは低下します。
「どこまで情報を保持するかのレベル」をprecisionと呼びます。
多くの情報を保持すれば正確な計算が可能ですが、計算量が多くなります。
コンピュータで数値は2進数のビットで扱いますので、小数の情報をどこまで保持するかではなくビット情報をどこまで保持するかになります。
IEEE 浮動小数点演算標準 / IEEE 754では、コンピュータ上で数値を2進数で表現するための一般的な規則が示されています。
こちらにはいくつかの表現方法が記載されていますが、よく利用される表現方法を記載します。
- Single-Precision (32bit)
- Double-Precision (64bit)
- Half-Precision (16bit)
ビット表現
仮数と指数の部分の情報を調整することで、情報量を減らすことができます。
情報量を減らすと、以下のようなメリットが得られます。
- 計算の高速化: 扱うデータ量が減るため、メモリからのデータ読み込みや演算処理が速くなります。
- メモリ使用量の削減: パラメータや中間データを格納するために必要なメモリ容量を削減できます。これにより、より大きなモデルを学習したり、より大きなバッチサイズで学習したりすることが可能になります。
Mixed-Precision Computing
パラメータやデータの情報量を減らすだけでは正確さが落ちてしまいますが、これを有効的に活用する方法がMixed-Precision Computingになります。
Mixed-Precision Computingとは、ひとつの処理の中で異なるPrecision(例:Half-PrecisionとSingle-Precision)を使い分けることで、計算精度を大きく損なうことなく処理速度を向上させる手法です。
一般的に、精度を下げると、計算速度は向上しますが、計算結果の正確さは低下します。つまり、精度と速度はトレードオフの関係にあります。Mixed-Precision Computingは、このトレードオフをうまく調整し、精度の低下を最小限に抑えつつ、速度の大幅な向上を目指す手法です。深層学習においては、必ずしもすべての計算を高精度で行う必要はなく、部分的に精度を下げることで、全体のパフォーマンスを向上させられる可能性があります。
(処理ごとに、Precisionレベルを変える手法は、Multi-Precision Computingといいます)
処理例としてはHalf-Precisionから計算を開始して数値が計算されるについてより高いレベルのPrecisionで値を保持します。
具体的な計算例としては、「2つの16ビット行列を掛け合わせた結果は、32ビットのサイズで情報を保持する」などになります。
Mixed-Precision Computingは計算速度と精度のバランスを取るための有効な手段です。
しかし、これを深層学習の学習プロセスにそのまま適用するには、いくつかの課題が存在します。
Mixed-Precision Training
ここからは、Mixed-Precision Computingを深層学習の学習処理に導入することを考えます。
以下ような項目が課題になります。
- (課題1) : Half-Precision同士の演算の結果をHalf-Precisionの範囲のみで表現すると誤差が生じる
- (課題2) : Half-Precisionでは表現可能な数値の範囲が狭く、小さな値と大きな値の加算演算などでアンダーフローが発生しやすくなる
- (課題3) : Half-Precisionでは表現可能な数値の範囲が狭く、小さな値を保持できずオーバーフローが発生しやすくなる
これらに対して、いくつかの対策を考えます。
対策1:一部の情報をSingle-Precisionで保持する
一部の情報をSingle-Precisionで保持することで、課題1を解決します。
Half-Precisionの演算結果をSingle-Precisionとして保持し、Single-Precision同士の演算でも低速にならない加算などの演算ではSingle-Precsionを維持し、必要に応じて(後続の演算がHalf-Precisionに最適化されている場合やメモリ使用量を削減したい場合など)Half-Precisionに変換するという考えです。
このようにすることで、演算で生じる誤差を最小限にしながら、計算を高速化することが可能です。
対策2:FP32のWeightsパラメータのマスターコピーを準備する
想定される課題2のような状況は、Weightパラメータの更新の際に起こりえます。
学習率などの計算後、Weightの勾配はWeightに比べて小さくなることが多いです。
スケールが大きく異なる数値同士の加算演算でHalf-Precisionを利用すると結果に変化が起こらず、Weightパラメータが変更されないことになります。
例えば、「Weighパラメータが1.0、その勾配が0.00001の場合、Half-Precisionでは0.00001を表現できず、加算結果が1.0のまま変わらない」のような状況です。
そのために、Weightパラメータのマスターコピーを用意し、Half-PrecisionのWeightパラメータで計算された勾配を用いて、Single-Precisionのマスタコピーの更新を行います。
対策3:Loss Scaling
逆伝播の処理に利用される活性化関数の出力に対する勾配(活性化勾配)は、小さな値になる傾向があることがわかっています。(課題3)
上の図は、SSDをSingle-Precisionでトレーニングしたときに記録された活性化勾配のヒストグラムです。Y 軸は、ログスケール上のすべての値のパーセンテージです。X 軸は、絶対値のログスケールです。
つまり、活性化勾配の情報はビット下部分に偏っており、Half-Precisonにすることで多くの情報を失うことになります。
Lossにスケーリング係数Sを乗算することで、勾配の値をS倍し、ビット表現を上位にシフトさせることで重要な情報を失わないようにします。
注意しなければならい点は、LossにSを乗算したため勾配もS倍されるのでパラメータ更新前に勾配を1/S倍する必要がある点です。
処理フロー
上の3つの手法を以下のように導入します。
(FP16 = Half-Precision / FP32 = Single-Pricision)
- FP16のWeightパラメータなどを利用して順伝播の処理を行う。
- LossにスケールSを乗算する。
- FP16のWeightパラメータや勾配を利用して逆伝播の処理を行う。
- Weightパラメータの勾配をFP32に変換する。
- Weightパラメータの勾配に1/Sを乗算する。
- FP32のWeightパラメータのマスタコピーを更新する。
- FP32のマスターコピーから、FP16でコピーする。(次の順伝播ではFP16のWeightパラメータを使用するため)
このような処理フローで行うことで、精度を落とさず処理速度を向上させることができます。
Apex
Apexは2018年に、NVIDIAが開発したSingle PrecisionからMixed Precisionに自動的に変換できるPyTorchの拡張機能です。
PyTorch 1.6からはapex.amp
からtorch.cuda.amp
にAMPのパッケージが改善されていますが、ここではApexの基本的な機能について整理します。
Apexの機能では演算を3つグループに分け、それらに対応した処理を行います。
- whitelist : Half-Precisionにすることで高速化が見込まれる関数 (行列の乗算や畳み込み演算など)
- blacklist : Half-Precisionでは、精度に支障がある可能性のある関数(softmax関数などの損失関数)
- その他 : どちらにも当てはまらない関数
Apexは関数が呼び出されるたびにその関数が上の3つのどのグループに該当するか確認します。
その後、以下の対応を行います。
- whitelistである場合 : 関数への入力をすべてHalf-Precisionにキャストする
- blacklistである場合 : 関数への入力をすべてSingle-Precisionにキャストする
- どちらでもない場合 : 関数への入力が全て同じPrecisionであることを確認する(同じPrecisionでなければ、精度の高いPrecisionに合わせてキャストする)
このようにすることで、自動的にPrecisionレベルを決定することができます。
結局 Automatic Mixed Precision とは
ここまでの内容をまとめると以下のようになります。
- Single-PrecisionとHalf-Precsionを混合させたMixed-Precision Computingを深層学習の学習処理に導入する
- その際に発生する課題に対して適切な処理を加え精度の低下を抑える
- Apexなどの機能を用いて、Mixed Precisionをコードの修正を最小限に、自動的に適応させる
Discussion