NGレシート検知のための、画像分類モデルの作成とCoreMLへの変換
はじめに
WED株式会社でMLエンジニアをしています、ishi2kiです。
WEDでは、レシート買取アプリONEを通して、ユーザからレシート画像を買い取っています。そして買い取ったレシート情報は、購買活動の分析等に利用しています。
ここで重要になるのが、レシート画像からの情報抽出 (OCR) を精度良く行うことです。店舗名や商品名を正しく取得することができなければ、レシートを買い取っても、それを分析に繋げることができません。
(例えば、このような画像だと何も情報を得られません。)
OCRが上手くいかない理由はいくつかありますが、今回は以下の理由に着目しました。
- そもそもレシートじゃない
- ピンボケしている
- 暗すぎる
- レシートがぐしゃぐしゃになっている
- レシートが2枚以上写っている
これらは、ユーザの撮影の仕方によって解決 (あるいは影響を軽減) することが可能です。
そこで、ユーザの撮影画像が上記の点に該当していないかをアプリ上で判定し、もしいずれかに該当していれば再撮影を促す、という処理をONEに加えることを考えました。
本記事では、NG判定モデルの作成、および、作成モデルをiOSで使うための、CoreMLモデルへの変換について説明します。
NGレシート判定モデル
前項で5つのNG理由を紹介しましたが、今回は学習データの集めやすさの観点から以下2つに絞りました。
- そもそもレシートじゃない (not receipt)
- ピンボケしている (blurry)
この2つに加えて、問題のないレシート画像 (ok) を加えた3つのクラスに画像を分類する多クラス分類モデルを作成します。
なお、モデルの作成はPyTorchを使って行い、その後、CoreMLというiOSで動作するモデルへの変換を行います。
学習データ
WEDには、これまで買い取ってきたレシート画像とOCRの結果が多くあります。それらの中からOCRで文字が取得できなかった画像を中心にラベル付けを行い、各クラス1,896枚ずつ集めました。
モデル
アプリ上に組み込むことを考えると、モデルサイズは可能な限り小さい方が望ましいです。
そこでMobileNet V3というモデルを採用しました。これは、まさに、モバイル端末に特化した画像分類モデルであり、Smallであればモデルサイズは10MBもありません。
また、一から学習するとなると大量の学習データが必要となってしまうため、ImageNetで事前学習されたモデルをベースにしてファインチューニングを行うことにしました。
from torchvision.models import mobilenet_v3_small
weights = 'IMAGENET1K_V1'
model = mobilenet_v3_small(weights=weights)
# クラス数を合わせるため、出力層のみ再定義
labels = ['ok', 'not receipt', 'blurry']
label_num = len(labels)
model.classifier[3] = torch.nn.Linear(1024, label_num)
学習結果
学習データの90%を訓練に使用し、残りの10%でテストを行いました。
結果は下の表のようになりました。
- Accuracy: 95.1%
- Accuracy (OK or NG): 98.9%
なお、Accuracy (OK or NG) は、not receiptとblurryの区別をなくし、1つのNGというクラスだと捉えた場合の精度です。
かなり高い結果を出せているのではないでしょうか。
誤推定した画像の確認も行いましたが、軽度のピンボケなど判定の難しいものが多かったです。
CoreMLへの変換
PyTorchでNG判定モデルを作ることができたので、これをiOSで動作させるためにCoreMLに変換します。
まずは、学習したPyTorchのモデルをロードしますが、ここで出力層の後ろにSoftmaxを追加しています。これは、SwiftでCoreMLを扱うライブラリが、モデルの出力値が0〜1に分布していることを前提にしているためです。
import torch
base_model = mobilenet_v3_small(num_classes=3)
base_model.load_state_dict(torch.load('model_path'))
class NGCheckModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
base_model,
torch.nn.Softmax(dim=1)
)
def forward(self, x):
return self.layers(x)
torch_model = NGCheckModel().eval()
続いて、TorchScriptを経由してCoreMLへ変換します。
# TorchScriptに変換
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(torch_model, example_input)
# CoreMLに変換
class_labels = ['ok', 'not receipt', 'blurry']
classifier_config = ct.ClassifierConfig(class_labels)
image_input = ct.ImageType(shape=example_input.shape,
color_layout=ct.colorlayout.RGB)
mlmodel = ct.convert(traced_model,
convert_to='neuralnetwork',
inputs=[image_input],
classifier_config=classifier_config)
以上で、iOS上で動作するモデルの作成は完了です。
アプリ上でNG判定を行うことができるようになりました!
おわりに
今回、レシートじゃない画像とピンボケしている画像を判定するモデルを作成しました。
しかし、冒頭で説明したように、NGとみなしたいレシートはまだ残っています。
リリースに向け、今後さらにモデルを強化していく予定です。
Discussion