✈️

Open AI Realtime APIのPythonサンプルコードで、AI発話にカットイン(割り込み)する方法

2024/10/06に公開

はじめに

https://zenn.dev/asap/articles/4368fd306b592a

先日、Open AI Realtime APIのPythonサンプルコードを共有させていただきましたが、相手の発話にカットイン(割り込み)する方法だけ実装していなかったので、そちらを実装したコードを共有いたします。

取り急ぎコードだけ共有し、コードの解説は改めて記事更新する形で共有させていただきます
(コードの解説は10/7くらいに更新になりそうです)

解説に関しても追記いたしました。(10月6日 22:00)

今回のサンプルコードということで、1ファイルで全てのコードを記述していて、わかりやすく記載することを目指しましたので、皆様の参考になれば幸いです。

サンプルコード

相手の発話にカットインする機能を追加したサンプルコードは下記です。
実行方法は前回の記事と同じです。

main.py

import asyncio
import websockets
import pyaudio
import numpy as np
import base64
import json
import queue
import threading
import os
import time

API_KEY = os.environ.get('OPENAI_API_KEY')

# WebSocket URLとヘッダー情報
# OpenAI
WS_URL = "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
HEADERS = {
    "Authorization": "Bearer " + API_KEY,
    "OpenAI-Beta": "realtime=v1"
}

# キューを初期化
audio_send_queue = queue.Queue()
audio_receive_queue = queue.Queue()

# PCM16形式に変換する関数
def base64_to_pcm16(base64_audio):
    audio_data = base64.b64decode(base64_audio)
    return audio_data

# 音声を送信する非同期関数
async def send_audio_from_queue(websocket):
    while True:
        audio_data = await asyncio.get_event_loop().run_in_executor(None, audio_send_queue.get)
        if audio_data is None:
            continue
        
        # PCM16データをBase64にエンコード
        base64_audio = base64.b64encode(audio_data).decode("utf-8")

        audio_event = {
            "type": "input_audio_buffer.append",
            "audio": base64_audio
        }

        # WebSocketで音声データを送信
        await websocket.send(json.dumps(audio_event))

        # キューの処理間隔を少し空ける
        await asyncio.sleep(0)

# マイクからの音声を取得しキューに入れる関数
def read_audio_to_queue(stream, CHUNK):
    while True:
        try:
            audio_data = stream.read(CHUNK, exception_on_overflow=False)
            audio_send_queue.put(audio_data)
        except Exception as e:
            print(f"音声読み取りエラー: {e}")
            break

# サーバーから音声を受信してキューに格納する非同期関数
async def receive_audio_to_queue(websocket):
    print("assistant: ", end = "", flush = True)
    while True:
        response = await websocket.recv()
        if response:
            response_data = json.loads(response)

            # サーバーからの応答をリアルタイムに表示
            if "type" in response_data and response_data["type"] == "response.audio_transcript.delta":
                print(response_data["delta"], end = "", flush = True)
            # サーバからの応答が完了したことを取得
            elif "type" in response_data and response_data["type"] == "response.audio_transcript.done":
                print("\nassistant: ", end = "", flush = True)

            #こちらの発話がスタートしたことをサーバが取得したことを確認する
            if "type" in response_data and response_data["type"] == "input_audio_buffer.speech_started":
                #すでに存在する取得したAI発話音声をリセットする
                while not audio_receive_queue.empty():
                        audio_receive_queue.get() 

            # サーバーからの音声データをキューに格納
            if "type" in response_data and response_data["type"] == "response.audio.delta":
                base64_audio_response = response_data["delta"]
                if base64_audio_response:
                    pcm16_audio = base64_to_pcm16(base64_audio_response)
                    audio_receive_queue.put(pcm16_audio)
                    
        await asyncio.sleep(0)

# サーバーからの音声を再生する関数
def play_audio_from_queue(output_stream):
    while True:
        pcm16_audio = audio_receive_queue.get()
        if pcm16_audio:
            output_stream.write(pcm16_audio)

# マイクからの音声を取得し、WebSocketで送信しながらサーバーからの音声応答を再生する非同期関数
async def stream_audio_and_receive_response():
    # WebSocketに接続
    async with websockets.connect(WS_URL, extra_headers=HEADERS) as websocket:
        print("WebSocketに接続しました。")

        update_request = {
            "type": "session.update",
            "session": {
                "modalities": ["audio", "text"],
                "instructions": "日本語かつ関西弁で回答してください。",
                "voice": "alloy",
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                }
            }
        }
        await websocket.send(json.dumps(update_request))

        # PyAudioの設定
        INPUT_CHUNK = 2400
        OUTPUT_CHUNK = 2400
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        INPUT_RATE = 24000
        OUTPUT_RATE = 24000

        # PyAudioインスタンス
        p = pyaudio.PyAudio()

        # マイクストリームの初期化
        stream = p.open(format=FORMAT, channels=CHANNELS, rate=INPUT_RATE, input=True, frames_per_buffer=INPUT_CHUNK)

        # サーバーからの応答音声を再生するためのストリームを初期化
        output_stream = p.open(format=FORMAT, channels=CHANNELS, rate=OUTPUT_RATE, output=True, frames_per_buffer=OUTPUT_CHUNK)

        # マイクの音声読み取りをスレッドで開始
        threading.Thread(target=read_audio_to_queue, args=(stream, INPUT_CHUNK), daemon=True).start()

        # サーバーからの音声再生をスレッドで開始
        threading.Thread(target=play_audio_from_queue, args=(output_stream,), daemon=True).start()

        try:
            # 音声送信タスクと音声受信タスクを非同期で並行実行
            send_task = asyncio.create_task(send_audio_from_queue(websocket))
            receive_task = asyncio.create_task(receive_audio_to_queue(websocket))

            # タスクが終了するまで待機
            await asyncio.gather(send_task, receive_task)

        except KeyboardInterrupt:
            print("終了します...")
        finally:
            if stream.is_active():
                stream.stop_stream()
            stream.close()
            output_stream.stop_stream()
            output_stream.close()
            p.terminate()

if __name__ == "__main__":
    asyncio.run(stream_audio_and_receive_response())


Azure版に接続する

接続情報を下記のように変更することで、Azure版でも動作することを確認しました。


#Azure OpenAI
WS_URL = "wss://realtimeapi-inst-sample.openai.azure.com/openai/realtime?deployment=gpt-4o-realtime-xxxx&api-version=2024-10-01-preview"
HEADERS = {
    "api-key": "xxxxxxxx", 
}

URLの中身のwss://realtimeapi-inst-sample.openai.azure.comはAzure OpenAI リソースの「エンドポイント」を参照してください。
(最初の「https://」はWebsocketに接続するために「wss://」に変更してください)

また、URLにて、deployment=gpt-4o-realtime-xxxxの部分はAzure OpenAI Studioにて、gpt-4o-realtime-previewをデプロイする際の「デプロイ名」を入れてください。
"api-key": "xxxxxxxx"の部分はAzure OpenAI リソースの「キー」です。

実験結果

下記の動画をご覧ください(音量が小さめなので、静かな場所でご覧ください)
https://youtu.be/ky7ophAFNOI

動画の通り、AIの発話中にカットインができていると思います

コードの解説

設計概要

前回の記事で紹介したサンプルコードと比較して大きく変わったところは下記になります。

  • 音声を一旦Queueに格納することで、websocket通信が詰まらないように修正
  • マイク入力音声をOpenAIが取得したことをこちら側でも把握して、出力音声Bufferから削除

設計意図

音声を一旦Queueに格納することで、websocket通信が詰まらないように修正

Websocketから送信された音声をWebsocket処理と同じスレッドで処理をしたり、音声再生を行おうとすると、Websocket処理が詰まるので、一旦Queueにデータを逃すことで、通信が詰まらないように修正しました。
(最初からこの修正をサンプルコードに反映させるかは迷ったのですが、サンプルコードでは、直感的に中身がわかることを重視するべきだと思ったため、見送りました。Queueを間に挟むと4スレッドでの並行処理を行うことになるので、理解する難易度が増えると思いました)

マイク入力音声をOpenAIが取得したことをこちら側でも把握して、出力音声Bufferから削除

OpenAIサーバから送られるイベントのタイミングや中身に関してさまざま実験をしたところ、AI発話の音声に関しては、非常に高速に送られていることがわかりました。

例えば、ユーザ発話後、AI発話の音声が10秒ほどあった場合、その音声は発話開始後1-2秒ほどで全てクライアント側に送信されています。
したがって、AI発話にカットイン(割り込み)する場合、その処理はOpenAIのサーバ側ではなく、クライアント側で実装する必要があることがわかりました。
(OpenAIのサーバからしたら、すでに発話音声を全て送信済みのため、割り込まれても、発話中音声を中止させることができない)

OpenAIのAPI Referenceを確認したところ、ユーザ発話をOpenAIのサーバが取得した際、input_audio_buffer.speech_startedが発行されることがわかったので、このイベントを取得後、発話音声が格納されているQueueの中身を削除することで、発話を中止させています。

コード

前回のサンプルコードから変わっているところを主に紹介します。

Queueを経由した双方向やり取り

マイク入力音声をQueueに格納する関数

audio_send_queue = queue.Queue()
# マイクからの音声を取得しキューに入れる関数
def read_audio_to_queue(stream, CHUNK):
    while True:
        try:
            audio_data = stream.read(CHUNK, exception_on_overflow=False)
            audio_send_queue.put(audio_data)
        except Exception as e:
            print(f"音声読み取りエラー: {e}")
            break

ここで、マイク入力音声を細かい単位(CHUNK)でaudio_send_queueに格納しています。

Queueに格納したマイク入力音声をWebsocketでサーバに送信

# 音声を送信する非同期関数
async def send_audio_from_queue(websocket):
    while True:
        audio_data = await asyncio.get_event_loop().run_in_executor(None, audio_send_queue.get)
        if audio_data is None:
            continue
        
        ・・・

基本的には前回のサンプルコードと同じですが、送信する音声はQueueに格納された音声を取得しています。
この音声は、前の関数でQueueに格納されたマイク入力音声なので、マイクから音声が入力されるたびにこの関数が実行され、websocketで音声が送信されることになります。

OpenAIから取得したAI発話音声を受け取りQueueに格納する関数(割り込み実装)

audio_receive_queue = queue.Queue()
# サーバーから音声を受信してキューに格納する非同期関数
async def receive_audio_to_queue(websocket):
    ・・・
    while True:
        try:
            response = await websocket.recv()
            if response:
                response_data = json.loads(response)

                ・・・

                if "type" in response_data and response_data["type"] == "input_audio_buffer.speech_started":
                    while not audio_receive_queue.empty():
                            audio_receive_queue.get() 

                if "type" in response_data and response_data["type"] == "response.audio.delta":
                    base64_audio_response = response_data["delta"]
                    if base64_audio_response:
                        pcm16_audio = base64_to_pcm16(base64_audio_response)
                        audio_receive_queue.put(pcm16_audio)
                    
                        
        ・・・

前回のサンプルコードと同じようなところは、省略しています。

まず、下記の部分は、前回のサンプルコードとほぼ同一ですが、Websocketから取得された音声を、audio_receive_queueという新しいQueueに格納しています。

if "type" in response_data and response_data["type"] == "response.audio.delta":
    base64_audio_response = response_data["delta"]
    if base64_audio_response:
        pcm16_audio = base64_to_pcm16(base64_audio_response)
        audio_receive_queue.put(pcm16_audio)

続いて、下記の部分にてinput_audio_buffer.speech_startedのイベントを取得してます。
取得したら、AIの発話音声が格納されているaudio_receive_queueの中身を空にしています。

これは、ユーザが発話した段階(つまり割り込み)で、AIが今発話している音声は不要になるため、それ以上発話させないように、Queueは全て削除しています。

if "type" in response_data and response_data["type"] == "input_audio_buffer.speech_started":
    while not audio_receive_queue.empty():
            audio_receive_queue.get() 

ちなみに、Queueを削除する際に

audio_receive_queue = queue.Queue()

と初期化する方法もありますが、並行処理(マルチスレッド処理)を導入している場合、上記の方法で初期化すると別スレッドから参照できなくなることがあるので、中身の要素を全て取り出す形で、Queueの中身を取り除いています。

Queueに格納されたAI発話音声をスピーカーから再生する

# サーバーからの音声を再生する関数
def play_audio_from_queue(output_stream):
    while True:
        pcm16_audio = audio_receive_queue.get()
        if pcm16_audio:
            output_stream.write(pcm16_audio)

コードの通りで、Queueから取得した小さな細切れ(CHUNK)の音声データを、音声を再生するPyAudioのoutput_streamに送信しています。

まとめ

Realtime APIにて、AI発話音声にカットイン(割り込み)をすることができるpythonのサンプルコードを作成しました。
こちらのコードを利用することで、相手の発話に対してカットインが可能になります。

ちなみに、記事執筆時は10月6日だったのですが、APIが利用可能になった10月4日午前と比べると、音声の応答速度が遅くなっている気がします。
おそらく、多くの人が利用するようになったのかなと思います。これ以上遅くなるとちょっと違和感あるかなという感じなので、なるべく応答速度を今後も維持してほしいと思いました。

ここまで読んでくださりありがとうございました。

Discussion