🔱

Faster WhisperとDistil Whisperの音声認識gRPCサーバ

2024/06/27に公開

概要

自前のgRPCサーバに音声認識モデルをホスティングしてみました。
音声認識モデルのWhisperを使いたい場合、第一候補としてはOpanAIのAPI利用があるかと思います。
普段の利用であれば十分ですが、利用料が多い場合Rate Limitの問題で使用できなくなる可能性があり、その場合は自前で持つ必要があります。
今回は検証としてWhisperを載せた音声認識のgRPCサーバを作成してみたいと思います。
比較的載せやすいFaster WhisperDistil Whisperを載せてCPU推論させてみます。

GitHub: https://github.com/hosimesi/code-for-techblogs/tree/main/whisper_grpc_server

検証環境

以下の検証環境を用意しました。

チップ: Apple M3 Max
メモリ: 64 GB

ディレクトリ構成

.
├── README.md
├── clients
│   ├── preprocess.py
│   ├── proto
│   │   ├── inference_pb2.py
│   │   ├── inference_pb2.pyi
│   │   └── inference_pb2_grpc.py
│   └── request.py
├── compose.yaml
├── docker
│   └── Dockerfile
├── pretrained_models
│   ├── faster-distil-whisper-large-v2
│   │   ├── config.json
│   │   ├── model.bin
│   │   ├── preprocessor_config.json
│   │   ├── tokenizer.json
│   │   └── vocabulary.json
│   ├── faster-distil-whisper-large-v3
│   │   ├── config.json
│   │   ├── model.bin
│   │   ├── preprocessor_config.json
│   │   ├── tokenizer.json
│   │   └── vocabulary.json
│   ├── faster-whisper-large-v2
│   │   ├── config.json
│   │   ├── model.bin
│   │   ├── tokenizer.json
│   │   └── vocabulary.json
│   └── faster-whisper-large-v3
│       ├── config.json
│       ├── model.bin
│       ├── preprocessor_config.json
│       ├── tokenizer.json
│       └── vocabulary.json
├── proto
│   ├── codegen.py
│   └── inference.proto
├── pyproject.toml
├── requirements-dev.lock
├── requirements.lock
├── samples
│   └── audio.wav
└── src
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── base_asr_model.py
    │   ├── faster_distil_whisper_large_v2_model.py
    │   ├── faster_distil_whisper_large_v3_model.py
    │   ├── faster_whisper_large_v2_model.py
    │   ├── faster_whisper_large_v3_model.py
    │   └── model_info.py
    ├── proto
    │   ├── inference_pb2.py
    │   ├── inference_pb2.pyi
    │   └── inference_pb2_grpc.py
    ├── services
    │   ├── __init__.py
    │   └── inference_server.py
    └── utils
        ├── __init__.py
        ├── consts.py
        ├── enums.py
        └── logging.py

事前準備

pythonの準備

まずPython環境を用意します。今回はryeを使用します。

rye pin 3.12.3

そして、以下のpyproject.tomlを用意します

[project]
name = "whisper-grpc-server"
version = "0.1.0"
description = "Add your description here"
authors = [
    { name = "hosimesi", email = "hosimesi11@gmail.com" }
]
dependencies = [
    "grpcio>=1.64.1",
    "grpcio-tools>=1.64.1",
    "faster-whisper>=1.0.2",
    "pydantic>=2.7.4",
    "transformers>=4.41.2",
    "accelerate>=0.31.0",
    "datasets>=2.20.0",
]
readme = "README.md"
requires-python = ">= 3.12"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = [
    "pytest>=8.2.2",
    "mypy>=1.10.0",
    "ruff>=0.4.10",
    "mypy-protobuf>=3.6.0",
]

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/whisper_grpc_server"]

[tool.rye.scripts]
server = { cmd = "docker compose up --build" }
lint = { chain = ["format:ruff-check", "lint:ruff-fix", "lint:mypy", "lint:pytest" ] }
"lint:ruff-check" = "ruff format . --check --diff"
"lint:ruff-fix" = "ruff check ."
"lint:mypy" = "mypy . --no-site-packages --explicit-package-bases"
"lint:pytest" = "pytest -vv"

モデルの準備

Hugging Faceからそれぞれのモデルをダウンロードしておき、pretrained_models/model_name/に入れておきます。

gRPCサーバの構築

モデルの定義

まず、すべての音声認識モデルのベースクラスを用意します。

from abc import abstractmethod

import numpy as np


class BaseASRModel:
    @abstractmethod
    def transcribe(
        self,
        audio_array: np.ndarray,
        language: str = "ja",
        task: str = "transcribe",
        without_timestamps: bool = True,
    ) -> str:
        raise NotImplementedError

このベースクラスを継承した各モデルの定義を書いて行きます。以下はfaster-whisper-large-v2の例ですが、他のモデルも同様です。

import os

import numpy as np
from faster_whisper import WhisperModel
from src.models.base_asr_model import BaseASRModel
from src.utils import consts
from src.utils.enums import ASRModel
from src.utils.logging import get_logger

logger = get_logger(__name__)


class FasterWhisperLargeV2Model(BaseASRModel):
    def __init__(
        self,
        device: str = "cpu",
        compute_type: str = "int8",
        cpu_threads: int = os.cpu_count() or 1,
        num_workers: int = os.cpu_count() or 1,
    ):
        dir_name = os.path.join(
            consts.PRETRAINED_MODEL_DIR, ASRModel.FASTER_WHISPER_LARGE_V2.value
        )
        self.model = WhisperModel(
            dir_name,
            device=device,
            compute_type=compute_type,
            cpu_threads=cpu_threads,
            num_workers=num_workers,
        )

    def transcribe(
        self,
        audio_array: np.ndarray,
        language: str = "ja",
        task: str = "transcribe",
        without_timestamps: bool = False,
    ) -> str:
        transcription = ""
        segments, info = self.model.transcribe(
            audio=audio_array,
            language=language,
            task=task,
            without_timestamps=without_timestamps,
        )
        for segment in segments:
            transcription += segment.text
        return transcription

Protoファイルの定義

protoファイルを準備します。今回は入力として音声をbyte列として受け取り、認識結果の文字列を返します。

syntax = "proto3";

package inference;


message TranscribeRequest {
    bytes audio_bytes = 1;
    string target = 2;
}

message ASRResult {
    string transcription = 1;
}

message TranscribeResponse {
    map<string, ASRResult> result = 1;
}

service ASRInferenceServer {
    rpc transcribe(TranscribeRequest) returns (TranscribeResponse) {}
}

そしてこのprotoファイルからソースコードを生成しておきます。

from grpc.tools import protoc

protoc.main(
    (
        "",
        "-I.",
        "--python_out=./src/",
        "--grpc_python_out=./src/",
        "--mypy_out=./src/",
        "./proto/inference.proto",
    )
)

gRPCサーバの定義

gRPCサーバの実装をしていきます。
Bytesで送られてくるので、こちら側でnumpy配列に直してtranscribeメソッドに流していきます。
音声ということもあり、messageが長くなる可能性があったので、optionsでmessageサイズを変更しています。

async def serve(bind_address: str) -> None:
    logger.info("Starting new server.")
    server = grpc.aio.server(
        ThreadPoolExecutor(max_workers=50),
        options=[
            ("grpc.max_send_message_length", 50 * 1024 * 1024),
            ("grpc.max_receive_message_length", 50 * 1024 * 1024),
        ],
    )
    inference_pb2_grpc.add_ASRInferenceServerServicer_to_server(ASRInferenceServer(), server)
    server.add_insecure_port(bind_address)
    logger.info("The server started successfully.")
    await server.start()
    await server.wait_for_termination()


class ASRInferenceServer(inference_pb2_grpc.ASRInferenceServerServicer):
    def __init__(self) -> None:
        self.models = ALL_MODELS

    async def _extract_audio_from_request(self, audio_bytes: bytes) -> np.ndarray:
        audio_array = np.frombuffer(audio_bytes, dtype=np.float32)
        return audio_array

    async def transcribe(self, request: Any, context: Any) -> inference_pb2.TranscribeResponse:
        audio_array = await self._extract_audio_from_request(audio_bytes=request.audio_bytes)

        target: str = request.target
        try:
            transcription = self.models[target].transcribe(audio_array=audio_array)
        except Exception:
            logger.error(traceback.format_exc())

        return inference_pb2.TranscribeResponse(result={target: inference_pb2.ASRResult(transcription=transcription)})

音声認識

サーバの起動

実際にdocker環境で動かしてみます。

docker compose up --build

今4モデルをメモリに載せていますが、大体7GBくらい消費しています。

リクエストを投げる

簡単に音声ファイルを読み込んでbyte列にして送るクライアントを実装します。
こんばんは、今日はいい日でした。」と話した約5秒ほどの音声ファイルのデータを先ほど作成したgRPCサーバに投げていきます。
ターゲットにモデル名を入れることで、ターゲットのモデルで音声認識が行われるようになります。

import asyncio
import time
from enum import Enum

import ffmpeg
import grpc
import numpy as np
from google.protobuf.json_format import MessageToDict
from proto import inference_pb2, inference_pb2_grpc

SIGNED_INT16_MAX = 32768.0


class ASRModel(str, Enum):
    FASTER_WHISPER_LARGE_V3 = "faster-whisper-large-v3"
    FASTER_WHISPER_LARGE_V2 = "faster-whisper-large-v2"
    FASTER_DISTIL_WHISPER_LARGE_V3 = "faster-distil-whisper-large-v3"
    FASTER_DISTIL_WHISPER_LARGE_V2 = "faster-distil-whisper-large-v2"


def load_audio_file(audio_file_path: str, sampling_rate: int = 16000, channels: int = 1) -> list[np.ndarray]:
    # Create an input stream from the audio file
    input_stream = ffmpeg.input(audio_file_path)
    # Convert the audio to raw PCM data with a sample rate of `sampling_rate`
    output_stream = ffmpeg.output(input_stream, "pipe:", format="s16le", acodec="pcm_s16le", ar=sampling_rate)
    # Run the conversion and capture the raw PCM data
    pcm_data, _ = ffmpeg.run(output_stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
    # Convert the raw PCM data to a numpy array of 16-bit signed integers
    audio_data = np.frombuffer(pcm_data, np.int16)
    # Normalize the audio data to the range [-1.0, 1.0]
    audio_data = audio_data.astype(np.float32, order="C") / SIGNED_INT16_MAX
    # Split the audio data into separate channels
    return [audio_data]


async def transcribe(grpc_stub, audio_data, asr_model):
    start = time.time()
    request = inference_pb2.TranscribeRequest(audio_bytes=audio_data.tobytes(), target=asr_model)
    grpc_response = await grpc_stub.transcribe(request)
    response = MessageToDict(grpc_response).get("result", {})
    print(f"Response time: {time.time() - start} seconds")
    print(response)


async def main():
    audio_file_path = "samples/audio.wav"
    audio_data = load_audio_file(audio_file_path)

    grpc_channel = grpc.aio.insecure_channel("localhost:8080")
    grpc_stub = inference_pb2_grpc.ASRInferenceServerStub(grpc_channel)

    await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_WHISPER_LARGE_V2)
    await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_WHISPER_LARGE_V3)
    await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_DISTIL_WHISPER_LARGE_V3)
    await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_DISTIL_WHISPER_LARGE_V2)


if __name__ == "__main__":
    asyncio.run(main())

結果

一回の結果なので正確性には欠けますが、以下のような結果になりました。
認識結果については深く言及しないですが、Distil Whisperは速度面でも速いことが確認できました。

モデル名 認識結果 速度
faster-distil-whisper-large-v2 Good today was a today the today the today the today. Good today was a good today 21秒
faster-distil-whisper-large-v3 Good- Good-Wan-Wa, today was good day good day good day. Good night Good day today, good day day day day 21秒
faster-whisper-large-v2 こんばんは、今日はいい日でした。良い日でした。 35秒
faster-whisper-large-v3 こんばんは。今日は良い日でした。こんばんは。今日は良い日でした。こんばんは。 39秒

最後に

今回は自前でgRPCサーバを立て、Faster WhisperとDistil Whisperをホスティングしてみました。
initial promptやvad filter、beam sizeなどチューニングできる部分はたくさんあるので、いろいろ触ってみたいです。CPUでも簡単に動かせるので、普段使いにも便利です。

参考

https://www.ai-shift.co.jp/techblog/3093

Discussion