👋

【LangChain】データベースを扱う② SQL エージェントを試す

2024/04/07に公開

前の章では、LLMを使ってクエリを作成し、データベースからデータを抽出して回答を生成しました。
簡単な指示なら、LLMを使うだけでも実行できますが、人間の指示がふわっとしているなどLLMだけでは手に追えない場合はエージェントを使うことで、より効果的にデータベースを扱うことができます。

SQLエージェントの作成と実行

これまで学習したエージェントの同じく、エージェントのexecutorを作成してinvokeします。

from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

agent_executor.invoke(
    {
        "input": "List the total sales per country. Which country's customers spent the most?"
    }
)

ちなみに質問内容は「どの国の顧客が最も多く支出したか」です。

実行結果は出力された通りなんですが、読み解いてみます。

Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

まずはデータベースのテーブルに何があるのかを調べています。
エージェントさんは、今回渡されたデータベースがどんなものか知らないので、当たり前です。

Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`

CREATE TABLE "Customer" (
・・・(省略)・・・
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
・・・(省略)・・・
*/

次にテーブル構造を解読しています。
今回渡された質問(どの国の顧客が最も多く支出したか)の回答を得るために、どのテーブルから何のデータを抽出すれば良いか見当をつけているわけですね。

Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'}`
responded: To find the total sales per country, I will first calculate the total sales for each invoice by summing the total amount in each invoice. Then, I will group the results by the country of the customer. Finally, I will order the results by the total sales in descending order to determine which country's customers spent the most.

Here is the SQL query to achieve this:

``sql
SELECT c.Country, SUM(i.Total) AS TotalSales
FROM Customer c
JOIN Invoice i ON c.CustomerId = i.CustomerId
GROUP BY c.Country
ORDER BY TotalSales DESC
``

そして、クエリを作成しています。
ご丁寧に結果を得るまでのプロセスまで解説してくれていますね。

sql_db_query コマンドを使用し、クエリ {'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'} を実行しました。その結果、各国の総売上を見つけるために、まず各請求書の総額を合計して各請求書の総売上を計算します。次に、顧客の国別に結果をグループ化します。最後に、総売上の降順で結果を並べ替え、どの国の顧客が最も多くのお金を使ったかを判断します。

I will now execute this query to find out which country's customers spent the most.

[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]The total sales per country are as follows:

1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48

The country whose customers spent the most is the USA with a total sales amount of $523.06.

この部分はクエリ実行結果です
エージェントさんには、上位5位までを回答してくださいと投げかけたわけではありませんが、気を利かせて上位5位までに絞ってくれています。

> Finished chain.
{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nThe country whose customers spent the most is the USA with a total sales amount of $523.06.'}

最後にchainの実行結果です。
人間の問いかけ(input)に対して、AIの回答(output)を作成しています。

中身を読み解くと改めてエージェントの優秀さに驚かされます。

few-shotプロンプトの実行

前述のようにエージェントは、人間の質問に対して適宜クエリ作成・実行・集計まで考えてやってくれると相当優秀ですが、
私のようにITリテラシーが半世紀遅れているようなJTCに勤めていると、社内でしか通用しない独特な用語や、一般的に用いられている意味と違う意味になっている社内用語がたくさん飛び交っていますので、適宜社内用語を一般的な言葉に翻訳する必要があります。
(例えていうなら、他社から転職してきたエンジニアが社内用語に戸惑う感じです。)

そのような場合、「人間がこういう感じの問いかけをしてきたら、こういうクエリを作成したらいいよ」というようにプロンプトを予め作成することで、ITリテラシーのない人間のふわっとした質問にも答えられるようにできます。

以下、人間のふわっとした質問と、実行すべきクエリのリストです。

examples = [
    {   "input": "List all artists.", 
        "query": "SELECT * FROM Artist;"
    },{
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]

これをfew-shotプロンプトとして用いるには、まずSemanticSimilarityExampleSelectorを使って、inputに対してどのクエリを実行するかを特定するセレクタを作成します

from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

見ての通り、セレクタではFAISSを使って近傍検索をしています。
これにより、人間のふわっとした質問に対して、「おそらくこのクエリを実行すればいいんだな」とクエリを選ぶことができます。

次に、このセレクタをぶち込んだFewShotPromptTemplateを作成します。

from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

まだエージェントに渡すには不十分です。
過去の記事に書いたように、エージェントが使えるプロンプトとするには、inputとagent_scratchpadを備えている必要があります。

これらを揃えると以下のようになります。

full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

これでやっとエージェントが使えるプロンプトになりましたので、few-shotプロンプトを兼ね備えた独自のSQLエージェントを作成しましょう。

agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)

実際に試すことができますが、チュートリアルではfew-shotプロンプトでなくても答えを得られるサンプルなので、社内用語などfew-shotプロンプトの威力を実感できるサンプルが欲しいところ

agent.invoke({"input": "How many artists are there?"})

Discussion