SlackとChatGPT APIでチャットボットを作る パート3(function calling編)
(2023-12-13 追記)最近のOpenAI SDKは仕様が変わっており、現在載せているコードは動かないので互換性のある古いOpenAI SDKを含む
requirents.txt
を掲載します。Pythonコードも細部を少し修正しました。
wheel
tenacity
slack_bolt
openai==0.28.1
tiktoken
pandas
matplotlib
japanize_matplotlib
seaborn
scikit-learn
ipykernel
SlackとChatGPT APIでチャットボットを作る パート3(function calling編)
パート1(基礎編)、パート2(会話履歴管理編)と来て、パート3はやっとfunction calling編です。
function callingとは
function callingでは、ChatGPTが回答作成のために使用することができる「関数」をあらかじめ用意し、質問などとともにこの関数の仕様を含んだメッセージをChatGPTに送ります。そうするとChatGPTは関数を使用したい場合は使用したい関数の名前とその関数に与える引数をメッセージとして返してきます。これを受け取ったユーザーは指定された関数を指定された引数で実行します。関数の実行はChatGPTが勝手に行うのではなく、ユーザーが自らの環境で行うという点がミソです。無事関数の結果が得られたら、これをメッセージとしてChatGPTに送ります。そうするとChatGPTは受け取った関数の出力を使って当初のメッセージに含まれていた質問に対する回答を作成し、ユーザーに送ります。
データベースの用意
「関数」としてはアイデア次第で様々なものがありえますが、このパートではSQLを引数としてユーザーのデータベースを検索する「関数」を取上げることにします。データベースの中身に関する質問をChatGPTに投げかけるとChatGPTがこの「関数」の引数としてSQLを組み立ててくれます。
お手持ちのデータベースを使うと便利な検索チャットボットができます。この記事では適当なデータを用意してSQLite3のファイルにしました(tf-koichi/slack-chatbot at part3の/data/world_stats.sqlite3
)。データ作成に用いたノートブックは/notebooks/data.ipynb
で、元データ取得ページへのリンクも記されています。
スラッシュ・コマンドの設定
Slackのスラッシュ・アプリを使いますので、Slackアプリの設定ページの左のメニューからSlash Commandsをクリックし、Create New Commandをクリックして/verbose
と/style
の2つのコマンドを登録します。
スクショを撮り忘れたのですが、最初のスラッシュ・コマンドを登録したときに、設定ページの上部にアプリのリロードを促すメッセージが表示されますので、リンクの部分をクリックしてリロードを完了しておいてください。
スラッシュ・コマンド/verbose
はverboseモードをオン・オフするために使い、/style
はチャットボットの返答スタイルをセットするのに使います。
ソース・コード
パート3ではutils.py
とchatbot.py
を以下のように書き換えます。
from typing import Optional, Any, Callable, Generator
import io
import os
import re
import json
from pathlib import Path
import sqlite3
import pandas as pd
import openai
from openai.error import InvalidRequestError
import tiktoken
from tenacity import retry, retry_if_not_exception_type, wait_fixed
class WSDatabase:
data_path = Path("../data/world_stats.sqlite3")
schema = [
{
"name": "country",
"description": "国名"
},{
"name": "country_code",
"description": "国コード"
},{
"name": "average life expectancy at birth",
"description": "平均寿命(年)"
},{
"name": "alcohol_consumption",
"description": "一人当たりの年間アルコール消費量(リットル)"
},{
"name": "region",
"description": "地域"
},{
"name": "gdp per capita",
"description": "一人当たりのGDP(ドル)"
}
]
def __enter__(self):
self.conn = sqlite3.connect(self.data_path)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.conn.close()
@classmethod
def schema_str(cls):
schema_df = pd.DataFrame.from_records(cls.schema)
text_buffer = io.StringIO()
schema_df.to_csv(text_buffer, index=False)
text_buffer.seek(0)
schema_csv = text_buffer.read()
schema_csv = "table: world_stats\ncolumns:\n" + schema_csv
return schema_csv
def ask_database(self, query):
"""Function to query SQLite database with a provided SQL query."""
try:
cursor = self.conn.cursor()
cursor.execute(query)
results = cursor.fetchall()
cols = [col[0] for col in cursor.description]
results_df = pd.DataFrame(results, columns=cols)
text_buffer = io.StringIO()
results_df.to_csv(text_buffer, index=False)
text_buffer.seek(0)
results_csv = text_buffer.read()
except Exception as e:
results_csv = f"query failed with error: {e}"
return results_csv
class Messages:
def __init__(self, tokens_estimator: Callable[[dict], int]) -> None:
"""Initializes the Messages class.
Args:
tokens_estimator (Callable[[Dict], int]):
Function to estimate the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
self.tokens_estimator = tokens_estimator
self.messages = list()
self.num_tokens = list()
def append(self, message: dict[str, str], num_tokens: Optional[int]=None) -> None:
"""Appends a message to the messages.
Args:
message (Dict[str, str]): The message to append.
num_tokens (Optional[int]):
The number of tokens of the message.
If None, self.tokens_estimator will be used.
"""
self.messages.append(message)
if num_tokens is None:
self.num_tokens.append(self.tokens_estimator(message))
else:
self.num_tokens.append(num_tokens)
def trim(self, max_num_tokens: int) -> None:
"""Trims the messages to max_num_tokens."""
while sum(self.num_tokens) > max_num_tokens:
_ = self.messages.pop(1)
_ = self.num_tokens.pop(1)
def rollback(self, n: int) -> None:
"""Rolls back the messages by n steps."""
for _ in range(n):
_ = self.messages.pop()
_ = self.num_tokens.pop()
class ChatEngine:
"""Chatbot engine that uses OpenAI's API to generate responses."""
size_pattern = re.compile(r"\-(\d+)k")
@classmethod
def get_max_num_tokens(cls) -> int:
"""Returns the maximum number of tokens allowed for the model."""
mo = cls.size_pattern.search(cls.model)
if mo:
return int(mo.group(1))*1024
elif cls.model.startswith("gpt-3.5"):
return 4*1024
elif cls.model.startswith("gpt-4"):
return 8*1024
else:
raise ValueError(f"Unknown model: {cls.model}")
@classmethod
def setup(cls, model: str, tokens_haircut: float|tuple[float]=0.9, quotify_fn: Callable[[str], str]=lambda x: x) -> None:
"""Basic setup of the class.
Args:
model (str): The name of the OpenAI model to use, i.e. "gpt-3-0613" or "gpt-4-0613"
tokens_haircut (float|Tuple[float]): coefficients to modify the maximum number of tokens allowed for the model.
quotify_fn (Callable[[str], str]): Function to quotify a string.
"""
openai.api_key = os.getenv("OPENAI_API_KEY")
cls.model = model
cls.enc = tiktoken.encoding_for_model(model)
match tokens_haircut:
case tuple(x) if len(x) == 2:
cls.max_num_tokens = round(cls.get_max_num_tokens()*x[1] + x[0])
case float(x):
cls.max_num_tokens = round(cls.get_max_num_tokens()*x)
case _:
raise ValueError(f"Invalid tokens_haircut: {tokens_haircut}")
cls.functions = [
{
"name": "ask_database",
"description": "世界各国の平均寿命、アルコール消費量、一人あたりGDPのデータベースを検索するための関数。出力はSQLite3が理解できる完全なSQLクエリである必要がある。",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": f"""
SQL query extracting info to answer the user's question.
SQL should be written using this database schema:
{WSDatabase.schema_str()}
""",
}
},
"required": ["query"]
},
}
]
cls.quotify_fn = staticmethod(quotify_fn)
@classmethod
def estimate_num_tokens(cls, message: dict) -> int:
"""Estimates the number of tokens of a message.
Args:
message (Dict): The message to estimate the number of tokens of.
Returns:
(int): The estimated number of tokens.
"""
return len(cls.enc.encode(message["content"]))
def __init__(self, style: str="博多弁") -> None:
"""Initializes the chatbot engine.
"""
style_direction = f"{style}で答えます" if style else ""
self.style = style
self.messages = Messages(self.estimate_num_tokens)
self.messages.append({
"role": "system",
"content": f"必要に応じてデータベースを検索し、ユーザーを助けるチャットボットです。{style_direction}"
})
self.completion_tokens_prev = 0
self.total_tokens_prev = self.messages.num_tokens[-1]
self._verbose = False
@property
def verbose(self) -> bool:
return self._verbose
@verbose.setter
def verbose(self, value: bool) -> None:
self._verbose = value
@retry(retry=retry_if_not_exception_type(InvalidRequestError), wait=wait_fixed(10))
def _process_chat_completion(self, **kwargs) -> dict[str, Any]:
"""Processes ChatGPT API calling."""
self.messages.trim(self.max_num_tokens)
response = openai.ChatCompletion.create(
model=self.model,
messages=self.messages.messages,
**kwargs
)
assert isinstance(response, dict)
message = response["choices"][0]["message"]
usage = response["usage"]
self.messages.append(message, num_tokens=usage["completion_tokens"] - self.completion_tokens_prev)
self.messages.num_tokens[-2] = usage["prompt_tokens"] - self.total_tokens_prev
self.completion_tokens_prev = usage["completion_tokens"]
self.total_tokens_prev = usage["total_tokens"]
return message
def reply_message(self, user_message: str) -> Generator:
"""Replies to the user's message.
Args:
user_message (str): The user's message.
Yields:
(str): The chatbot's response(s)
"""
message = {"role": "user", "content": user_message}
self.messages.append(message)
try:
message = self._process_chat_completion(
functions=self.functions,
)
except InvalidRequestError as e:
yield f"## Error while Chat GPT API calling with the user message: {e}"
return
while message.get("function_call"):
function_name = message["function_call"]["name"]
arguments = json.loads(message["function_call"]["arguments"])
if self._verbose:
yield self.quotify_fn(f"function name: {function_name}")
yield self.quotify_fn(f"arguments: {arguments}")
if function_name == "ask_database":
with WSDatabase() as db:
function_response = db.ask_database(arguments["query"])
else:
function_response = f"## Unknown function name: {function_name}"
if self._verbose:
yield self.quotify_fn(f"function response:\n{function_response}")
self.messages.append({
"role": "function",
"name": function_name,
"content": function_response
})
try:
message = self._process_chat_completion()
except InvalidRequestError as e:
yield f"## Error while ChatGPT API calling with the function response: {e}"
self.messages.rollback(3)
return
yield message['content']
WSDatabase
はこの記事で取上げるデータベースまわりの機能を集めたクラスです。クラスメソッドschema_str()
はデータベースのスキーマを出力します。スキーマはのちにChatGPTに送られ、それによってChatGPTはデータベースの構造を理解します。メソッドask_database()
はSQLの文字列を引数に取り、データベースの検索結果を返します。これが今回ChatGPTが使う「関数」です。独自のデータベースを使用したい場合はお使いのDBMSに合わせてこのクラスを書き換え、前出の2つのメソッドを実装してください。結果はChatGPTが理解できる形であれば良く、JSONで返す例をよく見ますが、ここではCSVで返しています。
次にChatEngine
クラスですが、setup()
クラスメソッドにquotify_fn: Callable[[str], str]
というあらたな引数を与えることができるようになっています。これはチャットボットの返答文字列を引用形に変換する関数です。
__init__()
にはstyle: str
という引数を与えることができ、これはチャットボットの返答のスタイルを設定するのに使用されます。パート2まではこれは「博多弁」とハードコードされていました。ここで注意なのですが、 この引数はプロンプト・インジェクションに悪用される可能性がありますが、現状では何の対策もされていません。 もしもチャットボットを大勢の使用に供する場合はこの点の対策を検討してください。"role": "system"
のプロンプトがfunction callingに対応して若干変更されています。
reply_message()
メソッドは大幅に変更されています。最初にメッセージをChatGPTに投げるときにfunctions
という引数が追加されています。これによってChatGPTが使用できる「関数」の説明を送ります。今回の「関数」はデータベースの検索なのでデータベースのスキーマも含まれます。
「関数を使いたい」とChatGPTが判断した場合には返信のメッセージにfunction_call
というキーが含まれていて、その下にname
というキーで関数名、arguments
というキーで引数が格納されています。こちらで指定された関数を実行し、その結果をもとに{"role": "function", "name": "<関数名>", "content": "<関数の出力>"}
というかたちのメッセージを組み立ててChatGPTに投げます。そうするとChatGPTは当初の質問に対する回答を含んだメッセージを打ち返してきます。
import os
from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
from utils import ChatEngine
chatbot_app_token = os.environ["CHATBOT_APP_TOKEN"]
slack_bot_token = os.environ["SLACK_BOT_TOKEN"]
app = App(token=slack_bot_token)
@app.message()
def handle(message, say):
global chat_engine_dict
if message["user"] not in chat_engine_dict.keys():
chat_engine_dict[message["user"]] = ChatEngine()
for reply in chat_engine_dict[message["user"]].reply_message(message['text']):
say(reply)
@app.command("/verbose")
def verbose_function(ack, body, respond):
ack()
global chat_engine_dict
user_id = body["user_id"]
if user_id not in chat_engine_dict.keys():
chat_engine_dict[user_id] = ChatEngine()
switch = body["text"].lower().strip()
if not switch:
respond("Verbose mode." if chat_engine_dict[user_id].verbose else "Quiet mode.")
elif switch == "on":
chat_engine_dict[user_id].verbose = True
respond("Verbose mode.")
elif switch == "off":
chat_engine_dict[user_id].verbose = False
respond("Quiet mode.")
else:
respond("usage: /verbose [on|off]")
@app.command("/style")
def style_function(ack, body, respond):
ack()
global chat_engine_dict
user_id = body["user_id"]
switch = body["text"].lower().strip()
if switch:
chat_engine_dict[user_id] = ChatEngine(style=switch)
respond(f"Style: {chat_engine_dict[user_id].style}")
elif user_id in chat_engine_dict.keys():
respond(f"Style: {chat_engine_dict[user_id].style}")
else:
respond("まだ会話が始まっていません。")
model = "gpt-4-0613"
def quotify(s: str) -> str:
"""Adds quotes to a string.
Args:
s (str): The string to add quotes to.
Returns:
(str) The string with quotes added.
"""
return "\n".join([f"> {l}" for l in s.split("\n")])
ChatEngine.setup(model, quotify_fn=quotify)
chat_engine_dict = dict()
SocketModeHandler(app, chatbot_app_token).start()
スラッシュ・コマンドのハンドラが追加されています。また、ChatEngine.setup()
のquotify_fn
に渡すための関数quotify()
を定義しています。これはmarkdown形式に対応して行頭に"> "を追加するものになっています。
動作例
それではパート2までと同様にチャットボットを起動してください。
python chatbot.py
⚡️ Bolt app is running!
/style 秋田弁
で秋田弁に変更することができます。
/verbose on
でverboseモードになり、function callingの動作の舞台裏を見ることができます。
パート3は以上です。再掲になりますが、パート3のコードやデータはtf-koichi/slack-chatbot at part3に置いてあります。パート4に続きます。何かお気づきの点がありましたらフィードバックよろしくお願いします。
Discussion