Stream APIで、OpenAIを利用してみる

に公開

LLMを使った機能実装時に、Stream APIを使いたい(でも面倒くさい…)ということが増えてきたので、勉強の意味も兼ねて実装してみた。Clineと頭突き合わせながら、実装の仔細を確認する。こういう使い方もある。

概要

ストリーミングでは、ネットワーク経由で受信するリソースを小さなチャンク(塊)に分割し、少しずつ処理します。
https://developer.mozilla.org/ja/docs/Web/API/Streams_API

Stream APIは、Javascriptで非同期にデータを取得してくる際などに利用するAPIです。レスポンスをストリームとして読み取り、取得状況に合わせて非同期的に利用することができます。そのため、大きなデータや時間のかかる処理を、断片的にレスポンスを取得できるというのが最大の利点です。その最中で発生したエラーも捕捉することが可能になります。

OpenAIなどのLLMとの通信の場合、トークンのサイズによって、途中でサーバー側がタイムアウトをかけることが多いので、このようなストリーム処理でタイムアウトを避けてみました。

実装例

バックエンド

TransformStream

まずは、ストリームの読み取りです。TransformStreamを用いて、streamを作成します。このstreamにストリームデータを書き込むためのwriterが、WritableStreamから提供されます。

const stream = new TransformStream();
const writer = stream.writable.getWriter();

Encode

TextEncoderを用いて、受け取ったストリームを、UTF-8のバイトストリームに変換します。

const encoder = new TextEncoder();

Uint8Arrayに変換したオブジェクトを、streamに書き込んでやります。

writer.write(
  encoder.encode(
    formatData({
      content,
      result: '',
    })
  )
);

formatData()は、オブジェクトを文字列に変換して、プレフィックスとしてdata: を付与している単純な処理です。

const formatData = (dataObject: ServerMessage) =>
  `data: ${JSON.stringify(dataObject)}\n\n`;

レスポンス

以下のような形で、Content-Type: 'text/event-stream'で、stream.readableを返します。

return new Response(stream.readable, {
  headers: {
    'Content-Type': 'text/event-stream',
    'Cache-Control': 'no-cache',
    Connection: 'keep-alive',
  },
});

非同期処理/ポーリング

この後は色々とやりようがあると思います。今回はポーリング形式での実装を試しています。

// 非同期でOpenAI処理を開始
(async () => {
  try {
    // ...省略
    // ステータスをポーリングして進捗を送信
    let status = await openai.beta.threads.runs.retrieve(thread.id, run.id);

    while (
      !['completed', 'failed', 'cancelled'].includes(status.status)
    ) {
      // 定期的にステータスを更新
      await new Promise(resolve => setTimeout(resolve, 1000));// 1秒ごとに
      status = await openai.beta.threads.runs.retrieve(thread.id, run.id);

      // 処理中のステータスを送信
      writer.write(
        encoder.encode(
          formatData({
            content,
            status: status.status,
            progress: true,
          })
        )
      );
    }

    if (status.status === 'completed') {
      const messages = await openai.beta.threads.messages.list(run.thread_id);

      let result = {};
      for (const message of messages.data.reverse()) {
        if (message.role !== 'assistant') continue;
        if (message.content[0].type !== 'text') continue;
        result = message.content[0];
        break;
      }

      // 最終結果を送信
      writer.write(
        encoder.encode(
          formatData({
            content,
            result: result as ServerMessage['result'],
            completed: true,
          })
        )
      );
    } else {
      throw new Error();
    }
  } catch (error: unknown) {
    console.error('Stream processing error:', error);
    // エラー状態を送信
    writer.write(
      encoder.encode(
        formatData({
          content,
          error: '処理中にエラーが発生しました',
          completed: true,
        })
      )
    );
  } finally {
    writer.close();
  }
})();

と、こんな感じ。要は、

  1. ポーリングで定期的にOpenAIからステータスを取得し続ける
  2. status.status(ややこしい&ダサい)がcompleted、failed、canceledにならない限り、ポーリングを続ける
  3. 同時にその結果を、StreamAPI経由で流し続ける
  4. status.status(ややこしい&ダサい)がcompleted、failed、canceledになったら、ポーリングを停止して、resultをStreamAPIに流して終了

という感じの処理になっています。

完成系

バックエンドはこんな感じ。react-router v7で書いています。

import OpenAI from 'openai';
import { data } from 'react-router';

export const config = {
  maxDuration: 60,
};

export type ServerStatus =
  | 'in_progress'
  | 'completed'
  | 'queued'
  | 'requires_action'
  | 'cancelling'
  | 'cancelled'
  | 'failed'
  | 'incomplete'
  | 'expired';

export interface ServerMessage {
  content: string;
  status?: ServerStatus;
  result?: string;
  progress?: boolean;
  error?: string;
  completed?: boolean;
}

const wait = async (msec = 1000) => await new Promise((resolve) => setTimeout(resolve, msec));

const formatData = (dataObject: ServerMessage) =>
  `data: ${JSON.stringify(dataObject)}\n\n`;

export const action = async ({ request }: { request: Request }) => {
  // 認証チェックとか

  // FormDataからcontent取得
  const formData = await request.formData();
  const content = formData.get('content') as string;

  if (!content) return data({ error: 'Content is required' }, { status: 400 });

  // ストリーミング用のTransformStream
  const stream = new TransformStream();
  const writer = stream.writable.getWriter();
  const encoder = new TextEncoder();

  // 空のデータで初期化
  writer.write(
    encoder.encode(
      formatData({
        content,
        result: '',
      })
    )
  );

  // 非同期でOpenAI処理を開始
  (async () => {
    try {
      const openai = new OpenAI({
        apiKey: process.env.VITE_OPENAI_APIKEY,
        organization: process.env.VITE_OPENAI_ORGANIZATION_ID,
        project: process.env.VITE_OPENAI_PROJECT_ID,
      });

      // スレッドとメッセージの作成
      const thread = await openai.beta.threads.create();
      await openai.beta.threads.messages.create(thread.id, {
        role: 'user',
        content,
      });

      // Runの作成
      const run = await openai.beta.threads.runs.create(thread.id, {
        assistant_id: process.env.VITE_OPENAI_ENTITY_ASSISTANT as string,
      });

      // ステータスをポーリングして進捗を送信
      let status = await openai.beta.threads.runs.retrieve(thread.id, run.id);

      while (
        !['completed', 'failed', 'cancelled'].includes(status.status)
      ) {
        // 定期的にステータスを更新
        await wait();
        status = await openai.beta.threads.runs.retrieve(thread.id, run.id);

        // 処理中のステータスを送信
        writer.write(
          encoder.encode(
            formatData({
              content,
              status: status.status,
              progress: true,
            })
          )
        );
      }

      if (status.status === 'completed') {
        const messages = await openai.beta.threads.messages.list(run.thread_id);

        let result = {};
        for (const message of messages.data.reverse()) {
          if (message.role !== 'assistant') continue;
          if (message.content[0].type !== 'text') continue;
          result = message.content[0];
          break;
        }

        // 最終結果を送信
        writer.write(
          encoder.encode(
            formatData({
              content,
              result: result as ServerMessage['result'],
              completed: true,
            })
          )
        );
      } else {
        // エラー状態を送信
        throw new Error();
      }
    } catch (error: unknown) {
      console.error('Stream processing error:', error);
      // エラー状態を送信
      writer.write(
        encoder.encode(
          formatData({
            content,
            error: '処理中にエラーが発生しました',
            completed: true,
          })
        )
      );
    } finally {
      writer.close();
    }
  })();

  // ストリームレスポンスを返す
  return new Response(stream.readable, {
    headers: {
      'Content-Type': 'text/event-stream',
      'Cache-Control': 'no-cache',
      Connection: 'keep-alive',
    },
  });
};

フロント

フロントでは、APIをStreamとして読み込みます。AbortControllerを渡して、中止処理はそこで実装します(Clineがfetchで書いてきたが、当然axiosでも実装可能であろう)。

fetch('/api/retrive', {
  method: 'POST',
  body: formData,
  signal: abortController.signal,
})

レスポンスの処理

「あー、こういうことができるんだ」と、ちょっと目から鱗だったんですが、responce.bodyからReaderを直接取得してきます。Uint8Arrayをデコードするために、TextDecorderも初期化します。

const reader = response.body.getReader();
const decoder = new TextDecoder();

ストリーミング処理

processStreamが、ストリーミング処理を行う関数です。await reader.read()でdoneがtruthyになるまでストリームを読みこみます。

const processStream = async () => {
  let buffer = ''; // 未完成のメッセージを保持するバッファ
  try {
    let processComplete = false;
    while (!processComplete) {
     const { done, value } = await reader.read();
     if (done) {
       processComplete = true;
       break;
     }

完成系

// SSEエンドポイントにPOSTリクエストを送信
fetch('/api/retrive', {
  method: 'POST',
  body: formData,
  signal: abortController.signal,
})
  .then(async (response) => {
    if (!response.ok) {
      throw new Error(`HTTP error! status: ${response.status}`);
    }

    // レスポンスのContent-Typeをチェック
    const contentType = response.headers.get('Content-Type');
    if (!contentType || !contentType.includes('text/event-stream')) {
      throw new Error(`Invalid content type: ${contentType}`);
    }

    // レスポンスのボディが存在することを確認
    if (!response.body) {
      throw new Error('Response body is null');
    }

    // レスポンスのストリームを直接読み込む
    const reader = response.body.getReader();
    const decoder = new TextDecoder();

    // 手動でSSEを処理
    const processStream = async () => {
      let buffer = ''; // 未完成のメッセージを保持するバッファ
      try {
        let processComplete = false;
        while (!processComplete) {
          const { done, value } = await reader.read();
          if (done) {
            processComplete = true;
            break;
          }

          // 受信データをデコード
          const chunk = decoder.decode(value, { stream: true });
          buffer += chunk;

          // バッファからメッセージを抽出
          const messages = buffer.split('\n\n');
          // 最後のメッセージは未完成の可能性があるため保持
          buffer = messages.pop() || '';

          for (const message of messages.filter(Boolean)) {
            // 'data: ' プレフィックスを削除してJSONをパース
            if (message.startsWith('data: ')) {
              try {
                const data = JSON.parse(message.substring(6));
                handleServerMessage(data);
              } catch (parseError: unknown) {
                console.error('JSON parse error:', parseError, message);
              }
            }
          }
        }
      } catch (streamError: unknown) {
        // AbortErrorはエラーとして表示しない(ユーザーによる中断)
        const error = streamError as Error;
        if (error.name !== 'AbortError') {
          console.error('Stream reading error:', streamError);
          setError(
            'ストリーム読み込み中にエラーが発生しました'
          );
          setIsPending(false);
        }
      }
    };

    // ストリーム処理を開始
    processStream();
  })
  .catch((fetchError: unknown) => {
    // AbortErrorはエラーとして表示しない(ユーザーによる中断)
    const error = fetchError as Error;
    if (error.name !== 'AbortError') {
      console.error('Fetch error:', fetchError);
      setError(
        'リクエスト中にエラーが発生しました: ' + error.message
      );
      setIsPending(false);
    }
  });

実際にAPIを叩いてみると、こんな感じでEventStreamが発行され、定期的に値を取得しているのがわかると思います。

Discussion