🎉

TiDB for AI : pytidb AI に特化したTiDB用 Python SDK その1

に公開

今日はTiDB for AI というものを試していきたいと思います。
https://pingcap.github.io/ai/

どうやらAIに特化したTiDBを操作する用のpython SDKとのことで、pytidbというパッケージが準備されています。

さっそくやってみる

1. インストール

https://pingcap.github.io/ai/quickstart/

まずはインストールを行います。

pip install pytidb

このSDKには以下との依存関係が存在しているようです。
AIフレームワーク: LlamaIndex、LangChain
ORM ライブラリ: SQLAlchemy、Django-ORM、Peewee
AIサービス: Bedrock
埋め込みモデル/サービス: JinaAI

2. 接続

https://pingcap.github.io/ai/quickstart/
ここにクイックスタートがありますが、なかなかに記載がハードです。サンプルスクリプトが記載されているのですがそのまま実行できる形になっていません。このため1つづつ実行可能なPythonスクリプトを作っていきます。

from pytidb import TiDBClient

client = TiDBClient.connect(
    host="gateway01.us-east-1.prod.shared.aws.tidbcloud.com",
    port=4000,
    username="4EfqPF23YKBxaQb.root",
    password="abcd1234",
    database="test",
)

このままでは接続されたかどうか不明ですので、以下に変更して実行します。

connect.py
from pytidb import TiDBClient

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")
python connect.py
接続成功
結果: rowcount=1 success=True message=None

3. 埋め込みモデルの設定と埋め込み接続テスト

次にOpenAIの埋め込みモデルの設定と実際の埋め込み接続テストを行います。
まず必要なライブラリをインストールします。

pip install litellm

インストールが完了したらconnect.pyを以下に置き換えます。

from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 埋め込み機能の設定
text_embed = EmbeddingFunction(
    model_name="openai/text-embedding-3-small",
    api_key="<api key>",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")

# 埋め込みモデルのテスト
print("\n埋め込みモデルテスト...")
test_text = "こんにちは、これはテストです。"
try:
    embedding = text_embed.get_query_embedding(test_text)
    print("✅ 埋め込み生成成功")
    print(f"テキスト: {test_text}")
    print(f"埋め込み次元: {len(embedding)}")
    print(f"埋め込みの最初の5要素: {embedding[:5]}")
except Exception as e:
    print(f"❌ 埋め込み生成エラー: {e}")

<api key>は皆さんの環境ごとに置き換えておきます。
では実行してみます。

python connect.py
接続成功
結果: rowcount=1 success=True message=None

埋め込みモデルテスト...
✅ 埋め込み生成成功
テキスト: こんにちは、これはテストです。
埋め込み次元: 1536
埋め込みの最初の5要素: [0.02179015800356865, 0.021149272099137306, -0.021624768152832985, -0.061483804136514664, 0.04730159416794777]

無事OpenAIへの埋め込みが接続が成功しています。

4. ベクトルデータ保存用テーブルの作成

先ほど埋め込みで生成されたベクトルデータを保存するテーブルを作成します。再度connect.pyを以下に置き換えます。

connect.py
from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction
from pytidb.schema import TableModel, Field, VectorField

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 埋め込み機能の設定
text_embed = EmbeddingFunction(
    model_name="openai/text-embedding-3-small",
    api_key="<api key>",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")

# 埋め込みモデルのテスト
print("\n埋め込みモデルテスト...")
test_text = "こんにちは、これはテストです。"
try:
    embedding = text_embed.get_query_embedding(test_text)
    print("✅ 埋め込み生成成功")
    print(f"テキスト: {test_text}")
    print(f"埋め込み次元: {len(embedding)}")
    print(f"埋め込みの最初の5要素: {embedding[:5]}")
except Exception as e:
    print(f"❌ 埋め込み生成エラー: {e}")

# テーブルスキーマの定義
print("\nテーブル作成...")
class Chunk(TableModel):
    id: int | None = Field(default=None, primary_key=True)
    text: str = Field()
    text_vec: list[float] = text_embed.VectorField(source_field="text")
    user_id: int = Field()

try:
    table = client.create_table(schema=Chunk, mode="overwrite")
    print("✅ テーブル作成成功")
    print(f"テーブル: {table}")
except Exception as e:
    print(f"❌ テーブル作成エラー: {e}")
python connect.py
接続成功
結果: rowcount=1 success=True message=None

埋め込みモデルテスト...
✅ 埋め込み生成成功
テキスト: こんにちは、これはテストです。
埋め込み次元: 1536
埋め込みの最初の5要素: [0.02179015800356865, 0.021149272099137306, -0.021624768152832985, -0.061483804136514664, 0.04730159416794777]

テーブル作成...
✅ テーブル作成成功
テーブル: <pytidb.table.Table object at 0x713bd3ba5eb0>

5. データの埋め込み

では次に与えられた文字列をベクトル化してテーブルに保存します。再度connect.pyを以下に変更します。

from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction
from pytidb.schema import TableModel, Field, VectorField

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 埋め込み機能の設定
text_embed = EmbeddingFunction(
    model_name="openai/text-embedding-3-small",
    api_key="<api key>",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")

# 埋め込みモデルのテスト
print("\n埋め込みモデルテスト...")
test_text = "こんにちは、これはテストです。"
try:
    embedding = text_embed.get_query_embedding(test_text)
    print("✅ 埋め込み生成成功")
    print(f"テキスト: {test_text}")
    print(f"埋め込み次元: {len(embedding)}")
    print(f"埋め込みの最初の5要素: {embedding[:5]}")
except Exception as e:
    print(f"❌ 埋め込み生成エラー: {e}")

# テーブルスキーマの定義
print("\nテーブル作成...")
class Chunk(TableModel):
    id: int | None = Field(default=None, primary_key=True)
    text: str = Field()
    text_vec: list[float] = text_embed.VectorField(source_field="text")
    user_id: int = Field()

try:
    table = client.create_table(schema=Chunk, mode="overwrite")
    print("✅ テーブル作成成功")
    print(f"テーブル: {table}")
except Exception as e:
    print(f"❌ テーブル作成エラー: {e}")

# データの一括挿入
print("\nデータ挿入...")
try:
    table.bulk_insert([
        # テキストは自動的に埋め込まれて text_vec フィールドに格納されます
        Chunk(text="PyTiDB is a Python library for developers to connect to TiDB.", user_id=2),
        Chunk(text="LlamaIndex is a framework for building AI applications.", user_id=2),
        Chunk(text="OpenAI is a company and platform that provides AI models service and tools.", user_id=3),
    ])
    print("✅ データ挿入成功")
except Exception as e:
    print(f"❌ データ挿入エラー: {e}")
python connect.py
接続成功
結果: rowcount=1 success=True message=None

埋め込みモデルテスト...
✅ 埋め込み生成成功
テキスト: こんにちは、これはテストです。
埋め込み次元: 1536
埋め込みの最初の5要素: [0.021751966327428818, 0.02115233987569809, -0.02162790484726429, -0.0613686628639698, 0.04722575098276138]

テーブル作成...
✅ テーブル作成成功
テーブル: <pytidb.table.Table object at 0x75f6584fe4e0>

データ挿入...
✅ データ挿入成功

先ほど作成されたテーブルにベクトルデータの格納が完了しています。
なんとtable.bulk_insertだけで埋め込みができるようです。まさにAI用のSDKですね。

    table.bulk_insert([
        # テキストは自動的に埋め込まれて text_vec フィールドに格納されます
        Chunk(text="PyTiDB is a Python library for developers to connect to TiDB.", user_id=2),
        Chunk(text="LlamaIndex is a framework for building AI applications.", user_id=2),
        Chunk(text="OpenAI is a company and platform that provides AI models service and tools.", user_id=3),
    ])

text_embed.VectorField(source_field="text")により、textフィールドの値が自動的に埋め込みAPIに送信されます。OpenAI APIから返されたベクトルがtext_vecフィールドに格納されるとうになっており、この処理はbulk_insert時に自動実行されるためです。

class Chunk(TableModel):
    id: int | None = Field(default=None, primary_key=True)
    text: str = Field()
    # ⭐ この部分の説明をもう少し詳しく
    text_vec: list[float] = text_embed.VectorField(source_field="text")
    user_id: int = Field()

6. 検索

次に検索を行います。再度connect.pyを置き換えます。

connect.py
from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction
from pytidb.schema import TableModel, Field, VectorField

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 埋め込み機能の設定
text_embed = EmbeddingFunction(
    model_name="openai/text-embedding-3-small",
    api_key="<api key>",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")

# 埋め込みモデルのテスト
print("\n埋め込みモデルテスト...")
test_text = "こんにちは、これはテストです。"
try:
    embedding = text_embed.get_query_embedding(test_text)
    print("✅ 埋め込み生成成功")
    print(f"テキスト: {test_text}")
    print(f"埋め込み次元: {len(embedding)}")
    print(f"埋め込みの最初の5要素: {embedding[:5]}")
except Exception as e:
    print(f"❌ 埋め込み生成エラー: {e}")

# テーブルスキーマの定義
print("\nテーブル作成...")
class Chunk(TableModel):
    id: int | None = Field(default=None, primary_key=True)
    text: str = Field()
    text_vec: list[float] = text_embed.VectorField(source_field="text")
    user_id: int = Field()

try:
    table = client.create_table(schema=Chunk, mode="overwrite")
    print("✅ テーブル作成成功")
    print(f"テーブル: {table}")
except Exception as e:
    print(f"❌ テーブル作成エラー: {e}")

# データの一括挿入
print("\nデータ挿入...")
try:
    table.bulk_insert([
        # テキストは自動的に埋め込まれて text_vec フィールドに格納されます
        Chunk(text="PyTiDB is a Python library for developers to connect to TiDB.", user_id=2),
        Chunk(text="LlamaIndex is a framework for building AI applications.", user_id=2),
        Chunk(text="OpenAI is a company and platform that provides AI models service and tools.", user_id=3),
    ])
    print("✅ データ挿入成功")
except Exception as e:
    print(f"❌ データ挿入エラー: {e}")

# ベクトル検索
print("\nベクトル検索...")
try:
    results = table.search(
        # クエリテキストを直接渡すと、自動的にクエリベクトルに埋め込まれます
        "A library for my artificial intelligence software"
    ).limit(3).to_list()
    
    print("✅ 検索成功")
    print("検索結果:")
    for i, result in enumerate(results, 1):
        print(f"  {i}. {result}")
except Exception as e:
    print(f"❌ 検索エラー: {e}")
python connect.py
接続成功
結果: rowcount=1 success=True message=None

埋め込みモデルテスト...
✅ 埋め込み生成成功
テキスト: こんにちは、これはテストです。
埋め込み次元: 1536
埋め込みの最初の5要素: [0.02179015800356865, 0.021149272099137306, -0.021624768152832985, -0.061483804136514664, 0.04730159416794777]

テーブル作成...
✅ テーブル作成成功
テーブル: <pytidb.table.Table object at 0x7f10942bc680>

データ挿入...
✅ データ挿入成功

ベクトル検索...
✅ 検索成功
検索結果:
  1. {'id': 2, 'text': 'LlamaIndex is a framework for building AI applications.', 'text_vec': array([-0.00899981, -0.0415895 ,  0.03556553, ...,  0.02199954,
       -0.02746931,  0.00598783], shape=(1536,), dtype=float32), 'user_id': 2, '_distance': 0.5720425717931883, '_score': 0.4279574282068117}
  2. {'id': 3, 'text': 'OpenAI is a company and platform that provides AI models service and tools.', 'text_vec': array([-0.02068236, -0.02139141,  0.02399852, ...,  0.01264286,
       -0.01109386,  0.02537298], shape=(1536,), dtype=float32), 'user_id': 3, '_distance': 0.6031775185994683, '_score': 0.39682248140053167}
  3. {'id': 1, 'text': 'PyTiDB is a Python library for developers to connect to TiDB.', 'text_vec': array([-0.05514091, -0.07363754, -0.00621099, ...,  0.00800503,
        0.01695888,  0.02312079], shape=(1536,), dtype=float32), 'user_id': 2, '_distance': 0.6202760421735025, '_score': 0.3797239578264975}

無事検索結果が表示されています。素晴らしい!!

7.テーブルの削除

では最後にdelete.pyを作成後実行して終わりです。

delete.py
from pytidb import TiDBClient

# データベース接続
client = TiDBClient.connect(
    host="gateway01.us-west-2.prod.aws.tidbcloud.com",
    port=4000,
    username="237VxcjMhxqE85K.root",
    password="mnxxxxx",
    database="test",
)

# 接続確認
result = client.execute("SELECT 1 as test")
print("接続成功")
print(f"結果: {result}")

# テーブル削除(存在する場合のみ)
try:
    client.execute("DROP TABLE IF EXISTS chunks")
    print("\n✅ chunksテーブルを削除しました(存在していた場合)")
except Exception as e:
    print(f"\n❌ テーブル削除エラー: {e}")
python delete.py
接続成功
結果: rowcount=1 success=True message=None

✅ chunksテーブルを削除しました(存在していた場合)

おつかれさまでした!

Discussion