🌴

Kerasで学習済みモデルを読み込んで一部分だけ実行する方法

2 min read

Kerasでモデルや重みの保存・読み込み方法の解説記事はあるが、全体を保存して一部分だけを実行するやり方が載っていなかったのでメモ。
自力で見つけた方法なので、もっといいやり方があるかもしれない。

実行環境

TensorFlow内部のKerasを使用したので、バックエンドとしてTensorFlowが動作する。

  • Ubuntu 18.04.3 LTS
  • Python 3.6.9
  • tensorflow-gpu 2.2.0 (Keras 2.3.0-tf)

前提:モデル作成と保存

エンコーダ‐デコーダモデルが最もポピュラーな例だと思うので、それを採用した例。
例ではFunctionalモデルだが、Sequentialモデルでも同じはず。

from tensorflow import keras

encoder_input = keras.layers.Input((32, ))
encoder_output = keras.layers.Dense(16)(encoder_input)
encoder = keras.Model(encoder_input, encoder_output)

decoder_input = keras.layers.Input((16, ))
decoder_output = keras.layers.Dense(32)(decoder_input)
decoder = keras.Model(decoder_input, decoder_output)

model_input = keras.layers.Input((32, ))
model = keras.Model(model_input, decoder(encoder(model_input)))
# Run: model.fit()

with open('path/to/model/file.json', 'w') as f:
    f.write(model.to_json())
model.save_weights('path/to/weights/file.h5')

encoderdecoderは、modelに含まれるので保存する必要はない。

コード

読み込んだmodelのメンバ変数であるlayersがレイヤーの実体を指しているので、添字を指定して読み込むだけで構成要素のモデルにアクセスできる。

with open('path/to/model/file.json', 'r') as f:
    model = keras.models.model_from_json(f.read())
model_model.save_weights('path/to/weights/file.h5')

encoder = model.layers[1] # これ
decoder = model.layers[2]
# Run: encoder.predict()

添字はmodel.summary()を実行すれば、何番目か確認できる。

>>> model.summary()
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 32)]              0         
_________________________________________________________________
model (Model)                (None, 16)                528       <- Encoder
_________________________________________________________________
model_1 (Model)              (None, 32)                1056      <- Decoder
=================================================================
Total params: 1,072
Trainable params: 1,072
Non-trainable params: 0
_________________________________________________________________

モデルは同一の実体を指しているので、modelで追加学習をしてもencoderdecoderに反映される。

あとがき

間違いがあったらごめんね。

Discussion

ログインするとコメントできます