👋

ONNX Runtime をつかって Web ブラウザで機械学習モデル推論してみる

に公開

web ブラウザで機械学習モデルを動かしてみたいと思い、試行錯誤してみました。成果物としては、 React アプリ上で画像認識をします。

完成形

推論は開発に利用している m1 mac で、 0.1 秒ほどで完了します。この方法では、モデルをフロントエンドで動かす都合上、モデルを公開することになります。

ONNX Runtime

いろいろ方法がありますが、今回は ONNX Runtime というライブラリを使うことにしました。
onnx(オニキス)は機械学習モデルのオープンなフォーマットです。Opern Nerural Network の略です。pytorch などで実装したモデルをこの形式で出力することができます。
ONNX Runtime は onnx 形式のモデルをクロスプラットフォームで実行する環境です。cpu や gpu をつかって実行することができます。

https://onnxruntime.ai/

基本情報

以下の点で選びました。

  • マイクロソフトが管理
  • 開発が活発
  • ドキュメントが充実
  • pytorch で作成したモデルを実行できる(ONNX 対応)

他にも類似のものがいくつかでてきましたが以下の理由で今回は ONIX Runtime にしました。

今回は ONNX Runtime をつかって Web ブラウザで機械学習モデル推論してみます。

実践

こちらの Next.js で書かれたサンプルをもとに実装していきます。

https://github.com/microsoft/onnxruntime-nextjs-template

最終更新日が 3 年前と古く、使っているパッケージが現環境でうまく動作しない問題があるので適宜書き換えて利用します。

モデルの用意

モデルは画像認識できるものを選びます。
今後は自分でつくったモデルを実行してみたいですが、まずは既存のものを利用します。

ONNX Model Zoo https://github.com/onnx/models というサイトで onix 形式で既存のモデルが並んでいます。

今回は画像認識モデルの Squeezenet を選びました。AlexNet なみの性能でパラメータ数が大幅に少ないモデルです。

参考にしたサンプルもこちらを使っていたので採用しました。同じモデルを使うことで後の画像の前処理の実装は流用できます。onnx runtime はモデルの実装によっては対応していないものもあるそうなので、他のモデルを利用するときは、この点も考慮が必要です。

Squeezenet は以下でダウンロードできます。
https://github.com/onnx/models/blob/main/validated/vision/classification/squeezenet/model/squeezenet1.1-7.tar.gz

react web アプリを用意

vite で react アプリの雛形を作成します。

npm create vite@latest test_ts_react_onnx  -- --template react-ts
cd test_ts_react_onnx
npm i
npm run dev

初期画面が表示されます。
react初期画面

onnix runtime を読み込む

onnix runtime を読み込みます。合わせてその他の react アプリに必要なパッケージも追加します。

package.json に必要なパッケージを追記します。

{
  "name": "test_ts_react_onnx",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "tsc -b && vite build",
    "lint": "eslint .",
    "preview": "vite preview"
  },
  "dependencies": {
    "@emotion/react": "^11.14.0",
    "@emotion/styled": "^11.14.0",
    "@mui/material": "^6.3.0",
    "@types/lodash": "^4.17.13",
    "lodash": "^4.17.21",
    "onnxruntime-web": "^1.20.1",
    "react": "^18.3.1",
    "react-dom": "^18.3.1"
  },
  "devDependencies": {
    "@eslint/js": "^9.17.0",
    "@types/react": "^18.3.18",
    "@types/react-dom": "^18.3.5",
    "@vitejs/plugin-react": "^4.3.4",
    "eslint": "^9.17.0",
    "eslint-plugin-react-hooks": "^5.0.0",
    "eslint-plugin-react-refresh": "^0.4.16",
    "globals": "^15.14.0",
    "typescript": "~5.6.2",
    "typescript-eslint": "^8.18.2",
    "vite": "^6.0.5"
  }
}

インストールします。

npm i

モデルを public ディレクトリに配置

さきほどダウンロードした zip ファイルを解凍し、squeezenet1.1.onnx ファイルを public ディレクトリに配置します。
このモデルをプログラム上で利用します。

モデル実行用 util プログラムの作成

以下 3 つのファイルを作成します

  • predict.ts
    • react ui 側から推論を開始するための関数を定義します。predict.ts から以下 2 つのファイルの関数を呼び出します。
  • imageHelper.ts
    • 画像をモデルに渡すために必要な形式に変換します。具体的には[3, 224, 224]のテンソルにします。
  • modelHelper.ts
    • モデルを読み込んで、推論を実行します。

predict.ts

画像をテンソル化し、それをモデルに渡し、推論結果と推論時間を返します。

こちらはサンプルのコードほぼそのままです。linter の警告だけ無効化しています。

import { getImageTensorFromPath } from "./imageHelper";
import { runSqueezenetModel } from "./modelHelper";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export async function inferenceSqueezenet(
  path: string
): Promise<[any, number]> {
  // 1. Convert image to tensor
  const imageTensor = await getImageTensorFromPath(path);
  // 2. Run model
  const [predictions, inferenceTime] = await runSqueezenetModel(imageTensor);
  // 3. Return predictions and the amount of time it took to inference.
  return [predictions, inferenceTime];
}

imageHelper.ts

画像を読み込み、[3, 224, 224]のテンソルに変換します。

サンプルで多用されている jimp というパッケージは、バージョンが古く現バージョンではそのまま動かないことや、vite では最新版の jimp でも動作しない問題があります。
今回は、canvas を利用する方法に書き換えることで対応しました。

import { Tensor } from "onnxruntime-web";

export async function getImageTensorFromPath(
  path: string,
  dims: number[] = [1, 3, 224, 224]
): Promise<Tensor> {
  // 1. load the image
  const image = await loadImageFromPath(path);
  // 2. convert to tensor
  const imageTensor = imageDataToTensor(image, dims);
  // 3. return the tensor
  return imageTensor;
}

async function loadImageFromPath(path: string): Promise<HTMLImageElement> {
  return new Promise((resolve, reject) => {
    const img = new Image();
    img.crossOrigin = "anonymous"; // CORSポリシー対応が必要な場合
    img.onload = () => resolve(img);
    img.onerror = reject;
    img.src = path;
  });
}

function imageDataToTensor(image: HTMLImageElement, dims: number[]): Tensor {
  const [width, height] = [dims[2], dims[3]];

  // Canvas作成
  const canvas = document.createElement("canvas");
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext("2d");

  if (!ctx) {
    throw new Error("Failed to get 2D context");
  }

  // 画像をリサイズしてCanvasに描画
  ctx.drawImage(image, 0, 0, width, height);

  // ピクセルデータを取得
  const imageData = ctx.getImageData(0, 0, width, height);
  const data = imageData.data;

  // R, G, Bチャンネルの配列を作成
  const [redArray, greenArray, blueArray] = [
    new Array<number>(),
    new Array<number>(),
    new Array<number>(),
  ];

  // RGBチャンネルを分離
  for (let i = 0; i < data.length; i += 4) {
    redArray.push(data[i]);
    greenArray.push(data[i + 1]);
    blueArray.push(data[i + 2]);
    // アルファチャンネルはスキップ
  }

  // RGB配列を結合して[3, 224, 224]の形式に変換
  const transposedData = redArray.concat(greenArray).concat(blueArray);

  // Float32Arrayに変換
  const float32Data = new Float32Array(dims[1] * dims[2] * dims[3]);
  for (let i = 0; i < transposedData.length; i++) {
    float32Data[i] = transposedData[i] / 255.0;
  }

  // Tensorを作成して返す
  return new Tensor("float32", float32Data, dims);
}

modelHelper.ts

モデルを読み込み、実行し、結果から必要なトップ 4 のデータをラベルと合わせて取得します。
ラベルは data/imagenet.tsを配置し、参照しています。

こちらもサンプルほぼそのままです。さきほど public ディレクトリに配置した squeezenet1.1.onnx を読み込んでいます。

import * as ort from "onnxruntime-web";
import _ from "lodash";
import { imagenetClasses } from "../data/imagenet";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export async function runSqueezenetModel(
  preprocessedData: any
): Promise<[any, number]> {
  console.log("Running Squeezenet model");
  // Create session and set options. See the docs here for more options:
  //https://onnxruntime.ai/docs/api/js/interfaces/InferenceSession.SessionOptions.html#graphOptimizationLevel
  const session = await ort.InferenceSession
    // public ディレクトリからの相対パス
    .create("/squeezenet1.1.onnx", {
      executionProviders: ["webgl"],
      graphOptimizationLevel: "all",
    });
  console.log("Inference session created");
  // Run inference and get results.
  const [results, inferenceTime] = await runInference(
    session,
    preprocessedData
  );
  return [results, inferenceTime];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
async function runInference(
  session: ort.InferenceSession,
  preprocessedData: any
): Promise<[any, number]> {
  // Get start time to calculate inference time.
  const start = new Date();
  // create feeds with the input name from model export and the preprocessed data.
  const feeds: Record<string, ort.Tensor> = {};
  feeds[session.inputNames[0]] = preprocessedData;

  // Run the session inference.
  const outputData = await session.run(feeds);
  // Get the end time to calculate inference time.
  const end = new Date();
  // Convert to seconds.
  const inferenceTime = (end.getTime() - start.getTime()) / 1000;
  // Get output results with the output name from the model export.
  const output = outputData[session.outputNames[0]];
  //Get the softmax of the output data. The softmax transforms values to be between 0 and 1
  const outputSoftmax = softmax(Array.prototype.slice.call(output.data));

  //Get the top 5 results.
  const results = imagenetClassesTopK(outputSoftmax, 5);
  console.log("results: ", results);
  return [results, inferenceTime];
}

//The softmax transforms values to be between 0 and 1
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function softmax(resultArray: number[]): any {
  // Get the largest value in the array.
  const largestNumber = Math.max(...resultArray);
  // Apply exponential function to each result item subtracted by the largest number, use reduce to get the previous result number and the current number to sum all the exponentials results.
  const sumOfExp = resultArray
    .map((resultItem) => Math.exp(resultItem - largestNumber))
    .reduce((prevNumber, currentNumber) => prevNumber + currentNumber);
  //Normalizes the resultArray by dividing by the sum of all exponentials; this normalization ensures that the sum of the components of the output vector is 1.
  return resultArray.map((resultValue) => {
    return Math.exp(resultValue - largestNumber) / sumOfExp;
  });
}
/**
 * Find top k imagenet classes
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function imagenetClassesTopK(classProbabilities: any, k = 5) {
  const probs = _.isTypedArray(classProbabilities)
    ? Array.prototype.slice.call(classProbabilities)
    : classProbabilities;

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const sorted = _.reverse(
    _.sortBy(
      probs.map((prob: any, index: number) => [prob, index]),
      (probIndex: Array<number>) => probIndex[0]
    )
  );

  const topK = _.take(sorted, k).map((probIndex: Array<number>) => {
    const iClass = imagenetClasses[probIndex[1]];
    return {
      id: iClass[0],
      index: parseInt(probIndex[1].toString(), 10),
      name: iClass[1].replace(/_/g, " "),
      probability: probIndex[0],
    };
  });
  return topK;
}

これで、predict.ts の inferenceSqueezenet 関数を実行すれば、画像を渡して推論ができます。

react ui から推論を行う

次は推論対象の画像を表示して、推論ボタンを押して、推論を実行する画面を作成していきます。

完成形

以下のような画面を作成します。

  • 「No IMAGE」の部分に推論対象の画像が表示されます。
  • 「RANDOM IMAGE」ボタンを押すと、あらかじめ用意された画像がランダムに選ばれます。
  • 「CHOOSE IMAGE」ボタンを押すと、端末上の任意の画像を選択することができます。
  • 「CLASSIFY IMAGE」を押すと推論が実行されます。
    alt text

推論実行後の画面は以下のようになります。
下に推論結果と、確率、推論時間が表示されます。
alt text

画像を用意する

public ディレクトリ以下に適当な画像を用意します。
今回は、no_image_square.jpg と推論用の猫の画像などを配置しました。
alt text

App.tsx を編集する

以下のように App.tsx を変更します。

変数 image には推論対象の画像が入ります。
関数 setRandomImage や関数 chooseImage は、変数 image に画像をそれぞれの方法で設定します。
関数 classifyImage で、先に util として用意した関数 inferenceSqueezenet を呼び出して、推論を行います。

ui は mui パッケージをつかって構成しています。

import { useState, useCallback, useRef } from "react";
import "./App.css";
import { Box, Button, Container, Typography } from "@mui/material";
import { inferenceSqueezenet } from "./utils/predict";

function App() {
  const [image, setImage] = useState<string | null>(null);
  const noImage = "/no_image_square.jpg";

  const [imageIndex, setImageIndex] = useState<number>(0);
  const imageList = ["necklace.webp", "strawberry.webp", "cat.webp"];

  const [resultLabel, setResultLabel] = useState<string>("");
  const [resultConfidence, setResultConfidence] = useState<string>("");
  const [inferenceTime, setInferenceTime] = useState<string>("");
  const [error, setError] = useState<string | null>(null);

  const fileInputRef = useRef<HTMLInputElement | null>(null);

  const setRandomImage = (currentImageIndex: number) => {
    let newIndex = Math.floor(Math.random() * imageList.length);
    while (newIndex === currentImageIndex) {
      newIndex = Math.floor(Math.random() * imageList.length);
    }
    setImageIndex(newIndex);
    setImage(imageList[newIndex]);
  };

  const chooseImage = () => {
    if (fileInputRef.current) {
      fileInputRef.current.click();
    }
  };

  const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const file = e.target.files?.[0];
    if (file) {
      const reader = new FileReader();
      reader.onload = () => {
        setImage(reader.result as string);
      };
      reader.readAsDataURL(file);
    }
  };

  const classifyImage = useCallback(async () => {
    if (!image) {
      setError("Please select an image first.");
      return;
    }

    setResultLabel("Inferencing...");
    setResultConfidence("");
    setInferenceTime("");
    setError(null);

    try {
      const [inferenceResult, inferenceTime] = await inferenceSqueezenet(image);

      const topResult = inferenceResult[0];
      setResultLabel(topResult.name.toUpperCase());
      setResultConfidence(topResult.probability);
      setInferenceTime(`Inference speed: ${inferenceTime} seconds`);
    } catch (error) {
      console.error("Error during inference:", error);
      setError("Error during inference");
    }
  }, [image]);

  return (
    <>
      <Container maxWidth="sm">
        <Typography variant="h3">React Image Classification App</Typography>

        <img src={image ?? noImage} alt="Selected Image" height={240} />
        <Box display="flex" justifyContent="center">
          <Button
            variant="contained"
            color="primary"
            onClick={() => setRandomImage(imageIndex)}
          >
            Random Image
          </Button>
          <Box width={16} />
          <Button variant="contained" color="primary" onClick={chooseImage}>
            CHOOSE IMAGE
          </Button>
          <input
            type="file"
            accept="image/*"
            ref={fileInputRef}
            style={{ display: "none" }}
            onChange={handleFileChange}
          />
        </Box>

        <Box height={16} />

        <Button variant="contained" color="primary" onClick={classifyImage}>
          Classify Image
        </Button>

        <Box height={16} />

        <Typography variant="h4">Results:</Typography>
        <Typography variant="body1">{resultLabel}</Typography>
        <Typography variant="body1">{resultConfidence}</Typography>
        <Typography variant="body1">{inferenceTime}</Typography>
        {error && (
          <Typography variant="body1" color="error">
            {error}
          </Typography>
        )}
      </Container>
    </>
  );
}

export default App;

これで完成です。

おわりに

web ブラウザで機械学習モデルを動かしてみました。react web app 上で、画像認識をすることができました。

今後、自作の機械学習モデルや他のモデルを動かす際には、前処理や、結果の取り出し部分の実装がさらに必要になります。このあたりを効率よく実装できる方法を調査したいです。

Discussion