TensorFlow を用いて、自前学習したモデルを Mobile Web で動かしてみた
TensorFlow を用いて、自前学習したモデルを Mobile Web で動かしてみた
犬猫判別用の教師データがたくさん転がっているので、これを TensorFlow を用いて機械学習させ、それを Mobile Web 上で JS を使って動かすまでをやってみました。
準備物
- Python3 の動く学習用マシン。 GPU を積んでいることが望ましいが、今回は CPU 24コアだけで頑張った。
- 出来上がった JS を配信する Web サーバ。残念ながら、 file: では動かない(らしい)。
- 犬猫教師データ。 zip ファイルとして https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip で配布されている。
学習準備
犬猫教師データを unzip する。それぞれ train(学習用)と validation(検証用)のデータフォルダがあり、混ぜないで運用しなければいけないが、今回はそれを無視して、数が多いというだけで validation フォルダの中のみを使用することに。フォルダ構造は以下のようにすでに cats, dogs ラベルで分けられた状態になっている。
validation
+--cats
| + 写真データ
| :
+--dogs
+ 写真データ
:
学習
以下の学習用 Python3 スクリプトを作成。足りないモジュールは、必要に応じて pip install して下さい。
import tensorflow as tf
import os
# from tensorflow import keras
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
# from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image_dataset_from_directory
IMAGE_SIZE = 224
def get_data():
# 画像ファイル含むデータセットダウンロード
train_dir = os.path.dirname(
"/home/tam/cats_and_dogs_filtered/validation/")
BATCH_SIZE = 32
IMG_SIZE = (IMAGE_SIZE, IMAGE_SIZE)
# 訓練データセット作成
train_dataset = image_dataset_from_directory(
train_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
return train_dataset
def create_model():
# Lambdaを使わずにMobaileNetのpreprocess_inputを再現することが
# Sequential APIでは難しそうなので、自前で関数を作ってnetworkの外に切り出す
# Sequential APIでモデルを定義する
model = Sequential()
base_model = MobileNet(
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3],
include_top=False, weights='imagenet')
base_model.trainable = False
# conv_dw_11以降のLaylerを学習対象にする
for layer in base_model.layers[67:]:
layer.trainable = True
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(1, activation='sigmoid'))
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
# モデルのアーキテクチャを出力
model.summary()
return model
model = create_model()
epochs = 10
train_dataset = get_data()
# model.load_weights('model')
history = model.fit(train_dataset,
epochs=epochs)
model.save_weights('model')
肝となるのは create_model() 関数の中のモデル作成部分。何も考えずに作ってはいけなくて、最初に作っていたモデルでは、 tensorflow.js 形式に変換しても、 JS 側での読み込み時に以下のようなエラーを吐いて、うまく動いてくれませんでした。
Uncaught (in promise) Error: Unknown layer: TFOpLambda. This may be due to one of the following reasons:
1. The layer is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
2. The custom layer is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().
at jN (generic_utils.js:242)
at GI (serialization.js:31)
at u (container.js:1197)
at e.fromConfig (container.js:1225)
at jN (generic_utils.js:277)
at GI (serialization.js:31)
at models.js:295
at u (runtime.js:45)
at Generator._invoke (runtime.js:274)
at Generator.forEach.t.<computed> [as next] (runtime.js:97)
というわけで参考にしたのが以下のページ。
TensorFlow.jsを使ってKerasのモデルを動かすWebアプリを作ってみました。
これを参照しながら、少しいじってモデルを作成したらうまくいきました。
このスクリプトは終了するたびに、学習モデルを保存するので、一回実行して保存したあと、
# model.load_weights('model')
のコメント行を外して、保存した学習モデルを読み込んで再学習するようにします。
ここからはシェルでスクリプトを回し続けて、パソコンさんにいっぱい勉強してもらいます。
JS 用に変換
tensorflow.js では、 tensorflow 用の学習モデルをそのままでは読み込めないらしいので、以下のスクリプトで読み込める形式に変換します。
# import numpy as np
import tensorflow as tf
import tensorflowjs as tfjs
# from tensorflow import keras
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
# from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image_dataset_from_directory
IMAGE_SIZE = 224
def create_model():
# Lambdaを使わずにMobaileNetのpreprocess_inputを再現することが
# Sequential APIでは難しそうなので、自前で関数を作ってnetworkの外に切り出す
# Sequential APIでモデルを定義する
model = Sequential()
base_model = MobileNet(
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3],
include_top=False, weights='imagenet')
base_model.trainable = False
# conv_dw_11以降のLaylerを学習対象にする
for layer in base_model.layers[67:]:
layer.trainable = True
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(1, activation='sigmoid'))
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
# モデルのアーキテクチャを出力
model.summary()
return model
model = create_model()
model.load_weights('model')
tfjs.converters.save_keras_model(model, "./tfjs_model")
カレントディレクトリに tfjs_model が出来れば成功です。
Web サーバ上へ配置
先程作成した tfjs_model と以下の index.html, script.js を Web サーバへ配置します。
これらの HTML と JS も
TensorFlow.jsを使ってKerasのモデルを動かすWebアプリを作ってみました。
を参考にさせてもらいました。
<!DOCTYPE html>
<html>
<head>
<title>cats_vs_dogs TEST</title>
<!--Import TensorFlow.js-->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
</head>
<body>
<h1>cats_vs_dogs model test</h1>
<p id="status"></p>
<video autoplay playsinline muted id="webcam" width="224" height="224"></video>
<p id="result"></p>
<!--Import app.js-->
<script src="./script.js"></script>
</body>
</html>
const webcamElement = document.getElementById('webcam');
async function run() {
document.getElementById("status").textContent = "model load";
const model = await tf.loadLayersModel("./tfjs_model/model.json");
console.log(model.summary());
document.getElementById("status").textContent = "finish";
const webcam = await tf.data.webcam(webcamElement);
while (true) {
const img = await webcam.capture();
const result = await model.predict(tf.expandDims(img.div(tf.scalar(127.5).sub(tf.scalar(1)))));
// console.log(result.dataSync());
const arrDogScore = await Array.from(result.dataSync());
if (arrDogScore[0] > 0.5) {
document.getElementById("result").textContent = "イヌです";
} else {
document.getElementById("result").textContent = "ネコです";
}
img.dispose();
// Give some breathing room by waiting for the next animation frame to
// fire.
await tf.nextFrame();
}
}
document.addEventListener('DOMContentLoaded', run());
結果
Web サーバにアクセスして、カメラを許可したところ、無事犬猫判定が出来ました。
Discussion