JINSテックブログ
🤖

TypeScriptでTensorflowモデルを手元で動かした

2024/12/03に公開

この投稿は、2024年JINSのアドベントカレンダー3日目の記事です。👓

JOIN記事に憧れがあったので前半は自己紹介、後半は表題の件について紹介します。

自己紹介

10月にJINSにJOINしました、ITデジタル部のいしざき(@ishizak1111)です。
前職では主にPythonで製品開発を含めたデータエンジニアリングをしていました。
現在は主に購入/保証システムの刷新やデータ移行を担当しています。

今後は海外店舗のシステム導入にも携わる予定なので、ドメイン知識や新旧システム仕様のインプットを頑張っています。

動機

自己紹介で述べてしまいましたが、現在JINSはシステム刷新プロジェクトの最中です。
新システムのデータモデリングをしている際に、「自分が作ったシステムで収集したデータをバリバリ活用したいな〜🤖」という気持ちに。

JINSにジョインしてからTypeScriptを使い始めたので、
せっかくならPythonで検討したモデルをTypeScriptで動かしちゃおう!ということで挑戦しました。

Pythonでモデルを作る

コード全体像

model4ts.py
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test,  y_test, verbose=2)

import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, "saved_model/")

擦りに擦られたMNIST判定器です。
TensorFlow公式を参考にしました。

Kerasは2系じゃないとtensorflow-jsがモデルロード時に上手く読み込めないので、os.environ["TF_USE_LEGACY_KERAS"] = "1"としています。(https://github.com/tensorflow/tfjs/issues/8321)

tensorflow-jsはjsonのモデルを読み込むので、以下のようにモデルをエクスポートします。

import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, "saved_model/")

Typescriptでモデルを動かす

コード全体像

tftest.ts
import * as tf from "@tensorflow/tfjs";
import * as fs from "node:fs";
import * as tfn from "@tensorflow/tfjs-node";

const loadModelAndImage = async () => {
  const model_path = "./saved_model/model.json";

  const handler = tfn.io.fileSystem(model_path);
  const model = await tf.loadLayersModel(handler);
  const filepath = "./000000-num7.png";
  const data = fs.readFileSync(filepath);
  let array = new Uint8Array(data);
  const ftensor = tfn.node.decodePng(array, 1).toFloat();
  const reshaped_normalized_tensor = tf.reshape(ftensor, [28, 28]).expandDims().div(255);

  const predictions = model.predict(reshaped_normalized_tensor) as tf.Tensor;

  const topK = 5;
  const topKIndices = tf.topk(predictions, topK).indices.dataSync();

  console.log("Top", topK, "predictions:");
  for (let i = 0; i < topKIndices.length; i++) {
    console.log(`#${i + 1}: ${topKIndices[i]}`);
  }
};

const result = loadModelAndImage();

以下の部分でモデルを読み込んでいます。

const model_path = "./saved_model/model.json";

const handler = tfn.io.fileSystem(model_path);
const model = await tf.loadLayersModel(handler);

公式だとhttpでモデルを引っ張ってきたり、localstrageやネイティブファイルシステムからモデルを読み込んでいるのですが、今回は手元にあるモデルで動かしたかったのでこんな記述になりました。
この1行(const handler = tfn.io.fileSystem(modelPath);)を見つけるのに異常に手間取りました。

続いてモデルに読み込ませるために手元にある画像を読み込む部分

const filepath = "./000000-num7.png";
const data = fs.readFileSync(filepath);
let array = new Uint8Array(data);
const ftensor = tfn.node.decodePng(array, 1).toFloat();
const reshaped_normalized_tensor = tf.reshape(ftensor, [28, 28]).expandDims().div(255);

const predictions = model.predict(reshaped_normalized_tensor) as tf.Tensor;

これもTensorFlow公式ではブラウザの要素?フォーム?から画像読み込んで云々みたいな感じだったので、いい感じに変換するのに苦労しました。
多分もうちょっと良い感じの関数ありそうですが、自分は諦めてコネコネしてます。

そして以下が実行結果。確信度順に推論結果が表示されています。

$ node test.js
Platform node has already been set. Overwriting the platform with node.
Top 5 predictions:
#1: 7
#2: 3
#3: 9
#4: 0
#5: 5

ちゃんと7が推論されているので成功っぽい 🤸‍♂️
TensorflowとTensorflow-jsの推論速度の違いも確認したかったですが、またいつか...!!

失敗したこと (おまけ)

記事ネタ探しの段階では『Pythonのsklearnでサクッとモデル作って、TypescriptでPoC作ろう!』をテーマにしていたのですが、言語間でのモデルファイル(pickleファイル)のRead/Writeがうまく行かず断念しました。

参考リンク

MNIST: https://www.tensorflow.org/tutorials/quickstart/beginner?hl=ja

tfjsのissue: https://github.com/tensorflow/tfjs/issues/8321

手元のモデルファイルを読み込む: https://github.com/tensorflow/tfjs/issues/4568

手元の画像ファイルを読み込む: https://stackoverflow.com/questions/65650770/how-do-i-load-an-local-image-in-javascript-for-tensorflow

JINSテックブログ
JINSテックブログ

Discussion