🗣️

音声認識AIのWhisperをUnreal Engineでリアルタイムに動かすためにやったこと

2022/12/23に公開

「Unreal Engine (UE) Advent Calendar 2022 その3」23日目の記事です。

はじめに

OpenAIの音声認識AI「Whisper」がすごいらしい。これをUnreal Engineでリアルタイムに動かせるようにしたら応用範囲が広がっておもしろいんじゃないかと思いました。

(「異議あり!」って実際に声に出させたいよね)
NLPアドベンチャーを音声入力で、みたいな夢も広がる)

しかし、いざやってみたらいろいろな課題にぶつかりました。この記事は、それらをどう解決したかの記録です。

目次

  1. 目標設定:C++とONNX Runtimeで実装する
  2. Whisperの処理の全体感
  3. 課題と対応
    • 課題1:マイク入力と前処理をC++で実装する
    • 課題2:Whisperの機械学習モデルをONNXにエクスポートする
    • 課題3:ONNXモデルをtransformer&FP16向けに最適化する
    • 課題4:ONNXモデルをUnreal Engineから実行できるようにする
    • 課題5:TokenizerをC++で実装する
    • 課題6:ゲームスレッドをブロックしないようマルチスレッド化する
  4. ゲーム向けに使いやすいコンポーネントを用意する
    • 音声をリアルタイムに文字に起こす
    • 指定された短いフレーズの内のどれが話されたかを判別する
    • 指定された長いフレーズの内のどれがどこまで話されたかを判別する

目標設定:C++とONNX Runtimeで実装する

Whisperの公式リポジトリでは、PyTorchによる実装が公開されています。
これをこのままUE4/UE5で使おうとすると、Pythonインタプリタおよび依存するライブラリの再配布が必要になるとともに、UE-Python間の呼び出し等で処理に無駄が多くなりそうです。

そこで、PyTorchで実装された機械学習モデルをONNX形式でエクスポートし、ONNX RuntimeのネイティブライブラリをUnreal C++から呼び出す形で実装することにします。

ONNX Runtimeは、Windows、Linux、Mac、Android、iOSなどの各種プラットフォーム向けにDLLやSOファイルが用意されていますし、ONNX Runtimeで実行する方がPyTorchより処理速度が向上することが多いので、非常に都合が良いです。
(ちなみに、UE5.0で追加されたML Deformerプラグインも内部ではONNX Runtimeを使用しているようです)

また、Whisperのメインの処理の前後に必要な各種依存ライブラリ(Tokenizer等)も、Pythonを使わずC++で完結するようになんとかしていきます。

Whisperの処理の全体感

まずは、Whisperで音声を処理する流れの全体感を確認しましょう。ざっくり言うと下記の通りです。

  1. 前処理
    1.1. マイク入力:マイクから音声波形を取得
    1.2. スペクトル分析:音声波形をLog-Melスペクトログラムに変換
  2. AI処理
    2.1. Whisperエンコード:Log-Melスペクトログラムをエンコーダに入力し、音声特徴量を取得
    2.2. Whisperデコード:音声特徴量をデコーダに入力し、一つ目のトークン(≒単語)を取得
    2.3. Whisperデコードループ:音声特徴量とこれまでのトークンを入力に、次のトークンを取得。これを文末まで繰り返す。
  3. 後処理
    3.1. Tokenizerデコード:トークンの配列を文字列に変換

「1.前処理」は、一般的な音声信号処理です。マイク入力をWhisperに入力可能な形式に変換します。
「2.AI処理」は、Whisperのメインの部分です。GPUを使った並列計算で10~1000ms程度かかる重い処理を行います。
「3.後処理」は、Whisperの出力する「トークン配列」(実体はint配列)を人間が理解可能な文字列に変換する処理です。

課題と対応

上記の処理を目標設定に沿ってUnreal Engineでリアルタイム実行するためには、下記のような課題に対応する必要があります。

  1. 前処理
    • 課題1:マイク入力とスペクトル分析をC++で実装する
  2. AI処理
    • 課題2:Whisperの機械学習モデルをONNXにエクスポートする
    • 課題3:ONNXモデルをtransformer&FP16向けに最適化する
    • 課題4:ONNXモデルをUnreal Engineから実行できるようにする
  3. 後処理
    • 課題5:TokenizerをC++で実装する
  4. その他
    • 課題6:ゲームスレッドをブロックしないようマルチスレッド化する

ということで、これらをゴリゴリとやっつけていきます。
各課題で対応する内容が多岐にわたるため、それぞれどんな分野に関わるかを簡単に表にしました。興味のある部分だけ読んでくださいませ。

課題 Unreal Engine ONNX NLP Whisper
1
2
3
4
5
6

個別課題格闘編

課題1:マイク入力とスペクトル分析をC++で実装する

Whisperの公式実装は録音済みのデータに対して処理を行う形になっているので、マイク入力の仕組みを用意してやる必要があります。また、音声波形に対するスペクトル分析をC++に移植する必要があります。

Unreal Engineを使っていて嬉しいのは、マイク入力のような一般的な機能はエンジン標準で用意されていることです。
UEのAudioCaptureCoreモジュールのFAudioCaptureを使用すると、マイク入力とスペクトル分析は下記のように実装できます。

MyPreprocessor.h
#include "AudioCaptureCore.h"

class MyPreprocessor
{
public:
    bool OpenDefaultAudioStream();
    void StartCapturingAudio();

protected:
    Audio::FAudioCapture AudioCapture;
};
MyPreprocessor.cpp
bool MyPreprocessor::OpenDefaultAudioStream()
{
    if (!AudioCapture.IsStreamOpen())
    {
        Audio::FOnCaptureFunction OnCapture = [this](const float* AudioData, int32 NumFrames, int32 InNumChannels, int32 InSampleRate, double StreamTime, bool bOverFlow)
        {
            // スペクトル分析の中身
        };

        Audio::FAudioCaptureDeviceParams Params;
        if (AudioCapture.OpenCaptureStream(Params, MoveTemp(OnCapture), 1024))
        {
            Audio::FCaptureDeviceInfo Info;
            if (AudioCapture.GetCaptureDeviceInfo(Info))
            {
                Init(Info.PreferredSampleRate, Info.InputChannels);

                // スペクトル分析のための初期化処理

                return true;
            }
        }
    }
    return false;
}

void MyPreprocessor::StartCapturingAudio()
{
    if (AudioCapture.IsStreamOpen())
    {
        AudioCapture.StartStream();
    }
}

マイクからの波形データ(AudioData)が取得されると、FAudioCapture::OpenCaptureStreamに指定したコールバック関数OnCaptureが呼び出されます。この中で波形データに対する処理を行ってやればよいというわけです。

なお、より豊富な実装例が、Engine\Plugins\Runtime\AudioCapture\Source\AudioCapture\Public\AudioCapture.hにあったりしますので、詳しくはそちらを…。上記のコードもほぼそのコピペです。

具体的なスペクトル分析の中身は、あまり書いてもおもしろくないので割愛します。フーリエ変換してlogとって定数を掛けたりするだけなので…。

課題2:Whisperの機械学習モデルをONNXにエクスポートする

目標設定の項で述べた通り、機械学習モデルをPyTorchからONNXへ変換してやりたいわけですが、torch.onnx.exportでさくっと一発とはいきません。

なぜなら、デコーダにtorch.TensorでないKvCacheという名前の変数が使われていて、これがそのままだとONNXに出力されないからです。

なので、KvCache相当の変数をtorch.Tensorで定義しなおし、これを使用するようモデル構造を修正します。
具体的には、まずは下記のようにMultiHeadAttentionをラップするクラスを定義してやります。

model.py
class MultiHeadAttention_cross(nn.Module):
    def __init__(self, in_multiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = in_multiHeadAttention

    def forward(
        self,
        x: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
    ):
        q = self.multiHeadAttention.query(x)
        wv = self.multiHeadAttention.qkv_attention(q, k, v, mask)
        return self.multiHeadAttention.out(wv)

class MultiHeadAttention_self(nn.Module):
    def __init__(self, in_multiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = in_multiHeadAttention

    def forward(
        self,
        x: Tensor,       #(b, n_ctx      , n_state)
        k_cache: Tensor, #(b, n_ctx_cache, n_state)
        v_cache: Tensor, #(b, n_ctx_cache, n_state)
        mask: Optional[Tensor] = None,
    ):
        q = self.multiHeadAttention.query(x) #(b, n_ctx, n_state)
        k = self.multiHeadAttention.key(x)   #(b, n_ctx, n_state)
        v = self.multiHeadAttention.value(x) #(b, n_ctx, n_state)

        if k_cache is not None:
            k = torch.cat((k_cache, k), 1) #(b, n_ctx_cache + n_ctx, n_state)
            v = torch.cat((v_cache, v), 1) #(b, n_ctx_cache + n_ctx, n_state)

        wv = self.multiHeadAttention.qkv_attention(q, k, v, mask)
        return self.multiHeadAttention.out(wv), k, v

ポイントは、

  • MultiHeadAttentionforwardを迂回してqkv_attentionを直接呼ぶことで、KvCacheを使わない
  • Crossアテンションでは、k, vはデコードのループ中は変わらないので、エンコーダで事前に計算しておいてやり、それを入力として渡すようにする
  • Selfアテンションでは、k, vはデコードのループ中で末尾に追加していくので、torch.catしてその結果を出力として返すようにする
    (デコードのSelfアテンションでは、未来へのアテンションをmaskすることでunidirectionalになっていて、後ろに追加されたトークンはそれ以前の結果に影響しないのでこういうことができる)

です。

次に、これらを使用するように、ResidualAttentionBlockTextDecoderをラップするクラスを定義します。

model.py
class ResidualAttentionBlock_tensorCache(nn.Module):
    def __init__(self, in_residualAttentionBlock: ResidualAttentionBlock):
        super().__init__()
        self.originalBlock = in_residualAttentionBlock
        self.attn = MultiHeadAttention_self(in_residualAttentionBlock.attn)
        self.cross_attn = MultiHeadAttention_cross(in_residualAttentionBlock.cross_attn) if in_residualAttentionBlock.cross_attn else None

    def forward(
        self,
        x: Tensor,
        self_k_cache: Optional[Tensor] = None,
        self_v_cache: Optional[Tensor] = None,
        cross_k: Optional[Tensor] = None,
        cross_v: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask) 
        x = x + self_attn_x

        if self.cross_attn:
            x = x + self.cross_attn(self.originalBlock.cross_attn_ln(x), cross_k, cross_v)

        x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
        return x, self_k_cache_updated, self_v_cache_updated

class TextDecoder_tensorCache(nn.Module):
    def __init__(self, in_textDecoder: TextDecoder, in_n_ctx: int):
        super().__init__()
        self.textDecoder = in_textDecoder
        self.n_ctx = in_n_ctx

        self.blocks = []
        for orginal_block in self.textDecoder.blocks:
            self.blocks.append(ResidualAttentionBlock_tensorCache(orginal_block))

    def forward(self, x: Tensor, 
                n_layer_self_k_cache: Tensor, 
                n_layer_self_v_cache: Tensor,
                n_layer_cross_k: Tensor, 
                n_layer_cross_v: Tensor, 
                positions: Optional[Tensor] = None,
                ):
        pos_emb_slice = self.textDecoder.positional_embedding[positions]
        x = self.textDecoder.token_embedding(x) + pos_emb_slice
        x = x.to(n_layer_cross_k[0].dtype)

        i = 0
        self_k_cache_list = []
        self_v_cache_list = []
        for block in self.blocks:
            x, self_k_cache, self_v_cache = block(x, 
                                                self_k_cache = n_layer_self_k_cache[i], 
                                                self_v_cache = n_layer_self_v_cache[i],
                                                cross_k = n_layer_cross_k[i], 
                                                cross_v = n_layer_cross_v[i], 
                                                mask=self.mask)
            self_k_cache_list.append(self_k_cache)
            self_v_cache_list.append(self_v_cache)
            i += 1

        n_layer_self_k_cache = torch.stack(self_k_cache_list)
        n_layer_self_v_cache = torch.stack(self_v_cache_list)

        x = self.textDecoder.ln(x)

        logits = (x @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return logits, n_layer_self_k_cache, n_layer_self_v_cache

各レイヤーごとにキャッシュしてやる必要があるので、TextDecoderblockのループでSelfアテンションのk, vのキャッシュをリストに格納しておき、最後にtorch.stackでこれをtorch.Tensorにしてから返却しています。次のforwardではこれを入力してやればいいというわけです。

なお、positional embeddingのための入力の位置情報は、引数positionsとして外から入力する形にしています。ONNX化した際にpositionsの計算をGPUでやる意味はない(CPUでやってから渡してあげた方が速そう)と思ったのでこうしています。

課題3:ONNXモデルをtransformer&FP16向けに最適化する

Whisperの機械学習モデルをONNX形式にできても、そのままでは十分な性能がでません。
なぜなら、素直にONNXにエクスポートするとFP32での計算となり、FP16で計算しているPyTorch実装より効率が低下するからです。

ここで登場するのが、ONNX公式のTransformer向けモデル最適化ツールです。これは下記のような機能を提供します。

  • ONNX Contrib Operatorで定義されているAttentionなどのtransformer向けFusedノードへの置き換え
  • FP32からFP16への変換
  • 入出力をDynamic shapeに変換

とても便利そうですし、実際これを使うとTransformerを使ったモデルはだいたい倍速くらいになります。

が、問題はこれを適用可能なモデルが限定されていることです。

Supported Models

Here is a list of PyTorch models from Huggingface Transformers that have been tested using the optimizer:

  • BERT
  • DistilBERT
  • DistilGPT2
  • RoBERTa
  • ALBERT
  • GPT-2 (GPT2Model, GPT2LMHeadModel)

For Tensorflow model, we only tested BERT model so far.

上記の不穏な文言とツールの実装を眺めればわかりますが、このツールは処理に特定のモデル構造がハードコードされているため、任意のモデル構造に適用可能な代物ではありません

でも、使いたい。ならば、自分のモデル構造をハードコードする闇の改造を施してやりましょう。

まずは、READMEの記載にしたがってとりあえずツールを実行してどうなるか見てみます。

python -m onnxruntime.transformers.optimizer --input mywhisper_encoder.onnx --output mywhisper_encoder_fp16.onnx --num_heads 12 --hidden_size 768 --float16

おそらく、GeluLayerNormarizationはFusedノードへの置き換えに成功しますが、Attentionノードは置き換えに失敗するのではないでしょうか。
原因は、Attentionノードへの置き換え対象を取得する処理が先に述べたように特定モデル用にハードコードされているからです。なので、そこを改造します。

fuse関数の改造

問題の箇所は、fusion_attention.pyにあります。
ここで定義されているfuse関数は、第一引数にモデル中で見つかったLayerNormarizationノードを受け取り、そこを起点にAttentionノードに置き換える対象を検索して置き換えを行うものです。

fuse関数の中で、下記のようにmatch_parent_pathという関数が何度も使われています。

qkv_nodes = self.model.match_parent_path(
            start_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [None, None, 0, 0, 0],
        )

これは、起点のノード(start_node)からさかのぼって、特定の種類のノード(ここでは["Add", "MatMul", "Reshape", "Transpose", "MatMul"])が繋がっているかを検索し、繋がっているならそれらのノードを返す関数です。第3引数に指定している数字は、各ノードの親が複数あるときにどれを調べるかをindexで指定するものです。

これを踏まえてソースを睨んでいると、fuse関数は下記図のようにノードを検索していることがわかります。(緑がmatch_parent_pathで検索している範囲、赤字が検索結果のノードの変数名)

fuse関数の目的は、この図に出てくるノードを見つけてくることだけです。そしてそれらのノードを入力として下記のcreate_attention_node関数が呼び出され、Attentionノードへの置き換えが行われます。

new_node = self.create_attention_node(
    mask_index,
    matmul_q,
    matmul_k,
    matmul_v,
    add_q,
    add_k,
    add_v,
    q_num_heads,
    q_hidden_size,
    root_input,
    attention_last_node.output[0],
    add_qk_str,
)

なので、自分のモデルにあわせてmatch_parent_pathを駆使して必要なノードを見つけてきて、create_attention_node関数にぶち込んでやればよいというわけです。がんばりましょう。

create_attention_node関数の改造

…というわけなのですが、実はまだやることがあります。

create_attention_node関数は、onnx.helperを用いてノードを作成してくれます。が、実はこの関数では課題2でせっかく用意したk, vのキャッシュを使った形のAttentionノードを作成してくれません。
そのため、Attentionノードの定義を見ながら、create_attention_node関数を適切に改造して、k, vキャッシュの入出力を追加してやる必要があります。

具体的には、下記のようにAttentionノードの入力のpastと、出力のpresentに適切なノード名を指定してやればOKです。

attention_inputs = [
    input,
    attention_node_name + "_qkv_weight",
    attention_node_name + "_qkv_bias",
    mask_index,
    past_node_name
]

attention_outputs = [
    output,
    present_node_name
]

attention_node = helper.make_node(
    "Attention",
    inputs=attention_inputs,
    outputs=attention_outputs,
    name=attention_node_name,
)

pastpresentのノード名は、やはりmatch_parent_pathを駆使してうまくモデルから見つけてくる必要があります。
これらのノードを見つけやすいようにPyTorch上でモデルをいい感じに改造からエクスポートしてやると、ここは比較的簡単に実装できる気がしますが、私のソースは試行錯誤の末大変汚いことになっていてもはや修正する気も起きないのでちょっと公開は控えさせていただきます。

課題4:ONNXモデルをUnreal Engineから実行できるようにする

ONNXにエクスポートした機械学習モデルをUnreal Engineから実行するためには、ONNX RuntimeをUnreal Engineから呼び出せるようにする必要があります。

  • 一つの方法は、UE5.0から提供されているNeuralNetworkInferenceプラグインの中のONNX Runtimeモジュールを使用する方法です。
    ただし、UE5.1時点で NeuralNetworkInferenceプラグインはExperimental の上、少々使い方がわかり辛いため注意が必要です。

  • そのため、ある程度ONNX Runtimeを使い慣れており、Unreal Engineのプラグインを作成できるのならば、Engine\Plugins\Experimental\NNI以下の実装を参考に、最新のONNX RuntimeのビルドをGitHubから取得してUnreal Engineに組み込むのがシンプルで確実です。

  • あるいは、Unreal EngineでONNX Runtimeを簡便に使用することに特化したプラグインがマーケットプレイスで販売されていたりしますので、それを使用してもよいでしょう。
    (私が作りました)

課題5:TokenizerをC++に実装する

Whisperの機械学習モデルは、音声を入力として受け取り、最終的に0~50256の整数型で表されるToken IDの配列を出力します。
一つのToken IDは、一つのバイト列に対応します。複数Token IDからのバイト列を結合してこれをUTF-8として解釈すると、最終的な文字列が取得できます。
(Byte LevelなByte Pair Encodingがなされている)


Whisperの出力するToken IDの配列から文字列への変換例

Token IDと文字列の相互変換に関する機能を提供するライブラリは、Tokenizerと呼ばれます。
Whisperの公式実装では、Tokenizerとしてhuggingface/tokenizersが使用されています。これはrustで実装されており、残念ながらC++のバインディングは提供されていません。

そこで、huggingface/tokenizersのrust実装を読み、ほしい機能をC++で実装することにしました。といっても、その内容はさほど難しくなく、

  1. 語彙ファイルから、Token IDとバイト列の対応表を取得して保持しておく
  2. Token IDの配列 → バイト列の配列 → 結合された1つのバイト列 → UTF-8の文字列」という変換をする

という処理を作るだけです。(これを理解するのに時間がかかったわけですが…)

上記2に該当するrustの実装はこのあたりこのあたりこのあたりにあり、抜粋するとこんな感じです。

src/tokenizer/mod.rs
/// Decode the given ids, back to a String
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
    let tokens = ids
        .into_iter()
        .filter_map(|id| {
            self.added_vocabulary
                .id_to_token(id, &self.model)
                .filter(|token| {
                    !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                })
        })
        .collect::<Vec<_>>();
src/pre_tokenizers/byte_level.rs
/// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
/// unicode counterpart, before merging everything back into a single String.
/// This decoder will consume the tokens and merge them in one step to alleviate
/// the fact that single token decoded might be a byte not representable as
/// as String.
impl Decoder for ByteLevel {
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        let toks = tokens
            .into_iter()
            .flat_map(|t| {
                t.chars()
                    .try_fold(vec![], |mut acc, c| {
                        CHAR_BYTES.get(&c).map(|b| {
                            acc.push(*b);
                            acc
                        })
                    })
                    .unwrap_or_else(|| t.as_bytes().to_vec())
            })
            .collect::<Vec<u8>>();
        Ok(vec![String::from_utf8_lossy(&toks).to_string()])
    }
}

イテレータを駆使した構文で、rust初心者の私は初見ではさっぱりわかりませんでしたが、気合で解読して日本語にすると、

  • uint32の配列のidsを、id_to_token関数を使って、stringの配列tokensに変換
  • stringの配列tokensを、CHAR_BYTESマップを使って、uint8の配列toksに変換。さらに、toksをfrom_utf8_lossy.to_stringを使って、最終的な文字列に変換。

ということをしています。先に説明した内容と違って途中にstringへの変換が挟まっていますが、これはあまり本質ではない(はず)です。(おそらく、語彙ファイルの仕様上そうしたのではないかと推測)

され、これをC++で実装したものがこちらです。

TokenizerDecoder.cpp
FString detokenize(const std::vector<int64_t>& ids)
{
    std::vector<uint8_t> utf8bytes;

    for (const int64_t id : ids)
    {
        // process only normal text tokens
        if (vocab->isText(id))
        {
            const auto decodedBytes = vocab->getBytes(id);
            appendBytes(decodedBytes, utf8bytes);
        }
    }

    // add 0 to the last
    if (utf8bytes.size() > 0 && utf8bytes.back() != 0)
    {
        utf8bytes.push_back(0);
    }

    TCHAR* characters = UTF8_TO_TCHAR((char*)utf8bytes.data());
    return FString(characters);
}

途中で使っているvocabは語彙ファイルを読み込んだ情報を保持しているクラスだと思ってください。

vocab->isText(id)で、与えられたToken IDが文字列に変換してよいものかどうかを判定し、OKならvocab->getBytes(id)でIDに対応するバイト配列decodedBytesを取ってきています。
decodedBytesutf8bytesに順に結合していき、末尾に文字列の終了を表す0を追加した上で、UEのマクロUTF8_TO_TCHARを経由してFStringに変換しています。

ちなみに、逆向きの変換(文字列→Token ID配列)は、もう少し複雑です。
後述する「ゲーム向けに使いやすいコンポーネントを用意する」ために必要なのでこちらも気合でC++に移植しました。

とても長くなるので詳細は割愛しますが、この変換のポイントは「Token IDとバイト列の対応表」だけでは変換が一意に決まらないことです。どのバイト列を優先してToken IDにするかという規則が必要で、これは別途mergeファイルとして用意してやることになります。

課題6:ゲームスレッドをブロックしないようマルチスレッド化する

機械学習モデル推論の重い処理でゲームのFPSが低下したり、マイク入力が止まってしまっては困ります。そこで、「1.前処理」、「2.AI処理」をそれぞれゲームスレッドとは別のスレッドで実行することにします。


マルチスレッド化のイメージ。処理の長さは適当

前処理の独立スレッド化

「1.前処理」は、マイクから取得されるデータに対して常に実行する必要があり、また、一般的なCPUで十分短い時間内に実行可能な程度の小さな計算量であることから、CPUの独立したスレッドで逐次実行することとしました。

課題1の項で書いたFAudioCapture::OpenCaptureStreamに指定したコールバック関数OnCaptureの呼び出しは、ゲームスレッドから独立したスレッドで実行されるので、そのスレッドの中でスペクトル分析までやってやるだけです。便利~。

AI処理の非同期タスク化

「2.AI処理」はかなり重い処理であり、かつ、ゲームプログラマが実行タイミングを制御できるようにしたいので、UEのFAsyncTaskを使ってゲームスレッドから非同期タスクとしてスレッドプールに実行を投げる形としました。

FAsyncTask自体の使い方は、historia様の記事unknown_ds様の記事が大変わかりやすく、また、Engine\Source\Runtime\Core\Public\Async\AsyncWork.hのコメントにも実装例が記載されています。

下記の実装では、これらを参考にさせていただきました。というよりほぼそのコピペですが…。異なる点としては、非同期実行した結果をBlueprintに返してあげる仕組みを組み込んだところです。

MyAsyncTask.h
#include "Async/AsyncWork.h"

// Class for async execution of a task
class TEST_API FMyAsyncTask : public FNonAbandonableTask
{
    friend class FAsyncTask<FMyAsyncTask>;

public:
    FMyAsyncTask(TFunction<void()> InWork)
        : Work(InWork)
    {}

    void DoWork()
    {
        // Execute the function specified in the constructor
        Work();
    }

    FORCEINLINE TStatId GetStatId() const
    {
        RETURN_QUICK_DECLARE_CYCLE_STAT(FMyAsyncTask, STATGROUP_ThreadPoolAsyncTasks);
    }

private:
    TFunction<void()> Work;
};
MyAsyncComponent.h
#include "CoreMinimal.h"
#include "Components/ActorComponent.h"
#include "MyAsyncTask.h"
#include "MyAsyncComponent.generated.h"

// 結果通知用のDelegateを定義
DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FMyResultDispatcher, FMyResult, result);

UCLASS()
class TEST_API UMyAsyncComponent : public UActorComponent
{
    GENERATED_BODY()

    // BPでアサイン可能なDelegateとして定義
    UPROPERTY(BlueprintAssignable) FMyResultDispatcher onAsyncTaskDone;

protected:
    virtual void BeginPlay() override;
    virtual void TickComponent( float DeltaTime, ELevelTick TickType, FActorComponentTickFunction* ThisTickFunction ) override;
    virtual void EndPlay() override;

    // 結果の一時格納用
    FMyResult result;

    // 非同期実行用
    TFunction< void() > taskFunc;
    FAsyncTask<FMyAsyncTask>* asyncTask = nullptr;
}
MyAsyncComponent.cpp
#include "MyAsyncComponent.h"

void UMyAsyncComponent::MainTask()
{
    result = /* AIを使った重い処理 */
}

void UMyAsyncComponent::BeginPlay()
{
    Super::BeginPlay();

    // 非同期タスクを登録
    taskFunc = [this] {
        this->MainTask();
    };
    asyncTask = new FAsyncTask<FMyAsyncTask>(taskFunc);
}

void UMyAsyncComponent::TickComponent(float DeltaTime, ELevelTick TickType, FActorComponentTickFunction* ThisTickFunction)
{
    // 前回非同期タスクが終わっていれば
    if (asyncTask && asyncTask->IsDone())
    {
        // 最後の結果をブロードキャストして
        onAsyncTaskDone.Broadcast(result);;

        // 次の非同期タスクを実行
        asyncTask->StartBackgroundTask();
    }
}

void UMyAsyncComponent::EndPlay(const EEndPlayReason::Type EndPlayReason)
{
    // 非同期タスクを削除
    if (asyncTask)
    {
        asyncTask->EnsureCompletion();
        delete asyncTask;
        asyncTask = nullptr;
    }

    Super::EndPlay(EndPlayReason);
}

非同期実行した結果をBlueprintに返すためにDECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParamでBPアサイン可能なDelegateを定義し、結果をブロードキャストしています。
ブロードキャストの実行タイミングを「非同期タスク終了後の次のTick」かつ「次の非同期タスクの実行前」とすることで、BP側ではデータアクセスが他スレッドと競合する可能性を極力考えなくていいようにしました。

ゲーム向けに使いやすいコンポーネントを用意する

ここまでで、WhisperをUnreal Engineで動かすための基本的な準備は整いました。
ここからは、実際にゲーム的な用途で使用するためのコンポーネントをどう作成したかという話をします。

音声をリアルタイムに文字に起こす

まずは、Whisperのもっとも普通の使い方として、マイク入力をリアルタイムで文字起こしする機能を作りました。
ここでポイントとなるのは、いつからいつまでの音声を対象として処理をするかです。

Whisperの公式実装は、録音データを30秒ごとのチャンクに区切って処理を行っています。リアルタイム用途としては30秒ごとの実行では遅すぎるので、適当な間隔(例えば100ms)ごとに直近30秒の音声データをとってきて処理してやることにしました。

また、Whisperのデコードループが回る回数を極力減らすため、下記のように前時刻との差分のみをデコードするよう効率化を行いました。

  1. Whisperのデコーダに渡すのは「直近の無音区間の終了から現在時刻まで」の音声特徴量だけにする
  2. 「直近の無音区間の終了から次の無音区間の出現まで」の間は、デコードしたトークンを一定の基準でキャッシュする


(0). 通常のデコード処理の場合。毎時刻で直近30秒の音声データに対して素直に処理を行う。この例ではデコードループを7回実行となる。
(1). 直近の無音区間の終了から現在時刻までの音声特徴量だけをデコードする場合。この例ではデコードループは4回実行で済む。
(2). デコードしたトークンをキャッシュする場合。理想的には、毎時刻デコードループは2回程度の実行で済む。

ここで、2の「一定の基準」が曲者でした。
なぜそのような基準が必要かというと、音声特徴量の末尾は発話途中でぶった切られているので、デコードしたトークン列の末尾は正しくない結果になりがちだからです。末尾の正しくないトークンを除いた上で、正しそうなものだけをキャッシュし、次の時刻の音声特徴量のデコードはその続きから再開したいわけです。
が、結局あまりいい方法が思いつかなかったので、単純に「末尾からN個はキャッシュしない」という実装を採用しました。

BPには、

  • OnSpeaking:「直近の無音区間の終了から現在時刻まで」のデコード結果を通知するイベント
  • OnSpoken:「前回の無音区間の終了から直近の無音区間の出現まで」のデコード結果を通知するイベント

の二つを公開し、前者はリアルタイムだが末尾の結果が信頼できないものとして、後者は遅延があるが結果はそれなりに信頼できるものとして使ってもらう形にしました。

指定された短いフレーズの内のどれが話されたかを判別する

自然言語で自由入力するゲームは将来的には発展しそうな気もしますが[1]、直近の応用としては自由入力ではなく「必殺技名を音声入力する」「パーティーメンバーにコマンドを口頭で指示する」といった使い方が有力な気がします。

そこで、事前に決めたフレーズの中から今どれが発話されたかを判定し、特定のイベントをトリガーするような仕組みを作ってみました。

これを実装するには、普通に音声から文字起こししてその結果が指定のフレーズにどれだけ近いかを判断する――という方法でもまあいいのですが、もう少し効率的な方法があります。
それは、Whisperのデコード処理の途中で出てくる「次のトークンの確率分布」を使う方法です。

Whisperのデコード処理では、音声特徴量とこれまでのトークンを入力に次のトークンを出力するという処理をループ実行しますが、そのとき実際に求めているのは「次のトークンがID=0~50256であるそれぞれの確率」です。
今回の問題設定では、指定のフレーズをトークン配列に変換した上で、それを先頭から順にWhisperデコーダに入力して、次のトークンが指定のIDのものである確率を取得してやり、これを繰り返してフレーズ全体で確率を平均すれば、そのフレーズが発話されたかどうかの指標を求めることができます。

この方法の利点は、上記の処理を実際には順に繰り返す必要はなく、GPUの得意とする並列計算でバッチ処理が可能なことです。Whisperデコード処理をバッチで一回実行すればよくなるので、普通にループを回して文字起こししてから比較する方法に比べて、はるかに遅延を削減することができます。

ちなみにこの仕組みは、フレーズを指定するスロットを用意しておき(動画の例だと合計6スロット)、各スロットがどのゲームイベントをトリガーするかをプログラムで指定する形としたので、そのスロットの中身の文言自体はエンドユーザに編集させる余地を残すことができます。(例えば、「前進」を「前に進め」など任意の文言に変更できる)

ゲーム実況などの動画配信で使ってもらうことを視野に入れると、配信者のパーソナリティにあわせて使用する文言が変更できる自由度を持たせておくというのは、もしかすると役に立つ場面があるかもしれません。

指定された長いフレーズの内のどれがどこまで話されたかを判別する

やはり呪文は自分で唱えたいですよね。プレイヤースキルで高速詠唱(=早口言葉)とか熱い。
長い呪文だと一息に言い切ることが難しいので、 「どこまで読み上げられたかを保存しておく」 機能を追加したバージョンも作成しました。

この際、「どこまで読み上げられたか」を判断する基準は、短フレーズと同様に指定のトークンが次に来る確率を求め、それがある閾値以上であるかどうかとしました。

…したのですが、問題が判明しました。喋ってる途中で意外と閾値を下回ります。(上記動画だとピンク色が残っている部分)
なので、閾値をいくつか下回ってもお目こぼしする機能を入れたりしましたが、正直結果は不安定です。

短いフレーズなら認識をミスってもエンドユーザは言い直してくれそうですが、長いフレーズはそうもいかない気がするので、安定した認識を実現する方法を引き続き探りたいと思っています。

終わりに:この記事では省略した話など

誰が読むんだこれ…という長さの記事になりましたが、書き切れていないことがまだあります。

  • Whisperのエンコードで直近の有音区間以外をキャッシュして使いまわす実験
  • WhisperのデコードループをONNXの中で完結させる実験
  • Tokenizerのエンコード処理の実装
  • マルチスレッド化にともなう、前処理済みデータをAI処理で利用する際のスレッドセーフティ
  • ゲーム向けの短フレーズ認識での、認識直後の言い直しを実現するための工夫
  • ゲーム向けの長フレーズ認識での、閾値設定の考え方の模索

これらを書いていると年が明けてしまうので、諦めました。
それではみなさん、メリークリスマス&よいお年をー!

脚注
  1. ChatGPT的な機能を持つNPCがいたりとか ↩︎

Discussion

Hidden comment