フロントエンド転移学習のデータセット、どこに置く?【Tensorflow.js】

3 min read読了の目安(約3500字

KNN Classifierを使ったときの雑メモです。

https://teachablemachine.withgoogle.com

フロントエンドでさくっと転移学習するならTeachable Machineが楽ちんなのですが、現状用意されているベースモデルは画像・音声・姿勢のみ。しかも出力されるモデルの構造が調べてもそのあたりのドキュメントないし、例えば今回使ってみたい、姿勢モデルを読み込んでクラス分けしてくれる@teachablemachine/poseも未だ依存性がtfjs@^1.3.1と出遅れをとっているので、なかなかに使いづらいです。

ということで転移学習周りも自力でやってみましょうということで以下のモデルを導入して使ってみます。

@tensorflow-models/knn-classifier

https://www.npmjs.com/package/@tensorflow-models/knn-classifier

執筆時のバージョンは1.2.2。

peer @tensorflow/tfjs-core@"^1.2.1" from @tensorflow-models/knn-classifier@1.2.2

とまあ、こちらもこういう感じなので、インストール時は--legacy-peer-depsをつけておきましょう。tfjs-core@2.7.0で動作していることは確認できました。

npmの依存性問題は色々議論あるようですが調べてる暇無かったのでどなたか教えてください。

学習は一旦Tensor化してから入力

今回はPosenetによる身体の姿勢データを、KNNによる転移学習でクラス分類させるのに使用しました。

https://editor.p5js.org/AndreasRef/sketches/RLv1QbuLa

全体的なコードはp5.jsでml5を使ってやっている上記例が参考になります。が、これは関数が内部で簡略化されているため、本来のtfjsのお作法に従う必要があります。

// poseにはPosenetの出力が入っている
const poseArray = pose.keypoints.map(p => [p.score, p.position.x, p.position.y]); // 行列にする(3次元ベクトル * 17点)
const tfPose = tf.tensor2d(poseArray); // Tensor化
classifier.addExample(tfPose, label); // ここで分類を追加
console.log('KNN class added:', classifier.getClassExampleCount());

データセットの入出力と置き方

本題です。

シリアライズ

まずJSON.stringify()をかけようとするのは当たり前なのですが、特にこのKNN Classifierにおいては保存前のシリアライズについての議論が3年ほど前から様々になされています。

https://github.com/tensorflow/tfjs/issues/633

ということで、上記issueの最後のほうで紹介されている以下のライブラリを使うことにしました。非常に安定していておすすめです。

https://www.npmjs.com/package/tensorset
// KNN分類結果のデータセットを出力する
async getDataSet() {
  return await Tensorset.stringify(this.classifier.getClassifierDataset())
}

これだけでOK。

データセットの出し入れ

mBaaS等の外部サービスをなるだけ使わないということにすると以下の3通りほどあると考えられます。

文字列コピペ

inputなどに都度貼り付けて読み込ませるやり方ですが、もちろん使いません。
クリップボードに入る量ではあるものの、リロードごとに読まなきゃいけなかったりで冗長すぎます。

Local Storage

Tensorset.stringifyの結果は文字列なので、データセットを複数のクライアントで使いまわさないのであればこれが最も手軽です。

Ajax

データセットをテキストで出力しておき、リソースとして読み込ませる方法です。
あらかじめ用意しておいたファイルでも、Form操作で読み込ませてでも使えるので汎用的です。
今回はこの方法でやります。

エクスポート(ローカルへダウンロードして保存)

以下のような関数を作り、Tensorsetで出力した文字列を渡せばダウンロードしてくれます。

function download(tensorsetString) {
  const link = document.createElement('a');
  link.download = 'poses.tensorset';
  link.href = URL.createObjectURL(new Blob([ tensorsetString ], { type: 'text.plain' }));
  link.dataset.downloadurl = ['text/plain', link.download, link.href].join(':');
  link.click();
}

インポート(サーバー配置)

今回はNuxt.jsを使ったため、データセットのファイルは~/staticに置いておくとaxiosを使って取得することができます。

axios.get('./poses.tensorset').then((res) => {
  this.knnDataSet = res.data;
  if (this.knnDataSet) {
    const dataset = Tensorset.parse(this.knnDataSet);
    this.classifier.setClassifierDataset(dataset);
    console.log('KNN Dataset available:', this.classifier.getClassExampleCount());
  }
});

一方で、Tensorflow.jsは負荷が大きい処理のためService Worker上で動かしたいことが多く(今回もそう)、ここからは直接XHRが使えないためaxiosが使いづらいです。
そこでkyというモジュールを使います。

https://zenn.dev/ukkz/articles/d622860d7cdf67

これを用いて書き換えると以下のようになりました。

(async () => {
  const response = await ky.get(self.location.origin + '/poses.tensorset');
  const reader = await response.body.getReader().read();
  const result = new TextDecoder('utf-8').decode(reader.value);
  this.knnDataSet = result;
  if (this.knnDataSet) {
    const dataset = Tensorset.parse(this.knnDataSet);
    this.classifier.setClassifierDataset(dataset);
    console.log('KNN Dataset available:', this.classifier.getClassExampleCount());
  }
})()

Service Workerを使うとモデル読み込みもデータセットの取得も比較的早く済むためとても楽ちんでした。お試しあれ。