🔥

LangChainを使って自然言語でデータベースを検索する(Azure編)

2023/07/20に公開

やること

この記事ではLangChainを使って自然言語でデータベースを検索してみます。
環境はAzureを使っていますので、Azure OpenAIとSQL Databse付属のサンプルスキーマを利用します。

手順はこの辺を参考にしています。
https://python.langchain.com/docs/modules/agents/toolkits/sql_database
https://python.langchain.com/docs/modules/chains/popular/sqlite

データベースの準備

マニュアルを参考にデータベースとサンプルデータを用意します。
https://learn.microsoft.com/ja-jp/azure/azure-sql/database/single-database-create-quickstart?view=azuresql-db&tabs=azure-portal

追加設定のデータソースでサンプルを選択してAdventureWorksLTが作成されるようにします。
今回はとりあえず動けば良いのでサーバレスモデルの一番小さいサイズにしています。

実行環境

モジュールのインストールなどはこちらを参照。
https://python.langchain.com/docs/get_started/installation

今回使ったバージョンはこちら。

langchain               0.0.237

サンプルコード

初期設定

.envファイルを用意して読み込みます。

.env
OPENAI_API_KEY=AOAIのキー
OPENAI_API_BASE=https://xxxx.openai.azure.com/
OPENAI_API_VERSION=2023-05-15
DATABASE_USERNAME=DB接続ユーザ
DATABASE_PASSWORD=パスワード
DATABASE_SERVER=xxxx.database.windows.net
DATABASE_DB=DB名

インポート部分

import os
import openai
from dotenv import load_dotenv
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.llms import AzureOpenAI
from langchain.sql_database import SQLDatabase

load_dotenv()

LLMインスタンス作成

AOAIのtext-davinci-003を使ってLLMインスタンスを作成します。

openai.api_type = "azure"
openai.api_version = os.getenv("OPENAI_API_VERSION")
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_key = os.getenv("OPENAI_API_KEY")

llm = AzureOpenAI(deployment_name="text-davinci-003", model_name="text-davinci-003", temperature=0

DB接続設定

今回はpyodbcで接続したので、次のような形で接続文字列を作ります。pymssqlでも多分できます。

database_user = os.getenv("DATABASE_USERNAME")
database_password = os.getenv("DATABASE_PASSWORD")
database_server = os.getenv("DATABASE_SERVER")
database_db = os.getenv("DATABASE_DB")
driver = "{ODBC Driver 18 for SQL Server}"

odbc_connection_string = f"mssql+pyodbc://?odbc_connect=DRIVER={driver};SERVER={database_server};DATABASE={database_db};ENCRYPT=yes;UID={database_user};PWD={database_password}"

サンプルデータがSalesLTに作成されているのでここで指定します。

db = SQLDatabase.from_uri(
    odbc_connection_string,
    schema="SalesLT",
)

toolkitとagentの設定。

toolkit = SQLDatabaseToolkit(
    db=db,
    llm=llm,
    reduce_k_below_max_tokens=True
)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    use_query_checker=True,
    verbose=True
)

検索実行

日本語で聞いてみます。

agent_executor.run("製品カテゴリがMountain Bikesの製品は何種類ありますか。"

実行結果1(エラー)

PostgreSQLの構文でスキーマを変更しようとしてエラーとなります。。。

SET search_path TO SalesLT

原因

原因はこの辺のようで。スキーマが設定されていてsnowflakeとbigquery以外のときはSET search_path TO {self._schema}"してしまうようです。

sql_database.py
    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET search_path='{self._schema}'"
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
                else:
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")

回避策

今回は、こんな感じのコードを追加して凌ぐことにします。

                elif self.dialect == "mssql":
                    pass

修正を待ちましょう。
https://github.com/hwchase17/langchain/issues/7928

実行結果2(成功!)

コード修正して再実行。
しっかり日本語で32種類あると返ってきました!

実行されていたクエリはこちら。

実行されていたクエリ
SELECT COUNT(*) FROM SalesLT.Product WHERE ProductCategoryID = (SELECT ProductCategoryID FROM SalesLT.ProductCategory WHERE Name = 'Mountain Bikes');

最終的に実行されたクエリにたどり着くまでに以下のようなクエリも実行されていました。
テーブルをスキーマ名で修飾していなく怒られたようです。

SELECT COUNT(*) FROM Product WHERE ProductCategoryID = (SELECT ProductCategoryID FROM ProductCategory WHERE Name = 'Mountain Bikes');

Error: (pyodbc.ProgrammingError) ('42S02', "[42S02] [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Invalid object name 'Product'. (208) (SQLExecDirectW)")

さいごに

現状ではLangChainがSQL Databaseにはそのままでは対応でき無さそうなので対応を待ったほうが良さそうです。
schemaの設定をNoneにするとdboスキーマを見に行くので、テーブルが全てdbo配下に作られている場合は対応できそうですが、今回のように特定スキーマにテーブルが有るときは何かしら工夫が必要そうな感じがします。

Discussion