👋

Style-Bert-VITS2のAPI化

2024/08/28に公開

はじめに

Style-Bert-VITS2(SBV2)は非常に高性能な合成音声AIモデルです。
https://github.com/litagin02/Style-Bert-VITS2

また、ありがたいことに商用利用も可能なライセンス「AGPL-3.0 license、
LGPL-3.0 license」になっております。
このライセンスについては下記の記事が非常に参考になります
https://qiita.com/tatsumi_t2/items/3da688a5123d37986331

そして、こちらの記事に下記の通り記載されています。

AGPLで作られたソフトウェアを自分のソフトウェアに組み込んで利用する場合、そのソフトウェアにもライセンスが伝搬します。

API呼び出しのようなネットワークを介した通信による利用の場合は、組み込みとは見なされずライセンスの伝搬はしないようです。

つまり、SBV2をソフトウェアに組み込んだ場合、ソフトウェア全体のソースコードを公開する必要があります。
しかしながら、SBV2をAPIなどのネットワークを介した通信により利用する場合は、その限りでは無いようです。
したがって、今回はSBV2をAPIで別サーバから呼び出す方法について記述します。
また、そのコードはライセンス的に公開する必要があるため、下記にて公開させていただきます。
https://github.com/personabb/sbv2_api

(2024年8月29日追記)AGPL ver3ライセンスについて

準備

開発環境

著者の開発環境は、M2 Mac、RAM16GBです。

環境構築

python 3.11を利用します。
(3.12は動かないことが確認できています)
pythonが利用できる環境にしてください。

リポジトリのクローン

git clone https://github.com/personabb/sbv2_api.git

必要な音声モデルの取得

「amitaro」モデルと「jvnv-F1-jp」モデルを取得して、下記のように格納してください。


sbv2_api/
    ├ model_assets/
    |      ├ amitaro/
    |      |       ├ amitaro.safetensors
    |      |       ├ config.json
    |      |       └ style_vectors.npy
    |      └ jvnv-F1-jp/
    |              ├ jvnv-F1-jp_e160_s14000.safetensors
    |              ├ config.json
    |              └ style_vectors.npy
    ├ dict_data/
    |     └ default.csv
    ├ sbv2_api.py
    └ client.py

「amitaro」モデルと「jvnv-F1-jp」モデルはSBV2のデフォルトモデルになります。

https://zenn.dev/asap/articles/f8c0621cdd74cc#環境構築
上記の記事の「環境構築」の章まで終われば、「Style-Bert-VITS2」リポジトリの「model_assets」フォルダに上記のモデルのフォルダが格納されているはずなので、それをコピーしてきてください。

また、上記のモデルは下記のコードで呼び出しているので、利用しますが、自分で学習したモデルであっても問題ありません。
その場合は、下記のコードの該当部分を書き換えてください。

必要パッケージのインストール

pip install numpy==1.26.4
pip install style-bert-vits2
pip install sounddevice
pip install fastapi\[all\]

コードの実装

サーバ側

sbv2_api.py
import os
import numpy as np
from pathlib import Path
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from style_bert_vits2.tts_model import TTSModel
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp.japanese.user_dict import update_dict
import torch
from pydantic import BaseModel

from fastapi import FastAPI, Depends, Header
from typing import List, Dict, Any
from fastapi.security.api_key import APIKeyHeader
import uvicorn
import json
import time
import glob

device = "cuda" if torch.cuda.is_available() else "cpu"
update_user_dict = False
default_dict_path = "dict_data/default.csv"
compiled_dict_path = "dict_data/user.dic"
bert_models_model = "ku-nlp/deberta-v2-large-japanese-char-wwm"
bert_models_tokenizer = "ku-nlp/deberta-v2-large-japanese-char-wwm"


class SBV2:
    def __init__(self, model_path):
        logger.remove()

        if update_user_dict:
            print("loading user dict")
            update_dict(default_dict_path = Path(default_dict_path), compiled_dict_path = Path(compiled_dict_path))
        

        if device == "auto":
            self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.DEVICE = device

        bert_models.load_model(Languages.JP, bert_models_model)
        bert_models.load_tokenizer(Languages.JP, bert_models_tokenizer)

        style_file = glob.glob(f'{model_path}/*.npy',recursive=True)[0]
        config_file = glob.glob(f'{model_path}/*.json',recursive=True)[0]
        model_file = glob.glob(f'{model_path}/*.safetensors',recursive=True)[0]

        print(style_file)
        print(config_file)
        print(model_file)

        
        self.model_TTS = TTSModel(
            model_path=model_file,
            config_path=config_file,
            style_vec_path=style_file,
            device=self.DEVICE
        )

    def call_TTS(self,message):
        sr, audio = self.model_TTS.infer(text=message)

        return sr, audio
    
    def text2speech(self,message):
        sr, audio = self.model_TTS.infer(text=message)
        sd.play(audio, sr)
        sd.wait()

app = FastAPI()

class SBV2_inputs(BaseModel):
    text: str

class SBV2_init(BaseModel):
    modelname: str

# ユーザごとのインスタンスを管理する辞書
user_instances: Dict[str, Dict] = {}

class Dependencies:
    def __init__(self,api_key, model):
        model_path = f"model_assets/{model}"
        self.sbv2 = SBV2(model_path = model_path)

    def get_sbv2(self):   
        return self.sbv2


def get_user_dependencies(api_key: str,model = None):
    #過去にAPIkeyが登録されていない場合は新規登録
    if api_key not in user_instances:
        if model is None:
            raise Exception("model is required for the first time initialization")
        user_instances[api_key] = Dependencies(api_key, model)
        
    #登録されている場合はそのまま返す
    return user_instances[api_key]

API_KEY_NAME = "api_key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
def get_api_key(api_key: str = Depends(api_key_header)):
    return api_key

print("server started")

@app.post("/initialize/")
async def initialize(
    inputs: SBV2_init,
    api_key: str =  Depends(get_api_key)
    ):
    dependencies = get_user_dependencies(api_key, inputs.modelname)
    #初回の実行は`torch.nn.utils.weight_norm`のFutureWarningのせいか、処理時間が長いので、初期化のタイミングで初回の実行を終わらせておく
    _, _ = dependencies.get_sbv2().call_TTS("初期化")
    return {"message": "Initialized"}

@app.post("/process/")
async def process_data(
    inputs: SBV2_inputs,
    api_key: str = Depends(get_api_key),    
):
    dependencies = get_user_dependencies(api_key)
    start_tts = time.time()
    sr, audio = dependencies.get_sbv2().call_TTS(inputs.text)
    print(f"Time taken for TTS: {time.time() - start_tts}")
    return {"audio": audio.tolist(), "sr": sr}


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8001)

解説

ユーザ辞書の更新

if update_user_dict:
    print("loading user dict")
    update_dict(default_dict_path = Path(default_dict_path), compiled_dict_path = Path(compiled_dict_path))

上記の部分でユーザ辞書の更新をしています。
ユーザ辞書を登録したい場合は、update_user_dictTrueに設定し、default_dict_path = "dict_data/default.csv"に指定されているパスに辞書を置いてください。
(初期はupdate_user_dict=Falseとなっているため、ユーザ辞書を登録していません)

ユーザ辞書に関しては下記をご覧ください。
https://zenn.dev/asap/articles/f8c0621cdd74cc#辞書登録

api_keyに関して

# ユーザごとのインスタンスを管理する辞書
user_instances: Dict[str, Dict] = {}

class Dependencies:
    def __init__(self,api_key, model):
        model_path = f"model_assets/{model}"
        self.sbv2 = SBV2(model_path = model_path)

    def get_sbv2(self):   
        return self.sbv2

def get_user_dependencies(api_key: str,model = None):
    #過去にAPIkeyが登録されていない場合は新規登録
    if api_key not in user_instances:
        if model is None:
            raise Exception("model is required for the first time initialization")
        user_instances[api_key] = Dependencies(api_key, model)
        
    #登録されている場合はそのまま返す
    return user_instances[api_key]

こちらの部分において、api_keyが違う場合、異なるSBV2のインスタンスが生成され(その後、user_instances辞書に保存されます)、api_keyが同じ場合は、user_instances辞書に保存されているSBV2インスタンスを再利用しています。

このように実装することにより、api_keyを変えることで、複数の音声モデルを切り替えて実行することができます。

初期化と実行を分割

@app.post("/initialize/")
async def initialize(
    inputs: SBV2_init,
    api_key: str =  Depends(get_api_key)
    ):
    dependencies = get_user_dependencies(api_key, inputs.modelname)
    #初回の実行は`torch.nn.utils.weight_norm`のFutureWarningのせいか、処理時間が長いので、初期化のタイミングで初回の実行を終わらせておく
    _, _ = dependencies.get_sbv2().call_TTS("初期化")
    return {"message": "Initialized"}

@app.post("/process/")
async def process_data(
    inputs: SBV2_inputs,
    api_key: str = Depends(get_api_key),    
):
    dependencies = get_user_dependencies(api_key)
    start_tts = time.time()
    sr, audio = dependencies.get_sbv2().call_TTS(inputs.text)
    print(f"Time taken for TTS: {time.time() - start_tts}")
    return {"audio": audio.tolist(), "sr": sr}

こちらの通り、初期化(/initialize/)と実行(/process/)は分けています。
初期化時には、api_keyと音声モデルの名前を取得して、SBV2のインスタンスの作成と、user_instances辞書への保存を行っています。

また、SBV2は、インスタンス作成後の最初の実行において、下記のWarningのせいか処理速度が若干遅くなるという課題があるため、初期化時に初回の実行も合わせて実行しています。
(これは私の環境だけかもしれませんので、不要であればコメントアウトしてください)

該当箇所

#初回の実行は`torch.nn.utils.weight_norm`のFutureWarningのせいか、処理時間が長いので、初期化のタイミングで初回の実行を終わらせておく
_, _ = dependencies.get_sbv2().call_TTS("初期化")

Warning

FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.

また、実行時(/process/)にはapi_keyと発話テキストを取得して、api_keyに合わせて作成したインスタンスを辞書から取得し、発話テキストを合成音声にて発話させています。

クライアント側

client.py
import sys, os
import torch
import requests
import numpy as np
import sounddevice as sd

def init_abv2_api(api_key = "sbv2_amitaro", model_name = "amitaro"):
    init_url = "http://127.0.0.1:8001/initialize/"

    # サーバー側のインスタンスを初期化
    headers = {"api_key": api_key}

    init_inputs = {
        "modelname": model_name,
    }

    init_response = requests.post(init_url, json=init_inputs, headers=headers)
    if init_response.status_code == 200:
        print("Initialization successful.")
    else:
        print("Initialization failed.")
        exit(1)

def call_TTS_API(text,api_key = "sbv2_amitaro"):
    url = "http://127.0.0.1:8001/process/"
    headers = {"api_key": api_key}

    inputs = {
        "text": text,
    }

    response = requests.post(url, json=inputs, headers=headers)
    # JSONデータとしてレスポンスを解析
    data = response.json() 

    audio = data['audio']
    audio = np.array(audio, dtype=np.float32)
    audio = audio / 32768.0
    sr = data['sr']

    return audio, sr

if __name__ == "__main__":
    init_abv2_api(api_key = "sbv2_amitaro", model_name = "amitaro")
    init_abv2_api(api_key = "sbv2_jvnv-F1-jp", model_name = "jvnv-F1-jp")

    audio, sr = call_TTS_API("こんにちは。",api_key = "sbv2_amitaro")
    sd.play(audio, sr)
    sd.wait()

    audio, sr = call_TTS_API("こんにちは。",api_key = "sbv2_jvnv-F1-jp")
    sd.play(audio, sr)
    sd.wait()

解説

初期化関数

def init_abv2_api(api_key = "sbv2_amitaro", model_name = "amitaro"):
    init_url = "http://127.0.0.1:8001/initialize/"

    # サーバー側のインスタンスを初期化
    headers = {"api_key": api_key}

    init_inputs = {
        "modelname": model_name,
    }

    init_response = requests.post(init_url, json=init_inputs, headers=headers)
    if init_response.status_code == 200:
        print("Initialization successful.")
    else:
        print("Initialization failed.")
        exit(1)

上記の通り、クライアント側の準備が整ったら、サーバ側にインスタンスを生成してもらうために初期化の要求を行います、
この時にapi_keymodel_nameをサーバ側に投げます。
このapi_keymodel_nameは一意に紐づくため、後述する実行時の関数において、呼び出したい音声モデルのapi_keyを指定する必要があります。

実行時は下記のように呼び出します

init_abv2_api(api_key = "sbv2_amitaro", model_name = "amitaro")
init_abv2_api(api_key = "sbv2_jvnv-F1-jp", model_name = "jvnv-F1-jp")

今回は、上記の通り、SBV2のデフォルトで用意されている2モデルを呼び出しています。
特にあみたろさんのモデルに関しては、音声の質が非常に良いためおすすめです。
下記のサイトにてあみたろさんが提供してくださっている音声素材で学習されています。
https://amitaro.net/voice/livevoice/

model_nameの指定方法に関してはmodel_assetsの直下のフォルダの名前をそのまま指定してください。
コード的には、指定したフォルダの中にあるsafetensor重みなどを検索して取得しています。

もし、自分で用意した学習済みSBV2重みを利用する場合は、ここで使いたいモデルファイルが格納されているフォルダ名を指定してください。

実行時関数

def call_TTS_API(text,api_key = "sbv2_amitaro"):
    url = "http://127.0.0.1:8001/process/"
    headers = {"api_key": api_key}

    inputs = {
        "text": text,
    }

    response = requests.post(url, json=inputs, headers=headers)
    # JSONデータとしてレスポンスを解析
    data = response.json() 

    audio = data['audio']
    audio = np.array(audio, dtype=np.float32)
    audio = audio / 32768.0
    sr = data['sr']

    return audio, sr

サーバ側のSBV2クラスのcall_TTSメソッドをAPIで呼び出す関数です。
呼び出したい音声モデルに紐づくapi_keyと発話したいテキストを引数に指定すると、合成音声後の音声波形audioとサンプリングレートsrが取得できます。

取得した音声波形の再生も含めて下記のように呼び出します。

audio, sr = call_TTS_API("こんにちは。",api_key = "sbv2_amitaro")
sd.play(audio, sr)
sd.wait()
audio, sr = call_TTS_API("こんにちは。",api_key = "sbv2_jvnv-F1-jp")
sd.play(audio, sr)
sd.wait()

上は「あみたろ」さんの音声で、下は「jvnv-F1-jp」の音声で「こんにちは」を発話します。
また、あらかじめ初期化を行っておけば、api_keyを変えてcall_TTS_API関数を呼び出すことで、さまざまな声で音声合成を行うことが可能です。

実行

環境構築済みの2つのターミナルを開き、それぞれのターミナルで下記を実行してください。

python sbv2_api.py
python client.py

sbv2_api.pyを実行した端末において、下記が表示されたら、client.pyのコマンドを実行してください

server started
INFO:     Started server process [11246]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8001 (Press CTRL+C to quit)

client.pyのコマンドを実行すると、2話者の声で「こんにちは」が再生されるはずです。

まとめ

今回はSBV2を別サーバからAPIで呼び出す方法について記載しました。
このように実装することで、クライアント側から音声を指定しつつ、API呼び出しでSBV2を呼べるかと思います。

また、SBV2自体は非常に高性能な合成音声モデルです。
詳しい使い方なども下記のように記事にしているため、ぜひ試してみてください。
https://zenn.dev/asap/articles/f8c0621cdd74cc

ここまで読んでくださり、ありがとうございました!

Discussion