Rustとpgvectorを使って、類義語の検索をやってみた
はじめに
この記事は株式会社LabBase テックカレンダー Advent Calendar 2023 24日目の記事です。
株式会社LabBaseでエンジニアをしている渡辺創です。
今回は pgvector を使って、類義語の提案ができるか試してみたのでそれについて書いていきたいと思います。
背景
LabBase就職というサービスを提供しており、研究に取り組んでいる学生が研究概要を登録してくれています。研究を頑張っている学生を採用したい企業の方が研究内容と募集内容のマッチングによって、就職活動・採用活動を支援するサービスとなっています。
企業の人事の方が学生の研究内容をキーワードで検索するのですが、ユーザは自分の語彙の中でしか検索することができないので、その語彙を拡張する方法の1つとして、ユーザに類義語を提案しようと考えています。
類義語関連だとこちらの記事にも書いたので、ぜひ読んでみてください。
使用する技術など
- Rust
- pgvector
- Docker
- OpenAI Embedding
- LabBase就職に登録されている技術キーワードの一部
手順
- データの準備
- データベースの準備
- データのインサート
- キーワードからベクトルを取得し、保存
- 類似キーワードを検索する
ディレクトリ構成
.
├── sql
│ └── create_table.sql
├── db
│ └── Dockerfile
├── src
│ ├── data
│ │ └── test.csv
│ └── main.rs
├── .env
├── .gitignore
├── Cargo.lock
├── Cargo.toml
└── docker-compose.yml
データの準備
LabBase就職のキーワードに登録されているデータを5000単語ダウンロードしました。
以下のようなデータをサンプルデータとして使用します。
id,tag
5,建設業
6,生産性向上
7,コンクリート
8,橋梁
10,N体シュミレーション
データベースの準備
Docker と sqlファイル を使って、データベースを立ち上げる
FROM postgres:14.1
# buildに必要な依存関係を入れる
RUN apt-get update && \
apt-get install -y git make gcc postgresql-server-dev-14
# pgvectorをbuildしてinstall
RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git && \
cd pgvector && \
make && \
make install && \
cd ../ && rm -rf pgvector
version: "3"
services:
db:
build:
context: ./db
restart: always
ports:
- "5433:5432"
volumes:
- db-store:/var/lib/postgresql/data
environment:
POSTGRES_USER: "your_user_name"
POSTGRES_PASSWORD: "your_password"
volumes:
db-store:
CREATE DATABASE pgvector_test;
\c pgvector_test
CREATE EXTENSION vector;
CREATE TABLE keywords
(
id BIGSERIAL PRIMARY KEY,
keyword VARCHAR(255) NOT NULL
);
CREATE TABLE keywords_vector
(
id BIGSERIAL PRIMARY KEY,
embedding vector(1536)
);
docker ps で立ち上がってることを確認する。ローカルのport:5433がコンテナのport:5432に接続されていることも確認する。
以下のコマンドを実行して、データベースとテーブルを作成する。
$ psql -p 5433 -h localhost -U your_user_name -f ./database/create_table.sql
データのインサート
CSVを読み込み、データベースに接続し、データを入れていく。
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
// csv読み込み
let mut rdr = csv::Reader::from_path("./src/data/test.csv").unwrap();
// データベース接続設定
let (client, connection) = tokio_postgres::connect(
"host=localhost user=your_user_name password=your_password dbname=pgvector_test port=5433",
NoTls,
)
.await?;
// コネクションを管理するための別のタスクを起動
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
// CSVから読み込んだデータをデータベースに挿入
for result in rdr.deserialize() {
let record: Keyword = result?;
client
.execute(
"INSERT INTO keywords (id, keyword) VALUES ($1, $2)",
&[&record.id, &record.tag],
)
.await?;
}
/*
キーワードからベクトルを取得し、保存
*/
/*
類義語を検索する
*/
Ok(())
}
#[derive(Debug, Deserialize)]
struct Keyword {
id: i64,
tag: String,
}
キーワードからベクトルを取得し、保存
OpenAIのAPIを利用して、embeddingを取得し、それをデータベースに保存する。
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
/*
データのインサート
*/
// データベースからデータを取り出して、Keywordに入れる
let mut ids: Vec<i64> = Vec::new();
let mut keywords: Vec<String> = Vec::new();
let rows = client.query("SELECT * FROM keywords", &[]).await?;
for row in rows {
let id: i64 = row.get(0);
let tag: String = row.get(1);
ids.push(id);
keywords.push(tag.clone());
}
// 1000データずつOpenAIに投げる
for (ids_chunk, keywords_chunk) in ids.chunks(1000).zip(keywords.chunks(1000)) {
let embeddings = get_embeddings(keywords_chunk.to_vec())
.await
.context("embedding fetch failed")?;
for (id, embedding) in ids_chunk.iter().zip(embeddings) {
let embedding_vector = Vector::from(embedding.embedding.clone());
client
.execute(
"INSERT INTO keywords_vector (id, embedding) VALUES ($1, $2)",
&[&id, &embedding_vector],
)
.await?;
}
}
/*
類義語を検索する
*/
Ok(())
}
pub async fn get_embeddings(kys: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
let openai_api_key = env::var("OPENAI_API_KEY").unwrap();
let config = OpenAIConfig::new().with_api_key(openai_api_key);
let client = OpenAIClient::with_config(config);
let request = CreateEmbeddingRequestArgs::default()
.model("text-embedding-ada-002")
.input(kys)
.build()
.context("failed to build openai embedding request")?;
let res = client
.embeddings()
.create(request)
.await
.context("failed embedding request")?;
Ok(res.data)
}
類似キーワードを検索する
その関数をつくる
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
/*
データのインサート
*/
/*
キーワードからベクトルを取得し、保存
*/
// 類似キーワードを検索したいキーワードの設定
let query = vec!["言語モデル".to_string()];
// Embeddingを取得する
let query_embeddings = get_embeddings(query)
.await
.context("embedding fetch failed")?;
let query_vector = Vector::from(query_embeddings[0].embedding.clone());
// 類似キーワードを検索する
let embedding_rows = client
.query(
"SELECT * FROM keywords_vector ORDER BY embedding <-> $1 LIMIT 10",
&[&query_vector],
)
.await?;
for embedding in embedding_rows {
let id: i64 = embedding.get(0);
// id からキーワードを取得する
let keyword = client
.query_one("SELECT * FROM keywords WHERE id = $1", &[&id])
.await?;
let name: String = keyword.get(1);
// let distance: f32 = embedding.get(2);
println!("id: {}, name: {}", id, name);
}
Ok(())
}
言語モデルを検索したところ、以下の結果を得ることができました。
思ったより、いい結果がでました。データを整形したら、ユーザへの提案としては使える可能性もありそうです。
おわりに
最後までお読みいただき、ありがとうございます!ソースコードも少し整えたらgithubに上げる予定です。
LabBaseでは、GPTを活用した推薦の実装や検証を実際に手を動かしながら、進めているのでご興味ある方がいらっしゃいましたら、ぜひ渡辺のTwitterまでお気軽にお声がけください!!
明日はラストです!CTOの https://qiita.com/mizno さんです。お楽しみに!
Discussion