OpenSearch で日本語 Sparse search を動かしてみる
本投稿は以下アドベントカレンダーとリンクしています。
- AWS Analytics Advent Calendar 2024 (12/25 分)
- 情報検索・検索技術 Advent Calendar 2024 (シリーズ2 12/25 分)
OpenSearch におけるベクトル検索の概要
OpenSearch はオープンソースの全文検索・分析スイートです。OpenSearch は全文検索に加えてベクトル検索の機能も備えています。
ベクトル検索は従来の全文検索による検索が苦手とするあいまいなクエリの処理を実現することができるため、あいまい検索や LLM と組み合わせた検索検索拡張生成(RAG) に代表される文書検索・ナレッジ検索で幅広く活用されています。
一般的にベクトル検索と呼ばれているのは、N 次元の数値配列からなる密ベクトルを使った検索です。密ベクトル検索では、クエリと検索対象のデータは N 次元の数値配列として扱われ、それらの距離や角度の差異が類似度として表されます。距離や角度が近いほど類似度が高いとみなされるわけです。
全文検索はクエリと検索対象のデータ間で厳密なマッチングが要求される一方、ベクトル検索は"意味的に近い" 文書を取得する際に有用であるため、うまく使い分けることで幅広い検索要件を達成できます。
従来型のベクトル検索の課題
クエリとデータの類似度を判定するためには、クエリを実行する都度、クエリベクトルと検索対象のベクトル間の距離計算が必要となります。小規模なベクトルデータに対する検索であれば、クエリと全ベクトル間の距離を計算することも現実的な選択肢となるのですが、検索対象のベクトル数が億単位に上るとレイテンシの面で実用的ではなくなります。この場合、Approximate kNN と呼ばれる、精度を少々犠牲にする代わりに探査回数を削減するアプローチが取られます。
高速な密ベクトル検索を実現するうえでは、検索対象のベクトルがメモリ上に格納されていることが求められるため、規模が大きくなるにつれて要求メモリも拡大し、コストが増大します。
ベクトルの量子化や Disk-ANN と呼ばれる仕組みで要求メモリを削減することはできますが、依然としてメモリ上にベクトルデータを格納する必要性は残ります。
スパース検索
OpenSearch では、バージョン 2.11 より Sparse neural search 機能によるスパース検索が実装されています。
スパース検索では、テキストから生成された数値ベクトルではなく、テキストデータをもとにスパースエンコーディングを使用して生成されたトークンリストに対して検索を行います。
スパースエンコーディングの過程で、エンコーダーは意味的に類似したトークンのリストを作成します。モデルの語彙(WordPiece)には、最も一般的に使用される単語に加えて、様々な時制の語尾(例: -edや-ing)や接尾辞(例: -ateや-ion)が含まれています。この語彙は、各文書がスパースベクトルとして表現される意味空間として考えることができます。
以下は密ベクトルと Sparse search におけるデータ構造の違いを表したものです。
出典:
OpenSearch では、スパース検索は転置インデックスを使用して実装されています。スパース検索のためのスパースベクトルは、トークンと重み付けのキーバリューリストとして、rank_features タイプフィールドに格納されています。これにより、従来の全文検索に近い形であいまい検索を実現することが可能です。
日本語による Sparse search 実装の流れ
Sparse search において使用されるスパースエンコーディングモデルとして、OpenSearch Project ではいくつかのモデルを作成・公開しています。しかしながらこれらのモデルは英語のみをサポートしており、日本語を適切にトークン化することができません。
今回は日本語に対応したスパースエンコーディングモデルである japanese-splade-v2 を使わせていただき、日本語による Sparse search を OpenSearch 上で実行できるようにしていきます。
本モデルの詳細や、本モデルのベースとなっている SPLADE についての解説については、モデルの製作者であるセコンさんによるリリースをご覧ください。
準備作業
ベクトルデータベースとして、 Amazon OpenSearch Service ドメインを用意する必要があります。記事執筆時点では OpenSearch Serverless では Sparse search を利用できないため、 OpenSearch
2.17 バージョンの、OpenSearch ドメインを使用しています。
また、作業は Jupyter Notebook もしくは JupyterLab 上で行います。事前に Amazon SageMaker Notebook もしくは Amazon SageMaker Studio などの環境を用意することをお勧めします。
以降の作業は全て JupyterLab 上で実行したものです。
モデルのデプロイ
JupyterLab 上で実行していきます。
事前作業
パッケージインストール、ヘルパー関数のセット、変数のセットなどを実施しています。この部分はそれほど重要な内容ではないので、一つ一つの処理を詳しく見ていく必要はありません。
パッケージインストール
!sudo apt-get update -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
!sudo apt-get install git-lfs git -y
!git lfs install
ライブラリのインポート
import boto3
import json
from datetime import datetime, timedelta
import time
from functools import lru_cache
import pandas as pd
import numpy as np
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
ヘルパー関数のセット
def get_huggingface_tei_image_uri(instance_type, region):
key = "huggingface-tei" if instance_type.startswith("ml.g") or instance_type.startswith("ml.p") else "huggingface-tei-cpu"
return get_huggingface_llm_image_uri(key, version="1.2.3", region=region)
def search_endpoint(model_id, region):
sagemaker_client = boto3.client("sagemaker", region_name=region)
response = sagemaker_client.search(
Resource="Endpoint",
SearchExpression={
"Filters": [
{
"Name": "EndpointName",
"Operator": "Contains",
"Value": model_id
},
],
},
SortBy="LastModifiedTime",
SortOrder="Descending",
MaxResults=1,
)
return response["Results"]
@lru_cache(maxsize=None)
def list_instance_quotas_for_realtime_inference(instance_family, region):
service_quotas_client = boto3.client("service-quotas", region_name=region)
quotas = []
paginator = service_quotas_client.get_paginator('list_service_quotas')
page_iterator = paginator.paginate(ServiceCode='sagemaker',
PaginationConfig={'MaxResults': 100})
for page in page_iterator:
for quota in page['Quotas']:
if quota["QuotaName"].endswith("endpoint usage") and quota["Value"] > 0 and quota["QuotaName"].startswith("ml."+instance_family):
quotas.append(quota)
return quotas
def list_instance_usages_from_quotas(quotas, region):
cloudwatch_client = boto3.client("cloudwatch", region_name=region)
end_time = datetime.utcnow()
end_time = end_time.replace(minute=end_time.minute - (end_time.minute % 5),
second=0,
microsecond=0) - timedelta(minutes=5)
start_time = end_time - timedelta(hours=1)
metric_data_queries = []
for i, quota in enumerate(quotas):
metric = quota['UsageMetric']
dimensions = []
for key, value in metric['MetricDimensions'].items():
dimensions.append({
'Name': key,
'Value': value
})
metric_data_queries.append({
'Id': f'usage_{i}',
'MetricStat': {
'Metric': {
'Namespace': metric['MetricNamespace'],
'MetricName': metric['MetricName'],
'Dimensions': dimensions
},
'Period': 300, # 5分間隔
'Stat': metric['MetricStatisticRecommendation']
}
})
payload = {
'MetricDataQueries': metric_data_queries,
'StartTime': start_time,
'EndTime': end_time
}
usages = cloudwatch_client.get_metric_data(**payload)
return usages
def list_instance_usages(region):
sagemaker_client = boto3.client("sagemaker", region_name=region)
response = sagemaker_client.list_endpoints(
)
instances = []
for endpoint in response["Endpoints"]:
response = sagemaker_client.describe_endpoint(EndpointName=endpoint["EndpointName"])
response = sagemaker_client.describe_endpoint_config(EndpointConfigName=response["EndpointConfigName"])
instances.append(response["ProductionVariants"][0]["InstanceType"])
values, counts = np.unique(instances, return_counts=True)
return values,counts
def list_instance_attributes_realtime_inference(instance_family, region):
pricing = boto3.client("pricing", region_name="us-east-1")
instance_types = []
paginator = pricing.get_paginator("get_products")
page_iterator = paginator.paginate(
ServiceCode="AmazonSageMaker",
Filters=[
{
"Type": "TERM_MATCH",
"Field": "productFamily",
"Value": "ML Instance"
},
{
"Type": "TERM_MATCH",
"Field": "regionCode",
"Value": region
},
{
"Type": "TERM_MATCH",
"Field": "platoinstancetype",
"Value": "Hosting"
},
{
"Type": "TERM_MATCH",
"Field": "platoinstancename",
"Value": instance_family
},
],
)
products = []
for page in page_iterator:
for product in page["PriceList"]:
products.append(json.loads(product)["product"]["attributes"])
return products
def list_available_instance_types_for_realtime_inference(instance_family, region):
quotas = list_instance_quotas_for_realtime_inference(instance_family=instance_family, region=region)
quotas_df = pd.json_normalize(quotas).loc[:,["QuotaName","Value"]]
quotas_df["InstanceType"] = quotas_df["QuotaName"].str.removesuffix(" for endpoint usage")
quotas_df = quotas_df.drop(columns=["QuotaName"]).rename(columns={"Value":"Limit"})
quotas_df["Limit"] = quotas_df["Limit"].astype(int)
usage_values,usage_counts = list_instance_usages(region=region)
usages_df = pd.DataFrame({"InstanceType": usage_values, "Usage": usage_counts})
attributes = list_instance_attributes_realtime_inference(instance_family=instance_family, region=region)
attributes_df = pd.json_normalize(attributes)
attributes_df = attributes_df.loc[:,["instanceName","vCpu"]].rename(columns={"instanceName": "InstanceType"})
merged_df = pd.merge(pd.merge(quotas_df, usages_df, how="left", on="InstanceType"), attributes_df, on='InstanceType').fillna(value=0)
filtered_df = merged_df.query("Usage<Limit")
instance_types = filtered_df.sort_values("vCpu", ascending=True, ignore_index=True).InstanceType
return instance_types.array
変数等のセット
sagemaker_region = boto3.Session().region_name
default_instance_family_gpu = "g5"
default_instance_family_cpu = "m5"
py_version='py310',
transformers_version="4.37.0", # transformers version used
pytorch_version="2.1.0", # pytorch version used
モデルのデプロイ
Huggingface からダウンロードしたモデルをベースに、Amazon SageMaker の推論エンドポイントを構築します。
モデルのダウンロード
hf_model_id_sparse_embedding = "hotchpotch/japanese-splade-v2"
# hf_model_id_sparse_embedding = "hotchpotch/japanese-splade-base-v1"
model_id_sparse_embedding = hf_model_id_sparse_embedding.lower().replace("/", "-")
role = sagemaker.get_execution_role()
session = sagemaker.Session()
default_bucket = session.default_bucket()
s3_location=f"s3://{default_bucket}/custom_inference/{model_id_sparse_embedding}/model.tar.gz"
%pushd
%mkdir -p ./models
%cd ./models
!git clone https://huggingface.co/$hf_model_id_sparse_embedding
カスタムモデルの作成とデプロイ
ダウンロードしたモデルを元に、Amazon SageMaker モデルエンドポイント用のファイルを追加したモデルファイルを作成、デプロイし、エンドポイントを構築します。
%mkdir -p ./code
Amazon SageMaker におけるリアルタイム推論エンドポイントでは、model_fn でモデルのロード、predict_fn で推論実行を行います。
model_fn 内では、モデルと同様にセコンさんが作成された yasem と呼ばれるライブラリを使用してモデルの読み込みと実行を行っています。
%%writefile ./code/inference.py
from yasem import SpladeEmbedder
def model_fn(model_dir, context=None):
model = SpladeEmbedder(model_dir)
return model
def predict_fn(input_data, model):
text_docs = input_data["inputs"]
embeddings = model.encode(text_docs)
token_values = model.get_token_values(embeddings)
results = {
"response": [token_values]
}
return results
inference.py から参照している yasem および本モデルの実行に必要な fugashi, unidic_lite を導入するために requirements.txt を作成します
%%writefile ./code/requirements.txt
yasem==0.3.2
fugashi
unidic_lite
作成した inference.py と requirements.txt をモデルファイルのディレクトリ配下にコピーし、アーカイブを作成して Amazon S3 バケットにアップロードします
model_dir_sparse_embedding = hf_model_id_sparse_embedding.split("/")[-1]
!cp -rf ./code/ ./$model_dir_sparse_embedding/code/
%cd $model_dir_sparse_embedding
!rm -f model.tar.gz
!tar zcvf model.tar.gz *
!aws s3 cp model.tar.gz $s3_location
アップロードされたモデルを元にエンドポイントを構築します
available_instance_types = list_available_instance_types_for_realtime_inference(instance_family=default_instance_family_gpu, region=sagemaker_region)
instance_type = available_instance_types[0]
print(f"start deploy {hf_model_id_sparse_embedding} on {instance_type}")
role = sagemaker.get_execution_role()
huggingface_model_sparse_embedding = HuggingFaceModel(
model_data=s3_location, # path to your model and script
entry_point='inference.py',
source_dir='code',
role=role, # iam role with permissions to create an Endpoint
py_version='py310',
transformers_version="4.37.0", # transformers version used
pytorch_version="2.1.0", # pytorch version used
)
huggingface_model_sparse_embedding.deploy(
endpoint_name=sagemaker.utils.name_from_base(model_id_sparse_embedding),
initial_instance_count=1,
instance_type=instance_type
)
作業後は元のディレクトリに戻っておきましょう。
%popd
推論のテスト
推論のテストをしていきます。"車の燃費を向上させる方法は?" と呼ばれるテキストから生成されるトークンのリストを見ていきましょう。
model_endpoint_name_sparse_embedding = search_endpoint(model_id_sparse_embedding,sagemaker_region)[0]["Endpoint"]["EndpointName"]
model_endpoint_url_sparse_embedding = f"https://runtime.sagemaker.{sagemaker_region}.amazonaws.com/endpoints/{model_endpoint_name_sparse_embedding}/invocations"
print("sparse embedding endpoint name: " + model_endpoint_name_sparse_embedding)
print("sparse embedding model endpoint url: " + model_endpoint_url_sparse_embedding)
payload = {
"inputs": [
"車の燃費を向上させる方法は?"
]
}
body = bytes(json.dumps(payload), 'utf-8')
sagemaker_runtime_client = boto3.client("sagemaker-runtime",region_name=sagemaker_region)
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=model_endpoint_name_sparse_embedding,
ContentType="application/json",
Accept="application/json",
Body=body
)
result = eval(response['Body'].read().decode('utf-8'))
result
結果は以下のようになりました。これでエンドポイントによる推論が可能となったため、OpenSearch との連携に進みます。
{'response': [[{'燃費': 1.130859375,
'方法': 1.0693359375,
'車': 1.05078125,
'高める': 0.67041015625,
'向上': 0.55615234375,
'増加': 0.5244140625,
'都市': 0.443115234375,
'ガソリン': 0.31982421875,
'改善': 0.297607421875,
'こと': 0.287353515625,
'せ': 0.21533203125,
'上昇': 0.18896484375,
'べき': 0.1451416015625,
'する': 0.128173828125,
'メリット': 0.12384033203125,
'すすめ': 0.10467529296875,
'減': 0.084228515625,
'軽減': 0.07159423828125,
'手段': 0.0250701904296875,
'効果': 0.022216796875,
'ガス': 0.0019512176513671875},
]]}
OpenSearch の設定
OpenSearch の Neural sparse search では、ユーザーがモデルを呼び出してテキストからトークンリストを生成することなく、Sparse search を実行することができます。モデルの呼び出しとトークン化は、OpenSearch が推論エンドポイントと連携して実施してくれるためです。以下はモデルと OpenSearch 内の各コンポーネントの対応図です。
Sparse neural search を実行するために、以下のコンポーネントを作成していきます。以降の作業も JupyterLab(もしくは Jupyter Notebook) 上で実行していきます。
- コネクター(Connector) - 推論エンドポイントやパラメーターなど、モデルを呼び出すために必要な情報を持つもの
- モデル(Model) - OpenSearch 上のモデルカタログに登録されたモデルの情報。コネクターからなる
- 取り込みパイプライン(Ingest Pipeline) - データ取り込み時に外部の推論エンドポイントを呼び出し、テキストからトークンリストを生成するパイプライン
- 検索パイプライン(Search Pipeline) - 検索時に外部の推論エンドポイントを呼び出し、クエリからトークンリストを生成するパイプライン
- インデックス(Index) - Sparse search に必要なドキュメントとドキュメントから生成されたトークンリストを格納
事前処理
パッケージインストール、インポートなど実施していきます。
!pip install opensearch-py requests-aws4auth 'awswrangler[opensearch]' --quiet
import boto3
import json
import time
import awswrangler as wr
import pandas as pd
import numpy as np
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
def search_sagemaker_inference_endpoint(model_id, region):
sagemaker_client = boto3.client("sagemaker", region_name=region)
response = sagemaker_client.search(
Resource="Endpoint",
SearchExpression={
"Filters": [
{
"Name": "EndpointName",
"Operator": "Contains",
"Value": model_id
},
],
},
SortBy="LastModifiedTime",
SortOrder="Descending",
MaxResults=1,
)
return response["Results"]
default_region = boto3.Session().region_name
エンドポイントをセットします。
opensearch_cluster_endpoint = "search-sample-qrvg3nyg1vxb-ox9pmxoel6g5zjegp7cvdyfuca.us-east-1.es.amazonaws.com"
credentials = boto3.Session().get_credentials()
service_code = "es"
auth = AWSV4SignerAuth(credentials=credentials, region=default_region, service=service_code)
opensearch_client = OpenSearch(
hosts=[{"host": opensearch_cluster_endpoint, "port": 443}],
http_compress=True,
http_auth=auth,
use_ssl=True,
verify_certs=True,
connection_class = RequestsHttpConnection
)
opensearch_client.info()
検索インデックスの作成
検索インデックスを作成します。
payload = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"question": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"context": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"answers": {"type": "text", "analyzer": "custom_sudachi_analyzer"},
"question_sparse_embedding": {
"type": "rank_features"
},
}
},
"settings": {
"index.knn": True,
"index.number_of_shards": 1,
"index.number_of_replicas": 0,
"analysis": {
"analyzer": {
"custom_sudachi_analyzer": {
"char_filter": ["icu_normalizer"],
"filter": [
"sudachi_normalizedform",
"custom_sudachi_part_of_speech"
],
"tokenizer": "sudachi_tokenizer",
"type": "custom"
}
},
"filter": {
"custom_sudachi_part_of_speech": {
"type": "sudachi_part_of_speech",
"stoptags": ["感動詞,フィラー","接頭辞","代名詞","副詞","助詞","助動詞","動詞,一般,*,*,*,終止形-一般","名詞,普通名詞,副詞可能"]
}
}
}
}
}
# インデックス名を指定
index_name = "jsquad"
try:
# 既に同名のインデックスが存在する場合、いったん削除を行う
print("# delete index")
response = opensearch_client.indices.delete(index=index_name)
print(json.dumps(response, indent=2))
except Exception as e:
print(e)
# インデックスを作成
response = opensearch_client.indices.create(index_name, body=payload)
response
コネクタの作成とモデルの登録
Amazon SageMaker 推論エンドポイントの情報から URL を作成し、コネクタの登録とモデルの作成を実行します。
hf_model_id_sparse_embedding = "hotchpotch/japanese-splade-v2"
model_id_sparse_embedding = hf_model_id_sparse_embedding.lower().replace("/", "-")
model_endpoint_name_sparse_embedding = search_sagemaker_inference_endpoint(model_id_sparse_embedding, default_region)[0]["Endpoint"]["EndpointName"]
model_endpoint_url_sparse_embedding = f"https://runtime.sagemaker.{default_region}.amazonaws.com/endpoints/{model_endpoint_name_sparse_embedding}/invocations"
print("embedding endpoint name: " + model_endpoint_name_sparse_embedding)
print("embedding model endpoint url: " + model_endpoint_url_sparse_embedding)
payload = {
"name": model_id_sparse_embedding,
"description": "Remote connector for "+ model_id_sparse_embedding,
"version": 1,
"protocol": "aws_sigv4",
"credential": {
"roleArn": opensearch_connector_role_arn
},
"parameters": {
"region": default_region,
"service_name": "sagemaker"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"headers": {
"content-type": "application/json"
},
"url": model_endpoint_url_sparse_embedding,
"pre_process_function": """
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('\"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('\"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ \"inputs\": ' + textDocsBuilder.toString() + ' }';
return '{\"parameters\": ' + parameters + '}';
""",
"request_body": "{ \"inputs\": ${parameters.inputs}}"
}
]
}
response = opensearch_client.http.post("/_plugins/_ml/connectors/_create", body=payload)
sparse_embedding_connector_id = response['connector_id']
print("sparse embedding connector id: " + sparse_embedding_connector_id)
payload = {
"name": model_id_sparse_embedding,
"description": model_id_sparse_embedding,
"function_name": "remote",
"connector_id": sparse_embedding_connector_id
}
response = opensearch_client.http.post("/_plugins/_ml/models/_register?deploy=true", body=payload)
opensearch_model_id_sparse_embedding = response['model_id']
for i in range(300):
ml_model_status = opensearch_client.http.get("/_plugins/_ml/models/"+ opensearch_model_id_sparse_embedding)
model_state = ml_model_status.get("model_state")
if model_state in ["DEPLOYED", "PARTIALLY_DEPLOYED"]:
break
time.sleep(1)
print(ml_model_status)
if model_state == "DEPLOYED":
print("sparse embedding model " + opensearch_model_id_sparse_embedding + " is deployed successfully")
elif model_state == "PARTIALLY_DEPLOYED":
print("sparse embedding model " + opensearch_model_id_sparse_embedding + " is deployed only partially")
else:
raise Exception("sparse embedding model " + opensearch_model_id_sparse_embedding + " deployment failed")
以下のような出力が得られます。モデル ID は後続のパイプライン作成時に必要になります。
{'name': 'hotchpotch-japanese-splade-v2', 'model_group_id': 'Dy8s95MB2SlQmDGVnqrm', 'algorithm': 'REMOTE', 'model_version': '1', 'description': 'hotchpotch-japanese-splade-v2', 'model_state': 'DEPLOYED', 'created_time': 1735018716938, 'last_updated_time': 1735018716991, 'last_deployed_time': 1735018716991, 'auto_redeploy_retry_times': 0, 'planning_worker_node_count': 1, 'current_worker_node_count': 1, 'planning_worker_nodes': ['zZiZmVlXSwqGMPV4yCjBeg'], 'deploy_to_all_nodes': True, 'is_hidden': False, 'connector_id': 'Di8s95MB2SlQmDGVk6pE'}
sparse embedding model ES8s95MB2SlQmDGVn6oK is deployed successfully
モデルが登録出来たら、OpenSearch 経由でモデルの呼び出しが行えるかをテストします。
path = "/_plugins/_ml/_predict/sparse_encoding/" + opensearch_model_id_sparse_embedding
payload = {
"text_docs": ["日本で梅雨がないのはどこ?"]
}
opensearch_client.http.post(path, body=payload)
モデルが正しく登録できている場合は、以下のような結果が得られます。
{'inference_results': [{'output': [{'name': 'response',
'dataAsMap': {'response': [{'日本': 1.380859375,
'##雨': 1.1826171875,
'ない': 1.0830078125,
'場所': 1.0546875,
'梅': 1.0205078125,
'なし': 0.90380859375,
'国': 0.7978515625,
'雨': 0.6806640625,
'気候': 0.67138671875,
'少ない': 0.3505859375,
'町': 0.331787109375,
'地域': 0.242431640625,
'季節': 0.20263671875,
'都市': 0.14013671875,
'不': 0.042999267578125}]}}],
'status_code': 200}]}
パイプラインの作成
インデクシングパイプラインと検索パイプラインを作成していきます
インデクシングパイプライン内では、Sparse Encoding プロセッサを使用してトークンリストの生成を行います。
# リクエストペイロードの作成
payload = {
"processors": [
{
"sparse_encoding": {
"model_id": question_sparse_embedding
}
}
}
]
}
# パイプライン ID の指定
indexing_embedding_pipeline_id = "indexing_embedding_pipeline"
# パイプライン作成 API の呼び出し
response = opensearch_client.http.put("/_ingest/pipeline/" + indexing_embedding_pipeline_id, body=payload)
response
検索パイプライン内では、neural_query_enricher および neural_sparse_two_phase_processor の 2 つを組み合わせてクエリのトークン化とスパース検索を行います。
payload={
"request_processors": [
{
"neural_sparse_two_phase_processor": {
"tag": "neural-sparse",
"description": "Creates a two-phase processor for neural sparse search."
}
},
{
"neural_query_enricher" : {
"default_model_id": opensearch_model_id_sparse_embedding
}
}
}
]
}
# パイプライン ID の指定
search_embedding_pipeline_id = "search_embedding_pipeline"
# パイプライン作成 API の呼び出し
response = opensearch_client.http.put("/_search/pipeline/" + search_embedding_pipeline_id, body=payload)
response
データ登録
作成したパイプラインを指定して OpenSearch にデータ書き込みを行います。元のデータファイル内のデータをもとにパイプラインから各種コネクタを通じて ML サービスを呼び出し、埋め込みを生成しながらデータ格納が行われます。
今回はサンプルデータとして JSQuAD を使用しています。
%%time
dataset_dir = "./dataset/jsquad/"
%mkdir -p $dataset_dir
!curl -L -s -o $dataset_dir/valid.json https://github.com/yahoojapan/JGLUE/raw/main/datasets/jsquad-v1.1/valid-v1.1.json
#!curl -L -s -o $dataset_dir/train.json https://github.com/yahoojapan/JGLUE/raw/main/datasets/jsquad-v1.1/train-v1.1.json
%%time
import pandas as pd
import json
def squad_json_to_dataframe(input_file_path, record_path=["data", "paragraphs", "qas", "answers"]):
file = json.loads(open(input_file_path).read())
m = pd.json_normalize(file, record_path[:-1])
r = pd.json_normalize(file, record_path[:-2])
idx = np.repeat(r["context"].values, r.qas.str.len())
m["context"] = idx
m["answers"] = m["answers"]
m["answers"] = m["answers"].apply(lambda x: np.unique(pd.json_normalize(x)["text"].to_list()))
return m[["id", "question", "context", "answers"]]
valid_filename = f"{dataset_dir}/valid.json"
valid_df = squad_json_to_dataframe(valid_filename)
#train_filename = f"{dataset_dir}/train.json"
#train_df = squad_json_to_dataframe(train_filename)
%%time
index_name = "jsquad"
response = wr.opensearch.index_df(
client=opensearch_client,
df=valid_df,
#df=pd.concat([train_df, valid_df]),
use_threads=True,
id_keys=["id"],
index=index_name,
bulk_size=10, # 10 件ずつ書き込み
refresh=False,
pipeline=indexing_embedding_pipeline_id
)
index_name = "jsquad"
response = opensearch_client.indices.refresh(index=index_name)
response = opensearch_client.indices.forcemerge(index=index_name)
テキスト検索と Neural sparse search によるスパース検索の比較
テキスト検索
テキスト検索のヒット率は検索キーワードとインデックスに格納されたコンテンツの内容、およびアナライザーによる正規化設定により左右されます。
テキスト検索では、不要なキーワードは極力含まれない方がよい結果を得られる可能性があります。以下のような単語の組み合わせによる検索では十分性能を発揮できるといえます。
クエリ
index_name = "jsquad"
payload = {
"query": {
"match": {
"question": {
"query": "日本 梅雨 ない どこ",
"operator": "and"
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": 10
}
response = opensearch_client.search(
index=index_name,
body=payload
)
pd.json_normalize(response["hits"]["hits"])
実行結果
index | id | score | fields.question | fields.answers | fields.context |
---|---|---|---|---|---|
jsquad | a10336p0q0 | 14.434446 | 日本で梅雨がないのは北海道とどこか。 | 小笠原諸島, 小笠原諸島を除く日本 | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
jsquad | a10336p24q1 | 14.434446 | 梅雨が日本の中でない地域はどこか。 | 北海道, 東北地方 | 梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
では、以下のような検索クエリではどうでしょうか。
index_name = "jsquad"
payload = {
"query": {
"match": {
"question": {
"query": "日本で梅雨がない場所は?",
"operator": "and"
}
}
},
"_source": False,
"fields": ["question", "answers", "context"],
"size": 10
}
response = opensearch_client.search(
index=index_name,
body=payload
)
response
{'took': 1,
'timed_out': False,
'_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0},
'hits': {'total': {'value': 0, 'relation': 'eq'},
'max_score': None,
'hits': []}}
ヒットしませんでした。"ところ" ではなく "場所" としたことでヒットしなくなってしまったことが主な要因です。いわゆる "てにをは" のようなストップワードはトークンフィルターにより落とされて検索の対象外となっているため、実は上記のクエリは "日本" "梅雨" "ない" "場所" での検索とほぼ同じ形で処理されています。この 4 つのキーワードにマッチするドキュメントがあれば結果として返却されたはずですが、厳密にマッチするドキュメントがないためゼロ件ヒットとなってしまいました。
Neural sparse search
Neural sparse search は、ユーザーから与えられたクエリテキストをスパースエンコーディングモデルを通じてトークンリストに変換することで検索を行います。
テキスト検索が match などのクエリを使用するのとは異なり、Neural sparse search では "neural_sparse" クエリを使用して検索を行います。
%%time
# search
index_name = "jsquad"
query = "日本で梅雨がない場所は?"
payload = {
"size": 10,
"query": {
"neural_sparse": {
"question_sparse_embedding": {
"query_text": query,
}
}
},
"_source" : False,
"fields": ["question", "answers", "context"]
}
# 検索 API を実行
response = opensearch_client.search(
body = payload,
index = index_name,
filter_path = "hits.hits",
search_pipeline = search_embedding_pipeline_id #ベクトル変換を行うパイプラインを指定
)
# 結果を表示
pd.json_normalize(response["hits"]["hits"])
結果は以下の通りです。上位 5 つは質問の意図に沿ったドキュメントとなっています。下位 5 つは質問の意図と離れていますが、上位 5 つとスコアに大きな開きがあることも分かりました。
index | id | score | fields.question | fields.answers | fields.context |
---|---|---|---|---|---|
jsquad | a10336p24q1 | 23.883795 | 梅雨が日本の中でない地域はどこか。 | 北海道, 東北地方 | 梅雨 [SEP] 年によっては梅雨明けの時期が特定できなかったり、あるいは発表がされないこ... |
jsquad | a10336p0q0 | 21.338238 | 日本で梅雨がないのは北海道とどこか。 | 小笠原諸島, 小笠原諸島を除く日本 | 梅雨 [SEP] 梅雨(つゆ、ばいう)は、北海道と小笠原諸島を除く日本、朝鮮半島南部、中国... |
jsquad | a10336p32q3 | 19.378119 | 梅雨がないとされている都道府県はどこ? | 北海道 | 梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
jsquad | a10336p32q2 | 17.695690 | 気候学的には梅雨はないとされている場所は? | 北海道 | 梅雨 [SEP] 実際の気象としては北海道にも道南を中心に梅雨前線がかかることはあるが、平... |
jsquad | a10336p18q0 | 16.959581 | 日本の地域で本格的な長雨に突入しない場所はどこか。 | 北海道 | 梅雨 [SEP] 次に梅雨前線は中国の江淮(長江流域・淮河流域)に北上する。6月下旬には華... |
jsquad | a10336p42q1 | 12.954581 | 梅雨の期間中ほとんど雨が降らない場合を何と呼ぶ? | 空梅雨, 空梅雨(からつゆ) | 梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
jsquad | a10336p42q2 | 12.952448 | 梅雨の期間中ほとんど雨が降らない場合をなんという? | 空梅雨, 空梅雨(からつゆ) | 梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
jsquad | a10336p42q4 | 12.803707 | ほとんど雨が降らない梅雨を何という? | 空梅雨, 空梅雨(からつゆ) | 梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
jsquad | a10336p42q0 | 12.474874 | 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことをなんというか? | 空梅雨, 空梅雨(からつゆ) | 梅雨 [SEP] 梅雨の期間中ほとんど雨が降らない場合がある。このような梅雨のことを空梅雨... |
jsquad | a10336p7q4 | 11.680378 | 梅雨の事を中国では、何というか。 | 「(メイユー)」, メイユー | 梅雨 [SEP] 中国では「(メイユー)」、台湾では「(メイユー)」や「芒種雨」、韓国では... |
まとめ
Sparse encoding によるトークンベースのあいまい検索により、従来型の全文検索ではカバーできないクエリによるゼロ件ヒットの課題を解消できることが確認できました。
Neural sparse search の機能を使用することで、ユーザーはトークンリストをクライアントサイドで生成することなく、通常の全文検索と同じクエリテキストを投げるだけでスパース検索が実行できることも確認できました。OpenSearch のコネクタやパイプラインの機能を活用することで、OpenSearch 内で埋め込みの処理を完結できるため、実装の選択肢が広がります。もちろん、クライアント側でトークンリストを生成したスパース検索を行うことも可能です。
OpenSearch ではテキスト検索とスパース検索のハイブリッド検索も可能であるため、これらの検索を併用することでより幅広いユースケースに対応できる可能性があります。
Discussion