Querying Vector Search with a LangChain Retriever
In the previous article, following the LangChain documentation, I walked through running natural-language queries against the vector data stored in Vector Search.
Today we will use a feature called the Retriever to make natural-language search even simpler.
What is a Retriever?
The Retriever is a feature supported by LangChain. It further abstracts operations on vector data, which makes coding easier, and it comes with built-in filtering of search results, enabling more natural output. In addition, TiDB Serverless Vector Search keeps its vector search tables inside a MySQL-compatible database, and this layout makes it easy to implement hybrid search that combines vector search with ordinary SQL queries against the same database.
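To make the interface concrete, here is a minimal sketch (the threshold value is illustrative, and `db` is assumed to be an already-populated LangChain vector store, such as the `TiDBVectorStore` instance built later in this article):

```python
# Any LangChain vector store can be wrapped as a retriever.
# `db` is assumed to be an already-populated vector store instance.
retriever = db.as_retriever(
    search_type="similarity_score_threshold",  # built-in score filtering
    search_kwargs={"k": 3, "score_threshold": 0.8},
)

# invoke() embeds the query and returns only the matching Documents.
docs = retriever.invoke("your natural-language question")
for doc in docs:
    print(doc.page_content)
```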
Now let's actually try it.
First, make sure you have finished the steps from the previous article. You can skip the Gundam parts, but remember: Gundam will enrich your life! Running the search code from the previous article should produce output like this:
--------------------------------------------------------------------------------
Score: 0.21836181550705258
As we take the fight to al-Qaida, we are responsibly leaving Iraq to its people. As a candidate, I promised that I would end this war, and that is what I am doing as president. We will have all of our combat troops out of Iraq by the end of this August. We will support the Iraqi government as they hold elections, and continue to partner with the Iraqi people to promote regional peace and prosperity. But make no mistake: This war is ending, and all of our troops are coming home. Tonight, all of our men and women in uniform -- in Iraq, Afghanistan and around the world -- must know that they have our respect, our gratitude and our full support. And just as they must have the resources they need in war, we all have a responsibility to support them when they come home. That is why we made the largest increase in investments for veterans in decades. That is why we are building a 21st century VA. And that is why Michelle has joined with Jill Biden to forge a national commitment to support military families.
Even as we prosecute two wars, we are also confronting perhaps the greatest danger to the American people -- the threat of nuclear weapons. I have embraced the vision of John F. Kennedy and Ronald Reagan through a strategy that reverses the spread of these weapons and seeks a world without them. To reduce our stockpiles and launchers, while ensuring our deterrent, the United States and Russia are completing negotiations on the farthest-reaching arms control treaty in nearly two decades. And at April's nuclear security summit, we will bring 44 nations together behind a clear goal: securing all vulnerable nuclear materials around the world in four years, so that they never fall into the hands of terrorists.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Score: 0.2232635918072473
So I know the anxieties that are out there right now. They're not new. These struggles are the reason I ran for president. These struggles are what I've witnessed for years in places like Elkhart, Ind., and Galesburg, Ill. I hear about them in the letters that I read each night. The toughest to read are those written by children asking why they have to move from their home, or when their mom or dad will be able to go back to work.
For these Americans and so many others, change has not come fast enough. Some are frustrated; some are angry. They don't understand why it seems like bad behavior on Wall Street is rewarded but hard work on Main Street isn't, or why Washington has been unable or unwilling to solve any of our problems. They are tired of the partisanship and the shouting and the pettiness. They know we can't afford it. Not now.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Score: 0.2232635918072473
So I know the anxieties that are out there right now. They're not new. These struggles are the reason I ran for president. These struggles are what I've witnessed for years in places like Elkhart, Ind., and Galesburg, Ill. I hear about them in the letters that I read each night. The toughest to read are those written by children asking why they have to move from their home, or when their mom or dad will be able to go back to work.
For these Americans and so many others, change has not come fast enough. Some are frustrated; some are angry. They don't understand why it seems like bad behavior on Wall Street is rewarded but hard work on Main Street isn't, or why Washington has been unable or unwilling to solve any of our problems. They are tired of the partisanship and the shouting and the pettiness. They know we can't afford it. Not now.
--------------------------------------------------------------------------------
Here is that code rewritten to use a retriever:
import getpass
import os

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")

# copy from tidb cloud console
tidb_connection_string_template = "mysql+pymysql://4JYpfj5Y7etkbg1.root:<PASSWORD>@gateway01.eu-central-1.prod.aws.tidbcloud.com:4000/test?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true"
# tidb_connection_string_template = "mysql+pymysql://root:<PASSWORD>@34.212.137.91:4000/test"
tidb_password = getpass.getpass("Input your TiDB password:")
tidb_connection_string = tidb_connection_string_template.replace(
    "<PASSWORD>", tidb_password
)

loader = TextLoader("state_of_the_union.txt", encoding="utf-8")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

embeddings = OpenAIEmbeddings()
TABLE_NAME = "semantic_embeddings"
db = TiDBVectorStore.from_documents(
    documents=docs,
    embedding=embeddings,
    table_name=TABLE_NAME,
    connection_string=tidb_connection_string,
    distance_strategy="cosine",  # default, another option is "l2"
)

retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.1},
)
semantic_query = "What did the president say about Ketanji Brown Jackson"
results = retriever.invoke(semantic_query)
for result in results:
    print("-" * 80)
    print(result.page_content)
    print("-" * 80)
The changes are as follows.

Before:

query = "What did the president say about Ketanji Brown Jackson"
docs_with_score = db.similarity_search_with_score(query, k=3)

After:

retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.1},
)
semantic_query = "What did the president say about Ketanji Brown Jackson"
results = retriever.invoke(semantic_query)
The code has grown a little and may look more complicated, but the parameters that previously had to be passed to `similarity_search_with_score` on every call are now set once when the retriever is declared. As a result, the part that is called repeatedly, `results = retriever.invoke(semantic_query)`, becomes much cleaner.
Also, whereas `similarity_search_with_score` returns every result document together with its score, `retriever.invoke` returns only the documents that meet the score threshold configured up front. In other words, the post-processing that used to adjust the output based on each score is no longer necessary, which simplifies everything downstream.
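As a sketch of the difference (the manual threshold of 0.3 is illustrative, assuming cosine distance where lower means more similar):

```python
# Before: every call returns (Document, score) pairs that we must
# filter ourselves after the fact.
docs_with_score = db.similarity_search_with_score(query, k=3)
filtered = [doc for doc, score in docs_with_score if score <= 0.3]

# After: the retriever applies the configured threshold internally,
# so the call site stays a single line.
results = retriever.invoke(query)
```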
From here on, we will work through the following content.
First, replace the code with the following.
import getpass
import os

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")

# copy from tidb cloud console
tidb_connection_string_template = "mysql+pymysql://4JYpfj5Y7etkbg1.root:<PASSWORD>@gateway01.eu-central-1.prod.aws.tidbcloud.com:4000/test?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true"
# tidb_connection_string_template = "mysql+pymysql://root:<PASSWORD>@34.212.137.91:4000/test"
tidb_password = getpass.getpass("Input your TiDB password:")
tidb_connection_string = tidb_connection_string_template.replace(
    "<PASSWORD>", tidb_password
)

loader = TextLoader("state_of_the_union.txt", encoding="utf-8")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

embeddings = OpenAIEmbeddings()
TABLE_NAME = "semantic_embeddings"
db = TiDBVectorStore.from_documents(
    documents=docs,
    embedding=embeddings,
    table_name=TABLE_NAME,
    connection_string=tidb_connection_string,
    distance_strategy="cosine",  # default, another option is "l2"
)

db.tidb_vector_client.execute(
    """CREATE TABLE airplan_routes (
        id INT AUTO_INCREMENT PRIMARY KEY,
        airport_code VARCHAR(10),
        airline_code VARCHAR(10),
        destination_code VARCHAR(10),
        route_details TEXT,
        duration TIME,
        frequency INT,
        airplane_type VARCHAR(50),
        price DECIMAL(10, 2),
        layover TEXT
    );"""
)

# insert some data into Routes and our vector table
db.tidb_vector_client.execute(
    """INSERT INTO airplan_routes (
        airport_code,
        airline_code,
        destination_code,
        route_details,
        duration,
        frequency,
        airplane_type,
        price,
        layover
    ) VALUES
    ('JFK', 'DL', 'LAX', 'Non-stop from JFK to LAX.', '06:00:00', 5, 'Boeing 777', 299.99, 'None'),
    ('LAX', 'AA', 'ORD', 'Direct LAX to ORD route.', '04:00:00', 3, 'Airbus A320', 149.99, 'None'),
    ('EFGH', 'UA', 'SEA', 'Daily flights from SFO to SEA.', '02:30:00', 7, 'Boeing 737', 129.99, 'None');
    """
)

db.add_texts(
    texts=[
        "Clean lounges and excellent vegetarian dining options. Highly recommended.",
        "Comfortable seating in lounge areas and diverse food selections, including vegetarian.",
        "Small airport with basic facilities.",
    ],
    metadatas=[
        {"airport_code": "JFK"},
        {"airport_code": "LAX"},
        {"airport_code": "EFGH"},
    ],
)

retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.85},
)
semantic_query = "Could you recommend a US airport with clean lounges and good vegetarian dining options?"
reviews = retriever.invoke(semantic_query)
for r in reviews:
    print("-" * 80)
    print(r.page_content)
    print(r.metadata)
    print("-" * 80)
Running this produces the following output:
--------------------------------------------------------------------------------
Clean lounges and excellent vegetarian dining options. Highly recommended.
{'airport_code': 'JFK'}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Clean lounges and excellent vegetarian dining options. Highly recommended.
{'airport_code': 'JFK'}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Comfortable seating in lounge areas and diverse food selections, including vegetarian.
{'airport_code': 'LAX'}
--------------------------------------------------------------------------------
What is happening here
Let's take a look inside. Everything above the first `db.tidb_vector_client.execute` call simply reuses the connection setup to Vector Search and OpenAI, that is, the `db` instance of the `TiDBVectorStore` class. By now it includes some steps that are no longer strictly meaningful, but the point is that we are just reusing that instance.
Note that if you run this code more than once, the `db.tidb_vector_client.execute` call that creates the table will fail from the second run onward with a duplicate-table error, so comment it out on subsequent runs.
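As an alternative to commenting it out, you could make the statement idempotent. A sketch relying on `CREATE TABLE IF NOT EXISTS`, which TiDB supports thanks to its MySQL compatibility (note that the INSERT further down would still add duplicate rows on every run):

```python
# IF NOT EXISTS makes table creation a no-op when the table is already
# there, so the script survives repeated runs without a duplicate error.
db.tidb_vector_client.execute(
    """CREATE TABLE IF NOT EXISTS airplan_routes (
        id INT AUTO_INCREMENT PRIMARY KEY,
        airport_code VARCHAR(10),
        airline_code VARCHAR(10),
        destination_code VARCHAR(10),
        route_details TEXT,
        duration TIME,
        frequency INT,
        airplane_type VARCHAR(50),
        price DECIMAL(10, 2),
        layover TEXT
    );"""
)
```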
db.tidb_vector_client.execute(
    """INSERT INTO airplan_routes (

This part inserts the three rows of route data.
Separately, the following part loads three pieces of vector data:
db.add_texts(
    texts=[
        "Clean lounges and excellent vegetarian dining options. Highly recommended.",
        "Comfortable seating in lounge areas and diverse food selections, including vegetarian.",
        "Small airport with basic facilities.",
    ],
    metadatas=[
        {"airport_code": "JFK"},
        {"airport_code": "LAX"},
        {"airport_code": "EFGH"},
    ],
)
To picture it: the `airplan_routes` table now contains the three rows inserted above.
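If you want to verify this yourself, a quick check along these lines should work, reusing the result-dict shape that `tidb_vector_client.execute` returns (the same pattern appears later in this article):

```python
# Plain SQL through the same client; the rows come back under "result".
rows = db.tidb_vector_client.execute("SELECT * FROM airplan_routes;")
for row in rows.get("result", []):
    print(row)
```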
Meanwhile, the three documents loaded via `db.add_texts` are stored in the `semantic_embeddings` table, which you can check with a query along these lines:
SELECT * FROM semantic_embeddings WHERE JSON_EXTRACT(meta, '$.airport_code') IS NOT NULL ORDER BY meta;
The search is then executed against the `semantic_embeddings` table in the following part:
retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.85},
)
semantic_query = "Could you recommend a US airport with clean lounges and good vegetarian dining options?"
reviews = retriever.invoke(semantic_query)
Append the following to the end of the code and run it again.
# Extracting airport codes from the metadata
airport_codes = [review.metadata["airport_code"] for review in reviews]

# Executing a query to get the airport details
search_query = "SELECT * FROM airplan_routes WHERE airport_code IN :codes"
params = {"codes": tuple(airport_codes)}

result = db.tidb_vector_client.execute(search_query, params)
airport_details = result.get('result', [])
for detail in airport_details:
    print(detail)
You should see output like the following:
(1, 'JFK', 'DL', 'LAX', 'Non-stop from JFK to LAX.', datetime.timedelta(seconds=21600), 5, 'Boeing 777', Decimal('299.99'), 'None')
This comes from taking JFK, the airport recommended by `reviews = retriever.invoke(semantic_query)`, and using that value to run an ordinary SQL query against the `airplan_routes` table in this part:
search_query = "SELECT * FROM airplan_routes WHERE airport_code IN :codes"
params = {"codes": tuple(airport_codes)}
result = db.tidb_vector_client.execute(search_query, params)
These two steps can also be combined into one.
import getpass
import os

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")

# copy from tidb cloud console
tidb_connection_string_template = "mysql+pymysql://4JYpfj5Y7etkbg1.root:<PASSWORD>@gateway01.eu-central-1.prod.aws.tidbcloud.com:4000/test?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true"
# tidb_connection_string_template = "mysql+pymysql://root:<PASSWORD>@34.212.137.91:4000/test"
tidb_password = getpass.getpass("Input your TiDB password:")
tidb_connection_string = tidb_connection_string_template.replace(
    "<PASSWORD>", tidb_password
)

loader = TextLoader("state_of_the_union.txt", encoding="utf-8")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

embeddings = OpenAIEmbeddings()
TABLE_NAME = "semantic_embeddings"
db = TiDBVectorStore.from_documents(
    documents=docs,
    embedding=embeddings,
    table_name=TABLE_NAME,
    connection_string=tidb_connection_string,
    distance_strategy="cosine",  # default, another option is "l2"
)

db.tidb_vector_client.execute(
    """CREATE TABLE airplan_routes (
        id INT AUTO_INCREMENT PRIMARY KEY,
        airport_code VARCHAR(10),
        airline_code VARCHAR(10),
        destination_code VARCHAR(10),
        route_details TEXT,
        duration TIME,
        frequency INT,
        airplane_type VARCHAR(50),
        price DECIMAL(10, 2),
        layover TEXT
    );"""
)

# insert some data into Routes and our vector table
db.tidb_vector_client.execute(
    """INSERT INTO airplan_routes (
        airport_code,
        airline_code,
        destination_code,
        route_details,
        duration,
        frequency,
        airplane_type,
        price,
        layover
    ) VALUES
    ('JFK', 'DL', 'LAX', 'Non-stop from JFK to LAX.', '06:00:00', 5, 'Boeing 777', 299.99, 'None'),
    ('LAX', 'AA', 'ORD', 'Direct LAX to ORD route.', '04:00:00', 3, 'Airbus A320', 149.99, 'None'),
    ('EFGH', 'UA', 'SEA', 'Daily flights from SFO to SEA.', '02:30:00', 7, 'Boeing 737', 129.99, 'None');
    """
)

db.add_texts(
    texts=[
        "Clean lounges and excellent vegetarian dining options. Highly recommended.",
        "Comfortable seating in lounge areas and diverse food selections, including vegetarian.",
        "Small airport with basic facilities.",
    ],
    metadatas=[
        {"airport_code": "JFK"},
        {"airport_code": "LAX"},
        {"airport_code": "EFGH"},
    ],
)

retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 3, "score_threshold": 0.85},
)
semantic_query = "Could you recommend a US airport with clean lounges and good vegetarian dining options?"
reviews = retriever.invoke(semantic_query)
for r in reviews:
    print("-" * 80)
    print(r.page_content)
    print(r.metadata)
    print("-" * 80)

search_query = f"""
    SELECT
        VEC_Cosine_Distance(se.embedding, :query_vector) as distance,
        ar.*,
        se.document as airport_review
    FROM
        airplan_routes ar
    JOIN
        {TABLE_NAME} se ON ar.airport_code = JSON_UNQUOTE(JSON_EXTRACT(se.meta, '$.airport_code'))
    ORDER BY distance ASC
    LIMIT 5;
"""
query_vector = embeddings.embed_query(semantic_query)
params = {"query_vector": str(query_vector)}
result = db.tidb_vector_client.execute(search_query, params)
airport_details = result.get('result', [])
for detail in airport_details:
    print(detail)
Retrievers come with many other capabilities as well, such as working over documents loaded from PDFs, CSVs, and other formats.
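For instance, swapping in another document loader leaves the rest of the pipeline untouched. A sketch, assuming the `pypdf` package is installed and that local files `example.pdf` and `example.csv` exist (both file names are placeholders):

```python
from langchain_community.document_loaders import CSVLoader, PyPDFLoader

# Each loader yields the same Document objects that TextLoader did, so the
# splitter / embeddings / vector store steps above work unchanged.
pdf_docs = PyPDFLoader("example.pdf").load()
csv_docs = CSVLoader("example.csv").load()
```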