🥣

【機械学習メモ】混合精度の利用

2024/07/06に公開

混合精度とは

混合精度とは、16bitと32bit浮動小数点型の両方をモデルに使ってモデルのトレーニングを高速化し、使用するメモリを少なくする手法。
混合精度を利用することで、少ないリソースの中で学習を可能にする。

イメージは以下の記事の図が分かりやすい。
https://qiita.com/MotonobuHommi/items/f12a500d6c475ce59790#3-mixed-precision

各bitはIEEE 浮動小数点演算規格において、以下をさす。
※浮動小数点演算=実数をコンピュータで処理(演算や記憶、通信)するために有限桁の小数で近似値として扱う方式

  • 単精度 : 32bit
  • 倍精度 : 64bit
  • 半精度 : 16bit

TensorFlowの混合精度のガイドを動かしてみる

TensorFlowが提供するガイドを参考に、Kerasを利用して混合精度を用いた学習を行ってみる。

モデルの作成

# ライブラリのインストール
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision

# 混合精度を利用するためのdtypeポリシーを設定
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy) 

# inputの指定
# GPU利用時、混合精度によるメリットを確認できるよう4096ユニットのDenseレイヤを形成
inputs = keras.Input(shape=(784,), name='digits')
num_units = 4096

dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(x)

# outputの指定
# outputは32bitで出力する
x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)

# モデルの作成
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])

※Denseレイヤでは全結合層を形成し、このレイヤーに入力されたデータのすべてのユニットが次の層のすべてのユニットと結合される。
Dense層については以下の記事を参考。
https://qiita.com/Ishotihadus/items/c2f864c0cde3d17b7efb#ニューラルネットを作る
https://zenn.dev/mi_ztyanya/books/9bcdf19bb90504/viewer/f19568

学習

MNISTのデータセットを利用して学習を行う。

# MNISTのデータを利用して学習する
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
initial_weights = model.get_weights()

history = model.fit(x_train, y_train,
                    batch_size=8192,
                    epochs=5,
                    validation_split=0.2)
test_scores = model.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])

以下の実行結果が得られた。

Epoch 1/5
6/6 [==============================] - 3s 126ms/step - loss: 2.3732 - accuracy: 0.4058 - val_loss: 0.7861 - val_accuracy: 0.7893
Epoch 2/5
6/6 [==============================] - 0s 66ms/step - loss: 1.0039 - accuracy: 0.7030 - val_loss: 0.4294 - val_accuracy: 0.8905
Epoch 3/5
6/6 [==============================] - 0s 67ms/step - loss: 0.5551 - accuracy: 0.8302 - val_loss: 0.5790 - val_accuracy: 0.8098
Epoch 4/5
6/6 [==============================] - 0s 64ms/step - loss: 0.4478 - accuracy: 0.8620 - val_loss: 0.2923 - val_accuracy: 0.9076
Epoch 5/5
6/6 [==============================] - 0s 62ms/step - loss: 0.3981 - accuracy: 0.8697 - val_loss: 0.3019 - val_accuracy: 0.9081
313/313 - 1s - loss: 0.3142 - accuracy: 0.9020 - 666ms/epoch - 2ms/step
Test loss: 0.3141573965549469
Test accuracy: 0.9020000100135803

「126ms/step」の部分が各ステップの所要時間を示す。
ステップごとの実行時間について、混合精度を利用することで高速化できる。

混合精度の利用なしの場合はこのような実行結果。

Epoch 1/5
6/6 [==============================] - 4s 289ms/step - loss: 2.4416 - accuracy: 0.3534 - val_loss: 0.8367 - val_accuracy: 0.7702
Epoch 2/5
6/6 [==============================] - 1s 236ms/step - loss: 0.8884 - accuracy: 0.7355 - val_loss: 0.4825 - val_accuracy: 0.8613
Epoch 3/5
6/6 [==============================] - 1s 242ms/step - loss: 0.5855 - accuracy: 0.8029 - val_loss: 0.5534 - val_accuracy: 0.8153
Epoch 4/5
6/6 [==============================] - 1s 239ms/step - loss: 0.4164 - accuracy: 0.8741 - val_loss: 0.4726 - val_accuracy: 0.8491
Epoch 5/5
6/6 [==============================] - 1s 240ms/step - loss: 0.3455 - accuracy: 0.8935 - val_loss: 0.2465 - val_accuracy: 0.9298
313/313 - 1s - loss: 0.2530 - accuracy: 0.9247 - 624ms/epoch - 2ms/step
Test loss: 0.252986341714859
Test accuracy: 0.9247000217437744
GitHubで編集を提案

Discussion