Faster WhisperとDistil Whisperの音声認識gRPCサーバ
概要
自前のgRPCサーバに音声認識モデルをホスティングしてみました。
音声認識モデルのWhisperを使いたい場合、第一候補としてはOpanAIのAPI利用があるかと思います。
普段の利用であれば十分ですが、利用料が多い場合Rate Limitの問題で使用できなくなる可能性があり、その場合は自前で持つ必要があります。
今回は検証としてWhisperを載せた音声認識のgRPCサーバを作成してみたいと思います。
比較的載せやすいFaster WhisperとDistil Whisperを載せてCPU推論させてみます。
GitHub: https://github.com/hosimesi/code-for-techblogs/tree/main/whisper_grpc_server
検証環境
以下の検証環境を用意しました。
チップ: Apple M3 Max
メモリ: 64 GB
ディレクトリ構成
.
├── README.md
├── clients
│ ├── preprocess.py
│ ├── proto
│ │ ├── inference_pb2.py
│ │ ├── inference_pb2.pyi
│ │ └── inference_pb2_grpc.py
│ └── request.py
├── compose.yaml
├── docker
│ └── Dockerfile
├── pretrained_models
│ ├── faster-distil-whisper-large-v2
│ │ ├── config.json
│ │ ├── model.bin
│ │ ├── preprocessor_config.json
│ │ ├── tokenizer.json
│ │ └── vocabulary.json
│ ├── faster-distil-whisper-large-v3
│ │ ├── config.json
│ │ ├── model.bin
│ │ ├── preprocessor_config.json
│ │ ├── tokenizer.json
│ │ └── vocabulary.json
│ ├── faster-whisper-large-v2
│ │ ├── config.json
│ │ ├── model.bin
│ │ ├── tokenizer.json
│ │ └── vocabulary.json
│ └── faster-whisper-large-v3
│ ├── config.json
│ ├── model.bin
│ ├── preprocessor_config.json
│ ├── tokenizer.json
│ └── vocabulary.json
├── proto
│ ├── codegen.py
│ └── inference.proto
├── pyproject.toml
├── requirements-dev.lock
├── requirements.lock
├── samples
│ └── audio.wav
└── src
├── main.py
├── models
│ ├── __init__.py
│ ├── base_asr_model.py
│ ├── faster_distil_whisper_large_v2_model.py
│ ├── faster_distil_whisper_large_v3_model.py
│ ├── faster_whisper_large_v2_model.py
│ ├── faster_whisper_large_v3_model.py
│ └── model_info.py
├── proto
│ ├── inference_pb2.py
│ ├── inference_pb2.pyi
│ └── inference_pb2_grpc.py
├── services
│ ├── __init__.py
│ └── inference_server.py
└── utils
├── __init__.py
├── consts.py
├── enums.py
└── logging.py
事前準備
pythonの準備
まずPython環境を用意します。今回はryeを使用します。
rye pin 3.12.3
そして、以下のpyproject.tomlを用意します
[project]
name = "whisper-grpc-server"
version = "0.1.0"
description = "Add your description here"
authors = [
{ name = "hosimesi", email = "hosimesi11@gmail.com" }
]
dependencies = [
"grpcio>=1.64.1",
"grpcio-tools>=1.64.1",
"faster-whisper>=1.0.2",
"pydantic>=2.7.4",
"transformers>=4.41.2",
"accelerate>=0.31.0",
"datasets>=2.20.0",
]
readme = "README.md"
requires-python = ">= 3.12"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"pytest>=8.2.2",
"mypy>=1.10.0",
"ruff>=0.4.10",
"mypy-protobuf>=3.6.0",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/whisper_grpc_server"]
[tool.rye.scripts]
server = { cmd = "docker compose up --build" }
lint = { chain = ["format:ruff-check", "lint:ruff-fix", "lint:mypy", "lint:pytest" ] }
"lint:ruff-check" = "ruff format . --check --diff"
"lint:ruff-fix" = "ruff check ."
"lint:mypy" = "mypy . --no-site-packages --explicit-package-bases"
"lint:pytest" = "pytest -vv"
モデルの準備
Hugging Faceからそれぞれのモデルをダウンロードしておき、pretrained_models/model_name/
に入れておきます。
- faster-whisper-large-v2
- faster-whisper-large-v3
- faster-distil-whisper-large-v2
- faster-distil-whisper-large-v3
gRPCサーバの構築
モデルの定義
まず、すべての音声認識モデルのベースクラスを用意します。
from abc import abstractmethod
import numpy as np
class BaseASRModel:
@abstractmethod
def transcribe(
self,
audio_array: np.ndarray,
language: str = "ja",
task: str = "transcribe",
without_timestamps: bool = True,
) -> str:
raise NotImplementedError
このベースクラスを継承した各モデルの定義を書いて行きます。以下はfaster-whisper-large-v2の例ですが、他のモデルも同様です。
import os
import numpy as np
from faster_whisper import WhisperModel
from src.models.base_asr_model import BaseASRModel
from src.utils import consts
from src.utils.enums import ASRModel
from src.utils.logging import get_logger
logger = get_logger(__name__)
class FasterWhisperLargeV2Model(BaseASRModel):
def __init__(
self,
device: str = "cpu",
compute_type: str = "int8",
cpu_threads: int = os.cpu_count() or 1,
num_workers: int = os.cpu_count() or 1,
):
dir_name = os.path.join(
consts.PRETRAINED_MODEL_DIR, ASRModel.FASTER_WHISPER_LARGE_V2.value
)
self.model = WhisperModel(
dir_name,
device=device,
compute_type=compute_type,
cpu_threads=cpu_threads,
num_workers=num_workers,
)
def transcribe(
self,
audio_array: np.ndarray,
language: str = "ja",
task: str = "transcribe",
without_timestamps: bool = False,
) -> str:
transcription = ""
segments, info = self.model.transcribe(
audio=audio_array,
language=language,
task=task,
without_timestamps=without_timestamps,
)
for segment in segments:
transcription += segment.text
return transcription
Protoファイルの定義
protoファイルを準備します。今回は入力として音声をbyte列として受け取り、認識結果の文字列を返します。
syntax = "proto3";
package inference;
message TranscribeRequest {
bytes audio_bytes = 1;
string target = 2;
}
message ASRResult {
string transcription = 1;
}
message TranscribeResponse {
map<string, ASRResult> result = 1;
}
service ASRInferenceServer {
rpc transcribe(TranscribeRequest) returns (TranscribeResponse) {}
}
そしてこのprotoファイルからソースコードを生成しておきます。
from grpc.tools import protoc
protoc.main(
(
"",
"-I.",
"--python_out=./src/",
"--grpc_python_out=./src/",
"--mypy_out=./src/",
"./proto/inference.proto",
)
)
gRPCサーバの定義
gRPCサーバの実装をしていきます。
Bytesで送られてくるので、こちら側でnumpy配列に直してtranscribeメソッドに流していきます。
音声ということもあり、messageが長くなる可能性があったので、optionsでmessageサイズを変更しています。
async def serve(bind_address: str) -> None:
logger.info("Starting new server.")
server = grpc.aio.server(
ThreadPoolExecutor(max_workers=50),
options=[
("grpc.max_send_message_length", 50 * 1024 * 1024),
("grpc.max_receive_message_length", 50 * 1024 * 1024),
],
)
inference_pb2_grpc.add_ASRInferenceServerServicer_to_server(ASRInferenceServer(), server)
server.add_insecure_port(bind_address)
logger.info("The server started successfully.")
await server.start()
await server.wait_for_termination()
class ASRInferenceServer(inference_pb2_grpc.ASRInferenceServerServicer):
def __init__(self) -> None:
self.models = ALL_MODELS
async def _extract_audio_from_request(self, audio_bytes: bytes) -> np.ndarray:
audio_array = np.frombuffer(audio_bytes, dtype=np.float32)
return audio_array
async def transcribe(self, request: Any, context: Any) -> inference_pb2.TranscribeResponse:
audio_array = await self._extract_audio_from_request(audio_bytes=request.audio_bytes)
target: str = request.target
try:
transcription = self.models[target].transcribe(audio_array=audio_array)
except Exception:
logger.error(traceback.format_exc())
return inference_pb2.TranscribeResponse(result={target: inference_pb2.ASRResult(transcription=transcription)})
音声認識
サーバの起動
実際にdocker環境で動かしてみます。
docker compose up --build
今4モデルをメモリに載せていますが、大体7GBくらい消費しています。
リクエストを投げる
簡単に音声ファイルを読み込んでbyte列にして送るクライアントを実装します。
「こんばんは、今日はいい日でした。」と話した約5秒ほどの音声ファイルのデータを先ほど作成したgRPCサーバに投げていきます。
ターゲットにモデル名を入れることで、ターゲットのモデルで音声認識が行われるようになります。
import asyncio
import time
from enum import Enum
import ffmpeg
import grpc
import numpy as np
from google.protobuf.json_format import MessageToDict
from proto import inference_pb2, inference_pb2_grpc
SIGNED_INT16_MAX = 32768.0
class ASRModel(str, Enum):
FASTER_WHISPER_LARGE_V3 = "faster-whisper-large-v3"
FASTER_WHISPER_LARGE_V2 = "faster-whisper-large-v2"
FASTER_DISTIL_WHISPER_LARGE_V3 = "faster-distil-whisper-large-v3"
FASTER_DISTIL_WHISPER_LARGE_V2 = "faster-distil-whisper-large-v2"
def load_audio_file(audio_file_path: str, sampling_rate: int = 16000, channels: int = 1) -> list[np.ndarray]:
# Create an input stream from the audio file
input_stream = ffmpeg.input(audio_file_path)
# Convert the audio to raw PCM data with a sample rate of `sampling_rate`
output_stream = ffmpeg.output(input_stream, "pipe:", format="s16le", acodec="pcm_s16le", ar=sampling_rate)
# Run the conversion and capture the raw PCM data
pcm_data, _ = ffmpeg.run(output_stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
# Convert the raw PCM data to a numpy array of 16-bit signed integers
audio_data = np.frombuffer(pcm_data, np.int16)
# Normalize the audio data to the range [-1.0, 1.0]
audio_data = audio_data.astype(np.float32, order="C") / SIGNED_INT16_MAX
# Split the audio data into separate channels
return [audio_data]
async def transcribe(grpc_stub, audio_data, asr_model):
start = time.time()
request = inference_pb2.TranscribeRequest(audio_bytes=audio_data.tobytes(), target=asr_model)
grpc_response = await grpc_stub.transcribe(request)
response = MessageToDict(grpc_response).get("result", {})
print(f"Response time: {time.time() - start} seconds")
print(response)
async def main():
audio_file_path = "samples/audio.wav"
audio_data = load_audio_file(audio_file_path)
grpc_channel = grpc.aio.insecure_channel("localhost:8080")
grpc_stub = inference_pb2_grpc.ASRInferenceServerStub(grpc_channel)
await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_WHISPER_LARGE_V2)
await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_WHISPER_LARGE_V3)
await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_DISTIL_WHISPER_LARGE_V3)
await transcribe(grpc_stub, audio_data[0], ASRModel.FASTER_DISTIL_WHISPER_LARGE_V2)
if __name__ == "__main__":
asyncio.run(main())
結果
一回の結果なので正確性には欠けますが、以下のような結果になりました。
認識結果については深く言及しないですが、Distil Whisperは速度面でも速いことが確認できました。
モデル名 | 認識結果 | 速度 |
---|---|---|
faster-distil-whisper-large-v2 | Good today was a today the today the today the today. Good today was a good today | 21秒 |
faster-distil-whisper-large-v3 | Good- Good-Wan-Wa, today was good day good day good day. Good night Good day today, good day day day day | 21秒 |
faster-whisper-large-v2 | こんばんは、今日はいい日でした。良い日でした。 | 35秒 |
faster-whisper-large-v3 | こんばんは。今日は良い日でした。こんばんは。今日は良い日でした。こんばんは。 | 39秒 |
最後に
今回は自前でgRPCサーバを立て、Faster WhisperとDistil Whisperをホスティングしてみました。
initial promptやvad filter、beam sizeなどチューニングできる部分はたくさんあるので、いろいろ触ってみたいです。CPUでも簡単に動かせるので、普段使いにも便利です。
参考
Discussion