🦔

LangGraphの途中出力をストリーミングする

に公開

コード生成エージェントを作っていた際に、途中結果をチャット画面に表示する方法を調べていました。
その過程で、LangGraphの途中出力をストリーミングする方法を見つけたので、ご紹介します。

LangGraph stremingのドキュメントを見ると、ストリーミングの設定がいくつかあります。今回は、customを利用したストリーミングを行います。

customストリーミング

ここでは、customストリーミングのために設定する箇所をご説明します。実際に動くコードは最後に記載します。
TypeScriptとPythonのコードのそれぞれを記載します。

customストリーミングを利用する場合に確認する場所は3つです。

  1. streamModecustomに設定
  2. writerを設定
  3. checkpointerを設定

streamModeの設定

グラフのstreamメソッドをコールする際に、streamModeをcustomに設定します。streamModeには、他にもmessages, values, debug, updatesを指定可能です。

typescript

const stream = await graph.stream({
    messages: ...
}, {
    ...
    streamMode: "custom",
});

python

stream = app.astream(
    {"messages": ...},
    config=...,
    stream_mode="custom",
)

writerの設定

ノード内でconfig.writerをコールして、ストリーミング出力したいデータを送ります。出力するデータとGraphStateは同じ型である必要はありません。

typescript

async function llmNode(state: typeof GraphState.State, config: LangGraphRunnableConfig): Promise<Command> {
    ...
    for await (const chunk of stream) {
        config?.writer?.({
            key: value,
            ...
        });
    }
    ...
    return new Command({
        update: {
            messages: ...,
        },
        goto: "next_node",
    });
}

python

async def llm_node(state: GraphState, writer: StreamWriter):
    ...
    async for chunk in stream:
        writer({
            key: value,
            ...
        })
    ...
    return Command(
        goto="next_node",
        update={"messages": ...},
    )

checkpointerの設定

checkpointerはワークフローの状態を保存する仕組みです。
streamModeを利用する際にはcheckpointerの設定が必要です。

typescript

const workflow = new StateGraph(GraphState)
...
const memory = new MemorySaver();
const app = workflow.compile({ 
    checkpointer: memory,
});

python

workflow = StateGraph(GraphState)
...
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

まとめ

steamModeはアプリにワークフローを組み込む際に非常に便利です。
TypeScriptとPythonのコードを記載したので、ケースに合わせてご利用ください。

実際のコード

最後にテストで動かしたコード全文と出力結果例を記載します。

コードでは、思考と回答を分けて出力させています。アプリではテキストボックスの色を変えるなどのケースを想定し、出力途中のメッセージとメッセージタイプの2つのデータをストリーミング出力させます。

TypeScript

import "dotenv/config";
import { v4 as uuidv4 } from 'uuid';
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { ChatOpenAI } from "@langchain/openai";
import {
    START,
    END,
    StateGraph,
    MemorySaver,
    StreamMode,
    LangGraphRunnableConfig,
    Annotation,
    Command, 
} from "@langchain/langgraph";
import { SystemMessage, HumanMessage, AIMessage, BaseMessage } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";


type LangGraphConfigType = {
    configurable: {
        thread_id: string;
    };
    streamMode: StreamMode;
};

type ParserOutputType = {
    textList: string[];
    nextType?: string | null;
}

const GraphState = Annotation.Root({
    messages: Annotation<BaseMessage[]>({
        // reducer: (x, y) => x.concat(y),
        reducer: (x, y) => x.concat(y),
        default: () => [],
    }),
})

function getSystemPrompt() {
    return `ユーザー空のクエリに対して、以下の手順で回答してください。
1. ユーザーのクエリを分析する。
2. 分析結果をもとに、回答を生成する。

出力ルール:
出力は2つのブロックからなります。
1. 分析結果を記述したブロック([THOUGHT])
2. 回答を記述したブロック([OUTPUT])

以下の例を参考に、出力してください。
出力例: 
[THOUGHT]ここに分析結果を記述します。
[OUTPUT]ここに回答を記述します。`
}


// output type getter
function XmlOutputParser(currentContent: string): ParserOutputType {
    // currentContentを<thought>で分割し、リストを取得
    const thoughtStartList = currentContent.split('[THOUGHT]');
    if (thoughtStartList.length >= 2) {
        return {
            textList: thoughtStartList,
            nextType: 'thought',
        }
    }
    const outputStartList = currentContent.split('[OUTPUT]');
    if (outputStartList.length >= 2) {
        return {
            textList: outputStartList,
            nextType: 'output',
        }
    }
    return {
        textList: [currentContent],
        nextType: undefined,
    }
}

async function buildLangGraph(llm: BaseChatModel) {
    // llm generation node
    async function llmNode(state: typeof GraphState.State, config: LangGraphRunnableConfig): Promise<Command> {
        try {
            console.log("Starting llm_node execution");
            const messages = [new SystemMessage(getSystemPrompt())].concat(state.messages);
            const stream = await llm.stream(messages);
            let allContent = '';  // chunkを結合
            let currentContent = '';  // typeの塊で保持
            let outputContent = '';  // 出力する内容
            let currentType: string | undefined = undefined;  // 現在のtype
            let nextType: string | undefined = undefined;  // 次のtype
            
            // ストリームを処理して結果を得る
            for await (const chunk of stream) {
                allContent += chunk.content;
                currentContent += chunk.content;
                const output = XmlOutputParser(currentContent);
                if (output.nextType === 'thought') {
                    // thoughtがきた場合は、次のメッセージのタイプをthoughtにして、切り替わる前を連携
                    outputContent = output.textList[0];
                    currentContent = output.textList[output.textList.length - 1];
                    nextType = 'thought';
                } else if (output.nextType === 'output') {
                    // outputも同様
                    outputContent = output.textList[0];
                    currentContent = output.textList[output.textList.length - 1];
                    nextType = 'output';
                } else {
                    // 切り替わる前以外は、前回のタイプを保持して連携
                    currentContent = output.textList[output.textList.length - 1];
                    outputContent = currentContent;
                }

                if (currentType && ['thought', 'output'].includes(currentType)) {
                    // config.writerでカスタム出力を定義
                    config?.writer?.({
                        messageType: currentType,
                        messageContent: outputContent,
                    });
                }
                currentType = nextType;
            }
            
            console.log("LLM processing completed");
            console.log("--------------------------------");
            
            // LLMの処理が完全に終わってから次のノードに移る
            return new Command({
                update: {
                    messages: [new AIMessage(allContent)],
                },
                goto: END,
            });
        } catch (error) {
            console.error("Error in llm_node:", error);
            throw error; // エラーを再スローして上位で処理できるようにする
        }
    }

    const workflow = new StateGraph(GraphState)
        .addNode("llm_node", llmNode, { ends: [END] })
        .addEdge(START, "llm_node")

    // custom出力を利用する場合は、checkpointerを設定
    const memory = new MemorySaver();
    
    const app = workflow.compile({ 
        checkpointer: memory,
    });

    return app;
};

async function main() {
    const llm = new ChatOpenAI({
        model: "gpt-4o-mini",
        temperature: 0,
    });

    const app = await buildLangGraph(llm);

    // LangGraphの設定
    const langGraphConfig = {
        configurable: {
            thread_id: uuidv4(),
        },
        streamMode: "custom" as StreamMode,
    };

    const stream = await app.stream(
        {
            messages: [new HumanMessage("LLMについて教えてください。")],
        }, 
        { 
            ...langGraphConfig, 
            recursionLimit: 10  // 再帰回数の上限. ループが長くなる場合は多めに設定
        }
    );

    for await (const chunk of stream) {
        console.log(chunk);
    }

    const resultState = await app.getState(langGraphConfig as RunnableConfig);
    console.log(resultState.values);
}

main();

Python

import asyncio
import os
import uuid
from typing import TypedDict, List, Optional, Annotated, Sequence
from dotenv import load_dotenv

from langgraph.graph import StateGraph, START, END
from langgraph.types import StreamWriter, Command
from langgraph.checkpoint.memory import MemorySaver
# Assuming StreamMode exists or is handled differently in Python LangGraph stream APIs
# from langgraph.stream import StreamMode # May not be needed directly
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from termcolor import colored

# Load environment variables
load_dotenv()

# Type definitions (using TypedDict for structured data)
class Configurable(TypedDict):
    thread_id: str

class ParserOutputType(TypedDict):
    text_list: List[str]
    next_type: Optional[str]

# Define the state structure using TypedDict and Annotated
class GraphState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], lambda x, y: x + y]

def get_system_prompt() -> str:
    return """ユーザー空のクエリに対して、以下の手順で回答してください。
1. ユーザーのクエリを分析する。
2. 分析結果をもとに、回答を生成する。

出力ルール:
出力は2つのブロックからなります。
1. 分析結果を記述したブロック([THOUGHT])
2. 回答を記述したブロック([OUTPUT])

以下の例を参考に、出力してください。
出力例:
[THOUGHT]ここに分析結果を記述します。
[OUTPUT]ここに回答を記述します。"""

# output type getter
def xml_output_parser(current_content: str) -> ParserOutputType:
    # current_contentを<thought>で分割し、リストを取得
    thought_start_list = current_content.split('[THOUGHT]')
    if len(thought_start_list) >= 2:
        return {
            "text_list": thought_start_list,
            "next_type": 'thought',
        }
    output_start_list = current_content.split('[OUTPUT]')
    if len(output_start_list) >= 2:
        return {
            "text_list": output_start_list,
            "next_type": 'output',
        }
    return {
        "text_list": [current_content],
        "next_type": None,
    }

async def build_langgraph(llm: BaseChatModel):
    # llm generation node
    async def llm_node(state: GraphState, writer: StreamWriter):
        try:
            print("Starting llm_node execution")
            messages = [SystemMessage(content=get_system_prompt())] + state['messages']
            # Use astream for async iteration in Python
            stream = llm.astream(messages)

            all_content = ''  # chunkを結合
            current_content = ''  # typeの塊で保持
            output_content = ''  # 出力する内容
            current_type: Optional[str] = None  # 現在のtype
            next_type: Optional[str] = None  # 次のtype

            # ストリームを処理して結果を得る
            async for chunk in stream:
                chunk_content = chunk.content
                if isinstance(chunk_content, str):
                    all_content += chunk_content
                    current_content += chunk_content
                    output = xml_output_parser(current_content)

                    if output['next_type'] == 'thought':
                        # thoughtがきた場合は、次のメッセージのタイプをthoughtにして、切り替わる前を連携
                        output_content = output['text_list'][0]
                        current_content = output['text_list'][-1]
                        next_type = 'thought'
                    elif output['next_type'] == 'output':
                        # outputも同様
                        output_content = output['text_list'][0]
                        current_content = output['text_list'][-1]
                        next_type = 'output'
                    else:
                        current_content = output['text_list'][-1]
                        output_content = current_content

                    if current_type and current_type in ['thought', 'output']:
                        # writerでストリーミング出力したいデータを送る
                        writer({"message_type": current_type, "message_content": output_content})

                    current_type = next_type

            print("LLM processing completed")
            print("--------------------------------")

            return Command(
                goto=END,
                update={"messages": [AIMessage(content=all_content)]},
            )

        except Exception as e:
            print(f"Error in llm_node: {e}")
            raise e # Re-throw the error

    workflow = StateGraph(GraphState)
    workflow.add_node("llm_node", llm_node)
    workflow.add_edge(START, "llm_node")

    # custom出力を利用する場合は、checkpointerを設定
    memory = MemorySaver()

    app = workflow.compile(checkpointer=memory)
    return app

async def main():
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
    )

    app = await build_langgraph(llm)

    # LangGraphの設定
    thread_id = str(uuid.uuid4())
    langgraph_config: RunnableConfig = {
        "configurable": {
            "thread_id": thread_id,
        },
        "recursion_limit": 10  # Recursion limit is part of the config
    }

    print(f"Using Thread ID: {thread_id}")
    async for chunk in app.astream(
        {"messages": [HumanMessage(content="LLMについて教えてください。")]},
        config=langgraph_config,
        stream_mode="custom",
    ):
        print(colored(chunk, "green"))

    # Get final state
    final_state = await app.aget_state(langgraph_config)
    print("--- Final State ---")
    print(final_state.values)
    print("-------------------")


if __name__ == "__main__":
    asyncio.run(main()) 

出力結果例

Starting llm_node execution
{'message_type': 'thought', 'message_content': ' ユ'}
{'message_type': 'thought', 'message_content': ' ユー'}
...
{'message_type': 'output', 'message_content': ' L'}
{'message_type': 'output', 'message_content': ' LLM'}
...
LLM processing completed
--------------------------------
--- Final State ---
{'messages': [HumanMessage(content='LLMについて教えてください。', additional_kwargs={}, response_metadata={}), AIMessage(content='[THOUGHT] ユーザーは「LLM」という用語についての情報を求めています。LLMは「大規模言語モデル(Large Language Model)」の略であり、自然言語処理において使用されるAI技術の一つです。ユーザーはおそらくLLMの基本的な概念、機能、用途、またはその 利点について知りたいと考えているでしょう。\n\n[OUTPUT] LLM(大規模言語モデル)とは、膨大な量のテキストデータを学習し、自然言語を理解し生成する能力を持つAIモデルのことです。これらのモデルは、文章の生成、質問応答、翻訳、要約など、さまざまな自然言語処理タスクに利用されます。LLMは、トランスフォーマーアーキテクチャを基にしており、文脈を考慮した言語理解が可能です。主な 利点としては、高い精度での言語生成、柔軟な応用範囲、そして人間のような対話能力が挙げられます。代表的なLLMには、OpenAIのGPTシリーズやGoogleのBERTなどがあります。', additional_kwargs={}, response_metadata={})]}
-------------------

Discussion