🛠️

SlackとChatGPT APIでチャットボットを作る パート2(会話履歴管理編)

2023/08/08に公開

(2023-12-13 追記)最近のOpenAI SDKは仕様が変わっており、現在載せているコードは動かないので互換性のある古いOpenAI SDKを含むrequirents.txtを掲載します。Pythonコードも細部を少し修正しました。

requirements.txt
wheel
tenacity
slack_bolt
openai==0.28.1
tiktoken
pandas
matplotlib
japanize_matplotlib
seaborn
scikit-learn
ipykernel

SlackとChatGPT APIでチャットボットを作る パート2( function calling編 会話履歴管理編)

パート1ではチャットボットの骨組み部分を作りました。パート2ではfunction callingを使って自分のデータベースにアクセスする機能を実装 しようと思いますが、 するつもりだったのですが、その前にまずチャットの履歴管理を強化しようとしたらそこそこ長くなったので今回は会話履歴管理編とします。

履歴管理の強化

パート1のutilsモジュールを再掲します。

utils.py
from typing import Optional, Any, Callable, Generator
import os
import re
import openai
from openai.error import InvalidRequestError
import tiktoken
from tenacity import retry, retry_if_not_exception_type, wait_fixed

class Messages:
    def __init__(self, tokens_estimator: Callable[[dict], int]) -> None:
        """Initializes the Messages class.
        Args:
            tokens_estimator (Callable[[Dict], int]):
                Function to estimate the number of tokens of a message.
                Args:
                    message (Dict): The message to estimate the number of tokens of.
                Returns:
                    (int): The estimated number of tokens.
        """
        self.tokens_estimator = tokens_estimator
        self.messages = list()
        self.num_tokens = list()
    
    def append(self, message: dict[str, str], num_tokens: Optional[int]=None) -> None:
        """Appends a message to the messages.
        Args:
            message (Dict[str, str]): The message to append.
            num_tokens (Optional[int]):
                The number of tokens of the message.
                If None, self.tokens_estimator will be used.
        """
        self.messages.append(message)
        if num_tokens is None:
            self.num_tokens.append(self.tokens_estimator(message))
        else:
            self.num_tokens.append(num_tokens)
    
    def trim(self, max_num_tokens: int) -> None:
        """Trims the messages to max_num_tokens."""
        while sum(self.num_tokens) > max_num_tokens:
            _ = self.messages.pop(1)
            _ = self.num_tokens.pop(1)
    
    def rollback(self, n: int) -> None:
        """Rolls back the messages by n steps."""
        for _ in range(n):
            _ = self.messages.pop()
            _ = self.num_tokens.pop()

class ChatEngine:
    """Chatbot engine that uses OpenAI's API to generate responses."""
    size_pattern = re.compile(r"\-(\d+)k")

    @classmethod
    def get_max_num_tokens(cls) -> int:
        """Returns the maximum number of tokens allowed for the model."""
        mo = cls.size_pattern.search(cls.model)
        if mo:
            return int(mo.group(1))*1024
        elif cls.model.startswith("gpt-3.5"):
            return 4*1024
        elif cls.model.startswith("gpt-4"):
            return 8*1024
        else:
            raise ValueError(f"Unknown model: {cls.model}")

    @classmethod
    def setup(cls, model: str, tokens_haircut: float|tuple[float]=0.9) -> None:
        """Basic setup of the class.
        Args:
            model (str): The name of the OpenAI model to use, i.e. "gpt-3-0613" or "gpt-4-0613"
            tokens_haircut (float|Tuple[float]): coefficients to modify the maximum number of tokens allowed for the model.
        """
        cls.model = model
        cls.enc = tiktoken.encoding_for_model(model)
        match tokens_haircut:
            case tuple(x) if len(x) == 2:
                cls.max_num_tokens = round(cls.get_max_num_tokens()*x[1] + x[0])
            case float(x):
                cls.max_num_tokens = round(cls.get_max_num_tokens()*x)
        
        openai.api_key = os.getenv("OPENAI_API_KEY")

    @classmethod
    def estimate_num_tokens(cls, message: dict) -> int:
        """Estimates the number of tokens of a message.
        Args:
            message (Dict): The message to estimate the number of tokens of.
        Returns:
            (int): The estimated number of tokens.
        """
        return len(cls.enc.encode(message["content"]))
    
    def __init__(self) -> None:
        """Initializes the chatbot engine.
        """
        self.messages = Messages(self.estimate_num_tokens)
        self.messages.append({
            "role": "system",
            "content": "ユーザーを助けるチャットボットです。博多弁で答えます。"
        })
        self.completion_tokens_prev = 0
        self.total_tokens_prev = self.messages.num_tokens[-1]

    @retry(retry=retry_if_not_exception_type(InvalidRequestError), wait=wait_fixed(10))
    def _process_chat_completion(self, **kwargs) -> dict[str, Any]:
        """Processes ChatGPT API calling."""
        self.messages.trim(self.max_num_tokens)
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=self.messages.messages,
            **kwargs
        )
        assert isinstance(response, dict)
        message = response["choices"][0]["message"]
        usage = response["usage"]
        self.messages.append(message, num_tokens=usage["completion_tokens"] - self.completion_tokens_prev)
        self.messages.num_tokens[-2] = usage["prompt_tokens"] - self.total_tokens_prev
        self.completion_tokens_prev = usage["completion_tokens"]
        self.total_tokens_prev = usage["total_tokens"]
        return message
    
    def reply_message(self, user_message: str) -> Generator:
        """Replies to the user's message.
        Args:
            user_message (str): The user's message.
        Yields:
            (str): The chatbot's response(s)
        """
        message = {"role": "user", "content": user_message}
        self.messages.append(message)
        try:
            message = self._process_chat_completion()
        except InvalidRequestError as e:
            yield f"## Error while Chat GPT API calling with the user message: {e}"
            return
        
        yield message['content']


人類の叡智を大半を学習しているかのようなChatGPTですが、その核となるLLMには実はあなたとの直前の会話を記憶する能力はありません。なので、パート1におけるChatEngineクラスではself.messagesというlistでチャット履歴を管理し、都度新しいメッセージを付け加えてChatGPTに投げ、返事を得ています。つまり、会話が続くとこのself.messagesが大きくなっていくのです。

ちなみに、このような管理はウェブからChatGPTを使う場合にはサーバーサイドで自動的に行われていて、ユーザーは意識する必要はありません。

本題に戻ると、ここで問題なのは、ChatGPTが扱える文章の長さ(トークン数)には制限があるということです。具体的には標準のgpt-3.5-turboでは4Kトークン、標準のgpt-4では8Kトークンが上限となっています。今後function callingを実装すると履歴にかなり大きなデータが加わるケースが想定されることもあり、会話履歴のサイズを管理する機能、より具体的にいうと、サイズが上限を超えたら古い履歴から順に消去する機能を実装したいと思います。

utils.py
from typing import List, Dict, Tuple, Optional, Union, Any, Callable
import os
import re
import openai
from openai.error import InvalidRequestError
import tiktoken
from tenacity import retry, retry_if_not_exception_type, wait_fixed

class Messages:
    def __init__(self, tokens_estimator: Callable[[Dict], int]) -> None:
        """Initializes the Messages class.
        Args:
            tokens_estimator (Callable[[Dict], int]):
                Function to estimate the number of tokens of a message.
                Args:
                    message (Dict): The message to estimate the number of tokens of.
                Returns:
                    (int): The estimated number of tokens.
        """
        self.tokens_estimator = tokens_estimator
        self.messages = list()
        self.num_tokens = list()
    
    def append(self, message: Dict[str, str], num_tokens: Optional[int]=None) -> None:
        """Appends a message to the messages.
        Args:
            message (Dict[str, str]): The message to append.
            num_tokens (Optional[int]):
                The number of tokens of the message.
                If None, self.tokens_estimator will be used.
        """
        self.messages.append(message)
        if num_tokens is None:
            self.num_tokens.append(self.tokens_estimator(message))
        else:
            self.num_tokens.append(num_tokens)
    
    def trim(self, max_num_tokens: int) -> None:
        """Trims the messages to max_num_tokens."""
        while sum(self.num_tokens) > max_num_tokens:
            _ = self.messages.pop(1)
            _ = self.num_tokens.pop(1)
    
    def rollback(self, n: int) -> None:
        """Rolls back the messages by n steps."""
        for _ in range(n):
            _ = self.messages.pop()
            _ = self.num_tokens.pop()

class ChatEngine:
    """Chatbot engine that uses OpenAI's API to generate responses."""
    size_pattern = re.compile(r"\-(\d+)k")

    @classmethod
    def get_max_num_tokens(cls) -> int:
        """Returns the maximum number of tokens allowed for the model."""
        mo = cls.size_pattern.search(cls.model)
        if mo:
            return int(mo.group(1))*1024
        elif cls.model.startswith("gpt-3.5"):
            return 4*1024
        elif cls.model.startswith("gpt-4"):
            return 8*1024
        else:
            raise ValueError(f"Unknown model: {cls.model}")

    @classmethod
    def setup(cls, model: str, tokens_haircut: float|Tuple[float]=0.9) -> None:
        """Basic setup of the class.
        Args:
            model (str): The name of the OpenAI model to use, i.e. "gpt-3-0613" or "gpt-4-0613"
            tokens_haircut (float|Tuple[float]): coefficients to modify the maximum number of tokens allowed for the model.
        """
        cls.model = model
        cls.enc = tiktoken.encoding_for_model(model)
        if isinstance(tokens_haircut, tuple):
            cls.max_num_tokens = round(cls.get_max_num_tokens()*tokens_haircut[1] + tokens_haircut[0])
        else:
            cls.max_num_tokens = round(cls.get_max_num_tokens()*tokens_haircut)
        openai.api_key = os.getenv("OPENAI_API_KEY")

    @classmethod
    def estimate_num_tokens(cls, message: Dict) -> int:
        """Estimates the number of tokens of a message.
        Args:
            message (Dict): The message to estimate the number of tokens of.
        Returns:
            (int): The estimated number of tokens.
        """
        return len(cls.enc.encode(message["content"]))
    
    def __init__(self) -> None:
        """Initializes the chatbot engine.
        """
        self.messages = Messages(self.estimate_num_tokens)
        self.messages.append({
            "role": "system",
            "content": "ユーザーを助けるチャットボットです。博多弁で答えます。"
        })
        self.completion_tokens_prev = 0
        self.total_tokens_prev = self.messages.num_tokens[-1]

    @retry(retry=retry_if_not_exception_type(InvalidRequestError), wait=wait_fixed(10))
    def _process_chat_completion(self, **kwargs) -> Dict[str, Any]:
        """Processes ChatGPT API calling."""
        self.messages.trim(self.max_num_tokens)
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=self.messages.messages,
            **kwargs
        )
        message = response["choices"][0]["message"]
        usage = response["usage"]
        self.messages.append(message, num_tokens=usage["completion_tokens"] - self.completion_tokens_prev)
        self.messages.num_tokens[-2] = usage["prompt_tokens"] - self.total_tokens_prev
        self.completion_tokens_prev = usage["completion_tokens"]
        self.total_tokens_prev = usage["total_tokens"]
        return message
    
    def reply_message(self, user_message: str) -> None:
        """Replies to the user's message.
        Args:
            user_message (str): The user's message.
        Yields:
            (str): The chatbot's response(s)
        """
        message = {"role": "user", "content": user_message}
        self.messages.append(message)
        try:
            message = self._process_chat_completion()
        except InvalidRequestError as e:
            yield f"## Error while Chat GPT API calling with the user message: {e}"
            return
        
        yield message['content']

この改訂版のutils.pyではlistの代わりにMessagesというクラスで会話履歴を管理しています。Messagesのインスタンスは従来のself.messageslistに加えてself.num_tokensという各メッセージのトークン数を保持するlistを持っています。trimというメソッドを呼ぶことによって古いメッセージを削除し、トータルのトークン数を上限以内に抑えることができます。このとき、最初のメッセージは"role": "system"の重要なメッセージなので削除せず、2番目のメッセージから削除していきます。

API呼び出し前のトークン数の推定にはtiktokenを用いています。tiktokenは次のようにインストールできます:

pip install tiktoken

APIを呼び出した後、戻り値に含まれているusageでトークン数をアップデートしています。これは累積値らしいので差分を計算しています。usageから得られるトークン数は事前にtiktokenで計算したトークン数より若干多いのですが、実際にLLMに投げられるプロンプトはmessage["content"]に若干付け加わったものであるためと考えています。

パート2はこれでおしまいです。コードはtf-koichi/slack-chatbot at part2に置いてあります。今回は履歴管理を強化しただけなので、普通にチャットした感じはパート1と変わらないはずです。パート3は今度こそfunction calling編にします。以上、何かお気づきの点がありましたらフィードバックよろしくお願いします。

Discussion