🌴
Kerasで学習済みモデルを読み込んで一部分だけ実行する方法
Kerasでモデルや重みの保存・読み込み方法の解説記事はあるが、全体を保存して一部分だけを実行するやり方が載っていなかったのでメモ。
自力で見つけた方法なので、もっといいやり方があるかもしれない。
実行環境
TensorFlow内部のKerasを使用したので、バックエンドとしてTensorFlowが動作する。
- Ubuntu 18.04.3 LTS
- Python 3.6.9
- tensorflow-gpu 2.2.0 (Keras 2.3.0-tf)
前提:モデル作成と保存
エンコーダ‐デコーダモデルが最もポピュラーな例だと思うので、それを採用した例。
例ではFunctionalモデルだが、Sequentialモデルでも同じはず。
from tensorflow import keras
encoder_input = keras.layers.Input((32, ))
encoder_output = keras.layers.Dense(16)(encoder_input)
encoder = keras.Model(encoder_input, encoder_output)
decoder_input = keras.layers.Input((16, ))
decoder_output = keras.layers.Dense(32)(decoder_input)
decoder = keras.Model(decoder_input, decoder_output)
model_input = keras.layers.Input((32, ))
model = keras.Model(model_input, decoder(encoder(model_input)))
# Run: model.fit()
with open('path/to/model/file.json', 'w') as f:
f.write(model.to_json())
model.save_weights('path/to/weights/file.h5')
encoder
やdecoder
は、model
に含まれるので保存する必要はない。
コード
読み込んだmodel
のメンバ変数であるlayers
がレイヤーの実体を指しているので、添字を指定して読み込むだけで構成要素のモデルにアクセスできる。
with open('path/to/model/file.json', 'r') as f:
model = keras.models.model_from_json(f.read())
model_model.save_weights('path/to/weights/file.h5')
encoder = model.layers[1] # これ
decoder = model.layers[2]
# Run: encoder.predict()
添字はmodel.summary()
を実行すれば、何番目か確認できる。
>>> model.summary()
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 32)] 0
_________________________________________________________________
model (Model) (None, 16) 528 <- Encoder
_________________________________________________________________
model_1 (Model) (None, 32) 1056 <- Decoder
=================================================================
Total params: 1,072
Trainable params: 1,072
Non-trainable params: 0
_________________________________________________________________
モデルは同一の実体を指しているので、model
で追加学習をしてもencoder
とdecoder
に反映される。
あとがき
間違いがあったらごめんね。
Discussion