📱

【検証】Flutterで機械学習モデルを動かせるか(iOS編)

に公開

米国株に200円から投資できるアプリを開発しているwoodstock.clubに入社したnigowです。弊社のアプリはFlutterで書かれていますが、バックエンドではAIに金融や投資に関する質問ができるような、フィンテック領域のユニークな機能も実装されています。
https://zenn.dev/woodstock_tech/articles/cedd7eaf5843ac

入社後初めてFlutterで開発をしていて、モバイル端末で機械学習モデルを動かせないかな?と疑問に思ったので、検証してみました。なお、私が所持しているのはiPhoneなので、今回はiOS上でCoreMLを使って検証していきます。

CoreMLとNeural Engine

CoreMLはAppleが提供している機械学習フレームワークで、macOS / iOS / iPadOS / watchOS / tvOSといったApple製品上で、学習済みのモデルを効率的に実行することができます。

特徴的なのは、オンデバイスで動作する点です。つまり、モデル推論をクラウドに依存せず、ユーザーのデバイス上で完結できるため、シンプルなモデルであればオフラインでも使える・応答が速い・プライバシーを守れるという利点があります。

CoreMLはPyTorchやTensorFlowはもちろん、scikit-learnやXGBoostなど、さまざまなフレームワークで学習したモデルをサポートしており、Appleが提供するcoremltoolsを使えば簡単に変換できます。

さらにiPhoneやiPadでは、A11 Bionic以降のチップにNeural Engineという専用ハードウェアが搭載されています。これはApple独自の機械学習用アクセラレータで、主にディープラーニング系のモデルの高速処理に特化しています。iPhone X以降FaceIDで瞬時にアンロックできるのも、このNeural Engineの発達が大いに関係しています。

CoreMLでは、デフォルトで利用可能なコンピュートユニット(CPU, GPU, Neural Engine)を自動的に選んでくれますが、モデル実行時にMLComputeUnitsを使って指定することもできます。たとえば「CPUだけ使ってくれ!」という指定も可能です。

let config = MLModelConfiguration()
config.computeUnits = .cpuOnly

機械学習モデルの作成・変換

実際にCoreMLで使うモデルとしてシンプルに実証したいので、今回はsklearn定番のアヤメ(Iris)データセットを使い、KNN (K-Nearest Neighbors)で分類モデルを作成しました。Irisデータセットは、各サンプルに4つの特徴量があり、それぞれ「ガク片の長さ・幅」「花弁の長さ・幅」を表します。最終的にそれらの特徴量からアヤメの品種(setosa, versicolor, virginica)を予測するプログラムとなっています。

以下は、Pythonでモデルを学習し、CoreML用の形式に変換するコードです:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
import coremltools as cml

iris = load_iris()
X, y = iris.data, iris.target

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

# CoreMLに変換(入力は4次元のベクトル、出力はカテゴリラベルのインデックス)
feature_descriptions = [('input', cml.models.datatypes.Array(4))]
coreml_model = cml.converters.sklearn.convert(knn, feature_descriptions, 'species')
coreml_model.save("IrisClassifier.mlmodel")

このコードでは、sklearnでモデルをトレーニングし、coremltoolsを使ってCoreML形式に変換しています。Array(4)という指定が、iOS側に「このモデルの入力は4つの数値である」という情報を与え、出力名として指定したspeciesがモデルの予測結果として返されます。

作成したIrisClassifier.mlmodelを、そのままでは利用できません。Xcodeで動作させるには、以下のコマンドでコンパイルし、mlmodelc形式に変換する必要があります。

xcrun coremlc compile IrisClassifier.mlmodel path/to/compile/model

実行すると、最適化されたIrisClassifier.mlmodelcディレクトリが生成されます。このディレクトリが、iOSアプリに組み込むためのモデルファイルとなります。今回の検証では、生成したIrisClassifier.mlmodelcをXcodeのプロジェクト内Build PhasesでCopied Bundle Resourcesに追加しました。

xcode iris classifier.mlmodelc

FlutterからMLモデルの呼び出し

FlutterからiOSなどプラットフォーム特有の関数を呼ぶには、MethodChannel APIを利用します。今回のケースでは、Flutter側から_channelのメソッドを呼ぶと、Swift側でCoreMLモデルが動作し、結果が返される仕組みです。

Dart/Flutter側
今回のケースでは、Flutterから機械学習モデルによる予測を呼ぶ際に、拡張性を高めるため、入力となる特徴量をJSON形式の文字列にエンコードしてSwift側へ送るようにしました。

import 'package:flutter/services.dart';

class CoreMLService {
  static const MethodChannel _channel =
      MethodChannel('com.example.flutter_coreml/coreml');

  Future<String?> predictIrisSpecies({
    required double sepalLength,
    required double sepalWidth,
    required double petalLength,
    required double petalWidth,
  }) async {
    return await _channel.invokeMethod('predict', {
      'sepalLength': sepalLength,
      'sepalWidth': sepalWidth,
      'petalLength': petalLength,
      'petalWidth': petalWidth,
    });
  }
}

Swift側
Flutterから受け取ったメソッド呼び出しをhandle(_:result:)で受け取り、call.methodの内容に応じて処理を分岐します。predictメソッドが呼ばれたときは、jsonInputというキーで送られてきたJSON文字列をデコードし、必要な値を取り出して推論処理に渡します。

public func handle(_ call: FlutterMethodCall, result: @escaping FlutterResult) {
    case "predict":
        guard let args = call.arguments as? [String: Any],
              let jsonInput = args["jsonInput"] as? String,
              let data = jsonInput.data(using: .utf8),
              let jsonDict = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let sepalLength = jsonDict["sepal_length"] as? Double,
              let sepalWidth = jsonDict["sepal_width"] as? Double,
              let petalLength = jsonDict["petal_length"] as? Double,
              let petalWidth = jsonDict["petal_width"] as? Double else {
            result(FlutterError(code: "INVALID_ARGS", message: "Invalid JSON input format", details: nil))
            return
        }
    
        CoreMLHandler.shared.predict(
            sepalLength: sepalLength,
            sepalWidth: sepalWidth,
            petalLength: petalLength,
            petalWidth: petalWidth
        ) { species, error in
            if let error = error {
                result(FlutterError(code: "PREDICTION_FAILED", message: error.localizedDescription, details: nil))
            } else {
                result(species)
            }
        }
    ...
}

predictメソッドでは、Flutterから受け取った特徴量をCoreMLが期待する形式(MLMultiArray)に変換し、推論を実行しています。

let inputArray = try MLMultiArray(shape: [4], dataType: .double)
inputArray[0] = NSNumber(value: sepalLength)
// inputArrayの1~3番目にもsepal/petalの値をNSNumber型で挿入
...
let provider = try MLDictionaryFeatureProvider(dictionary: ["input": inputArray])
let output = try model.prediction(from: provider)

CoreMLにモデルを渡す際は、inputというラベル名で渡している点に注意が必要です(CoreMLモデルを作成する際に指定したものと一致させる必要があるため)。

結果は、出力されたspeciesラベルから整数のインデックスとして取り出され、それをアヤメの品種名(setosa, versicolor, virginica)に変換して返しています。

let speciesList: [String] = ["setosa", "versicolor", "virginica"]
completion(speciesList[Int(speciesIndex)], nil)

UI

レポジトリ
https://github.com/nigow/flutter-iris

FlutterからNeural Engineを使ってCoreMLモデルを呼び出す検証を行いました。Dart / FlutterからMethodChannelを介してネイティブコードを呼び出し、Swift側でCoreMLモデルを読み込んで推論するという、二段階の構成にはなりますが、思っていたよりもシンプルに実装できた印象です。

本当は、Flutter側で入力からCoreMLモデルの呼び出しまでJSONなどで一貫して完結できる構成にしたかったのですが、scikit-learnとcoremltoolsの仕様を深く理解できておらず、結果としてSwift側に処理が寄ってしまいました。現状では、モデルごとに専用のSwiftアダプターを書く必要があり、拡張性の面ではやや難があります。ただし、coremltoolsのドキュメントをざっと見る限り、sklearnモデルはstr / dict / list型の入力に対応しているようなので、Swift側で受け取ったJSONをそのままモデルに渡す構成も可能かもしれません。

また、CoreMLのモデルファイル(.mlmodelc)の取り扱いにも少し苦労しました。iOSのRunnerでモデルのファイルパスをうまく指定できなかったため、今回はXcodeのCopied Bundle Resourcesに追加する方式を取りました。これだとモデルを更新するたびにXcodeで手動設定が必要になりやや手間です。モデルサイズが小さければ、Dart側でファイルストレージ経由で差し替えるような運用も検討できそうなので、別の方法を発見したら追記したいと思います。

まだまだ改善点はありますが、Neural Engineを活用した軽量な推論をFlutterアプリに組み込むという点では、有意義な検証だったと感じています。特に、クロスプラットフォーム開発において、リアルタイム性が求められる、あるいは頻繁に呼ばれるけれど処理はそこまで複雑ではないAIタスクをローカルで完結させたい場面では、今回のアプローチは十分選択肢に入ると感じました。

ちなみに、Android側にはTFLite(TensorFlow Lite)という似たようなフレームワークも存在しているようなので、次回はこちらも触ってみて比較してみたいと思います。

Discussion