📖

Keras のモデルを使って推論できる frugally-deep を試してみた。

4 min read

はじめに

C++ ヘッダオンリーのライブラリが大好きなので、しばしば色々なライブラリを探索する趣味があるのですが、frugally-deep というライブラリを見つけたので試してみました。

frugally-deep とは

https://github.com/Dobiasd/frugally-deep
  • モダンでピュアな C++ で書かれた小さなヘッダーのみのライブラリ
  • 非常に簡単に統合して使うことができる
  • FunctionalPlus、Eigen、json にのみ依存しており、これらもヘッダーオンリー
  • 逐次モデルだけでなく、関数型 API で作成された、より複雑なトポロジー計算グラフの推論(model.predict)もサポート
  • TensorFlow の(小さな)サブセット、つまり予測をサポートするために必要な操作を再実装している
  • TensorFlow をリンクするよりもはるかに小さいバイナリサイズ
  • 32ビットの実行ファイルにコンパイルしても動作
  • システムの中で最も強力な GPU を完全に無視し、予測ごとに1つのCPUコアしか使わない ;-) 🤔
  • TensorFlow に比べて1つのCPUコアでもかなり高速で、複数の予測を並行して実行できる。好きなだけCPUを利用して、アプリケーション/パイプラインの全体的な予測スループットを向上させることができる。

この様に若干、自虐的なライブラリです。以前記事を書いた Genann は ANSI C から使えるライブラリでしたが、frugally-deep は C++ です。

https://zenn.dev/mattn/articles/b43444214e06f3

Keras モデルの利用

frugally-deep の売り文句の1つに、Keras のモデルを利用できる事があります。実際は Keras で作り HDF5 で出力したファイルを JSON に変換し、その JSON を利用する方法です。試しに FizzBuzz のモデルを作ってみます。

import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.lite.python import lite


def fizzbuzz(i):
    if i % 15 == 0:
        return np.array([0, 0, 0, 1]).astype(np.float32)
    elif i % 5 == 0:
        return np.array([0, 0, 1, 0]).astype(np.float32)
    elif i % 3 == 0:
        return np.array([0, 1, 0, 0]).astype(np.float32)
    else:
        return np.array([1, 0, 0, 0]).astype(np.float32)


def bin(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)]).astype(np.float32)


trX = np.array([bin(i, 7) for i in range(1, 101)])
trY = np.array([fizzbuzz(i) for i in range(1, 101)])
model = Sequential()
model.add(Dense(64, input_dim=7))
model.add(Activation('tanh'))
model.add(Dense(4, input_dim=64))
model.add(Activation('softmax'))
model.compile(
    loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(trX, trY, epochs=3600, batch_size=64)
model.save('fizzbuzz_model.h5')

7要素に OneHot された入力を学習して4要素を出力するモデルです。実行すると fizzbuzz_model.h5 が生成されます。frugally-deep のリポジトリにある keras_export/convert_model_py を使って以下の様に変換します。

$ python keras_export/convert_model.py fizzbuzz_model.h5 fizzbuzz_model.json

中身をちょっと覗いたところウェイトの部分は BASE64 されたバイナリになっている様でした。

利用してみる

Keras で生成したモデルから変換された JSON ファイルを利用するには以下の様なコードを用意します。

#include <vector>
#include <iterator>
#include <algorithm>
#include <fdeep/fdeep.hpp>

static std::vector<float>
bin(int n, size_t digits) {
  std::vector<float> ret;
  for (auto i = 0; i < digits; i++) ret.push_back((float)((n >> i) & 1));
  return ret;
}

static int
dec(const std::vector<float> d) {
  auto it = std::max_element(d.begin(), d.end());
  return std::distance(d.begin(), it);
}

int
main() {
  const auto model = fdeep::load_model("fizzbuzz_model.json");
  for (auto i = 1; i <= 100; i++) {
    const auto result = model.predict(
        {fdeep::tensor(fdeep::tensor_shape(static_cast<std::size_t>(7)),
        bin(i, 7))});
    switch (dec(result[0].to_vector())) {
      case 0: std::cout << i << std::endl; break;
      case 1: std::cout << "Fizz" << std::endl; break;
      case 2: std::cout << "Buzz" << std::endl; break;
      case 3: std::cout << "FizzBuzz" << std::endl; break;
    }
  }
}

前述の様に、frugally-deep は FunctionalPlus と nlohmann/json に依存しているので、ビルドには以下の様な Makefile を用意しておくと良いでしょう。

SRCS = \
	fizzbuzz.cxx

OBJS = $(subst .cxx,.o,$(SRCS))

CXX = clang++
CXXFLAGS = -std=c++20 \
	-I c:/dev/frugally-deep/include \
	-I c:/dev/FunctionalPlus/include \
	-I c:/dev/nlohmann_json/include \
	-I c:/msys64/mingw64/include/eigen3
LIBS = 
TARGET = fizzbuzz
ifeq ($(OS),Windows_NT)
TARGET := $(TARGET).exe
endif

.SUFFIXES: .cxx .o

all : $(TARGET)

$(TARGET) : $(OBJS)
	$(CXX) -o $@ $(OBJS) $(LIBS)

.cxx.o :
	$(CXX) -c $(CXXFLAGS) -I. $< -o $@

clean :
	rm -f *.o $(TARGET)

ヘッダオンリーですが、ビルドにはそこそこ時間とマシンパワーが必要になります。実行すると見慣れた結果が表示されます。

1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz
16

Keras で HDF5 なモデルさえ吐けば簡単に利用する事ができるので、割と便利なのではと思います。

まとめ

Keras のモデルを利用できる C++ のヘッダオンリー機械学習ライブラリ、frugally-deep を紹介しました。今回は FizzBuzz という簡単な例でしたが、パフォーマンスも悪くないので、画像データの様な大きい入力データを扱う事もそれほど苦ではないと思います。興味のある方はチャレンジしてみては如何でしょうか。

この記事に贈られたバッジ

Discussion

ログインするとコメントできます