🦜

【LangChain】「チェーンの型」を定義する方法

に公開

LangChain を使ってアプリケーションを開発していると、チェーンの入出力で辞書(dict)を扱う場面が頻繁にあります。

Python は動的型付け言語であり、LangChain 自体も非常に柔軟な設計になっています。
この柔軟性は大きなメリットである一方、特に複数人で開発していたり、後からコードを修正したりする際に、データの受け渡し部分で混乱が生じやすいポイントにもなり得ます。

もっと安心して開発を進めるための一つのアプローチとして、本記事では LangChain のチェーン自体に型を定義する方法をご紹介します。

※以前に書いた記事のリライト版です。

TypedDict@chain でチェーンに型を定義

入出力に型を定義するのはご存知、標準ライブラリ typing に含まれる TypedDict です。
TypedDict を使うと、辞書のキーとその値の型を明示的に定義できます。

from typing import TypedDict

class MyInput(TypedDict):
    name: str
    age: int

このように定義することで、MyInput 型の辞書は name というキーに文字列を、age というキーに整数を持つことが期待されるとコード上で明確に示せます。

さらに @chain デコレータを利用するとチェーン自体にも型を付与できるため、より堅牢な設計が可能になります。
@chain デコレータについては後述しますが、ここでは簡単に「関数をチェーンに変換するデコレータ」とだけ覚えておいてください。

https://python.langchain.com/docs/how_to/functions/#the-convenience-chain-decorator

実践: TypedDict@chain を使った具体的な型定義方法

それでは、実際にLangChainのチェーンに TypedDict@chain を使って型を定義する方法を見ていきましょう。

1. 入出力用の TypedDict を定義する

まず、チェーンの入力と出力に対応する TypedDict を定義します。
ここでは、入力として文字列を受け取り、出力として文字列を返す簡単な翻訳チェーンを例にします。

from typing import TypedDict

class ChainInput(TypedDict):
    input: str  # 入力文字列

class ChainOutput(TypedDict):
    output: str # 出力文字列

2. Runnable でチェーンの型定義し、@chain デコレータで型通りのチェーンを作成

次に、この ChainInputChainOutput を使って、チェーンの型を定義します。
チェーンの型は Runnable[チェーン入力, チェーン出力] の形式で指定します。

@chain で関数をデコレートし、型が付いたチェーンを返却できるようにします。

# チェーンの型を定義👇🏻
Chain = Runnable[ChainInput, ChainOutput]

def build_chain(chat_model: BaseChatModel) -> Chain:

    @chain
    async def _build_chain(input_data: ChainInput) -> AsyncGenerator[ChainOutput, None]:
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "あなたは優秀な翻訳者です。"),
                (
                    "user",
                    "次の文章を英訳してください。\n{input_text}",
                ),
            ]
        )

        lc_chain = prompt | chat_model

        async for output_chunk in lc_chain.astream({"input_text": input_data["input"]}):
            if isinstance(output_chunk.content, str):
                yield {"output": output_chunk.content} # ChainOutput 型で返す

    return _build_chain

ちょっと分かりづらいかもしれませんが、build_chain は独自で定義したチェーンの型 Chain を持つチェーンを返す関数です。
build_chain を利用してチェーンを構築すると、ChainInput 型の辞書を引数に取り、ChainOutput 型の辞書を返すことが保証されます。

ポイント:

  • Chain = Runnable[ChainInput, ChainOutput] とすることで、build_chain 関数が返すチェーンの入出力型が明確になります。
  • @chain でデコレートされた関数 _build_chain の引数 input_dataChainInput 型アノテーションを付け、返り値の yieldChainOutput 型の辞書を返すようにします。

3. 型定義の恩恵

このようにチェーンの型を定義すると、当然ですが開発時に IDE がサポートしてくれます。

例えば、チェーンを呼び出す際に、ChainInput で定義したキー(この例では input)以外を指定しようとしたり、異なる型の値を渡そうとすると、エディタが警告を表示してくれます。

alt text

また、チェーンの処理フローの中で、どのキーにどんなデータが入っているかが型定義から明確にわかるため、コードの可読性やメンテナンス性が大きく向上します。

コード全体と実行結果

以下に、ここまでの説明で用いた型定義を適用した LangChain のサンプルコード全体と、その実行結果を示します。

from typing import AsyncGenerator, TypedDict

from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import (
    ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, chain


class ChainInput(TypedDict):
    input: str


class ChainOutput(TypedDict):
    output: str


Chain = Runnable[ChainInput, ChainOutput]


def build_chain(chat_model: BaseChatModel) -> Chain:

    @chain
    async def _build_chain(input_data: ChainInput) -> AsyncGenerator[ChainOutput, None]:
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", "あなたは優秀な翻訳者です。"),
                (
                    "user",
                    "次の文章を英訳してください。\n{input_text}", # プロンプト内の変数名
                ),
            ]
        )

        lc_chain = prompt | chat_model

        async for output_chunk in lc_chain.astream({"input_text": input_data["input"]}):
            if isinstance(output_chunk.content, str):
                yield {"output": output_chunk.content}

    return _build_chain

if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv
    from langchain_openai import ChatOpenAI

    load_dotenv()

    chat_model = ChatOpenAI(model="gpt-3.5-turbo")

    typed_chain = build_chain(chat_model)

    async def main():
        async for output_chunk in typed_chain.astream(
            {"input": "こんにちは。私の名前はジョンです。"}
        ):
            print(output_chunk)

    asyncio.run(main())

実行結果例 (ストリーミング出力):

python -m chains.chain # ファイル名に合わせて実行
{'output': ''}
{'output': 'Hello'}
{'output': ','}
{'output': ' my'}
{'output': ' name'}
{'output': ' is'}
{'output': ' John'}
{'output': '.'}
{'output': ''}

(お使いのモデルやタイミングによって、細切れの粒度や内容は変わることがあります)

ストリーミング処理 (astream) を利用している場合でも、yield される各チャンクが ChainOutput の型に従うことが期待され、受け取り側もそれを前提に処理を書くことができます。

(補足) LangGraph のグラフも同様に型定義可能

今回は LangChain の基本的なチェーンに焦点を当てましたが、より複雑なエージェントやマルチステップの処理を構築できる LangGraph においても、この TypedDict@chain を使った型定義のアプローチで型を定義できます。

おわりに

LangChain のチェーンや LangGraph の入出力に TypedDict@chain を用いて型を定義することで、より安全で効率的な開発が可能になります。

IDE の補完やエラーチェックの恩恵を受けながら、コードを書けるようになるはずです。

とある通信会社の有志

Discussion