SlackとChatGPT APIでチャットボットを作る パート2(会話履歴管理編)
(2023-12-13 追記)最近のOpenAI SDKは仕様が変わっており、現在載せているコードは動かないので互換性のある古いOpenAI SDKを含む
requirents.txt
を掲載します。Pythonコードも細部を少し修正しました。
wheel
tenacity
slack_bolt
openai==0.28.1
tiktoken
pandas
matplotlib
japanize_matplotlib
seaborn
scikit-learn
ipykernel
function calling編 会話履歴管理編)
SlackとChatGPT APIでチャットボットを作る パート2( パート1ではチャットボットの骨組み部分を作りました。パート2ではfunction callingを使って自分のデータベースにアクセスする機能を実装 しようと思いますが、 するつもりだったのですが、その前にまずチャットの履歴管理を強化しようとしたらそこそこ長くなったので今回は会話履歴管理編とします。
履歴管理の強化
パート1のutilsモジュールを再掲します。
from typing import Optional, Any, Callable, Generator
import os
import re
import openai
from openai.error import InvalidRequestError
import tiktoken
from tenacity import retry, retry_if_not_exception_type, wait_fixed
class Messages:
def __init__(self, tokens_estimator: Callable[[dict], int]) -> None:
"""Initializes the Messages class.
Args:
tokens_estimator (Callable[[Dict], int]):
Function to estimate the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
self.tokens_estimator = tokens_estimator
self.messages = list()
self.num_tokens = list()
def append(self, message: dict[str, str], num_tokens: Optional[int]=None) -> None:
"""Appends a message to the messages.
Args:
message (Dict[str, str]): The message to append.
num_tokens (Optional[int]):
The number of tokens of the message.
If None, self.tokens_estimator will be used.
"""
self.messages.append(message)
if num_tokens is None:
self.num_tokens.append(self.tokens_estimator(message))
else:
self.num_tokens.append(num_tokens)
def trim(self, max_num_tokens: int) -> None:
"""Trims the messages to max_num_tokens."""
while sum(self.num_tokens) > max_num_tokens:
_ = self.messages.pop(1)
_ = self.num_tokens.pop(1)
def rollback(self, n: int) -> None:
"""Rolls back the messages by n steps."""
for _ in range(n):
_ = self.messages.pop()
_ = self.num_tokens.pop()
class ChatEngine:
"""Chatbot engine that uses OpenAI's API to generate responses."""
size_pattern = re.compile(r"\-(\d+)k")
@classmethod
def get_max_num_tokens(cls) -> int:
"""Returns the maximum number of tokens allowed for the model."""
mo = cls.size_pattern.search(cls.model)
if mo:
return int(mo.group(1))*1024
elif cls.model.startswith("gpt-3.5"):
return 4*1024
elif cls.model.startswith("gpt-4"):
return 8*1024
else:
raise ValueError(f"Unknown model: {cls.model}")
@classmethod
def setup(cls, model: str, tokens_haircut: float|tuple[float]=0.9) -> None:
"""Basic setup of the class.
Args:
model (str): The name of the OpenAI model to use, i.e. "gpt-3-0613" or "gpt-4-0613"
tokens_haircut (float|Tuple[float]): coefficients to modify the maximum number of tokens allowed for the model.
"""
cls.model = model
cls.enc = tiktoken.encoding_for_model(model)
match tokens_haircut:
case tuple(x) if len(x) == 2:
cls.max_num_tokens = round(cls.get_max_num_tokens()*x[1] + x[0])
case float(x):
cls.max_num_tokens = round(cls.get_max_num_tokens()*x)
openai.api_key = os.getenv("OPENAI_API_KEY")
@classmethod
def estimate_num_tokens(cls, message: dict) -> int:
"""Estimates the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
return len(cls.enc.encode(message["content"]))
def __init__(self) -> None:
"""Initializes the chatbot engine.
"""
self.messages = Messages(self.estimate_num_tokens)
self.messages.append({
"role": "system",
"content": "ユーザーを助けるチャットボットです。博多弁で答えます。"
})
self.completion_tokens_prev = 0
self.total_tokens_prev = self.messages.num_tokens[-1]
@retry(retry=retry_if_not_exception_type(InvalidRequestError), wait=wait_fixed(10))
def _process_chat_completion(self, **kwargs) -> dict[str, Any]:
"""Processes ChatGPT API calling."""
self.messages.trim(self.max_num_tokens)
response = openai.ChatCompletion.create(
model=self.model,
messages=self.messages.messages,
**kwargs
)
assert isinstance(response, dict)
message = response["choices"][0]["message"]
usage = response["usage"]
self.messages.append(message, num_tokens=usage["completion_tokens"] - self.completion_tokens_prev)
self.messages.num_tokens[-2] = usage["prompt_tokens"] - self.total_tokens_prev
self.completion_tokens_prev = usage["completion_tokens"]
self.total_tokens_prev = usage["total_tokens"]
return message
def reply_message(self, user_message: str) -> Generator:
"""Replies to the user's message.
Args:
user_message (str): The user's message.
Yields:
(str): The chatbot's response(s)
"""
message = {"role": "user", "content": user_message}
self.messages.append(message)
try:
message = self._process_chat_completion()
except InvalidRequestError as e:
yield f"## Error while Chat GPT API calling with the user message: {e}"
return
yield message['content']
人類の叡智を大半を学習しているかのようなChatGPTですが、その核となるLLMには実はあなたとの直前の会話を記憶する能力はありません。なので、パート1におけるChatEngine
クラスではself.messages
というlistでチャット履歴を管理し、都度新しいメッセージを付け加えてChatGPTに投げ、返事を得ています。つまり、会話が続くとこのself.messages
が大きくなっていくのです。
ちなみに、このような管理はウェブからChatGPTを使う場合にはサーバーサイドで自動的に行われていて、ユーザーは意識する必要はありません。
本題に戻ると、ここで問題なのは、ChatGPTが扱える文章の長さ(トークン数)には制限があるということです。具体的には標準のgpt-3.5-turbo
では4Kトークン、標準のgpt-4
では8Kトークンが上限となっています。今後function callingを実装すると履歴にかなり大きなデータが加わるケースが想定されることもあり、会話履歴のサイズを管理する機能、より具体的にいうと、サイズが上限を超えたら古い履歴から順に消去する機能を実装したいと思います。
from typing import List, Dict, Tuple, Optional, Union, Any, Callable
import os
import re
import openai
from openai.error import InvalidRequestError
import tiktoken
from tenacity import retry, retry_if_not_exception_type, wait_fixed
class Messages:
def __init__(self, tokens_estimator: Callable[[Dict], int]) -> None:
"""Initializes the Messages class.
Args:
tokens_estimator (Callable[[Dict], int]):
Function to estimate the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
self.tokens_estimator = tokens_estimator
self.messages = list()
self.num_tokens = list()
def append(self, message: Dict[str, str], num_tokens: Optional[int]=None) -> None:
"""Appends a message to the messages.
Args:
message (Dict[str, str]): The message to append.
num_tokens (Optional[int]):
The number of tokens of the message.
If None, self.tokens_estimator will be used.
"""
self.messages.append(message)
if num_tokens is None:
self.num_tokens.append(self.tokens_estimator(message))
else:
self.num_tokens.append(num_tokens)
def trim(self, max_num_tokens: int) -> None:
"""Trims the messages to max_num_tokens."""
while sum(self.num_tokens) > max_num_tokens:
_ = self.messages.pop(1)
_ = self.num_tokens.pop(1)
def rollback(self, n: int) -> None:
"""Rolls back the messages by n steps."""
for _ in range(n):
_ = self.messages.pop()
_ = self.num_tokens.pop()
class ChatEngine:
"""Chatbot engine that uses OpenAI's API to generate responses."""
size_pattern = re.compile(r"\-(\d+)k")
@classmethod
def get_max_num_tokens(cls) -> int:
"""Returns the maximum number of tokens allowed for the model."""
mo = cls.size_pattern.search(cls.model)
if mo:
return int(mo.group(1))*1024
elif cls.model.startswith("gpt-3.5"):
return 4*1024
elif cls.model.startswith("gpt-4"):
return 8*1024
else:
raise ValueError(f"Unknown model: {cls.model}")
@classmethod
def setup(cls, model: str, tokens_haircut: float|Tuple[float]=0.9) -> None:
"""Basic setup of the class.
Args:
model (str): The name of the OpenAI model to use, i.e. "gpt-3-0613" or "gpt-4-0613"
tokens_haircut (float|Tuple[float]): coefficients to modify the maximum number of tokens allowed for the model.
"""
cls.model = model
cls.enc = tiktoken.encoding_for_model(model)
if isinstance(tokens_haircut, tuple):
cls.max_num_tokens = round(cls.get_max_num_tokens()*tokens_haircut[1] + tokens_haircut[0])
else:
cls.max_num_tokens = round(cls.get_max_num_tokens()*tokens_haircut)
openai.api_key = os.getenv("OPENAI_API_KEY")
@classmethod
def estimate_num_tokens(cls, message: Dict) -> int:
"""Estimates the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
return len(cls.enc.encode(message["content"]))
def __init__(self) -> None:
"""Initializes the chatbot engine.
"""
self.messages = Messages(self.estimate_num_tokens)
self.messages.append({
"role": "system",
"content": "ユーザーを助けるチャットボットです。博多弁で答えます。"
})
self.completion_tokens_prev = 0
self.total_tokens_prev = self.messages.num_tokens[-1]
@retry(retry=retry_if_not_exception_type(InvalidRequestError), wait=wait_fixed(10))
def _process_chat_completion(self, **kwargs) -> Dict[str, Any]:
"""Processes ChatGPT API calling."""
self.messages.trim(self.max_num_tokens)
response = openai.ChatCompletion.create(
model=self.model,
messages=self.messages.messages,
**kwargs
)
message = response["choices"][0]["message"]
usage = response["usage"]
self.messages.append(message, num_tokens=usage["completion_tokens"] - self.completion_tokens_prev)
self.messages.num_tokens[-2] = usage["prompt_tokens"] - self.total_tokens_prev
self.completion_tokens_prev = usage["completion_tokens"]
self.total_tokens_prev = usage["total_tokens"]
return message
def reply_message(self, user_message: str) -> None:
"""Replies to the user's message.
Args:
user_message (str): The user's message.
Yields:
(str): The chatbot's response(s)
"""
message = {"role": "user", "content": user_message}
self.messages.append(message)
try:
message = self._process_chat_completion()
except InvalidRequestError as e:
yield f"## Error while Chat GPT API calling with the user message: {e}"
return
yield message['content']
この改訂版のutils.py
ではlistの代わりにMessages
というクラスで会話履歴を管理しています。Messages
のインスタンスは従来のself.messages
listに加えてself.num_tokens
という各メッセージのトークン数を保持するlistを持っています。trim
というメソッドを呼ぶことによって古いメッセージを削除し、トータルのトークン数を上限以内に抑えることができます。このとき、最初のメッセージは"role": "system"
の重要なメッセージなので削除せず、2番目のメッセージから削除していきます。
API呼び出し前のトークン数の推定にはtiktoken
を用いています。tiktoken
は次のようにインストールできます:
pip install tiktoken
APIを呼び出した後、戻り値に含まれているusage
でトークン数をアップデートしています。これは累積値らしいので差分を計算しています。usage
から得られるトークン数は事前にtiktoken
で計算したトークン数より若干多いのですが、実際にLLMに投げられるプロンプトはmessage["content"]
に若干付け加わったものであるためと考えています。
パート2はこれでおしまいです。コードはtf-koichi/slack-chatbot at part2に置いてあります。今回は履歴管理を強化しただけなので、普通にチャットした感じはパート1と変わらないはずです。パート3は今度こそfunction calling編にします。以上、何かお気づきの点がありましたらフィードバックよろしくお願いします。
Discussion