GPT-4o APIを叩くRust製ぬるぬるCLIアプリ
先日【Rust】✨CLIアプリ向けクレート4選✨ ~ clap, dialoguer, indicatif, console ~ #Rust - Qiitaという記事をQiitaに投稿しまして、その副産物としてChatGPT APIを叩くRust CLIアプリケーションを作成しましたので、折角ですから紹介したいと思います!
Rustに関係のない部分から初めて、Rustで工夫した部分をゆるく紹介できればと思います✨
直接叩いてみる & SSE
公式のQuick Startにはcurl, Python, Node.jsしか用意されておらず、Rustのチュートリアルは当然ありません。とりあえずcurlのチュートリアルに従ってから、Rustで対応するコードを書くことにします。
curlチュートリアルなのでcurlで叩いても良いですが、面倒なのでVSCode拡張のREST Clientを使って叩いてみます。
POST https://api.openai.com/v1/chat/completions HTTP/1.1
Content-Type: application/json
Authorization: Bearer {{$dotenv CHATGPT_APIKEY}}
{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Say this is a test!"}],
"temperature": 0.7
}
###
POST https://api.openai.com/v1/chat/completions HTTP/1.1
Content-Type: application/json
Authorization: Bearer {{$dotenv CHATGPT_APIKEY}}
{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "「これはテストです」と言ってください!"}],
"stream": true
}
{{$dotenv CHATGPT_APIKEY}}
としておくと、.env
ファイルにAPIキーを書いて置くだけで読み込んでくれるようになります!
CHATGPT_APIKEY="..."
queries.http
ファイルの1つ目のクエリは返答を一気に返してもらうモードで、2つ目のクエリはストリームモードを有効にしたものです。"stream": true
と付けることでストリームモードが有効になっています。
ストリームモードはSSEという仕組みを使って、一度の返答ではなく、サーバーから連続した返答を受け取れるものです。
詳しくは次の記事様が参考になりました🙇
ストリームモードの返答はこんな感じです。
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"これは"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"テ"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"スト"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"です"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"。"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
データが残っているうちは下記フォーマットのJSONが返ってきて、完了すると[DONE]
が返ってきています。
{
"id": "chatcmpl-xxx",
"object": "chat.completion.chunk",
"created": 1716825742,
"model": "gpt-4o-2024-05-13",
"system_fingerprint": "fp_xxx",
"choices": [
{
"index": 0,
"delta": {
"content":"これは"
},
"logprobs":null,
"finish_reason":null
}
]
}
見づらいですが、「これは」「テ」「スト」「です」「。」と順番に返ってきています!
Rustから叩く
POSTクエリと付属させるJSONの構造がわかったのでいよいよRustから叩いてみます!
ChatGPT APIを叩くにはreqwestとserdeを使います。
serdeでJSON構造体を扱う
REST Clientの実験より、送信用の構造体と受信用の構造体を次の通り用意すれば通信できそうです。
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum Role {
System,
User,
Assistant,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
struct Message {
role: Role,
content: String,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct RequestBody {
model: String,
messages: Vec<Message>,
stream: bool,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Content {
content: String,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Choice {
delta: Content,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ResponseBody {
choices: Vec<Choice>,
}
構造体の数が多くなってしまったのはご愛嬌。
reqwestでクエリする
reqwest部分についてソースコードを一部抜粋すると以下のような感じです。
use reqwest::{Client, RequestBuilder};
let response = Client::new()
.post("https://api.openai.com/v1/chat/completions")
.header("Content-Type", "application/json")
.header("Authorization", api_key_field.as_str())
.json(&RequestBody {
model: "gpt-4o".to_string(),
messages: Vec::from(input_messages),
stream: true,
})
.send()
.await?
.bytes_stream();
-
input_messages
:&[Message]
型変数 -
api_key_field
:String
型変数
.bytes_stream()
で取得すると、impl Stream<Item = Result<...>>
というStream型の値を得られます。Stream型は非同期版Iteratorと言えるようなもので、非同期にデータを受け取り次第随時処理することができるようになります。
StreamReaderを使って行単位で処理できるようにする
ここが筆者的には今回のアプリケーションで一番工夫した部分になります。
先ほど紹介した.bytes_stream()
を使ってStreamを取得し、そのままStreamに対して.next().await
を呼んだ際はバッファされたバイト列がバッファのサイズに応じて取得されるため(つまり多分一定量集まったら出力されるイメージ)、適切なJSON文字列を得るのが難しくなっています。次の記事様が詳しいです。
各JSONは改行文字で区切られていることを利用し、StreamReaderを間に噛ませることで(AsyncBufReadExt::read_untilを呼び出すことにより)改行ごとに取り出せるようにします[1]。
発想としてはBufReader構造体+BufReadトレイトを使うのに似ています。知っておくと便利なテクニックです!
use tokio::io::AsyncBufReadExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader;
// ...
let mut response = StreamReader::new(
response.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))),
);
// ...
let mut line = Vec::new();
while response.read_until(b'\n', &mut line).await? > 0 {
let line_str = String::from_utf8_lossy(&line);
// ...
}
ソースコード全体
全体としては次のようになりました!dialoguer::Input等話足りない部分もありますが、執筆に疲れたので今回はこの辺にしたいと思います😇
Inputについては是非Qiita記事の方を読んでいただけると幸いです🙇
[package]
name = "chatgpt_cli"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.83"
dialoguer = "0.11.0"
dotenvy = "0.15.7"
reqwest = { version = "0.12.4", features = ["json", "stream"] }
serde = { version = "1.0.202", features = ["derive"] }
serde_json = "1.0.117"
tokio = { version = "1.37.0", features = ["full"] }
tokio-stream = "0.1.15"
bytes = "1.6.0"
tokio-util = { version = "0.7.11", features = ["io"] }
use anyhow::Result;
use bytes::Bytes;
use dialoguer::Input;
use reqwest::{Client, RequestBuilder};
use std::io;
use std::io::Write;
use tokio::io::AsyncBufReadExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader;
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum Role {
System,
User,
Assistant,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
struct Message {
role: Role,
content: String,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct RequestBody {
model: String,
messages: Vec<Message>,
stream: bool,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Content {
content: String,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Choice {
delta: Content,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ResponseBody {
choices: Vec<Choice>,
}
fn common_header(api_key: &str) -> RequestBuilder {
let api_key_field = format!("Bearer {}", api_key);
Client::new()
.post("https://api.openai.com/v1/chat/completions")
.header("Content-Type", "application/json")
.header("Authorization", api_key_field.as_str())
}
async fn query(
api_key: &str,
input_messages: &[Message],
) -> Result<impl Stream<Item = reqwest::Result<Bytes>>> {
let res = common_header(api_key)
.json(&RequestBody {
model: "gpt-4o".to_string(),
messages: Vec::from(input_messages),
stream: true,
})
.send()
.await?
.bytes_stream();
Ok(res)
}
fn to_response(line: String) -> Result<ResponseBody> {
let line = line.replace("data: ", "");
let response_body: ResponseBody = serde_json::from_str(&line)?;
Ok(response_body)
}
#[tokio::main]
async fn main() -> Result<()> {
let mut stdout = std::io::stdout();
dotenvy::dotenv().ok();
let api_key = std::env::var("CHATGPT_APIKEY")?;
if api_key.is_empty() {
eprintln!("Please set the environment variable CHATGPT_APIKEY");
std::process::exit(1);
}
let mut messages = vec![Message {
role: Role::System,
content: "You are a helpful assistant.".to_string(),
}];
loop {
let input = Input::new()
.with_prompt("You")
.interact_text()
.unwrap_or_else(|_| "quit".to_string());
if input == "quit" || input == "q" {
break;
}
messages.push(Message {
role: Role::User,
content: input,
});
let response = query(&api_key, &messages).await?;
let mut response = StreamReader::new(
response.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))),
);
print!("ChatGPT: ");
stdout.flush()?;
let mut line = Vec::new();
let mut all_content = "ChatGPT: ".to_string();
while response.read_until(b'\n', &mut line).await? > 0 {
let line_str = String::from_utf8_lossy(&line);
let response_body_str = line_str.trim().to_string();
line.clear();
if let Ok(response_body) = to_response(response_body_str.to_string()) {
if let Some(c) = response_body.choices.first() {
let content_parts = &c.delta.content;
print!("{}", content_parts);
stdout.flush()?;
all_content.push_str(content_parts);
}
}
}
println!();
messages.push(Message {
role: Role::Assistant,
content: all_content,
});
}
Ok(())
}
GitHubリポジトリ: https://github.com/anotherhollow1125/chatgpt_cli
まとめ・所感
StreamReaderを噛ませる工夫が個人的にはツボったので本記事を書き始めましたが、ぶっちゃけ今回のようなシンプルなCLI版ChatGPTは不便だなぁって思っています...
というのも、我々は日本人のため日本語で入力したいですが、ターミナルで日本語入力を行うと大体ろくなことがありません。実際、バックスペース等が関わると表示と中身がバラバラになったりなんてザラです。カッコいいかもしれませんが使い勝手が悪いわけです。
ChatGPT APIを組み込んだCLIアプリはちょっと作ってみたいものができたので現在鋭意制作中です!こちらも完成し次第記事にできたらなぁと思います。
本記事が誰かの役に立てば幸いです、ここまで読んでいただき誠にありがとうございました!
-
read_line
というメソッドもあるのですが、バイト列に対して使うためにread_until
の方を採用しています。 ↩︎
Discussion