PyTorch と onnxruntime-web でつくるリアルタイム郵便番号推論アプリ
1. はじめに
おなじみ MNIST を使って、数字推論アプリを作成しました。
デモサイト
GitHub
- 概要:郵便番号を手書きで入力し、認識した数字から該当地域を表示するアプリ
- モデル作成:Python,PyTorch,ONNX
- データセット:MNIST
- UI:onnxruntime-web,React,TypeScript,Vite,Chakra UI
なぜ作った?
機械学習の基礎を学ぶ中で、CNN モデルを構築し、「実際に動くアプリを作りたい!」と思い、作成しました。
2. 全体アーキテクチャ設計
構成図
onnxruntime-web
を使うことで、ブラウザ上で処理が完結しています。
ユーザーの入力をモデルに渡し、推論結果をもとに郵便番号辞書(JSON)から地域名を取得します。
3. モデル設計
モデルの中身
torchinfo
より、モデルのサマリーを表示した結果は次です。
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
SimpleCNN [1, 10] --
├─Sequential: 1-1 [1, 32, 7, 7] --
│ └─Conv2d: 2-1 [1, 16, 28, 28] 160
│ └─ReLU: 2-2 [1, 16, 28, 28] --
│ └─Dropout2d: 2-3 [1, 16, 28, 28] --
│ └─MaxPool2d: 2-4 [1, 16, 14, 14] --
│ └─Conv2d: 2-5 [1, 32, 14, 14] 4,640
│ └─ReLU: 2-6 [1, 32, 14, 14] --
│ └─Dropout2d: 2-7 [1, 32, 14, 14] --
│ └─MaxPool2d: 2-8 [1, 32, 7, 7] --
├─Sequential: 1-2 [1, 10] --
│ └─Linear: 2-9 [1, 128] 200,832
│ └─ReLU: 2-10 [1, 128] --
│ └─Dropout: 2-11 [1, 128] --
│ └─Linear: 2-12 [1, 10] 1,290
==========================================================================================
SGD vs Adam 比較の実験計画
最適化アルゴリズムは Adam を使用しました。
SGD と精度を比較すると次のようになりました。
モデル | 精度 | 最終の損失関数 |
---|---|---|
SGD | 97.4% | 0.814 |
Adam | 99.1% | 0.273 |
Adam の方が精度が高いことがわかります。
また、グラフにすると次のようになります。
エポック1の段階で、Adamの方が損失関数が小さいことがわかります。
ONNX 量子化
ブラウザでモデルを使用するにあたり、ファイル容量を軽量化することで、
処理速度を高速化できる可能性があります。
INT8 に動的量子化を行い、4 分の 1 に削減できました。
モデル | サイズ |
---|---|
FP32 | 810 KB |
INT8(動的) | 221 KB |
ソースコード
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input="model.fp32.onnx",
model_output="model.int8.dynamic.onnx",
weight_type=QuantType.QInt8,
op_types_to_quantize=['MatMul', 'Gemm'], # Conv を外す
)
op_types_to_quantize
オプションにて、量子化するオペレータタイプを指定します。
「MatMul(行列積)」と「Gemm(一般化行列積)」の演算を量子化の対象とし、
「Conv(畳み込み)」をはずしています。
一般的に、CNNは静的量子化の方が推奨されています。(動的量子化は毎回scale/zero-pointを計算するため)
しかし、今回はより手軽に導入できる動的量子化を採用しました。
4. ブラウザ上での推論
描画データを Canvas から取得後、28×28 ピクセルにリサイズし、0〜1 に正規化して Tensor
に変換します。
モデル出力に、Softmax 関数 を適用して各クラスの確率を算出し、最大値のインデックスを予測値として返します。
おわりに
本アプリでは、PyTorch モデルを ONNX 形式に変換し、onnxruntime-web
を用いてブラウザ内で高速推論を実現しました。
実装当初はなにもわからず、「グレースケールはチャンネルが1とは一体?」「なぜフロントエンドで softmax 関数を使っているんだ?」とわからないことだらけでしたが、
実装していくにつれて、理解が深まりました。
数字の推論という非常に初歩的なアプリでしたが、学びは多く、楽しく実装できました。
参考
Deep learning で画像認識 ④〜畳み込みニューラルネットワークの構成〜
畳み込みニューラルネットワークの精度向上
Pytorch – Fashion-MNIST で CNN モデルによる画像分類を行う
書籍
ゼロから作る Deep Learning
Python と Keras によるディープラーニング
Discussion