🤖

Vercel AI SDKはなぜ複数ステップのLLM呼び出しを1つのStreamで返せるのか?

に公開

最近業務で AI エージェントの実装をすることが多い新卒 1 年目のエンジニアです。
Vercel AI SDK の実装で面白いなと思った部分について記事にしてみました!(初投稿)

はじめに

Vercel AI SDK では、以下のようにstreamTextに prompt や tool などを渡すだけで、AI エージェントを簡単に作ることができます。

example.ts
// お天気エージェント
const result = streamText({
  model: openai("gpt-4"),
  stopWhen: stepCountIs(5),
  tools: { weather: weatherTool },
  prompt: "What is the weather in Tokyo?",
});

for await (const chunk of result.fullStream) {
  console.log(chunk); // LLM応答 → ツール実行 → LLM応答 が連続して流れる
}

余談ですが、Agent というクラスも用意されているものの、Agent はgenerateTextstreamTextの薄いラッパーであり、エージェントループが実装されているのは、generateText, streamTextの方です。

streamText の裏側はどうなっている?

めちゃくちゃ簡単に書くと、streamTextの内部では以下のようなことをしています。

1. ユーザープロンプト、ツール定義などを用いてメッセージを構築し、LLMをコール(stream1)
2. LLMからtool-callのレスポンスが返される
3. toolを実行
4. tool-resultを含めたメッセージを構築し、LLMをコール(stream2)
5. tool-callが含まれている場合、3に戻る

でもこれ、どうやって複数の LLM からのストリームを 1 つにまとめてるんでしょうか?

先に答えを言うと、stitchableStreamというものを用いて、複数ストリームをキューで管理し、順番に読み取れるようにすることで、1 つのストリームとして扱っています。
ここからはstitchableStreamについて紹介します。

stitchableStream

stitchableStreamは、複数のストリームを順次繋ぎ合わせて 1 つのストリームとして公開するための仕組みです。
以下のようなイメージです。(あくまでイメージ)

usage-image.ts
const stitchableStream = createStitchableStream<TextStreamPart<TOOLS>>();

// クライアント側:先にストリームの消費を開始できる!
for await (const part of stitchableStream.stream) {
  console.log(part); // 全ステップが連続して流れる
}

// AI SDK内:非同期的に後からストリームを追加していく
stitchableStream.addStream(step1Stream); // Step 1のLLM応答
stitchableStream.addStream(toolStream); // ツール実行結果
stitchableStream.addStream(step2Stream); // Step 2のLLM応答

// すべてのストリームを追加し終えたら閉じる
stitchableStream.close();

ちなみに、stitch とは「縫い目」などを意味する英単語であり、stitchable な stream とは「継ぎ接ぎ可能なストリーム」という意味だと思われます。

次に、このstitchableStreamがどうやって作成されるかを見ていきたいと思います。

createStitchableStream

全体像

createStitchableStreamは以下のようなプロパティを持ったオブジェクトを返します。

  • stream: クライアント側から読み取るための ReadableStream(inner stream と対比させて outer stream と呼ぶ)
  • addStream: stitchableStreamの inner stream のキューにストリームを追加するためのメソッド
  • close: ストリームを処理してから終了するためのメソッド
  • terminate: 即座に終了するためのメソッド

createStitchableStreamの全体像は以下のようになっています。
クロージャになっているので、innerStreamReadersなどの変数は、stitchableStreamの各メソッドから内部状態としてアクセス可能なことに注意してください。

create-stitchable-stream.ts
export function createStitchableStream<T>() {
  // 内部状態
  let innerStreamReaders: ReadableStreamDefaultReader<T>[] = [];
  let controller: ReadableStreamDefaultController<T> | null = null;
  let isClosed = false;
  let waitForNewStream = createResolvablePromise<void>();

  // キューから読み取る処理
  const processPull = async () => {
    // 後述
  };

  return {
    stream: new ReadableStream<T>({
      start(controllerParam) {
        controller = controllerParam;
      },
      pull: processPull,
      cancel: async () => {
        /* ... */
      },
    }),
    addStream: (innerStream: ReadableStream<T>) => {
      innerStreamReaders.push(innerStream.getReader());
      waitForNewStream.resolve();
    },
    close: () => {
      /* ... */
    },
    terminate: () => {
      /* ... */
    },
  };
}

クライアント側に渡すのは、stream(= outer stream)なので、streamの挙動が重要です。ReadableStream はバッファに空きがある場合、pullメソッドを呼び出すので、この時に呼ばれることになるprocessPullの実装について詳しく見ていきます。

processPull を理解する

processPullは以下のように実装されています。
簡単に言葉で説明すると、次のような説明になるかと思います。

  • outer stream がpull()されるたびに呼ばれ、innerStreamReadersのキューの先頭からデータを読み取り、outer stream に流す
  • キューが空の場合は新しい stream が追加されるまで待機し、stream が完了したらキューから削除して次の stream に進む
create-stitchable-stream.ts
const processPull = async () => {
  // Case 1: 正常に終了
  if (isClosed && innerStreamReaders.length === 0) {
    controller?.close();
    return;
  }

  // Case 2: 新しいinner streamを待機
  if (innerStreamReaders.length === 0) {
    waitForNewStream = createResolvablePromise<void>();
    await waitForNewStream.promise; // promiseが解決されるまで待つ
    return processPull(); // 再帰実行
  }

  try {
    const { value, done } = await innerStreamReaders[0].read();

    if (done) {
      // Case 3: 現在のstreamが完了したら先頭の要素を取り除く
      innerStreamReaders.shift();

      if (innerStreamReaders.length === 0 && isClosed) {
        controller?.close();
      } else {
        await processPull();
      }
    } else {
      // Case 4: データをouter streamに出力
      controller?.enqueue(value);
    }
  } catch (error) {
    // Case 5: エラー発生時はouter streamにエラーを流して終了
    controller?.error(error);
    innerStreamReaders.shift();
    terminate();
  }
};

注目すべき点はケース 2 のawait 式です。
waitForNewStreamという変数名が分かりやすいので、あえて説明する必要もありませんが、この Promise は inner stream が追加された時に履行されるように実装されています。

create-stitchable-stream.ts
// 再掲
addStream: (innerStream: ReadableStream<T>) => {
  innerStreamReaders.push(innerStream.getReader());
  waitForNewStream.resolve();  // ここ!!
},

processPullは、waitForNewStreamが解決されるまでの間は停止していますが、その間にもイベントループは動いています。なので、この Promise の解決を待っている最中にaddStreamが呼び出されると、Promise が履行され、次の処理に進むことになり、再帰的にprocessPullが実行され、追加された新しい stream を処理することになります。

ちなみに、createResolvablePromiseは以下のような実装で、Promise 生成後に外部からresolve()を呼べるようにしています。

create-resolvable-promise.ts
function createResolvablePromise<T>() {
  let resolve: (value: T) => void;
  let reject: (error: any) => void;

  const promise = new Promise<T>((res, rej) => {
    resolve = res; // Promiseのresolve関数を外部変数に保存
    reject = rej;
  });

  return { promise, resolve: resolve!, reject: reject! };
}

通常の Promise はコンストラクタ内でしか解決できませんが、この実装により好きなタイミングでresolve()を呼び出せます。シンプルだけど、超重要な役割を担っている感じがします 🤔

以上のような仕組みにより、streamTextは複数の LLM コールに対するストリームを 1 本のストリームにまとめて、クライアントから消費できるようになっています。
streamTextの実装自体はもっと長く、pipeThrough でストリームの中身を見てツールを実行する処理なども面白いなと思ったので、興味のある方はソースコードを読んでみるのオススメです!

感想

stitchableStream、とてもスマートな設計だと思いました。
OSS を読むと新しい発見があって勉強になるので、今後も頑張っていきたいと思います!

参考

Discussion