📮

PyTorch と onnxruntime-web でつくるリアルタイム郵便番号推論アプリ

に公開

1. はじめに

おなじみ MNIST を使って、数字推論アプリを作成しました。

デモサイト
https://nyamadamadamada.github.io/predict-number-app/

GitHub
https://github.com/Nyamadamadamada/predict-number-app

  • 概要:郵便番号を手書きで入力し、認識した数字から該当地域を表示するアプリ
  • モデル作成:Python,PyTorch,ONNX
  • データセット:MNIST
  • UI:onnxruntime-web,React,TypeScript,Vite,Chakra UI

なぜ作った?

機械学習の基礎を学ぶ中で、CNN モデルを構築し、「実際に動くアプリを作りたい!」と思い、作成しました。

2. 全体アーキテクチャ設計

構成図

onnxruntime-webを使うことで、ブラウザ上で処理が完結しています。
ユーザーの入力をモデルに渡し、推論結果をもとに郵便番号辞書(JSON)から地域名を取得します。

3. モデル設計

モデルの中身

torchinfoより、モデルのサマリーを表示した結果は次です。

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN                                [1, 10]                   --
├─Sequential: 1-1                        [1, 32, 7, 7]             --
│    └─Conv2d: 2-1                       [1, 16, 28, 28]           160
│    └─ReLU: 2-2                         [1, 16, 28, 28]           --
│    └─Dropout2d: 2-3                    [1, 16, 28, 28]           --
│    └─MaxPool2d: 2-4                    [1, 16, 14, 14]           --
│    └─Conv2d: 2-5                       [1, 32, 14, 14]           4,640
│    └─ReLU: 2-6                         [1, 32, 14, 14]           --
│    └─Dropout2d: 2-7                    [1, 32, 14, 14]           --
│    └─MaxPool2d: 2-8                    [1, 32, 7, 7]             --
├─Sequential: 1-2                        [1, 10]                   --
│    └─Linear: 2-9                       [1, 128]                  200,832
│    └─ReLU: 2-10                        [1, 128]                  --
│    └─Dropout: 2-11                     [1, 128]                  --
│    └─Linear: 2-12                      [1, 10]                   1,290
==========================================================================================

SGD vs Adam 比較の実験計画

最適化アルゴリズムは Adam を使用しました。
SGD と精度を比較すると次のようになりました。

モデル 精度 最終の損失関数
SGD 97.4% 0.814
Adam 99.1% 0.273

Adam の方が精度が高いことがわかります。

また、グラフにすると次のようになります。
エポック1の段階で、Adamの方が損失関数が小さいことがわかります。

ONNX 量子化

ブラウザでモデルを使用するにあたり、ファイル容量を軽量化することで、
処理速度を高速化できる可能性があります。

INT8 に動的量子化を行い、4 分の 1 に削減できました。

モデル サイズ
FP32 810 KB
INT8(動的) 221 KB

ソースコード

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="model.fp32.onnx",
    model_output="model.int8.dynamic.onnx",
    weight_type=QuantType.QInt8,
    op_types_to_quantize=['MatMul', 'Gemm'],  # Conv を外す
)

op_types_to_quantizeオプションにて、量子化するオペレータタイプを指定します。
「MatMul(行列積)」と「Gemm(一般化行列積)」の演算を量子化の対象とし、
「Conv(畳み込み)」をはずしています。

一般的に、CNNは静的量子化の方が推奨されています。(動的量子化は毎回scale/zero-pointを計算するため)
しかし、今回はより手軽に導入できる動的量子化を採用しました。

https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html

4. ブラウザ上での推論

描画データを Canvas から取得後、28×28 ピクセルにリサイズし、0〜1 に正規化して Tensor に変換します。
モデル出力に、Softmax 関数 を適用して各クラスの確率を算出し、最大値のインデックスを予測値として返します。

おわりに

本アプリでは、PyTorch モデルを ONNX 形式に変換し、onnxruntime-webを用いてブラウザ内で高速推論を実現しました。

実装当初はなにもわからず、「グレースケールはチャンネルが1とは一体?」「なぜフロントエンドで softmax 関数を使っているんだ?」とわからないことだらけでしたが、
実装していくにつれて、理解が深まりました。

数字の推論という非常に初歩的なアプリでしたが、学びは多く、楽しく実装できました。

参考

https://qiita.com/Alesion30/items/27713d7a65dc2d12b259

Deep learning で画像認識 ④〜畳み込みニューラルネットワークの構成〜
https://lp-tech.net/articles/LVB9R

畳み込みニューラルネットワークの精度向上
https://free.kikagaku.ai/tutorial/basic_of_computer_vision/learn/pytorch_convolution_technic

Pytorch – Fashion-MNIST で CNN モデルによる画像分類を行う
https://pystyle.info/pytorch-cnn-based-classification-model-with-fashion-mnist/

書籍

ゼロから作る Deep Learning
https://www.oreilly.co.jp/books/9784873117584/

Python と Keras によるディープラーニング
https://book.mynavi.jp/ec/products/detail/id=90124

Discussion