🎧

GPT-4o APIを叩くRust製ぬるぬるCLIアプリ

2024/05/28に公開

先日【Rust】✨CLIアプリ向けクレート4選✨ ~ clap, dialoguer, indicatif, console ~ #Rust - Qiitaという記事をQiitaに投稿しまして、その副産物としてChatGPT APIを叩くRust CLIアプリケーションを作成しましたので、折角ですから紹介したいと思います!

about_rust.gif

https://github.com/anotherhollow1125/chatgpt_cli

Rustに関係のない部分から初めて、Rustで工夫した部分をゆるく紹介できればと思います✨

直接叩いてみる & SSE

公式のQuick Startにはcurl, Python, Node.jsしか用意されておらず、Rustのチュートリアルは当然ありません。とりあえずcurlのチュートリアルに従ってから、Rustで対応するコードを書くことにします。

curlチュートリアルなのでcurlで叩いても良いですが、面倒なのでVSCode拡張のREST Clientを使って叩いてみます。

https://marketplace.visualstudio.com/items?itemName=humao.rest-client

https://qiita.com/mgmgmogumi/items/61f0b896580d3e6db2bb

queries.http
POST https://api.openai.com/v1/chat/completions HTTP/1.1
Content-Type: application/json
Authorization: Bearer {{$dotenv CHATGPT_APIKEY}}

{
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Say this is a test!"}],
    "temperature": 0.7
}

###

POST https://api.openai.com/v1/chat/completions HTTP/1.1
Content-Type: application/json
Authorization: Bearer {{$dotenv CHATGPT_APIKEY}}

{
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "「これはテストです」と言ってください!"}],
    "stream": true
}

{{$dotenv CHATGPT_APIKEY}}としておくと、.envファイルにAPIキーを書いて置くだけで読み込んでくれるようになります!

.env
CHATGPT_APIKEY="..."

queries.httpファイルの1つ目のクエリは返答を一気に返してもらうモードで、2つ目のクエリはストリームモードを有効にしたものです。"stream": trueと付けることでストリームモードが有効になっています。

ストリームモードはSSEという仕組みを使って、一度の返答ではなく、サーバーから連続した返答を受け取れるものです。

詳しくは次の記事様が参考になりました🙇

https://zenn.dev/sekapi/articles/a089c203adad74

ストリームモードの返答はこんな感じです。

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"これは"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"テ"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"スト"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"です"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{"content":"。"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-xxx","object":"chat.completion.chunk","created":1716825742,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_xxx","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}

data: [DONE]

データが残っているうちは下記フォーマットのJSONが返ってきて、完了すると[DONE]が返ってきています。

{
  "id": "chatcmpl-xxx",
  "object": "chat.completion.chunk",
  "created": 1716825742,
  "model": "gpt-4o-2024-05-13",
  "system_fingerprint": "fp_xxx",
  "choices": [
    {
      "index": 0,
      "delta": {
        "content":"これは"
      },
      "logprobs":null,
      "finish_reason":null
    }
  ]
}

見づらいですが、「これは」「テ」「スト」「です」「。」と順番に返ってきています!

Rustから叩く

POSTクエリと付属させるJSONの構造がわかったのでいよいよRustから叩いてみます!

ChatGPT APIを叩くにはreqwestserdeを使います。

  • reqwest: POSTクエリを投げるために使います。
  • serde: JSONをRustの構造体にパースするために使います。

serdeでJSON構造体を扱う

REST Clientの実験より、送信用の構造体と受信用の構造体を次の通り用意すれば通信できそうです。

送信用構造体
#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    User,
    Assistant,
}

#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
struct Message {
    role: Role,
    content: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct RequestBody {
    model: String,
    messages: Vec<Message>,
    stream: bool,
}
受信用構造体
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Content {
    content: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Choice {
    delta: Content,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ResponseBody {
    choices: Vec<Choice>,
}

構造体の数が多くなってしまったのはご愛嬌。

reqwestでクエリする

reqwest部分についてソースコードを一部抜粋すると以下のような感じです。

use reqwest::{Client, RequestBuilder};

let response = Client::new()
    .post("https://api.openai.com/v1/chat/completions")
    .header("Content-Type", "application/json")
    .header("Authorization", api_key_field.as_str())
    .json(&RequestBody {
        model: "gpt-4o".to_string(),
        messages: Vec::from(input_messages),
        stream: true,
    })
    .send()
    .await?
    .bytes_stream();
  • input_messages: &[Message]型変数
  • api_key_field: String型変数

.bytes_stream()で取得すると、impl Stream<Item = Result<...>>というStream型の値を得られます。Stream型は非同期版Iteratorと言えるようなもので、非同期にデータを受け取り次第随時処理することができるようになります。

StreamReaderを使って行単位で処理できるようにする

ここが筆者的には今回のアプリケーションで一番工夫した部分になります。

先ほど紹介した.bytes_stream()を使ってStreamを取得し、そのままStreamに対して.next().awaitを呼んだ際はバッファされたバイト列がバッファのサイズに応じて取得されるため(つまり多分一定量集まったら出力されるイメージ)、適切なJSON文字列を得るのが難しくなっています。次の記事様が詳しいです。

https://zenn.dev/fraim/articles/2024-02-01-rust-hyper-buffer-size

各JSONは改行文字で区切られていることを利用し、StreamReaderを間に噛ませることで(AsyncBufReadExt::read_untilを呼び出すことにより)改行ごとに取り出せるようにします[1]

発想としてはBufReader構造体+BufReadトレイトを使うのに似ています。知っておくと便利なテクニックです!

use tokio::io::AsyncBufReadExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader;

// ...

let mut response = StreamReader::new(
    response.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))),
);

// ...

let mut line = Vec::new();
while response.read_until(b'\n', &mut line).await? > 0 {
    let line_str = String::from_utf8_lossy(&line);
    // ...
}

ソースコード全体

全体としては次のようになりました!dialoguer::Input等話足りない部分もありますが、執筆に疲れたので今回はこの辺にしたいと思います😇

Inputについては是非Qiita記事の方を読んでいただけると幸いです🙇

https://qiita.com/namn1125/items/5eb2c7cfecbf8870abe0

Cargo.toml
[package]
name = "chatgpt_cli"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.83"
dialoguer = "0.11.0"
dotenvy = "0.15.7"
reqwest = { version = "0.12.4", features = ["json", "stream"] }
serde = { version = "1.0.202", features = ["derive"] }
serde_json = "1.0.117"
tokio = { version = "1.37.0", features = ["full"] }
tokio-stream = "0.1.15"
bytes = "1.6.0"
tokio-util = { version = "0.7.11", features = ["io"] }
main.rs
use anyhow::Result;
use bytes::Bytes;
use dialoguer::Input;
use reqwest::{Client, RequestBuilder};
use std::io;
use std::io::Write;
use tokio::io::AsyncBufReadExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::io::StreamReader;

#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
enum Role {
    System,
    User,
    Assistant,
}

#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)]
struct Message {
    role: Role,
    content: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct RequestBody {
    model: String,
    messages: Vec<Message>,
    stream: bool,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Content {
    content: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Choice {
    delta: Content,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ResponseBody {
    choices: Vec<Choice>,
}

fn common_header(api_key: &str) -> RequestBuilder {
    let api_key_field = format!("Bearer {}", api_key);

    Client::new()
        .post("https://api.openai.com/v1/chat/completions")
        .header("Content-Type", "application/json")
        .header("Authorization", api_key_field.as_str())
}

async fn query(
    api_key: &str,
    input_messages: &[Message],
) -> Result<impl Stream<Item = reqwest::Result<Bytes>>> {
    let res = common_header(api_key)
        .json(&RequestBody {
            model: "gpt-4o".to_string(),
            messages: Vec::from(input_messages),
            stream: true,
        })
        .send()
        .await?
        .bytes_stream();

    Ok(res)
}

fn to_response(line: String) -> Result<ResponseBody> {
    let line = line.replace("data: ", "");

    let response_body: ResponseBody = serde_json::from_str(&line)?;

    Ok(response_body)
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut stdout = std::io::stdout();
    dotenvy::dotenv().ok();

    let api_key = std::env::var("CHATGPT_APIKEY")?;

    if api_key.is_empty() {
        eprintln!("Please set the environment variable CHATGPT_APIKEY");
        std::process::exit(1);
    }

    let mut messages = vec![Message {
        role: Role::System,
        content: "You are a helpful assistant.".to_string(),
    }];

    loop {
        let input = Input::new()
            .with_prompt("You")
            .interact_text()
            .unwrap_or_else(|_| "quit".to_string());

        if input == "quit" || input == "q" {
            break;
        }

        messages.push(Message {
            role: Role::User,
            content: input,
        });

        let response = query(&api_key, &messages).await?;
        let mut response = StreamReader::new(
            response.map(|r| r.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))),
        );

        print!("ChatGPT: ");
        stdout.flush()?;

        let mut line = Vec::new();
        let mut all_content = "ChatGPT: ".to_string();
        while response.read_until(b'\n', &mut line).await? > 0 {
            let line_str = String::from_utf8_lossy(&line);
            let response_body_str = line_str.trim().to_string();
            line.clear();

            if let Ok(response_body) = to_response(response_body_str.to_string()) {
                if let Some(c) = response_body.choices.first() {
                    let content_parts = &c.delta.content;

                    print!("{}", content_parts);
                    stdout.flush()?;

                    all_content.push_str(content_parts);
                }
            }
        }

        println!();

        messages.push(Message {
            role: Role::Assistant,
            content: all_content,
        });
    }

    Ok(())
}

GitHubリポジトリ: https://github.com/anotherhollow1125/chatgpt_cli

まとめ・所感

StreamReaderを噛ませる工夫が個人的にはツボったので本記事を書き始めましたが、ぶっちゃけ今回のようなシンプルなCLI版ChatGPTは不便だなぁって思っています...

というのも、我々は日本人のため日本語で入力したいですが、ターミナルで日本語入力を行うと大体ろくなことがありません。実際、バックスペース等が関わると表示と中身がバラバラになったりなんてザラです。カッコいいかもしれませんが使い勝手が悪いわけです。

ChatGPT APIを組み込んだCLIアプリはちょっと作ってみたいものができたので現在鋭意制作中です!こちらも完成し次第記事にできたらなぁと思います。

本記事が誰かの役に立てば幸いです、ここまで読んでいただき誠にありがとうございました!

脚注
  1. read_lineというメソッドもあるのですが、バイト列に対して使うためにread_untilの方を採用しています。 ↩︎

Discussion