🖼

転移学習のためダウンロードした学習済みモデルをGoogleDriveに保存して使いまわす

2020/10/18に公開

Google Driveで転移学習を使った画像解析をする際、学習済みモデルを毎回インストールするのは、固定回線のない自分には通信容量の無駄だな、と思ったので、学習済みモデルをマウントしたGoogle Driveに保存して、2回目以降はGoogle Driveからロードして節約するコードを書いてみましたので、まとめます。

コードはGoogle Colaboはこちら

GitHubはこちら

Google Driveのマウント

まず、Google Driveをマウントします。

やり方がわからない場合はこちらのサイトをご覧ください。

参考サイト
ColaboratoryでのGoogle Driveへのマウントが簡単になっていたお話

from google.colab import drive
drive.mount('/content/drive')

学習済みモデルの読込

Kerasを使用しているので、keras.models.load_modelを使って、h5型のファイルに保存します。

参考サイト
Kerasのモデルを保存する方法

Pytorch等の他のディープラーニングモデルを使用している場合は、下記のコードは参考程度にとどめ下さい。

import os

# keras のインポート
from tensorflow import keras
from keras.models import load_model

既に保存したモデルがある場合は、学習済みモデルを再びGoogle Driveに保存しようとしてパケットを無駄遣いしないよう、os.path.isfile()を使って保存済みのモデルがあるかどうか確認させます。

下記のコードはマウントしたGoogleDrive直下にモデルを保存していますが、パスは任意に変更ください。

なお、学習済みモデルはResNet50を使用しています。

他の学習済みモデルを試したい場合は、こちらのサイトをご参照ください。

# すでに保存済みのモデルがGoogle Drive内にある場合は保存済みのモデル(h5型)を読み込む
if os.path.isfile('model.h5'):
    model = load_model('model.h5')

# 保存済みのモデルがない場合は学習済みモデルをダウンロードする
else:
    pre_model = keras.applications.resnet50.ResNet50(
    weights='imagenet', 
    input_shape=(224, 224, 3),
    include_top=False
    )
    pre_model.save('model.h5')

参考サイト
Python 指定ファイルの存在をチェックし、無ければファイルを作成する。有れば追記する。

学習済みモデルが指定したパス内にない場合は、ダウンロードが行われ、下記のように表示されます。

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 1s 0us/step  

summaryを実行すると、ちゃんと以下のように表示されます。

pre_model.summary()
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________

()

conv5_block3_add (Add)          (None, 7, 7, 2048)   0           conv5_block2_out[0][0]           
                                                                 conv5_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_out (Activation)   (None, 7, 7, 2048)   0           conv5_block3_add[0][0]           
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

以上になります、最後までお読みいただきありがとうございました。

Discussion