📈

LangChain の Callback でトークン数のログを取る

2023/09/25に公開

LangChain にはコールバックを渡すことができ、LLM へのリクエストを実行した後にトークン数を取得することができます。

ドキュメントでは get_openai_callback を使った方法が紹介されていますが、弊プロダクトでは少し抽象化して ChatOpenAI を使っていたので、callback を引数に渡す方法で試してみたのでご紹介します。

https://python.langchain.com/docs/modules/model_io/models/llms/token_usage_tracking

まず前提として LangChain は次のような Callback 関数群のインターフェイスが定義されています。

class BaseCallbackHandler:
    """Base callback handler that can be used to handle callbacks from langchain."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        """Run when LLM starts running."""

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        """Run when LLM ends running."""

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when LLM errors."""

# ...色々割愛

https://python.langchain.com/docs/modules/callbacks/

なのでこれを継承した Log 用クラスを作って Callback に渡します。
今回は LLM 実行時にトークンの消費量が取りたいので on_llm_end を使い、その引数の response の中にトークン数が入ってます。
下記では BigQuery に送っていますが、ご自身のログの保存先に応じて適宜書き換えてください。

class LogTokensHandler(BaseCallbackHandler):
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
        bq_client.insert_rows(
            log_table,
            [
                {
                    "prompt_key": "_".join(kwargs["tags"]),
                    "prompt_tokens": response.llm_output["token_usage"].prompt_tokens,
                    "completion_tokens": response.llm_output[
                        "token_usage"
                    ].completion_tokens,
                    "total_tokens": response.llm_output["token_usage"].total_tokens,
                    "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                }
            ],
        )

ちなみに prompt_key のところですが、テンプレート毎に一意のタグを実行時に付与していることを前提としています。

response = self.llm.get_chat_completion(
    prompt, tags=["prompt_name"], **inputs
)

そしたらモデルを呼んでいる所で callbacks に渡すだけです。

ChatOpenAI(model="gpt-4-0613", callbacks=[LogTokensHandler()], **kwargs)

これでプロンプト実行毎にログされるので、あとは頑張ってクエリを書きましょう!

ちなみに LangChain が提供している LangSmith というモニタリングツールがあるのですが、こちらはトークン消費量を確認することができるのですが、タグでフィルターなどができないので今後のアップデートに期待といった状況です。

Discussion