🥣
【機械学習メモ】混合精度の利用
混合精度とは
混合精度とは、16bitと32bit浮動小数点型の両方をモデルに使ってモデルのトレーニングを高速化し、使用するメモリを少なくする手法。
混合精度を利用することで、少ないリソースの中で学習を可能にする。
イメージは以下の記事の図が分かりやすい。
各bitはIEEE 浮動小数点演算規格において、以下をさす。
※浮動小数点演算=実数をコンピュータで処理(演算や記憶、通信)するために有限桁の小数で近似値として扱う方式
- 単精度 : 32bit
- 倍精度 : 64bit
- 半精度 : 16bit
TensorFlowの混合精度のガイドを動かしてみる
TensorFlowが提供するガイドを参考に、Kerasを利用して混合精度を用いた学習を行ってみる。
モデルの作成
# ライブラリのインストール
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
# 混合精度を利用するためのdtypeポリシーを設定
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
# inputの指定
# GPU利用時、混合精度によるメリットを確認できるよう4096ユニットのDenseレイヤを形成
inputs = keras.Input(shape=(784,), name='digits')
num_units = 4096
dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(x)
# outputの指定
# outputは32bitで出力する
x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)
# モデルの作成
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=keras.optimizers.RMSprop(),
metrics=['accuracy'])
※Denseレイヤでは全結合層を形成し、このレイヤーに入力されたデータのすべてのユニットが次の層のすべてのユニットと結合される。
Dense層については以下の記事を参考。
学習
MNISTのデータセットを利用して学習を行う。
# MNISTのデータを利用して学習する
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
initial_weights = model.get_weights()
history = model.fit(x_train, y_train,
batch_size=8192,
epochs=5,
validation_split=0.2)
test_scores = model.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])
以下の実行結果が得られた。
Epoch 1/5
6/6 [==============================] - 3s 126ms/step - loss: 2.3732 - accuracy: 0.4058 - val_loss: 0.7861 - val_accuracy: 0.7893
Epoch 2/5
6/6 [==============================] - 0s 66ms/step - loss: 1.0039 - accuracy: 0.7030 - val_loss: 0.4294 - val_accuracy: 0.8905
Epoch 3/5
6/6 [==============================] - 0s 67ms/step - loss: 0.5551 - accuracy: 0.8302 - val_loss: 0.5790 - val_accuracy: 0.8098
Epoch 4/5
6/6 [==============================] - 0s 64ms/step - loss: 0.4478 - accuracy: 0.8620 - val_loss: 0.2923 - val_accuracy: 0.9076
Epoch 5/5
6/6 [==============================] - 0s 62ms/step - loss: 0.3981 - accuracy: 0.8697 - val_loss: 0.3019 - val_accuracy: 0.9081
313/313 - 1s - loss: 0.3142 - accuracy: 0.9020 - 666ms/epoch - 2ms/step
Test loss: 0.3141573965549469
Test accuracy: 0.9020000100135803
「126ms/step」の部分が各ステップの所要時間を示す。
ステップごとの実行時間について、混合精度を利用することで高速化できる。
混合精度の利用なしの場合はこのような実行結果。
Epoch 1/5
6/6 [==============================] - 4s 289ms/step - loss: 2.4416 - accuracy: 0.3534 - val_loss: 0.8367 - val_accuracy: 0.7702
Epoch 2/5
6/6 [==============================] - 1s 236ms/step - loss: 0.8884 - accuracy: 0.7355 - val_loss: 0.4825 - val_accuracy: 0.8613
Epoch 3/5
6/6 [==============================] - 1s 242ms/step - loss: 0.5855 - accuracy: 0.8029 - val_loss: 0.5534 - val_accuracy: 0.8153
Epoch 4/5
6/6 [==============================] - 1s 239ms/step - loss: 0.4164 - accuracy: 0.8741 - val_loss: 0.4726 - val_accuracy: 0.8491
Epoch 5/5
6/6 [==============================] - 1s 240ms/step - loss: 0.3455 - accuracy: 0.8935 - val_loss: 0.2465 - val_accuracy: 0.9298
313/313 - 1s - loss: 0.2530 - accuracy: 0.9247 - 624ms/epoch - 2ms/step
Test loss: 0.252986341714859
Test accuracy: 0.9247000217437744
Discussion