🌌

Rust on Lambda でWebSocket とストリームレスポンスを試す

2023/12/17に公開

本記事は Rust Advent Calendar 2023 16 日目の記事です。
https://qiita.com/advent-calendar/2023/rust

AWS Lambda からストリームレスポンスを受け取りたい

AWS Lambda 上で動く Rust アプリケーション上で OpenAI の API を叩いていたのですが、こちらのブログ の方と同じようにレスポンスをヌルヌル表示したくなりました。
そこで、今回以下 2 つの Lambda からストリームレスポンスを受け取る方法を試してみました。

  • Amazon API Gateway + AWS Lambda による WebSocket
  • AWS Lambda のレスポンスストリーミング

1 つ目に関しては、上記のブログ で紹介されているとおり、OpenAI からの出力をヌルヌル表示できそうです。
2 つ目に関しても、こちらのブログで紹介されており、Lambda 単体でもヌルヌル表示できそうです。
自分は Lambda を Rust で実装してみました(Rust Advent Calendar だからね!)。

コードについては GitHub 上で公開しています。
https://github.com/tyrwzl/rust-lambda-stream

Amazon API Gateway + AWS Lambda による WebSocket

まず main 関数ですが、aws-lambda-rust-runtime を利用しており、中身は handler 関数に書いています。

async fn main() -> Result<(), Error> {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .with_target(false)
        .without_time()
        .init();
    run(service_fn(handler)).await
}

次に handler 関数ですが、まず Lambda にやってきた event が WebSocket のどのイベントなのかを判別します。
そして、それぞれのイベントごとに処理を match 文で記載します。

  • event_type が CONNECT: コネクション開始時
  • event_type が DISCONNECT: コネクション終了時
  • event_type が MESSAGE: コネクション確立後、メッセージ受信時
async fn handler(
    event: LambdaEvent<ApiGatewayWebsocketProxyRequest>,
) -> anyhow::Result<ApiGatewayProxyResponse> {
    let event_type = event.payload.request_context.event_type.clone();
    match event_type.as_deref() {
        Some("CONNECT") => {
            tracing::info!("CONNECT");
        }
        Some("MESSAGE") => {
            tracing::info!("MESSAGE");
	    ...
        }
        Some("DISCONNECT") => {
            tracing::info!("DISCONNECT event");
        }
        Some(s) => {
            tracing::error!("Unknown event type: {:?}", s);
        }
        None => {
            tracing::error!("No event type found");
        }
    }

    Ok(ApiGatewayProxyResponse {
        status_code: 200,
        ..Default::default()
    })
}

event_type が MESSAGE の分岐のところでは、クライアントから受信したメッセージに対する処理を記載します。

      Some("MESSAGE") => {
            tracing::info!("MESSAGE");

            let message = event.payload.body.clone();

            let config = OpenAIConfig::new()
                .with_api_key(env::var("OPENAI_API_KEY").expect("No API key found"))
                .with_org_id(env::var("OPENAI_ORG_ID").expect("No org id found"));
            let client = Client::with_config(config);
            let request = CreateCompletionRequestArgs::default()
                .model("text-davinci-003")
                .n(1)
                .prompt(message.expect("No message found"))
                .stream(true)
                .max_tokens(1024_u16)
                .build()?;
            let mut stream = client.completions().create_stream(request).await?;

            let region = "ap-northeast-1";
            let region_provider = RegionProviderChain::first_try(Region::new(region));
            let shared_config = aws_config::defaults(BehaviorVersion::v2023_11_09())
                .region(region_provider)
                .load()
                .await;
            let apigw_id = env::var("APIGW_ID").expect("No APIGW_ID found");
            let api_uri =
                format!("https://{apigw_id}.execute-api.ap-northeast-1.amazonaws.com/prod");
            let api_management_config = config::Builder::from(&shared_config)
                .endpoint_url(api_uri)
                .build();
            let apigw_client =
                aws_sdk_apigatewaymanagement::Client::from_conf(api_management_config);

            let connection_id = &event
                .payload
                .request_context
                .connection_id
                .clone()
                .expect("No connection ID found");

            while let Some(response) = stream.next().await {
                let ccr = response?;
                let blob = Blob::new(
                    serde_json::to_vec(&ccr.choices[0].text.clone()).expect("Could not serialize"),
                );

                apigw_client
                    .post_to_connection()
                    .connection_id(connection_id)
                    .data(blob)
                    .send()
                    .await?;
            }
        }

まず、event 構造体からクライアントから受信したメッセージを抜き取ります。
この受け取ったメッセージを OpenAI に渡します。
今回は OpenAI からのレスポンスをストリーム形式で受信したかったので、async-openai を利用しました。
create_stream を利用することで futures の Stream を受け取ることができます。

OpenAI に API リクエストをしたら、Stream からデータを受信するごとに while ループ内で処理をします。
ループ内では、OpenAI から受け取ったストリームデータを読みとり、WebSocket クライアントに送信します。
AWS API Gateway と Lambda で WebSocket サーバーを構築した場合、API Gateway が WebSocket の管理をしているので、クライアントに送信したいメッセージがある場合は API Gateway の API を Lambda から呼び出してクライアントに送信したいメッセージを API Gateway に渡す必要があります。
ループ内では AWS SDK for Rust を利用して、post_to_connection を呼び出しています。

AWS Lambda のレスポンスストリーミング

こっちは WebSocket より簡素です。
aws-lambda-rust-runtime がレスポンスストリーム用に run_with_streaming_response を用意してくれているので、これを利用します。
また、hyper Body の channel を用意して、送信側を利用して OpenAI からのレスポンスを書き込み、受信側は Lambda のレスポンスとします。

#[tokio::main]
async fn main() -> Result<(), Error> {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .with_target(false)
        .without_time()
        .init();

    lambda_runtime::run_with_streaming_response(service_fn(func)).await?;
    Ok(())
}

#[derive(Deserialize)]
struct Request {
    body: String,
}

async fn func(event: LambdaEvent<Request>) -> anyhow::Result<Response<Body>> {
    let message = event.payload.body;
    let request = CreateCompletionRequestArgs::default()
        .model("text-davinci-003")
        .n(1)
        .prompt(message)
        .stream(true)
        .max_tokens(1024_u16)
        .build()?;

    let (mut tx, rx) = Body::channel();

    tokio::spawn(async move {
        let config = OpenAIConfig::new()
            .with_api_key(env::var("OPENAI_API_KEY").expect("No API key found"))
            .with_org_id(env::var("OPENAI_ORG_ID").expect("No org id found"));
        let client = Client::with_config(config);
        let mut stream = client.completions().create_stream(request).await.unwrap();
        while let Some(response) = stream.next().await {
            let ccr = response.unwrap();
            let chunk: Bytes = ccr.choices[0].text.clone().into();
            tx.send_data(chunk).await.unwrap();
        }
    });

    let resp = Response::builder()
        .header("content-type", "text/html")
        .body(rx)?;

    Ok(resp)
}

Frontend

今回は React で簡単な Frontend 画面を作って、それぞれ同じメッセージを送信してみました。
左側が WebSocket で右側がレスポンスストリームです。
メッセージを送信すると、レスポンスがリアルタイムに返却されていることがわかります。

優しい世界だね

まとめ

Amazon API Gateway と AWS Lambda を使ってお手軽 WebSocket サーバーと、AWS Lambda のレスポンスストリーミングを Rust で作成してみました。
実装が簡単なのはレスポンスストリーミングの方で(aws-lambda-rust-runtime のおかげ)、WebSocket にする必要性がないのであればレスポンスストリーミングを利用する方が良いと思いました。
また、今回認証をどうするか考えたのですが、WebSocket では制約が多く(ブラウザの場合はリクエストヘッダベースの認証は利用できない)、こういった面でもレスポンスストリーミングの方が好ましく感じられました。

参考

Discussion