🤖

C++/ONNX Runtime/gRPC/マルチモーダルで推論サーバーを構築してみた

2025/01/06に公開

はじめに 📘

この記事は ラクスパートナーズ Advent Calendar 2024 の 23 日目の記事になります。

体調不良で記事の公開が遅れてしまいました…😇

今回は、C++を用いて推論サーバーを構築した際の手法と結果を共有します。

昨年末から C++コードに初めて触れました。その中で「C++で推論サーバーを作るとどうなるのか?」「推論のレイテンシやスループットはどのようなパフォーマンスを示すのか?」といった疑問が生まれました。
そこで、C++で推論サーバーを実際に構築し、そのパフォーマンスを検証することにしました。

C++を触って日が浅い為、至らない実装等あると思います。
不明点や改善提案があれば、ぜひコメントいただけると幸いです。

記事の目的(自分) 🎯

  • C++で推論サーバーを構築するスキルの習得
  • C++実装に慣れる
  • ONNX Runtime を利用した推論の理解
  • gRPC によるサーバー構築の学習

この記事の対象者 👩‍💻

  • C++で推論サーバー構築に挑戦したい人
  • gRPC サーバーの実装に興味がある人
  • ONNX Runtime を C++で利用してみたい人

注意事項 ⚠️

本記事で紹介するモデルや前処理の実装は、本番環境での利用を前提としていません。
動作確認を目的としており、一部の実装に本番環境で使用するには不適切な箇所が含まれる可能性があります。

想定ケース 💡

以下のようなマルチモーダルな機械学習モデルを使用した推論サーバーを構築します:

  1. 入力形式:
    • 画像、動画、テキスト、テーブルデータの 4 種類のデータを入力。
  2. 画像の前処理:
    • リサイズと正規化を行い、ResNet50 のavg_pool層を使用。
  3. 動画の前処理:
    • 動画から 30 フレームを抽出し、各フレームに画像と同じ前処理を適用。
    • フレームが不足する場合は黒埋めで補完。
  4. テキストの前処理:
    • Pybind を用いて Python コードで処理を実装。
    • OCR を利用してテキスト抽出を行う。
  5. テーブルデータの前処理:
    • One Hot エンコード、Multi Hot エンコードを実行。

開発環境 🛠️

  • M1 Mac mini 16GB

ディレクトリ構成 📂

📂 .
├── 📄 CMakeLists.txt
├── 📄 README.md
├── 📂 external
├── 📂 include
│   ├── 📂 config
│   │   └── 📄 global_config.h
│   ├── 📂 domain
│   │   ├── 📄 mock_model_config_provider.h
│   │   ├── 📄 model_config_provider.h
│   │   ├── 📄 resnet_batch_config_provider.h
│   │   └── 📄 resnet_config_provider.h
│   ├── 📂 inference
│   │   └── 📄 inference_manager.h
│   ├── 📂 preprocessor
│   │   ├── 📄 context_preprocessor.h
│   │   ├── 📄 encoder_interface.h
│   │   ├── 📄 image_preprocessor.h
│   │   ├── 📄 multi_hot_encoder.h
│   │   ├── 📄 preprocessor_interface.h
│   │   ├── 📄 single_hot_encoder.h
│   │   ├── 📄 text_preprocessor.h
│   │   └── 📄 video_preprocessor.h
│   ├── 📂 server
│   │   └── 📄 prediction_service_impl.h
│   └── 📂 utils
│       ├── 📄 python_manager.h
│       └── 📄 timer.h
├── 📂 models
│   ├── 📄 multi_modal_model.onnx
│   └── 📄 resnet50.onnx
├── 📄 prediction_client.py
├── 📂 proto
│   ├── 📄 predict.grpc.pb.cc
│   ├── 📄 predict.grpc.pb.h
│   ├── 📄 predict.pb.cc
│   ├── 📄 predict.pb.h
│   └── 📄 predict.proto
├── 📄 pyproject.toml
├── 📂 scripts
│   ├── 📄 build.sh
│   ├── 📄 debug_build.sh
│   ├── 📄 format_cmake.sh
│   ├── 📄 format_code.sh
│   ├── 📄 run_server.sh
│   └── 📄 run_test.sh
├── 📂 src
│   ├── 📂 config
│   │   └── 📄 global_config.cpp
│   ├── 📂 inference
│   │   └── 📄 inference_manager.cpp
│   ├── 📄 main.cpp
│   ├── 📂 preprocessor
│   │   ├── 📄 context_preprocessor.cpp
│   │   ├── 📄 image_preprocessor.cpp
│   │   ├── 📄 multi_hot_encoder.cpp
│   │   ├── 📄 single_hot_encoder.cpp
│   │   ├── 📄 text_preprocessor.cpp
│   │   └── 📄 video_preprocessor.cpp
│   ├── 📂 python
│   │   └── 📄 text_preprocessor.py
│   └── 📂 server
│       └── 📄 prediction_service_impl.cpp
└── 📂 tests
    └── 📂 units
        └── 📄 inference_manager_test.cpp

開発環境構築 ⚙️

onnxruntime のライブラリは、Homebrew 経由でのインストールだとうまく動作しなかったため、公式サイトからビルド済みバイナリを直接ダウンロードして使用しています。

# Homebrew の準備
brew update
brew upgrade

# ライブラリのインストール
brew install opencv
brew install grpc
brew install openssl
brew install protobuf
brew install abseil
brew install jsoncpp

# arm 用にビルド済みのバイナリをダウンロードして展開
mkdir -p external
curl -L -o external/onnxruntime-osx-arm64-1.20.1.tgz https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-osx-arm64-1.20.1.tgz
cd external
tar -xvzf onnxruntime-osx-arm64-1.20.1.tgz

# C++サーバー用ファイルの生成
cd proto
protoc --proto_path=. --cpp_out=. --grpc_out=. --plugin=protoc-gen-grpc=$(which grpc_cpp_plugin) predict.proto

# Pythonクライアント用ファイルの生成(必要に応じて)
protoc --proto_path=. --python_out=. --grpc_python_out=. predict.proto

Reference 🌐

https://onnxruntime.ai/

https://grpc.io/

実装 👆

本当は色々と実装の解説を書いていきたいところですが実装だけで今回は力つきました。。。笑
なので、実装コードをペタペタ貼っていくだけとなります。

モデル構築
resnet50
import torch
from torchvision.models import resnet50


# ResNet50モデルをロードし、トップ層を除外するクラス
class ResNet50NoTop(torch.nn.Module):
    def __init__(self):
        super(ResNet50NoTop, self).__init__()
        # 事前学習済みのResNet50モデルをロード
        self.resnet_base = resnet50(pretrained=True)
        # トップ層(Fully Connected Layer)を除外し、Global Average Pooling手前までを取得
        self.features = torch.nn.Sequential(*(list(self.resnet_base.children())[:-2]))
        # Global Average Pooling層
        self.global_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # 特徴量抽出
        x = self.features(x)  # トップ層を除外した特徴マップ
        x = self.global_avg_pool(x)  # Global Average Pooling
        x = torch.flatten(x, 1)  # ベクトル化 (batch_size, 2048)
        return x


# モデルのインスタンス化と推論モードへの設定
model = ResNet50NoTop()
model.eval()  # 推論モードに切り替え

# ダミー入力を作成 (NCHW形式: バッチサイズ, チャンネル, 高さ, 幅)
dummy_input = torch.randn(1, 3, 224, 224)

# モデルをONNX形式にエクスポート
onnx_model_path = "resnet50.onnx"
torch.onnx.export(
    model,  # PyTorchモデル
    dummy_input,  # ダミー入力
    onnx_model_path,  # 保存先
    opset_version=13,  # ONNX opsetバージョン
    input_names=["input"],  # 入力ノードの名前
    output_names=["avg_pool"],  # 出力ノードの名前
    dynamic_axes={
        "input": {0: "batch_size"},  # バッチサイズを動的軸として指定
        "avg_pool": {0: "batch_size"},
    },
)

print(f"ONNXモデルが保存されました: {onnx_model_path}")
multi_modal_model
import tensorflow as tf
import tf2onnx


# マルチモーダルモデルの構築
def create_model():
    # 画像データの処理
    image_input = tf.keras.Input(shape=(2048,), name="image_input")
    image_processed = tf.keras.layers.Dense(256, activation="linear")(image_input)
    image_processed = tf.keras.layers.BatchNormalization()(image_processed)
    image_processed = tf.keras.layers.LeakyReLU(alpha=0.01)(image_processed)

    # テキストデータの処理
    text_input = tf.keras.Input(shape=(1024,), name="text_input")
    text_embedding = tf.keras.layers.Embedding(input_dim=30522, output_dim=32)(text_input)
    text_processed = tf.keras.layers.GlobalAveragePooling1D()(text_embedding)
    text_processed = tf.keras.layers.Dense(64, activation="linear")(text_processed)
    text_processed = tf.keras.layers.BatchNormalization()(text_processed)
    text_processed = tf.keras.layers.LeakyReLU(alpha=0.01)(text_processed)

    # テーブルデータの処理
    table_input = tf.keras.Input(shape=(1160,), name="table_input")
    table_processed = tf.keras.layers.Dense(16, activation="linear")(table_input)
    table_processed = tf.keras.layers.BatchNormalization()(table_processed)
    table_processed = tf.keras.layers.LeakyReLU(alpha=0.01)(table_processed)

    # 動画データの処理
    video_input = tf.keras.Input(shape=(30, 2048), name="video_input")
    video_processed = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(512, activation="linear"))(video_input)
    video_processed = tf.keras.layers.BatchNormalization()(video_processed)
    video_processed = tf.keras.layers.TimeDistributed(tf.keras.layers.LeakyReLU(alpha=0.01))(video_processed)
    video_processed = tf.keras.layers.Flatten()(video_processed)

    # モダリティの統合
    combined = tf.keras.layers.Concatenate()(
        [image_processed, text_processed, table_processed, video_processed]
    )
    combined = tf.keras.layers.Dense(64, activation="linear")(combined)
    combined = tf.keras.layers.BatchNormalization()(combined)
    combined = tf.keras.layers.LeakyReLU(alpha=0.01)(combined)
    final_output = tf.keras.layers.Dense(1, activation="linear")(combined)

    # モデル構築
    model = tf.keras.Model(
        inputs=[image_input, text_input, table_input, video_input],
        outputs=final_output,
    )
    return model


# モデルの生成
model = create_model()

# ONNX形式に変換
onnx_model_path = "multimodal_model.onnx"
spec = [
    tf.TensorSpec((None, 2048), tf.float32, name="image_input"),
    tf.TensorSpec((None, 1024), tf.float32, name="text_input"),
    tf.TensorSpec((None, 1160), tf.float32, name="table_input"),
    tf.TensorSpec((None, 30, 2048), tf.float32, name="video_input"),
]

# TensorFlowからONNX形式に変換
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)
with open(onnx_model_path, "wb") as f:
    f.write(model_proto.SerializeToString())

print(f"ONNXモデルが保存されました: {onnx_model_path}")
前処理 Interface
EncoderInterface.h
#pragma once

#include <json/json.h>

#include <map>
#include <string>
#include <vector>

class EncoderInterface
{
   public:
    virtual ~EncoderInterface() = default;

    virtual std::map<std::string, std::vector<float>> transform(const Json::Value& data) const = 0;
};
preprocessor_interface.h
#pragma once
#include <onnxruntime_cxx_api.h>

#include <opencv2/opencv.hpp>
#include <string>
#include <vector>

#include "inference/inference_manager.h"

class ContextPreprocessorInterface
{
   public:
    virtual ~ContextPreprocessorInterface() = default;

    virtual std::vector<float> preprocess(const std::string& context_data) const = 0;
};

class ImagePreprocessorInterface
{
   public:
    virtual ~ImagePreprocessorInterface() = default;

    virtual std::vector<float> preprocess(const std::shared_ptr<InferenceManager>& session,
                                          Ort::AllocatorWithDefaultOptions& allocator,
                                          const std::string& image_path) const = 0;
};

class VideoPreprocessorInterface
{
   public:
    virtual ~VideoPreprocessorInterface() = default;

    virtual std::vector<float> preprocess(const std::shared_ptr<InferenceManager>& session,
                                          Ort::AllocatorWithDefaultOptions& allocator,
                                          const std::string& video_path) const = 0;

   private:
    virtual std::vector<float> process_frame(const cv::Mat& frame, size_t height, size_t width,
                                             size_t channels) const = 0;
};

class TextPreprocessorInterface
{
   public:
    virtual ~TextPreprocessorInterface() = default;

    virtual std::vector<float> preprocess(const std::string& image_path,
                                          const std::string& text_data) const = 0;
};
前処理 Image
image_preprocessor.h
#pragma once
#include <onnxruntime_cxx_api.h>

#include <opencv2/opencv.hpp>
#include <string>
#include <vector>

#include "inference/inference_manager.h"
#include "preprocessor/preprocessor_interface.h"

class ImagePreprocessorImpl : public ImagePreprocessorInterface
{
   public:
    std::vector<float> preprocess(const std::shared_ptr<InferenceManager>& session,
                                  Ort::AllocatorWithDefaultOptions& allocator,
                                  const std::string& image_path) const override;
};
image_preprocessor.cpp
#include "preprocessor/image_preprocessor.h"

#include <stdexcept>

#include "inference/inference_manager.h"

std::vector<float> ImagePreprocessorImpl::preprocess(
    const std::shared_ptr<InferenceManager>& session, Ort::AllocatorWithDefaultOptions& allocator,
    const std::string& image_path) const
{
    // 画像の読み込み
    cv::Mat image = cv::imread(image_path);
    if (image.empty())
    {
        throw std::runtime_error("Failed to load image");
    }

    // フレームをRGBに変換しリサイズ
    cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
    cv::resize(image, image, cv::Size(224, 224));
    image.convertTo(image, CV_32F, 1.0 / 255.0);

    // チャンネルごとの正規化
    const float mean[3] = {0.485, 0.456, 0.406};
    const float std[3] = {0.229, 0.224, 0.225};
    std::vector<cv::Mat> channels(3);
    cv::split(image, channels);  // チャンネルごとに分割
    for (int c = 0; c < 3; ++c)
    {
        channels[c] = (channels[c] - mean[c]) / std[c];  // 正規化
    }
    cv::merge(channels, image);  // 再結合

    // フラット化して保存
    std::vector<float> preprocessed(image.total() * image.channels());
    memcpy(preprocessed.data(), image.data, preprocessed.size() * sizeof(float));

    // 入力の形状
    std::vector<float> resnet_preprocessed = session->runInference({preprocessed});

    return resnet_preprocessed;
}
前処理 Video
video_preprocessor.h
#pragma once
#include <onnxruntime_cxx_api.h>

#include <opencv2/opencv.hpp>
#include <string>
#include <thread>
#include <vector>

#include "inference/inference_manager.h"
#include "preprocessor/preprocessor_interface.h"

class VideoPreprocessorImpl : public VideoPreprocessorInterface
{
   public:
    std::vector<float> preprocess(const std::shared_ptr<InferenceManager>& session,
                                  Ort::AllocatorWithDefaultOptions& allocator,
                                  const std::string& video_path) const override;

   private:
    std::vector<float> process_frame(const cv::Mat& frame, size_t height, size_t width,
                                     size_t channels) const override;
};
video_preprocessor.cpp
#include "preprocessor/video_preprocessor.h"

#include <stdexcept>

#include "inference/inference_manager.h"

std::vector<float> VideoPreprocessorImpl::preprocess(
    const std::shared_ptr<InferenceManager>& session, Ort::AllocatorWithDefaultOptions& allocator,
    const std::string& video_path) const
{
    cv::VideoCapture cap(video_path);
    if (!cap.isOpened())
    {
        throw std::runtime_error("Failed to open video");
    }

    int target_frames = 30;
    size_t height = 224, width = 224, channels = 3;

    // 全フレームを読み取る
    std::vector<cv::Mat> frames(target_frames);
    for (int i = 0; i < target_frames; ++i)
    {
        if (!cap.read(frames[i]))
        {
            frames[i] = cv::Mat();  // 空フレーム(処理で黒埋めされる)
        }
    }

    // 出力ベクトルを事前に確保
    std::vector<std::vector<float>> processed_frames(target_frames);

    // スレッドで並列処理
    std::vector<std::thread> threads;
    for (int i = 0; i < target_frames; ++i)
    {
        threads.emplace_back(
            [&, i]() { processed_frames[i] = process_frame(frames[i], height, width, channels); });
    }

    // 全スレッドの完了を待機
    for (auto& thread : threads)
    {
        if (thread.joinable())
        {
            thread.join();
        }
    }

    // フラット化
    std::vector<float> flat_processed_frames;
    for (const auto& frame : processed_frames)
    {
        flat_processed_frames.insert(flat_processed_frames.end(), frame.begin(), frame.end());
    }

    std::vector<float> resnet_preprocessed = session->runInference({flat_processed_frames});
    return resnet_preprocessed;
}

std::vector<float> VideoPreprocessorImpl::process_frame(const cv::Mat& frame, size_t height,
                                                        size_t width, size_t channels) const
{
    if (frame.empty())
    {
        return std::vector<float>(height * width * channels, 0.0f);  // 黒埋めフレーム
    }

    cv::Mat image;
    cv::cvtColor(frame, image, cv::COLOR_BGR2RGB);
    cv::resize(image, image, cv::Size(width, height));
    image.convertTo(image, CV_32F, 1.0 / 255.0);

    cv::Mat mean_mat(image.size(), CV_32FC3, cv::Scalar(0.485, 0.456, 0.406));
    cv::Mat std_mat(image.size(), CV_32FC3, cv::Scalar(0.229, 0.224, 0.225));
    image = (image - mean_mat) / std_mat;

    std::vector<float> flat_frame(image.total() * image.channels());
    memcpy(flat_frame.data(), image.data, flat_frame.size() * sizeof(float));
    return flat_frame;
}
前処理 text
text_preprocessor.h
#pragma once
#include <pybind11/embed.h>
#include <pybind11/numpy.h>

#include <string>
#include <vector>

#include "preprocessor/preprocessor_interface.h"

class TextPreprocessorImpl : public TextPreprocessorInterface
{
   public:
    std::vector<float> preprocess(const std::string& image_path,
                                  const std::string& text_data) const override;
};
text_preprocessor.cpp
#include "preprocessor/text_preprocessor.h"

#include <iostream>
#include <stdexcept>
namespace py = pybind11;

std::vector<float> TextPreprocessorImpl::preprocess(const std::string& image_path,
                                                    const std::string& text_data) const
{
    PyGILState_STATE gil_state = PyGILState_Ensure();  // Acquire GIL
    try
    {
        std::vector<float> output;

        // Python object manipulation
        {
            py::module_ py_module = py::module_::import("text_preprocessor");
            py::object py_result =
                py_module.attr("preprocess_text_with_ocr")(image_path, text_data);

            // Convert NumPy array to std::vector<float>
            py::array_t<float> array = py_result.cast<py::array_t<float>>();
            output.resize(array.size());
            std::memcpy(output.data(), array.data(), array.size() * sizeof(float));
        }

        PyGILState_Release(gil_state);  // Release GIL
        return output;
    }
    catch (const py::error_already_set& e)
    {
        PyGILState_Release(gil_state);  // Release GIL on error
        std::cerr << "Python Error: " << e.what() << std::endl;
        throw std::runtime_error("Preprocessing failed");
    }
    catch (...)
    {
        PyGILState_Release(gil_state);  // Release GIL on unknown error
        throw;
    }
}
text_preprocessor.py
import re
import time

import numpy as np
from numpy.typing import NDArray
from transformers import AutoTokenizer


# モック関数でOCR処理をシミュレート
def mock_ocr_process(image_path: str, sleep_time: float = 1.0) -> str:
    time.sleep(sleep_time)  # API呼び出しの遅延を再現
    return "Extracted text from image"  # 固定値を返す


def preprocess_text(texts: str, maxlen: int = 50) -> NDArray[np.float32]:
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def clean_text(text: str) -> str:
        """
        基本的なテキストクリーニング。
        """
        text = text.lower()  # 小文字化
        text = re.sub(r"https?://\S+|www\.\S+", "", text)  # URL削除
        text = re.sub(r"[^a-zA-Z0-9\s]", "", text)  # 特殊文字削除
        text = re.sub(r"\d+", "[NUM]", text)  # 数字を正規化
        return text

    # テキストのクリーニング
    cleaned_texts = clean_text(texts)

    # テキストがmaxlenを超える場合、前半と後半に分割
    words = cleaned_texts.split()
    if len(words) > maxlen:
        print(f"Input text exceeds maxlen ({len(words)} > {maxlen}). Splitting...")
        midpoint = len(words) // 2
        front_segment = " ".join(words[:midpoint])  # 前半
        back_segment = " ".join(words[midpoint:])  # 後半
    else:
        front_segment = cleaned_texts
        back_segment = ""

    # 前半と後半をそれぞれトークナイゼーション
    front_tokenized = tokenizer(
        front_segment,
        max_length=maxlen,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )["input_ids"].astype(np.float32)

    back_tokenized = (
        tokenizer(
            back_segment,
            max_length=maxlen,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )["input_ids"].astype(np.float32)
        if back_segment
        else np.zeros_like(front_tokenized)
    )

    # 前半と後半を結合
    combined_tokenized = np.concatenate([front_tokenized, back_tokenized], axis=1)

    return combined_tokenized


# OCR処理を含むテキスト前処理
def preprocess_text_with_ocr(
    image_path: str, texts: str, maxlen: int = 512
) -> NDArray[np.float32]:
    """
    OCRを利用したテキスト前処理。

    Args:
        tokenizer (AutoTokenizer): トークナイザーオブジェクト。
        image_path (str): OCR対象の画像パス。
        maxlen (int): 最大トークン数。

    Returns:
        NDArray[np.float32]: 結合されたトークナイズ済みデータ。
    """
    # OCR処理
    extracted_text = mock_ocr_process(image_path, sleep_time=0.5)
    print(f"OCR extracted text: {extracted_text}")

    # 抽出されたテキストを前処理
    return preprocess_text(texts=texts + extracted_text, maxlen=maxlen)
前処理 Context
single_hot_encoder.h
#pragma once

#include "encoder_interface.h"

class SingleHotEncoder : public EncoderInterface
{
   public:
    explicit SingleHotEncoder(const std::map<std::string, std::vector<std::string>>& mapping)
        : mapping_(mapping)
    {
    }

    std::map<std::string, std::vector<float>> transform(const Json::Value& data) const override;

   private:
    std::map<std::string, std::vector<std::string>> mapping_;
};
single_hot_encoder.cpp
#include "preprocessor/single_hot_encoder.h"

#include <algorithm>

std::map<std::string, std::vector<float>> SingleHotEncoder::transform(const Json::Value& data) const
{
    std::map<std::string, std::vector<float>> processed;

    for (const auto& [col, classes] : mapping_)
    {
        std::vector<float> one_hot(classes.size(), 0.0f);
        if (data.isMember(col))
        {
            const std::string& value = data[col].asString();
            auto it = std::find(classes.begin(), classes.end(), value);
            if (it != classes.end())
            {
                one_hot[std::distance(classes.begin(), it)] = 1.0f;
            }
        }
        processed[col] = one_hot;
    }

    return processed;
}
multi_hot_encoder.h
#pragma once

#include "encoder_interface.h"

class MultiHotEncoder : public EncoderInterface
{
   public:
    explicit MultiHotEncoder(const std::map<std::string, std::vector<std::string>>& mapping)
        : mapping_(mapping)
    {
    }

    std::map<std::string, std::vector<float>> transform(const Json::Value& data) const override;

   private:
    std::map<std::string, std::vector<std::string>> mapping_;
};
multi_hot_encoder.cpp
#include "preprocessor/multi_hot_encoder.h"

#include <algorithm>

std::map<std::string, std::vector<float>> MultiHotEncoder::transform(const Json::Value& data) const
{
    std::map<std::string, std::vector<float>> processed;

    for (const auto& [col, classes] : mapping_)
    {
        std::vector<float> multi_hot(classes.size(), 0.0f);
        if (data.isMember(col) && data[col].isArray())
        {
            for (const auto& value : data[col])
            {
                auto it = std::find(classes.begin(), classes.end(), value.asString());
                if (it != classes.end())
                {
                    multi_hot[std::distance(classes.begin(), it)] = 1.0f;
                }
            }
        }
        processed[col] = multi_hot;
    }

    return processed;
}
context_preprocessor.h
#pragma once
#include <json/json.h>

#include <map>
#include <string>
#include <vector>

#include "preprocessor/multi_hot_encoder.h"
#include "preprocessor/preprocessor_interface.h"
#include "preprocessor/single_hot_encoder.h"

class ContextPreprocessorImpl : public ContextPreprocessorInterface
{
   public:
    explicit ContextPreprocessorImpl(
        const std::vector<std::string>& float_columns,
        const std::map<std::string, std::vector<std::string>>& single_hot_mapping,
        const std::map<std::string, std::vector<std::string>>& multi_hot_mapping,
        const std::vector<std::string>& context_order)
        : float_columns_(float_columns),
          single_hot_encoder_(SingleHotEncoder(single_hot_mapping)),
          multi_hot_encoder_(MultiHotEncoder(multi_hot_mapping)),
          context_order_(context_order)
    {
    }

    std::vector<float> preprocess(const std::string& context_data) const override;

   private:
    std::vector<std::string> float_columns_;
    SingleHotEncoder single_hot_encoder_;
    MultiHotEncoder multi_hot_encoder_;
    std::vector<std::string> context_order_;
};
context_preprocessor.cpp
#include "preprocessor/context_preprocessor.h"

#include <sstream>
#include <stdexcept>

std::vector<float> ContextPreprocessorImpl::preprocess(const std::string& context_data) const
{
    Json::CharReaderBuilder builder;
    Json::Value root;
    std::istringstream s(context_data);
    std::string errs;

    if (!Json::parseFromStream(builder, s, &root, &errs))
    {
        throw std::runtime_error("Failed to parse context data: " + errs);
    }

    std::map<std::string, std::vector<float>> processed_data;

    for (const auto& col : float_columns_)
    {
        if (root.isMember(col))
        {
            processed_data[col] = {root[col].asFloat()};
        }
    }

    auto single_hot_data = single_hot_encoder_.transform(root);
    processed_data.insert(single_hot_data.begin(), single_hot_data.end());

    auto multi_hot_data = multi_hot_encoder_.transform(root);
    processed_data.insert(multi_hot_data.begin(), multi_hot_data.end());

    std::vector<float> combined;
    for (const auto& col : context_order_)
    {
        if (processed_data.find(col) != processed_data.end())
        {
            combined.insert(combined.end(), processed_data[col].begin(), processed_data[col].end());
        }
    }

    return combined;
}
サーバー
prediction_service_impl.h
#pragma once
#include <grpcpp/grpcpp.h>
#include <onnxruntime_cxx_api.h>

#include <memory>
#include <string>

#include "domain/mock_model_config_provider.h"
#include "domain/resnet_batch_config_provider.h"
#include "domain/resnet_config_provider.h"
#include "inference/inference_manager.h"
#include "predict.grpc.pb.h"
#include "preprocessor/context_preprocessor.h"
#include "preprocessor/image_preprocessor.h"
#include "preprocessor/text_preprocessor.h"
#include "preprocessor/video_preprocessor.h"

class PredictionServiceImpl final : public predict::PredictionService::Service
{
   public:
    PredictionServiceImpl(std::unique_ptr<ImagePreprocessorImpl> image_preprocessor,
                          std::unique_ptr<VideoPreprocessorImpl> video_preprocessor,
                          std::unique_ptr<ContextPreprocessorImpl> context_preprocessor,
                          std::unique_ptr<TextPreprocessorImpl> text_preprocessor,
                          std::unique_ptr<MockModelConfigProvider> main_config,
                          std::unique_ptr<ResnetConfigProvider> resnet_config,
                          std::unique_ptr<ResnetBatchConfigProvider> resnet_batch_config);

    grpc::Status Predict(grpc::ServerContext* context, const predict::PredictRequest* request,
                         predict::PredictResponse* response) override;

   private:
    std::unique_ptr<ImagePreprocessorImpl> image_preprocessor_;
    std::unique_ptr<VideoPreprocessorImpl> video_preprocessor_;
    std::unique_ptr<ContextPreprocessorImpl> context_preprocessor_;
    std::unique_ptr<TextPreprocessorImpl> text_preprocessor_;
    Ort::Env env_;
    std::shared_ptr<InferenceManager> session_;
    std::shared_ptr<InferenceManager> resnet_session_;
    std::shared_ptr<InferenceManager> resnet_batch_session_;
    std::unique_ptr<Ort::AllocatorWithDefaultOptions> allocator_;
};
prediction_service_impl.cpp
#include "server/prediction_service_impl.h"

#include <filesystem>
#include <iostream>
#include <thread>

#include "domain/mock_model_config_provider.h"
#include "domain/resnet_batch_config_provider.h"
#include "domain/resnet_config_provider.h"
#include "inference/inference_manager.h"
#include "utils/timer.h"

PredictionServiceImpl::PredictionServiceImpl(
    std::unique_ptr<ImagePreprocessorImpl> image_preprocessor,
    std::unique_ptr<VideoPreprocessorImpl> video_preprocessor,
    std::unique_ptr<ContextPreprocessorImpl> context_preprocessor,
    std::unique_ptr<TextPreprocessorImpl> text_preprocessor,
    std::unique_ptr<MockModelConfigProvider> main_config,
    std::unique_ptr<ResnetConfigProvider> resnet_config,
    std::unique_ptr<ResnetBatchConfigProvider> resnet_batch_config)
    : image_preprocessor_(std::move(image_preprocessor)),
      video_preprocessor_(std::move(video_preprocessor)),
      context_preprocessor_(std::move(context_preprocessor)),
      text_preprocessor_(std::move(text_preprocessor)),
      session_(std::make_shared<InferenceManager>(main_config->getModelPath(),
                                                  main_config->getInputShapes())),
      resnet_session_(std::make_shared<InferenceManager>(resnet_config->getModelPath(),
                                                         resnet_config->getInputShapes())),
      resnet_batch_session_(std::make_shared<InferenceManager>(
          resnet_batch_config->getModelPath(), resnet_batch_config->getInputShapes())),
      allocator_(std::make_unique<Ort::AllocatorWithDefaultOptions>())
{
    std::cout << "PredictionServiceImpl initialized." << std::endl;
}

grpc::Status PredictionServiceImpl::Predict(grpc::ServerContext* context,
                                            const predict::PredictRequest* request,
                                            predict::PredictResponse* response)
{
    Timer timer("Prediction API");

    try
    {
        std::vector<float> context_tensor;
        std::vector<float> image_tensor;
        std::vector<float> video_tensor;
        std::vector<float> text_tensor;

        std::vector<std::thread> threads;
        threads.emplace_back(
            [&]()
            {
                context_tensor = measure_time(
                    "context_preprocessor_->preprocess",
                    [&]() { return context_preprocessor_->preprocess(request->context_data()); });
            });
        threads.emplace_back(
            [&]()
            {
                image_tensor =
                    measure_time("image_preprocessor_->preprocess",
                                 [&]()
                                 {
                                     return image_preprocessor_->preprocess(
                                         resnet_session_, *allocator_, request->image_data());
                                 });
            });
        threads.emplace_back(
            [&]()
            {
                video_tensor =
                    measure_time("video_preprocessor_->preprocess",
                                 [&]()
                                 {
                                     return video_preprocessor_->preprocess(
                                         resnet_batch_session_, *allocator_, request->video_data());
                                 });
            });
        threads.emplace_back(
            [&]()
            {
                text_tensor = measure_time("text_preprocessor_->preprocess",
                                           [&]()
                                           {
                                               return text_preprocessor_->preprocess(
                                                   request->image_data(), request->text_data());
                                           });
            });

        measure_time("Thread join",
                     [&]()
                     {
                         for (auto& thread : threads)
                         {
                             if (thread.joinable())
                             {
                                 thread.join();
                             }
                         }
                     });

        // 推論実行
        std::vector<std::vector<float>> inputs = {image_tensor, text_tensor, context_tensor,
                                                  video_tensor};
        std::vector<float> predictions = session_->runInference(inputs);

        // 結果をレスポンスに設定
        for (float value : predictions)
        {
            response->add_predictions(value);
        }

        return grpc::Status::OK;
    }
    catch (const std::exception& ex)
    {
        return grpc::Status(grpc::StatusCode::INTERNAL, ex.what());
    }
}
推論
inference_manager.h
#pragma once
#include <onnxruntime_cxx_api.h>

#include <string>
#include <vector>

class InferenceManager
{
   public:
    // コンストラクタでモデルパスと入力形状を受け取る
    InferenceManager(const std::string& model_path,
                     const std::vector<std::vector<int64_t>>& input_shapes);

    // 推論を実行する
    std::vector<float> runInference(const std::vector<std::vector<float>>& input_tensors);

   private:
    Ort::Env env_;
    Ort::Session session_;
    Ort::AllocatorWithDefaultOptions allocator_;
    std::vector<std::vector<int64_t>> input_shapes_;  // 入力形状を保持

    static Ort::Session createSession(const Ort::Env& env, const std::string& model_path);
};
inference_manager.cpp
#include "inference/inference_manager.h"

#include <filesystem>
#include <iostream>
#include <stdexcept>

InferenceManager::InferenceManager(const std::string& model_path,
                                   const std::vector<std::vector<int64_t>>& input_shapes)
    : env_(ORT_LOGGING_LEVEL_WARNING, "InferenceManager"),
      session_(createSession(env_, model_path)),
      input_shapes_(input_shapes)
{
}

std::vector<float> InferenceManager::runInference(
    const std::vector<std::vector<float>>& input_tensors)
{
    if (input_tensors.size() != input_shapes_.size())
    {
        throw std::invalid_argument("Input tensors and shapes size mismatch.");
    }

    // Prepare ONNX inputs
    std::vector<std::string> input_names;
    std::vector<Ort::Value> ort_inputs;

    for (size_t i = 0; i < input_tensors.size(); ++i)
    {
        auto input_name = session_.GetInputNameAllocated(i, allocator_);
        input_names.push_back(input_name.get());

        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
            allocator_.GetInfo(), const_cast<float*>(input_tensors[i].data()),
            input_tensors[i].size() * sizeof(float), input_shapes_[i].data(),
            input_shapes_[i].size());
        ort_inputs.push_back(std::move(input_tensor));
    }

    // ONNX の入力名ポインタの準備
    std::vector<const char*> input_name_ptrs;
    input_name_ptrs.reserve(input_names.size());
    for (const auto& name : input_names)
    {
        input_name_ptrs.push_back(name.c_str());
    }

    auto output_name_ptr = session_.GetOutputNameAllocated(0, allocator_);
    const char* output_name = output_name_ptr.get();

    // 推論を実行
    auto output_tensors = session_.Run(Ort::RunOptions{nullptr}, input_name_ptrs.data(),
                                       ort_inputs.data(), input_name_ptrs.size(), &output_name, 1);

    // 結果を取得
    float* output_data = output_tensors.front().GetTensorMutableData<float>();
    size_t output_size = output_tensors.front().GetTensorTypeAndShapeInfo().GetElementCount();

    return std::vector<float>(output_data, output_data + output_size);
}

Ort::Session InferenceManager::createSession(const Ort::Env& env, const std::string& model_path)
{
    if (!std::filesystem::exists(model_path))
    {
        throw std::runtime_error("Model file: " + model_path + " does not exits");
    }
    Ort::SessionOptions session_options;
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

    return Ort::Session(env, model_path.c_str(), session_options);
}
Utils
timer.h
#pragma once
#include <chrono>
#include <iostream>
#include <string>

class Timer
{
   public:
    explicit Timer(const std::string& name)
        : name_(name), start_time_(std::chrono::high_resolution_clock::now())
    {
    }

    ~Timer()
    {
        auto end_time = std::chrono::high_resolution_clock::now();
        auto duration =
            std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time_).count();
        std::cout << "[" << name_ << "] executed in " << duration << " ms.\n";
    }

   private:
    std::string name_;
    std::chrono::high_resolution_clock::time_point start_time_;
};

template <typename Func, typename... Args>
auto measure_time(const std::string& func_name, Func&& func, Args&&... args)
{
    using namespace std::chrono;

    auto start_time = high_resolution_clock::now();

    if constexpr (std::is_void_v<std::invoke_result_t<Func, Args...>>)
    {
        // 戻り値がvoidの場合
        std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
        auto end_time = high_resolution_clock::now();
        auto duration = duration_cast<milliseconds>(end_time - start_time).count();
        std::cout << "Function [" << func_name << "] executed in " << duration << " ms.\n";
    }
    else
    {
        // 戻り値がvoid以外の場合
        auto result = std::invoke(std::forward<Func>(func), std::forward<Args>(args)...);
        auto end_time = high_resolution_clock::now();
        auto duration = duration_cast<milliseconds>(end_time - start_time).count();
        std::cout << "Function [" << func_name << "] executed in " << duration << " ms.\n";
        return result;
    }
}
python_manager.h
class PythonInterpreterManager
{
   public:
    PythonInterpreterManager()
    {
        PyConfig config;
        PyConfig_InitPythonConfig(&config);

        // program_name の設定は不要
        Py_InitializeFromConfig(&config);

        PyConfig_Clear(&config);

        // ビルドディレクトリと仮想環境のsite-packagesをsys.pathに追加
        PyRun_SimpleString(
            "import site\n"
            "import signal\n"
            "import sys\n"
            "sys.path.append(\"" PYTHON_SCRIPT_PATH
            "\")\n"
            "sys.path.append(\"" PYTHON_SITE_PACKAGES
            "\")\n"
            "signal.signal(signal.SIGINT, signal.SIG_DFL)\n");

        _mainThreadState = PyEval_SaveThread();  // GILを解放
    }

    ~PythonInterpreterManager()
    {
        if (Py_IsInitialized())
        {
            PyEval_RestoreThread(_mainThreadState);  // GILを再取得
            Py_Finalize();
        }
    }

    PyThreadState* GetThreadState() { return _mainThreadState; }

   private:
    PyThreadState* _mainThreadState = nullptr;
};
Config
global_config.h
#pragma once
#include <string>

namespace GlobalConfig
{
// グローバルな設定値
inline const std::string MODEL_PATH = "models/model.onnx";
inline const std::string RESNET_MODEL_PATH = "models/resnet50.onnx";
inline const std::string SERVER_ADDRESS = "0.0.0.0:50051";

// 初期化関数
void printConfig();
}  // namespace GlobalConfig
global_config.cpp
#include "config/global_config.h"

#include <iostream>

void GlobalConfig::printConfig()
{
    std::cout << "Server Address: " << SERVER_ADDRESS << std::endl;
    std::cout << "Model Path: " << MODEL_PATH << std::endl;
    std::cout << "ResNet Model Path: " << RESNET_MODEL_PATH << std::endl;
}
tests/
inference_manager_test.cpp
#include "inference/inference_manager.h"

#include <gtest/gtest.h>

TEST(InferenceManagerTest, ThrowsWhenModelPathIsInvalid)
{
    EXPECT_THROW(InferenceManager("/invalid/path/to/model.onnx", {{1, 3, 224, 224}}),
                 std::runtime_error);
}

TEST(InferenceManagerTest, RunsInferenceCorrectly)
{
    // モデルファイルのパスと入力形状を指定
    InferenceManager manager("models/resnet50.onnx", {{1, 3, 224, 224}});

    // 入力データを用意
    std::vector<std::vector<float>> input_tensors(1);
    input_tensors[0].resize(224 * 224 * 3, 0.5f);

    // 推論の結果を取得
    std::vector<float> output = manager.runInference(input_tensors);

    // 結果を検証
    EXPECT_EQ(output.size(), 2048);  // 例: クラス数が1000の場合
}
main.cpp
main.cpp
#include <grpcpp/grpcpp.h>

#include <random>

#include "domain/mock_model_config_provider.h"
#include "domain/resnet_batch_config_provider.h"
#include "domain/resnet_config_provider.h"
#include "server/prediction_service_impl.h"
#include "utils/python_manager.h"
#include "config/global_config.cpp"

void RunServer(const std::string& server_address, const std::string& model_path,
               const std::string& resnet_model_path)
{
    PythonInterpreterManager python_manager;

    auto image_preprocessor = std::make_unique<ImagePreprocessorImpl>();
    auto video_preprocessor = std::make_unique<VideoPreprocessorImpl>();

    // Main Model
    std::unique_ptr<MockModelConfigProvider> main_model_config =
        std::make_unique<MockModelConfigProvider>(model_path);
    TableData table_data = main_model_config->get_table_data();
    std::vector<std::string> context_order = main_model_config->get_context_order(table_data);
    const auto& float_cols = main_model_config->getFloatCols();
    const auto& single_hot_cols = main_model_config->getSingleHotCols();
    const auto& multi_hot_cols = main_model_config->getMultiHotCols();

    // Resnet Model
    std::unique_ptr<ResnetConfigProvider> resnet_model_config =
        std::make_unique<ResnetConfigProvider>(resnet_model_path);
    std::unique_ptr<ResnetBatchConfigProvider> resnet_batch_model_config =
        std::make_unique<ResnetBatchConfigProvider>(resnet_model_path);

    auto context_preprocessor = std::make_unique<ContextPreprocessorImpl>(
        float_cols, single_hot_cols, multi_hot_cols, context_order);
    auto text_preprocessor = std::make_unique<TextPreprocessorImpl>();

    PredictionServiceImpl service(std::move(image_preprocessor), std::move(video_preprocessor),
                                  std::move(context_preprocessor), std::move(text_preprocessor),
                                  std::move(main_model_config), std::move(resnet_model_config),
                                  std::move(resnet_batch_model_config));

    grpc::ServerBuilder builder;
    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
    builder.RegisterService(&service);
    builder.SetSyncServerOption(grpc::ServerBuilder::SyncServerOption::NUM_CQS, 4);
    builder.SetSyncServerOption(grpc::ServerBuilder::SyncServerOption::MIN_POLLERS, 4);
    builder.SetSyncServerOption(grpc::ServerBuilder::SyncServerOption::MAX_POLLERS, 8);
    std::unique_ptr<grpc::Server> server(builder.BuildAndStart());

    std::cout << "Server listening on " << server_address << std::endl;
    server->Wait();
}

int main()
{
    GlobalConfig::printConfig();
    RunServer(GlobalConfig::SERVER_ADDRESS, GlobalConfig::MODEL_PATH, GlobalConfig::RESNET_MODEL_PATH);
    return 0;
}
CMakeLists.txt
cmake_minimum_required(VERSION 3.16)

project(
    PredictionService
    LANGUAGES CXX
    VERSION 1.0)

# C++標準を指定
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Python 環境を設定
set(Python_FIND_VIRTUALENV FIRST)
set(PYBIND11_FINDPYTHON ON)
find_package(Python REQUIRED COMPONENTS Interpreter Development)
find_package(pybind11 REQUIRED)

message(STATUS "Python interpreter: ${Python_EXECUTABLE}")
message(STATUS "Python include directories: ${Python_INCLUDE_DIRS}")
message(STATUS "Python libraries: ${Python_LIBRARIES}")
message(STATUS "Python version: ${Python_VERSION}")

# Pythonのヘッダファイルとライブラリを明示的に追加
include_directories(${Python_INCLUDE_DIRS})
link_directories(${Python_LIBRARIES})

# OpenCV を検索
find_package(OpenCV REQUIRED)

# ONNX Runtime ライブラリを設定
include_directories(${CMAKE_SOURCE_DIR}/external/onnxruntime-osx-arm64-1.20.1/include)

# gRPC と Protobuf を検索
find_package(gRPC REQUIRED)
find_package(Protobuf REQUIRED)

# JSONCPP ライブラリを検索
find_package(JsonCpp REQUIRED)

# absl ライブラリを検索
find_package(absl REQUIRED)

# 共通ライブラリを作成
add_library(
    inference_library STATIC
    src/inference/inference_manager.cpp
    src/preprocessor/image_preprocessor.cpp
    src/preprocessor/video_preprocessor.cpp
    src/preprocessor/context_preprocessor.cpp
    src/preprocessor/multi_hot_encoder.cpp
    src/preprocessor/single_hot_encoder.cpp
    src/preprocessor/text_preprocessor.cpp
    src/config/global_config.cpp
    src/server/prediction_service_impl.cpp)

# ライブラリにインクルードパスを設定
target_include_directories(
    inference_library
    PRIVATE ${Python_INCLUDE_DIRS}
            ${CMAKE_SOURCE_DIR}/src
            ${CMAKE_SOURCE_DIR}/include
            ${CMAKE_SOURCE_DIR}/proto
            ${CMAKE_SOURCE_DIR}/external/onnxruntime-osx-arm64-1.20.1/include
            ${OpenCV_INCLUDE_DIRS}
            ${Protobuf_INCLUDE_DIRS})

# ライブラリに必要なリンクを追加
target_link_libraries(
    inference_library
    PRIVATE ${Python_LIBRARIES}
            ${CMAKE_SOURCE_DIR}/external/onnxruntime-osx-arm64-1.20.1/lib/libonnxruntime.dylib
            ${OpenCV_LIBS}
            gRPC::grpc++
            # Protobuf::libprotobuf
            JsonCpp::JsonCpp
            absl::base
            absl::status
            absl::strings
            absl::log
            absl::check
            absl::failure_signal_handler
            pthread)

add_library(common_interface INTERFACE)
target_include_directories(
    common_interface
    INTERFACE ${CMAKE_SOURCE_DIR}/src
            ${CMAKE_SOURCE_DIR}/include
            ${CMAKE_SOURCE_DIR}/proto
            ${CMAKE_SOURCE_DIR}/external/onnxruntime-osx-arm64-1.20.1/include
            ${OpenCV_INCLUDE_DIRS}
            ${Protobuf_INCLUDE_DIRS}
            ${Python_INCLUDE_DIRS})

# 本番用の実行ファイルを作成
add_executable(PredictionService src/main.cpp proto/predict.grpc.pb.cc proto/predict.pb.cc)

# 実行ファイルにライブラリをリンク
target_link_libraries(PredictionService PRIVATE inference_library common_interface)

# 仮想環境のPythonパスを取得
execute_process(
    COMMAND ${Python_EXECUTABLE} -c
    "import sys; from distutils.sysconfig import get_python_lib; print(get_python_lib())"
    OUTPUT_VARIABLE SITE_PACKAGES
    OUTPUT_STRIP_TRAILING_WHITESPACE
)

# 必要なパスをコンパイル時定数に追加
target_compile_definitions(PredictionService PRIVATE
    PYTHON_SITE_PACKAGES="${SITE_PACKAGES}"
    PYTHON_SCRIPT_PATH="${CMAKE_SOURCE_DIR}/src/python"
)

# テスト用設定
include(FetchContent)
fetchcontent_declare(googletest URL https://github.com/google/googletest/archive/release-1.12.1.zip
                                    DOWNLOAD_EXTRACT_TIMESTAMP TRUE)
fetchcontent_makeavailable(googletest)

# テスト用実行ファイルを作成
add_executable(run_tests tests/units/inference_manager_test.cpp)

# テスト用ターゲットにライブラリをリンク
target_link_libraries(run_tests PRIVATE inference_library common_interface gtest_main gmock)


# テストの自動化
enable_testing()
add_test(
    NAME run_tests
    COMMAND run_tests
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR})
proto/predict.proto
syntax = "proto3";

package predict;

// メッセージ定義
message PredictRequest {
    string image_data = 1; // 画像URL
    string video_data = 2; // 動画URL
    string context_data = 3; // JSON形式のコンテキストデータ
    string text_data = 4; // テキストデータ
}

message PredictResponse {
    repeated float predictions = 1;
}

// サービス定義
service PredictionService {
    rpc Predict(PredictRequest) returns (PredictResponse);
}

Scripts

build.sh
# !/bin/bash

set -e

# プロジェクトのルートディレクトリを設定

PROJECT_ROOT=$(dirname "$(readlink -f "$0")")/..

# ビルドディレクトリを作成

BUILD_DIR="$PROJECT_ROOT/build"
mkdir -p "$BUILD_DIR"

# 利用可能な CPU コア数を取得(macOS 用の修正)

if [[ "$OSTYPE" == "darwin"* ]]; then
    NUM_CORES=$(sysctl -n hw.ncpu)
else
    NUM_CORES=$(nproc)
fi

# デバッグビルドの実行

cd "$BUILD_DIR"
cmake -DCMAKE_BUILD_TYPE=Debug ..
make -j"$NUM_CORES"

echo "Build completed successfully."
format_cmake.sh
#!/bin/bash

# プロジェクトのルートディレクトリを取得
PROJECT_ROOT=$(dirname "$(readlink -f "$0")")/..

# CMakeLists.txt と *.cmake ファイルを検索
FILES=$(find "$PROJECT_ROOT" -name "CMakeLists.txt" -o -name "*.cmake")

# cmake-formatのバージョン確認
cmake-format --version

# フォーマットを実行
for file in $FILES; do
    echo "Formatting $file"
    cmake-format -i "$file"
done

echo "All CMake files formatted successfully."
format_code.sh
#!/bin/bash

# スクリプトの中断条件
set -e

# プロジェクトのルートディレクトリを取得
PROJECT_ROOT=$(dirname "$(readlink -f "$0")")/..

# ClangFormatのバージョン確認
clang-format --version

# フォーマットするファイルを検索
FILES=$(find "$PROJECT_ROOT"/src \( -name "*.cpp" -o -name "*.h" \))

# フォーマット実行
for file in $FILES; do
    echo "Formatting $file"
    clang-format -i "$file"
done

echo "All files formatted successfully."
run_server.sh
#!/bin/bash

set -e

# プロジェクトのルートディレクトリを設定
PROJECT_ROOT=$(dirname "$(readlink -f "$0")")/..

# 実行可能ファイルのパス
EXECUTABLE="$PROJECT_ROOT/build/PredictionService"

# サーバーの実行
if [[ -f "$EXECUTABLE" ]]; then
    "$EXECUTABLE"
    # lldb "$EXECUTABLE"
else
    echo "Executable not found. Please run build.sh first."
    exit 1
fi
run_test.sh
#!/bin/bash

set -e  # エラーが発生した場合にスクリプトを終了

# プロジェクトのルートディレクトリを設定
PROJECT_ROOT=$(dirname "$(readlink -f "$0")")/..

# ビルドディレクトリを設定
BUILD_DIR="$PROJECT_ROOT/build"

# ビルドディレクトリを作成
if [ ! -d "$BUILD_DIR" ]; then
    mkdir -p "$BUILD_DIR"
fi

# ビルド実行
cd "$BUILD_DIR"
cmake ..
make -j$(sysctl -n hw.ncpu)

# テストを実行
echo "Running tests..."
ctest --output-on-failure

まとめ ✍️

C++での推論サーバーの実装は不慣れな為、想像以上に大変でしたが、この経験を通じて、C++を活用した推論サーバー構築の可能性を自身の選択肢の一つとして捉えられるようになりました。

今回取り組めなかった課題としては、Pythonで同様の推論サーバーを構築し、性能(スループット、レイテンシ)や開発効率を比較することが挙げられます。
時間と余力があれば、これらの比較検証を次回の記事で取り上げ、C++実装の有用性をさらに掘り下げていきたいと考えています。

本記事が、C++での推論サーバー構築を検討している方々の参考になれば幸いです。また、改善点や気になる点があれば、ぜひご意見をお寄せください!

その他、実装中に学んだことメモ(振り返り用)✏️

Debug

ライブラリのリンク確認

  • macOS:

    otool -L ./your_executable
    
    • 実行可能ファイルにリンクされているライブラリを確認できます。

    • 出力例:

      /usr/lib/libc++.1.dylib
      /usr/lib/libSystem.B.dylib
      
  • Linux:

    ldd ./your_executable
    
    • 出力例:

      linux-vdso.so.1
      libstdc++.so.6
      libc.so.6
      

ONNX Runtime のデバッグ

  • 環境変数でログレベルを詳細に設定:

    export ORT_LOG_LEVEL=VERBOSE
    
    • これにより、ONNX Runtime 内部の詳細ログを表示できます。

LLDB の使用

  • 基本操作:

    lldb ./your_executable
    (lldb) run
    
  • エラーで停止後の操作:

    • 続行:

      (lldb) continue
      
    • スタックトレースの確認:

      (lldb) thread backtrace
      
    • 変数の確認:

      (lldb) print variable_name
      
    • ステップ実行:

      (lldb) step  # 関数内部に入る
      (lldb) next  # 関数呼び出しをスキップ
      
RAII の原則

RAII の概要

  • 定義:
    • リソースの取得はコンストラクタで行い、解放はデストラクタで行う設計原則。
  • メリット:
    • 例外安全性の向上: 例外が発生しても、リソースが確実に解放される。
    • リソースリーク防止: 明示的に解放する必要がないため、漏れが発生しにくい。

悪い例(RAII を使わない場合)

std::mutex mtx;

void bad_function() {
    mtx.lock();  // ロック取得
    // 処理中に例外が発生すると…
    throw std::runtime_error("An error occurred!");
    mtx.unlock();  // 解放されない → デッドロック
}
  • 問題点:
    • unlock が呼ばれない場合、ロックが解除されず、他のスレッドが進行できなくなる。

良い例(RAII を使用する場合)

std::mutex mtx;

void good_function() {
    std::lock_guard<std::mutex> lock(mtx);  // コンストラクタでロック取得
    // 処理中に例外が発生しても…
    throw std::runtime_error("An error occurred!");
}  // デストラクタで自動的にロック解除
  • 改善点:
    • std::lock_guard がスコープを抜けるときに自動的に unlock が呼び出されるため、リソースリークが発生しない。

タイマーの RAII 活用例

#include <iostream>
#include <chrono>

class Timer {
   public:
    explicit Timer(const std::string& name)
        : name_(name), start_time_(std::chrono::high_resolution_clock::now()) {}
    ~Timer() {
        auto end_time = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time_).count();
        std::cout << "[" << name_ << "] executed in " << duration << " ms.\n";
    }

   private:
    std::string name_;
    std::chrono::high_resolution_clock::time_point start_time_;
};

void example_function() {
    Timer timer("ExampleFunction");
    // 長時間処理
}

Discussion