📝

BERTを使った類似したテキストメッセージの検索

2023/06/18に公開

自然言語処理(NLP)技術であるBERT(Bidirectional Encoder Representations from Transformers)を使い、入力された文章と似た意味をもつレコードをデータベースから取得するアプリケーションを作ってみました。
自然言語処理が未経験の状態から、完成までにかかった期間、結果の評価、今後に向けた反省をまとめています。

類似テキスト検索の原理

自然言語処理には、埋め込み(Embedding)という技術があります。これは、テキストメッセージの意味をベクトル表現(具体的には実数の配列)に変換する技術です。

ベクトル表現への変換は、コンピュータが文章の意味を計算可能な形式にするために行われます。

二次元平面上にある点Aと点Bは、それぞれの座標がわかれば距離を計算できますが、ベクトル表現に変換された文章同士も距離(意味の近さ)を計算することができます。この距離のことをコサイン類似度と呼ぶらしいです。

今回のアプリケーションでは、次のような処理の流れを実装しました。

完成までにかかった期間

開発期間は平日の朝を使用して4日程度です。

念のために断っておくと、今回は類似したテキストの検索機能を実装する際の手順等を把握することが目的だったので、モデルの選定やファインチューニングは行なっていません。このままの検索性能では実用には耐えられないので。性能の向上にはもっと時間とお金が必要だと思います。

使用した技術は次のとおりです。

  • Python
  • BERT (Bidirectional Encoder Representations from Transformers)
  • MySQL
  • Docker
  • Docker Compose
  • Jupyter Notebook

PythonやDockerの基本的な操作、前述した類似度計算の原理などは知っっている状態でスタートしました。
BERTや自然言語処理を実際に使用したのは今回が初めてでした。

実装詳細

下記URLで公開されているソースコードです。
GitHub

tweetを格納するテーブルの作成

MySQLに接続し、必要なテーブルを作成する。

tweets

Column name Data type Description
id INT テーブルの主キー。自動的にインクリメントされる。
text VARCHAR(255) ツイートの本文。
created_at TIMESTAMP レコードが作成された日時。デフォルトは現在のタイムスタンプ。

tweet_vectors

Column name Data type Description
id INT テーブルの主キー。自動的にインクリメントされる。
tweet_id INT tweetsテーブルのidを参照する外部キー。
vector JSON ツイートのBERTによる数値ベクトル表現。
created_at TIMESTAMP レコードが作成された日時。デフォルトは現在のタイムスタンプ。

tweetsにはテキストメッセージを格納する。

tweet_vectorsには、tweetsから取得したテキストメッセージをベクトル表現に変換したものを格納する。

import pymysql

# MySQLデータベースへの接続情報を設定
host = "db"  # Docker Composeで定義したMySQLサービスのサービス名
port = 3306  # MySQLのデフォルトのポート番号
user = "user"  # MySQLのユーザ名
password = "password"  # MySQLのパスワード
database = "tweets_db"  # MySQLのデータベース名

def connect_to_database():
    # MySQLデータベースに接続
    return pymysql.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database
    )

def create_table_if_not_exists(cursor, create_table_query):
    # テーブルを作成するSQL文
    cursor.execute(create_table_query)

    # 変更をコミット(確定)
    connection.commit()

# tweetsテーブルの作成
create_tweets_table_query = """
CREATE TABLE IF NOT EXISTS tweets (
    id INT AUTO_INCREMENT PRIMARY KEY,
    text VARCHAR(255),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""

# tweet_vectorsテーブルの作成
create_tweet_vectors_table_query = """
CREATE TABLE IF NOT EXISTS tweet_vectors (
    id INT AUTO_INCREMENT PRIMARY KEY,
    tweet_id INT,
    vector JSON,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (tweet_id) REFERENCES tweets(id)
);
"""

# データベースへの接続
connection = connect_to_database()

# カーソルオブジェクトを作成
cursor = connection.cursor()

# テーブルの作成
create_table_if_not_exists(cursor, create_tweets_table_query)
create_table_if_not_exists(cursor, create_tweet_vectors_table_query)

# MySQLデータベースとの接続を閉じる
connection.close()

ダミーのtweetの作成と保存

ダミーのtweetを作成してMySQLに保存する。

具体的には、事前に作成したテキストメッセージtweets.csvを読み込み、tweetsテーブルに保存する。

import configparser
import pandas as pd
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.sql import insert

# 設定を読み込む
config = configparser.ConfigParser()
config.read('config.ini')

# MySQLの設定
username = config['DATABASE']['USERNAME']
password = config['DATABASE']['PASSWORD']
hostname = config['DATABASE']['HOSTNAME']
database = config['DATABASE']['DATABASE']

engine = create_engine(f"mysql+pymysql://{username}:{password}@{hostname}/{database}")

# ダミーのツイート
df_tweet = pd.read_csv("tweets.csv")

# ツイートをデータベースに保存
with engine.begin() as connection:
    metadata = MetaData()
    metadata.bind = engine

    tweet_table = Table('tweets', metadata, autoload_with=engine)
    for tweet in df_tweet['text']:
        stmt = insert(tweet_table).values(text=tweet)
        connection.execute(stmt)

BERTによるtweetのベクトル変換

MySQLに保存されているtweetをBERTを利用してベクトル表現に変換し、tweet_vectorsに保存する。

  1. Hugging Faceのtransformersというライブラリを使って、BERTのモデルbert-base-uncasedとトークナイザをロード
  2. tweetsテーブルからテキストメッセージを取得する
  3. テキストメッセージのベクトル変換
  • トークナイザでテキストメッセージをトークン化(モデルに入力できる形式に変換)
  • トークンをBERTモデルに入力してベクトル表現を取得する
  1. ベクトル表現の保存
import configparser
import torch
from sqlalchemy import select
from transformers import BertModel, BertTokenizer
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.sql import insert

# 設定を読み込む
config = configparser.ConfigParser()
config.read('config.ini')

# MySQLの設定
username = config['DATABASE']['USERNAME']
password = config['DATABASE']['PASSWORD']
hostname = config['DATABASE']['HOSTNAME']
database = config['DATABASE']['DATABASE']

engine = create_engine(f"mysql+pymysql://{username}:{password}@{hostname}/{database}")

# BERTの設定
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# ツイートをデータベースから取得
with engine.begin() as connection:
    metadata = MetaData()
    metadata.bind = engine
    tweet_table = Table('tweets', metadata, autoload_with=engine)
    s = select(tweet_table.c.id, tweet_table.c.text)
    result = connection.execute(s).fetchall()

    for row in result:
        tweet_id, text = row
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        vector = outputs.last_hidden_state[:, 0, :].numpy().tolist()
        # ベクトルをデータベースに保存
        vector_table = Table("tweet_vectors", metadata, autoload_with=engine)
        stmt = insert(vector_table).values(tweet_id=tweet_id, vector=vector)
        connection.execute(stmt)

任意のtweetと類似するtweetの検索

与えられた入力テキストに最も類似したtweetをデータベースから検索する。

具体的には以下の手順で処理を行う。

  1. tweet_vectorsからすべてのベクトル表現を取得
  2. 検証用のテキストメッセージverification_tweets.csvを読み込み、それぞれに次の処理を行う
  • 入力テキストをベクトルに変換
  • 入力テキストのベクトルとデータベース内の各ツイートのベクトルのコサイン類似度を計算
  • 最も類似度が高いツイートを検索し、そのテキストをtweetsから取得
  1. 入力テキストと最も類似すると計算されたテキストを表示
import configparser
import pandas as pd
from sqlalchemy import select
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BertModel, BertTokenizer
from sqlalchemy import create_engine, Table, MetaData

# 設定を読み込む
config = configparser.ConfigParser()
config.read('config.ini')

# MySQLの設定
username = config['DATABASE']['USERNAME']
password = config['DATABASE']['PASSWORD']
hostname = config['DATABASE']['HOSTNAME']
database = config['DATABASE']['DATABASE']

engine = create_engine(f"mysql+pymysql://{username}:{password}@{hostname}/{database}")

# BERTの設定
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# データベースからベクトルを取得
metadata = MetaData()
metadata.bind = engine
vector_table = Table("tweet_vectors", metadata, autoload_with=engine)
s = select(vector_table.c.tweet_id, vector_table.c.vector)
with engine.begin() as connection:
    result = connection.execute(s).fetchall()

# 'verification_tweets.csv'を読み込む
df = pd.read_csv('verification_tweets.csv')

# DataFrameのtext列を1件ずつ取り出す
for input_text in df['text']:
    # 入力テキストをベクトルに変換
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model(**inputs)
    input_vector = outputs.last_hidden_state[:, 0, :].detach().numpy().tolist()

    max_similarity = 0
    most_similar_tweet = None

    # 各ベクトルと入力テキストのベクトルとの間で類似度を計算
    for row in result:
        tweet_id, vector = row
        similarity = cosine_similarity(input_vector, vector)

        # 最も類似度が高いツイートを見つける
        if similarity[0][0] > max_similarity:
            max_similarity = similarity[0][0]
            most_similar_tweet = tweet_id

    # 最も類似度が高いツイートの本文を取得
    tweet_table = Table("tweets", metadata, autoload_with=engine)
    s = select(tweet_table.c.text).where(tweet_table.c.id == most_similar_tweet)
    with engine.begin() as connection:
        most_similar_tweet_text = connection.execute(s).scalar_one()

    # 最も類似度が高いツイートと入力に使ったテキストを表示
    print(f"Input text: {input_text}")
    print(f"Most similar tweet: {most_similar_tweet_text}")

結果の評価

今回は、100件のテキストのベクトル表現をデータベースに保存し、その中から入力テキストと最も類似したものを検索しました。

試した入力テキストは5件で、結果は次のとおりです。

入力テキスト 期待した検索結果 実際の検索結果
一緒に働く田中が冗談を言って、場の雰囲気が良くなった。 同僚のジョークで一日が明るくなった。笑いは最高の緊張解消法。 犬の喜びそうな顔を見ると、一日の疲れが吹っ飛ぶ。
久しぶりに砂浜に行ってリフレッシュできた。 素晴らしいビーチでリラックス。これぞバケーションの極み。 新しいプログラミング言語を学び始めた。チャレンジは成長につながる。
昨日読んだ推理小説の最後に驚くべきどんでん返しがあった。 昨日観た映画が素晴らしかった。深いメッセージが心に残った。 昨日観た映画が素晴らしかった。深いメッセージが心に残った。
不具合の修正ばかりで1日が終わったが、やりきると清々しい。 コードをデバッグするのは大変だけど、問題を解決したときの達成感は格別。 猫の毛づくろいを見ていると癒される。彼らの日常が特別。
サンスベリアを部屋に置くと有害物質を除去してくれるらしい。 観葉植物が部屋の空気を浄化してくれる。自然の恵みに感謝。 パズルゲームの新レベルをクリア。脳を活性化させる感じが好き。

期待通りのテキストが取得できたのは5件中1件だけで、残り4件は意味的にまったく近くなさそうなものが取得されました。テキストメッセージをベクトルに変換するだけ、という一般的と思えるタスクだったため、ファインチューニングなしでもある程度の性能が出るのでは、と期待していただけに残念な結果でした。

今後に向けた反省

類似したメッセージの検索性能をもう少し向上するために、調べたところ次のことができそうでした。

  • 日本語テキストで事前に訓練されたBERTモデルを使う
      - Hugging Faceというプラットフォームで公開されているらしい
      - BERTを改良したモデルにRoBERTa、ALBERT、ELECTRAがあり、特にELECTRAが良いらしい
      - 今回はBERTのbert-base-uncasedというモデルを使用
  • ファインチューニングでモデルを用途に合わせて調整する
  • 一般的に訓練データは数千〜数十万ほどのサイズが必要
  • モデルの訓練にはGoogle CloudのNVIDIA Tesla V100やAWSのp3.2xlargeインスタンスが使われる

訓練には時間と費用がたくさんかかるらしく、あまり気軽には試せません。ということで、次は日本語テキストで事前に訓練されたELECTRAのモデルやOpenAIのEmbeddings APIを試してみたいと考えています。

また、今回はMySQLを使用しましたが、ベクトルの類似性を高速に検索するための専用のデータベース(ベクトルデータベース)というものがあるため、こちらも使ってみたいです。

Discussion