⛓️

LangChainでChain内のプロンプトも含めた消費トークンを取得する方法

2023/04/19に公開

LangChainでConversationChainなどライブラリ側で用意されているプロンプトを利用する場合に、プロンプトに埋め込まれる履歴情報などを含め、どの程度のトークンが消費されているか把握する方法はあるだろうか?

langchain.callbacks内にあるget_openai_callbackを利用すれば可能である。

ConversationChainの用意

それでは、まずはConversationChainを用意してみよう。

from langchain import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.chains.conversation.memory import ConversationBufferMemory

conversation = ConversationChain(
    llm=ChatOpenAI(),
    memory=ConversationBufferMemory()
)

動作はこのようになる。run関数の戻り値としてtoken消費量は返ってこない。

conversation.run("おはよう!")
# => "おはようございます!朝ですね。現在の時間は何時ですか?(Good morning! It's morning now. What time is it currently?)"

callbackを受け取る関数の作成

そこで以下のようにcallbackを受け取る関数を作成する。

from langchain.callbacks import get_openai_callback

def count_tokens(chain, query):
    with get_openai_callback() as cb:
        result = chain.run(query)
        tokens = cb.total_tokens

    return tokens, result

この関数を以下のようにして使う。

tokens, result = count_tokens(conversation, "今は午前7時、夜行バスから降りたところだよ")
print(f"消費トークン: {tokens}")
print(result)

すると以下のように消費トークンもあわせて表示することができる。

消費トークン: 279
ああ、そうですか。夜行バスから降りたのですね。現在の時刻は午前7時ですね。あなたはどの地域にいるのですか?私はあなたがどこにいるか知りませんが、あなたがいる地域の天気予報を調べることができます。

会話を続けると履歴の分、消費トークン数が増えていることが分かる。

tokens, result = count_tokens(conversation, "今は梅田にいるよ")
print(f"消費トークン: {tokens}")
print(result)
消費トークン: 426
ああ、梅田にいるのですね。梅田は大阪市北区の繁華街で、商業施設やオフィスビルが多くあります。また、梅田は大阪駅周辺に位置しており、交通の便が良いことでも知られています。現在の梅田の天気は曇りで、気温は摂氏22度です。

Discussion