Rust on Lambda でWebSocket とストリームレスポンスを試す
本記事は Rust Advent Calendar 2023 16 日目の記事です。
AWS Lambda からストリームレスポンスを受け取りたい
AWS Lambda 上で動く Rust アプリケーション上で OpenAI の API を叩いていたのですが、こちらのブログ の方と同じようにレスポンスをヌルヌル表示したくなりました。
そこで、今回以下 2 つの Lambda からストリームレスポンスを受け取る方法を試してみました。
- Amazon API Gateway + AWS Lambda による WebSocket
- AWS Lambda のレスポンスストリーミング
1 つ目に関しては、上記のブログ で紹介されているとおり、OpenAI からの出力をヌルヌル表示できそうです。
2 つ目に関しても、こちらのブログで紹介されており、Lambda 単体でもヌルヌル表示できそうです。
自分は Lambda を Rust で実装してみました(Rust Advent Calendar だからね!)。
コードについては GitHub 上で公開しています。
Amazon API Gateway + AWS Lambda による WebSocket
まず main 関数ですが、aws-lambda-rust-runtime を利用しており、中身は handler 関数に書いています。
async fn main() -> Result<(), Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.without_time()
.init();
run(service_fn(handler)).await
}
次に handler 関数ですが、まず Lambda にやってきた event が WebSocket のどのイベントなのかを判別します。
そして、それぞれのイベントごとに処理を match 文で記載します。
- event_type が CONNECT: コネクション開始時
- event_type が DISCONNECT: コネクション終了時
- event_type が MESSAGE: コネクション確立後、メッセージ受信時
async fn handler(
event: LambdaEvent<ApiGatewayWebsocketProxyRequest>,
) -> anyhow::Result<ApiGatewayProxyResponse> {
let event_type = event.payload.request_context.event_type.clone();
match event_type.as_deref() {
Some("CONNECT") => {
tracing::info!("CONNECT");
}
Some("MESSAGE") => {
tracing::info!("MESSAGE");
...
}
Some("DISCONNECT") => {
tracing::info!("DISCONNECT event");
}
Some(s) => {
tracing::error!("Unknown event type: {:?}", s);
}
None => {
tracing::error!("No event type found");
}
}
Ok(ApiGatewayProxyResponse {
status_code: 200,
..Default::default()
})
}
event_type が MESSAGE の分岐のところでは、クライアントから受信したメッセージに対する処理を記載します。
Some("MESSAGE") => {
tracing::info!("MESSAGE");
let message = event.payload.body.clone();
let config = OpenAIConfig::new()
.with_api_key(env::var("OPENAI_API_KEY").expect("No API key found"))
.with_org_id(env::var("OPENAI_ORG_ID").expect("No org id found"));
let client = Client::with_config(config);
let request = CreateCompletionRequestArgs::default()
.model("text-davinci-003")
.n(1)
.prompt(message.expect("No message found"))
.stream(true)
.max_tokens(1024_u16)
.build()?;
let mut stream = client.completions().create_stream(request).await?;
let region = "ap-northeast-1";
let region_provider = RegionProviderChain::first_try(Region::new(region));
let shared_config = aws_config::defaults(BehaviorVersion::v2023_11_09())
.region(region_provider)
.load()
.await;
let apigw_id = env::var("APIGW_ID").expect("No APIGW_ID found");
let api_uri =
format!("https://{apigw_id}.execute-api.ap-northeast-1.amazonaws.com/prod");
let api_management_config = config::Builder::from(&shared_config)
.endpoint_url(api_uri)
.build();
let apigw_client =
aws_sdk_apigatewaymanagement::Client::from_conf(api_management_config);
let connection_id = &event
.payload
.request_context
.connection_id
.clone()
.expect("No connection ID found");
while let Some(response) = stream.next().await {
let ccr = response?;
let blob = Blob::new(
serde_json::to_vec(&ccr.choices[0].text.clone()).expect("Could not serialize"),
);
apigw_client
.post_to_connection()
.connection_id(connection_id)
.data(blob)
.send()
.await?;
}
}
まず、event 構造体からクライアントから受信したメッセージを抜き取ります。
この受け取ったメッセージを OpenAI に渡します。
今回は OpenAI からのレスポンスをストリーム形式で受信したかったので、async-openai を利用しました。
create_stream を利用することで futures の Stream を受け取ることができます。
OpenAI に API リクエストをしたら、Stream からデータを受信するごとに while ループ内で処理をします。
ループ内では、OpenAI から受け取ったストリームデータを読みとり、WebSocket クライアントに送信します。
AWS API Gateway と Lambda で WebSocket サーバーを構築した場合、API Gateway が WebSocket の管理をしているので、クライアントに送信したいメッセージがある場合は API Gateway の API を Lambda から呼び出してクライアントに送信したいメッセージを API Gateway に渡す必要があります。
ループ内では AWS SDK for Rust を利用して、post_to_connection を呼び出しています。
AWS Lambda のレスポンスストリーミング
こっちは WebSocket より簡素です。
aws-lambda-rust-runtime がレスポンスストリーム用に run_with_streaming_response を用意してくれているので、これを利用します。
また、hyper Body の channel を用意して、送信側を利用して OpenAI からのレスポンスを書き込み、受信側は Lambda のレスポンスとします。
#[tokio::main]
async fn main() -> Result<(), Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.without_time()
.init();
lambda_runtime::run_with_streaming_response(service_fn(func)).await?;
Ok(())
}
#[derive(Deserialize)]
struct Request {
body: String,
}
async fn func(event: LambdaEvent<Request>) -> anyhow::Result<Response<Body>> {
let message = event.payload.body;
let request = CreateCompletionRequestArgs::default()
.model("text-davinci-003")
.n(1)
.prompt(message)
.stream(true)
.max_tokens(1024_u16)
.build()?;
let (mut tx, rx) = Body::channel();
tokio::spawn(async move {
let config = OpenAIConfig::new()
.with_api_key(env::var("OPENAI_API_KEY").expect("No API key found"))
.with_org_id(env::var("OPENAI_ORG_ID").expect("No org id found"));
let client = Client::with_config(config);
let mut stream = client.completions().create_stream(request).await.unwrap();
while let Some(response) = stream.next().await {
let ccr = response.unwrap();
let chunk: Bytes = ccr.choices[0].text.clone().into();
tx.send_data(chunk).await.unwrap();
}
});
let resp = Response::builder()
.header("content-type", "text/html")
.body(rx)?;
Ok(resp)
}
Frontend
今回は React で簡単な Frontend 画面を作って、それぞれ同じメッセージを送信してみました。
左側が WebSocket で右側がレスポンスストリームです。
メッセージを送信すると、レスポンスがリアルタイムに返却されていることがわかります。
まとめ
Amazon API Gateway と AWS Lambda を使ってお手軽 WebSocket サーバーと、AWS Lambda のレスポンスストリーミングを Rust で作成してみました。
実装が簡単なのはレスポンスストリーミングの方で(aws-lambda-rust-runtime のおかげ)、WebSocket にする必要性がないのであればレスポンスストリーミングを利用する方が良いと思いました。
また、今回認証をどうするか考えたのですが、WebSocket では制約が多く(ブラウザの場合はリクエストヘッダベースの認証は利用できない)、こういった面でもレスポンスストリーミングの方が好ましく感じられました。
参考
- https://qiita.com/hama1080/items/849888c4e6dfabf92cd2
- https://developer.mamezou-tech.com/blogs/2023/04/23/lambda-response-streaming-intro/
- https://medium.com/@mohammadaliasghar523/creating-a-real-time-chat-app-with-next-js-and-websockets-e41fd131949c
- https://medium.com/techhappily/unveiling-the-power-of-streaming-responses-in-aws-lambda-using-rust-793fa5c9faf8
Discussion