転移学習のためダウンロードした学習済みモデルをGoogleDriveに保存して使いまわす
Google Driveで転移学習を使った画像解析をする際、学習済みモデルを毎回インストールするのは、固定回線のない自分には通信容量の無駄だな、と思ったので、学習済みモデルをマウントしたGoogle Driveに保存して、2回目以降はGoogle Driveからロードして節約するコードを書いてみましたので、まとめます。
コードはGoogle Colaboはこちら。
GitHubはこちら。
Google Driveのマウント
まず、Google Driveをマウントします。
やり方がわからない場合はこちらのサイトをご覧ください。
参考サイト
ColaboratoryでのGoogle Driveへのマウントが簡単になっていたお話
from google.colab import drive
drive.mount('/content/drive')
学習済みモデルの読込
Kerasを使用しているので、keras.models.load_modelを使って、h5型のファイルに保存します。
参考サイト
Kerasのモデルを保存する方法
Pytorch等の他のディープラーニングモデルを使用している場合は、下記のコードは参考程度にとどめ下さい。
import os
# keras のインポート
from tensorflow import keras
from keras.models import load_model
既に保存したモデルがある場合は、学習済みモデルを再びGoogle Driveに保存しようとしてパケットを無駄遣いしないよう、os.path.isfile()を使って保存済みのモデルがあるかどうか確認させます。
下記のコードはマウントしたGoogleDrive直下にモデルを保存していますが、パスは任意に変更ください。
なお、学習済みモデルはResNet50を使用しています。
他の学習済みモデルを試したい場合は、こちらのサイトをご参照ください。
# すでに保存済みのモデルがGoogle Drive内にある場合は保存済みのモデル(h5型)を読み込む
if os.path.isfile('model.h5'):
model = load_model('model.h5')
# 保存済みのモデルがない場合は学習済みモデルをダウンロードする
else:
pre_model = keras.applications.resnet50.ResNet50(
weights='imagenet',
input_shape=(224, 224, 3),
include_top=False
)
pre_model.save('model.h5')
参考サイト
Python 指定ファイルの存在をチェックし、無ければファイルを作成する。有れば追記する。
学習済みモデルが指定したパス内にない場合は、ダウンロードが行われ、下記のように表示されます。
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step
summaryを実行すると、ちゃんと以下のように表示されます。
pre_model.summary()
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
(略)
conv5_block3_add (Add) (None, 7, 7, 2048) 0 conv5_block2_out[0][0]
conv5_block3_3_bn[0][0]
__________________________________________________________________________________________________
conv5_block3_out (Activation) (None, 7, 7, 2048) 0 conv5_block3_add[0][0]
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
以上になります、最後までお読みいただきありがとうございました。
Discussion