🏃‍♂️

【Python/TypeScript】LangChainを使ったOpenAIのトークン消費量の算出法

2024/05/07に公開

はじめに

最近GPT、Gemini、Claude、Command R+などのLLM(Large Language Models)間の性能比較が盛んに行われていますよね。特に、GPT-4-Turbo-2024-04-09がClaude 3 Opusのベンチマークを超えたというニュースは記憶に新しいでしょう。今回の記事では、このようなLLMを扱う際に欠かせないトークン消費量の計算方法を、GPTに限定する形になりますがPythonとTypeScriptの両方を例に解説します。
ただ、OpenAIのAPIに慣れた方は「レスポンスに含まれる情報では?」と思うかもしれませんが、ストリーミングを有効にした場合は、2024/05/07時点でこれは適用されません。しかし、今後APIの仕様が変更されストーリミングでトークン数が取得できるようになる可能性があるため、継続してのウォッチが必要です。
ですので、今回の記事の内容はそれまでの暫定対応に近いニュアンスがあります。

ここでいうトークンとは、OpenAIが定義する単位で、テキストの量を表します。この記事では、PythonとTypeScriptを使用してトークン消費量を具体的に計算する方法を紹介します。実装の詳細に先立ち、LangChainでストリーミングを活用しつつ、会話やRAG(Retrieval-Augmented Generation)のトークン消費を計測する手法については、こちらの記事も参考にしてください。
本記事では上記をPythonとTypeScriptの両方に展開をし、より詳細に解説をした形になります。

目次

  1. Pythonによる実装
  2. TypeScriptによる実装
  3. 実装内容の説明
  4. まとめ

Pythonでの実装

Pythonを使用してOpenAI APIのトークン消費量を計算する方法を以下に示します。具体的には、tiktokenライブラリを利用してエンコーディングを取得し、メッセージごとのトークン数を計算します。ここでは、各メッセージに「assistant」というロール識別子を含めた上で、メッセージの開始と終了を示す追加トークンも考慮に入れています。

import tiktoken

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage
from langchain.schema.output import LLMResult

from logging import getLogger

logger = getLogger(__name__)


def _num_tokens_from_messages(messages: list[BaseMessage], model: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
        # 新しいモデルではメッセージの開始と終了、およびロール識別子(例えば"assistant")の3つのトークンが追加される
        # 参考:https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
        # 参考:https://community.openai.com/t/what-is-the-reason-for-adding-total-7-tokens/337002/12
        tokens_per_message = 3
        num_tokens = 0
        for message in messages:
            # 参考:https://community.openai.com/t/when-we-prompt-the-open-ai-model-are-the-roles-counted-in-input-token/513343
            num_tokens += tokens_per_message
            num_tokens += len(encoding.encode(message.type))
            num_tokens += len(encoding.encode(message.content))
        # アシスタントの応答を開始するためのプロンプト
        # "assistant"というロール識別子と、その後に続くメッセージの開始と終了を示す
        # 参考:https://community.openai.com/t/what-is-the-reason-for-adding-total-7-tokens/337002/12
        num_tokens += 3
        return num_tokens
    except KeyError as e:
        # 不適切なモデル名を選択している場合
        logger.exception(e)
        return 0


# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
# トークン数算出のために必要に応じてカスタマイズする
class CostCalculateCallbackHandler(BaseCallbackHandler):
    token = ""
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0

    def __init__(self, model: str) -> None:
        super().__init__()
        self.model = model
        self.encoding = tiktoken.encoding_for_model(model)

    def on_llm_new_token(self, token: str, **kwargs: any) -> None:
        self.token += token

    # 処理内容によってon_llm_endかon_chain_endのいずれかが呼ばれる
    # 基本的にこちらが呼ばれる
    def on_llm_end(
        self,
        response: LLMResult,
        **kwargs: any,
    ) -> any:
        self.completion_tokens = len(self.encoding.encode(self.token))
        self.total_tokens = self.completion_tokens + self.prompt_tokens

    # QARetrieverを使う際に呼ばれる
    def on_chain_end(
        self,
        outputs: dict[str, any],
        **kwargs: any,
    ) -> any:
        self.completion_tokens = len(self.encoding.encode(self.token))
        self.total_tokens = self.completion_tokens + self.prompt_tokens

    def on_chat_model_start(
        self,
        serialized: dict[str, any],
        messages: list[list[BaseMessage]],
        **kwargs: any,
    ) -> any:
        for base_messages in messages:
            self.prompt_tokens += _num_tokens_from_messages(base_messages, self.model)


# embeddings API用
# https://github.com/langchain-ai/langchain/issues/945#issuecomment-1538498870
def num_tokens_for_embedding(string: str) -> int:
    """
    Returns the number of tokens in a text string for embedding.

    Parameters:
        string (str): The text string to be tokenized.

    Returns:
        int: The number of tokens in the text string.
    """
    encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = len(encoding.encode(string))
    return num_tokens

TypeScriptでの実装

同様にTypeScriptでの実装例は以下の通りになります。

import { BaseMessage } from "@langchain/core/messages";
import { BaseCallbackHandler, NewTokenIndices } from "@langchain/core/callbacks/base";
import { Tiktoken, getEncoding, encodingForModel, TiktokenModel } from "js-tiktoken";
import { Serialized } from "@langchain/core/load/serializable";
import { ChainValues } from "@langchain/core/utils/types";
import { LLMResult } from "@langchain/core/outputs";
import { DocumentInterface } from "@langchain/core/documents";

const numTokensFromMessages = (messages: BaseMessage[], model: TiktokenModel): number => {
    try {
        const encoding = encodingForModel(model);
        /* 
            NOTE: 新しいモデルではメッセージの開始と終了、およびロール識別子(例えば"assistant")の3つのトークンが追加される
            参考: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
            参考: https://community.openai.com/t/what-is-the-reason-for-adding-total-7-tokens/337002/12
        */
        const tokensPerMessage = 3;
        let numTokens = 0;
        for (const message of messages) {
            // 参考:https://community.openai.com/t/when-we-prompt-the-open-ai-model-are-the-roles-counted-in-input-token/513343
            numTokens += tokensPerMessage;
            numTokens += encoding.encode(message._getType()).length;
            numTokens += encoding.encode(message.content as string).length;
        }
        /*
            NOTE: アシスタントの応答を開始するためのプロンプト
            "assistant"というロール識別子と、その後に続くメッセージの開始と終了を示す
            参考:https://community.openai.com/t/what-is-the-reason-for-adding-total-7-tokens/337002/12
        */
        numTokens += 3;
        return numTokens;
    } catch (e) {
        // NOTE: 不適切なモデル名を選択している場合
        console.error(e);
        return 0;
    }
}

// NOTE: トークン数算出のために必要に応じてカスタマイズする
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
export class CostCalculateCallbackHandler extends BaseCallbackHandler {
    token = "";
    totalTokens = 0;
    promptTokens = 0;
    completionTokens = 0;
    model: TiktokenModel;
    encoding: Tiktoken;
    name = "CostCalculateCallbackHandler";

    constructor(model: TiktokenModel) {
        super();
        this.model = model;
        this.encoding = encodingForModel(model);
    }

    handleLLMNewToken(token: string, idx: NewTokenIndices, runId: string): void {
        this.token += token;
    }

    // NOTE: 処理内容によってonLlmEndかonChainEndのいずれかが呼ばれる
    // 基本的にこちらが呼ばれる
    handleLLMEnd(output: LLMResult, runId: string): any {
        this.completionTokens = this.encoding.encode(this.token).length;
        this.totalTokens = this.completionTokens + this.promptTokens;
    }

    // NOTE: QARetrieverを使う際に呼ばれる
    handleChainEnd(outputs: ChainValues, runId: string): any {
        this.completionTokens = this.encoding.encode(this.token).length;
        this.totalTokens = this.completionTokens + this.promptTokens;
    }

    handleRetrieverEnd(documents: DocumentInterface<Record<string, any>>[], runId: string, parentRunId?: string | undefined, tags?: string[] | undefined): any {
        this.completionTokens = this.encoding.encode(this.token).length;
        this.totalTokens = this.completionTokens + this.promptTokens;
    }

    handleChatModelStart(llm: Serialized, messages: BaseMessage[][], runId: string): any {
        for (const baseMessages of messages) {
            this.promptTokens += numTokensFromMessages(baseMessages, this.model);
        }
    }
}

// NOTE: embeddings API用
// https://github.com/langchain-ai/langchain/issues/945#issuecomment-1538498870
export const numTokensForEmbedding = (input: string): number => {
    const encoding = getEncoding("cl100k_base");
    const numTokens = encoding.encode(input).length;
    return numTokens;
}

実装内容の説明

このコードは、トークン数の計算やプロセスハンドリングに関わる複数の機能を含んでいます。今回はPython側の処理をピックアップして、それぞれの部分を簡単に分解して説明します:

1. tiktoken ライブラリの利用

import tiktoken で、tiktoken ライブラリをインポートしています。このライブラリは、テキストをトークン化し、それを数えるために使用されます。トークンは、自然言語処理モデルで入力を理解するための基本的な単位です。

2. _num_tokens_from_messages 関数

この関数は、メッセージリストからトータルのトークン数を計算します。各メッセージには、そのタイプと内容に応じてトークンが割り当てられ、特定のトークンがこれに追加されます(例えば、"assistant"としての役割識別子など)。

3. CostCalculateCallbackHandler クラス

このクラスは、特定のモデルに基づいてトークン計算を行いながら、異なるコールバックメソッド(on_llm_new_token, on_llm_end, on_chain_end, on_chat_model_start)を利用して、言語モデルの処理を追跡し、コスト計算をします。このクラスは、モデルからの新しいトークンを受け取り、プロンプトやモデルの出力にかかるトークン数を計算します。

4. num_tokens_for_embedding 関数

この関数は、特定のテキスト文字列をトークン化し、そのトークン数を返します。これは特に、言語モデルの入力としてどの程度のトークンが必要かを把握するのに有用です。(今回はEmbeddings APIのトークン消費量計算に使います。)

使い方の例

呼び出し側の実装例

  1. Chat Completions APIのトークン消費量
  • Python
    # CostCalculateCallbackHandlerのインスタンスを宣言し、ChatOpenAIインスタンス生成時にcallback関数として指定する
    gpt4_handler = CostCalculateCallbackHandler(model="gpt-4-turbo")
    llm_gpt4 = ChatOpenAI(
        model="gpt-4-turbo", temperature=0.5, streaming=True, callbacks=[gpt4_handler]
    )

    # チャットのやり取り

    # 入力プロンプト、出力プロンプト、及び入力・出力プロンプトの合計のトークン消費量を取得できる
    prompt_tokens: int = gpt4_handler.prompt_tokens,
    completion_tokens: int = gpt4_handler.completion_tokens,
    total_tokens: int = gpt4_handler.total_tokens,
  • TypeScript
    // CostCalculateCallbackHandlerのインスタンスを宣言し、ChatOpenAIインスタンス生成時にcallback関数として指定する
    const chatHandler = new CostCalculateCallbackHandler(selectedModel.modelName as TiktokenModel)
    const chat = new ChatOpenAI({
        modelName: "gpt-4-turbo",
        temperature: 0.5,
        streaming: true,
        callbacks: [chatHandler]
    });

    // チャットのやり取り

    // 入力プロンプト、出力プロンプト、及び入力・出力プロンプトの合計のトークン消費量を取得できる
    const promptTokens =  chatHandler.promptTokens
    const completionTokens = chatHandler.completionTokens
    const totalTokens = chatHandler.totalTokens
  1. Embeddings APIのトークン消費量
  • Embeddings APIは入力プロンプトの消費量のみが考慮されるため、こちらを元に算出する
  • Python
    # 入力プロンプト=合計プロンプトとなる
    prompt_tokens: int = num_tokens_for_embedding(user_input_message)
    completion_tokens: int = 0
    total_tokens: int = prompt_tokens
  • TypeScript
    // 入力プロンプト=合計プロンプトとなる
    const promptTokens = numTokensForEmbedding(userQuestion)
    const completionTokens = 0
    const totalTokens = promptTokens

まとめ

こちらの記事では、LangChainライブラリを使用してPythonとTypeScriptの両方でOpenAI APIのトークン消費量を計算する方法について解説しました。トークン消費量の計算は、APIコストの管理や制限内で最大限の効果を得るために重要です。特に、ストリーミングAPIを利用する場合、リアルタイムでのトークン計算が非常に重要になります。

どのようにしてトークン消費量を活用するか

トークン消費量の把握は、次のようなシナリオで役立ちます:

  1. 予算管理:APIのコストを把握し、予算内での使用を計画する。
  2. 効率的な設計:必要なトークン数を最小限に抑えるようなクエリ設計を行う。
  3. 性能の最適化:トークン数が多いクエリが性能に与える影響を分析し、最適化する。

冒頭でもお伝えした通り、今後はストリーミングAPIに対してもレスポンスでトークン消費量が返ってくるようになるかもしれませんので、情報のウォッチは必要です。また、今回の対応方法だとAssistant API等々他のAPIに対しては課金体系が異なってくることもあり、別途対処法を検討する必要があります。こちらに関しては、別の機会で執筆ができたらと思います。

Discussion