類似記事のグルーピングを ChatGPT と GNN でやってみた
シンプルフォームのエンジニアの杉です。本記事は SimpleForm Advent Calendar 2023 の 11 日目の記事です。
この記事ではニュースなどの文章に対して類似記事のグルーピングを ChatGPT と Graph Neural Network ( GNN )を用いて試してみたことについて書かせていただきます。
背景
世の中には多くのニュースが存在します。1 日のうちに作成される記事を全て読むということは難しいです。これらのニュースを要約やグルーピングすることができればより多くの記事を読むことが可能になります。
本記事では ChatGPT と GNN を用いたニュースのグルーピングに挑戦したことについて報告をします。
構成
今回試してみる内容は以下のような流れになっています。
- 文章を ChatGPT を用いて要約・カテゴリ検出・キーワード抽出を行う
- 要約文は BERT を用いてベクトル化する
- ベクトル化された要約文・カテゴリ・キーワードを用いてグラフに落とし込む
- グラフのリンク予測を行い、どのグループと繋がっているかを予測する
評価時には、要約ベクトルの類似度が高いものを10件抽出し、その中でリンク予測を行いました。
Graph Neural Network ( GNN )とは、ニューラルネットワークのインプットデータがグラフで表現されたデータであるものを指します。グラフで表現をすることで、繋がりを表現することや異なる種類のノード・エッジを一つのデータの中で表現することが可能になります。
データは ChatGPT を用いて生成をしました。仕組みとしてうまく機能するかを確認することが目的なため、今回は 1000 件のデータを生成しました。約 5 記事程度が同じグループの記事となっています。
1. 要約
そのままの文章を用いてベクトル化することも可能ですが、文章量などを揃えることで類似度の判定を高めるのではと思い、今回は要約を取り入れました。
例えば、以下の 2 つの記事を ChatGPT を用いて生成しました。
記事 A
本日、電車の遅延原因が驚くべきものとなりました。都内での列車において、なんとたぬきによる侵入により一時停止。乗客には大きな驚きと笑いが広がりました。当局は異例の事態に戸惑いつつも、迅速な対応を心掛けています。
記事 B
混沌とした朝、都内の通勤電車が予測不能な遅延に見舞われました。その原因は、驚くべきことに、駅構内に突如現れたたぬきでした。急停車は一時的な騒ぎを巻き起こし、通勤者たちは驚きと笑いの渦に包まれました。鉄道会社は、動物保護団体と協力し、たぬきを安全な場所へ移動させ、現場の復旧作業を開始しました。
こちらを 100 文字程度に要約をしたところ、以下のようになりました。
記事 A (要約後)
都内の列車が本日、たぬきの侵入により一時停止。驚きと笑いが広がり、当局は異例の事態に戸惑いつつも、対応に迅速に取り組んでいます。
記事 B (要約後)
都内の通勤電車が混乱の朝に、たぬきの出現で予測不能な遅延。急停車に通勤者は驚きと笑い、鉄道会社は動物保護団体と協力し、たぬきを安全な場所へ移動させ、復旧作業を始めました。
これらをそれぞれ BERT でベクトル化をし、類似度を確認しました。類似度の判定にはコサイン類似度を使用しました。コサイン類似度は 2 つのベクトルの向きがどのくらい似ているかを測る指標のひとつです。結果として 0.99400365
から 0.9980706
に向上しました。
要約をする前から類似度はかなり高くでているので要約をしてもあまり変わりがないという可能性もありますが、少しの向上は確認できたためこのまま進めていきます。
要約は ChatGPT を用いて行いました。 ChatGPT は他の特徴量抽出でも使用したため、特徴量の抽出の項目で詳細を書かせていただきます。
2. 特徴量の抽出
今回使用している特徴量は 3 つです。
1 つ目は要約文を BERT を用いてベクトル化をしたものです。以下のようにして、ベクトル化を行いました。
import torch
import transformers
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'
tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(model_name)
bert_model = transformers.BertModel.from_pretrained(model_name).to(device)
max_len = 128
def vectorize(text: str):
text_encode = tokenizer.encode(text)
len_text_encode = len(text_encode)
if len_text_encode >= max_len:
inputs = len_text_encode[:max_len]
else:
inputs = text_encode + [0] * (max_len - len_text_encode)
inputs_tensor = torch.tensor([inputs], dtype=torch.long).to(device)
seq_out = bert_model(inputs_tensor)[0]
pooled_out = bert_model(inputs_tensor)[1]
if device == "cuda":
return seq_out[0][0].cpu().detach().numpy()
else:
return seq_out[0][0].detach().numpy()
2 つ目はカテゴリです。今回は ChatGPT を用いて要約文を生成する際にカテゴリもつけてもらうようにしました。
カテゴリは多くなりすぎないように以下の 5 つに設定をしました。
Economy
, Entertainment
, Sports
, Weather
, Other
3 つ目は重要キーワードです。こちらも ChatGPT を用いて行なっています。以下のようなコードで取得を行いました。
import openai
def get_summarize_text(self, text: str):
response = openai.chat.completions.create(
model="gpt-3.5-turbo-1106",
messages=[{"role": "user", "content": text}],
tools=[
{
"type": "function",
"function": {
"name": "test_json",
"description": "入力された文章を要約し、JSONとして処理します。要約時に人名や法人名、年齢や日付などニュースを特定することに重要な要素は削らないでください。",
"parameters": {
"type": "object",
"properties": {
"summary_text": {
"type": "string",
"description": "100文字程度で要約した文章",
},
"keywords": {
"type": "string",
"description": "文章中で重要なキーワードをカンマ区切りで抽出してください。最大10個です。",
},
"category": {
"type": "string",
"description": "文章のカテゴリを次の中から1つ選んでください。[Economy, Entertainment, Sports, Weather, Other]",
},
},
"required": ["summary_text", "keywords", "category"],
},
},
}
],
)
return response
例えば、以下の記事では
りんご農園で大量のりんごが不正に持ち出され、被害額は数百万円に達しました。犯人は逃走中で、警察が捜査を進めています。
被害者
、りんご農家
、警察
、逃走
、不正持ち出し
の 5 キーワードが重要単語として抽出されています。
3. GNN
特徴量をそれぞれ変換をし、グラフとして保持します。
今回使用したグラフは以下のような構造になっています。
ノードとして
- 記事をまとめたグループ ID
- 記事 ID (要約のベクトル)
- キーワード
- カテゴリ
とし、
- グループ ID と記事 ID
- 記事 ID とキーワード
- 記事 ID とカテゴリ
にエッジをはっています。
このグラフの中でグループと記事の間にエッジが存在するかを予測しました。GNN モデルは以下のものを使用しました。
class GNN(Module):
def __init__(self, hidden_channels: int):
super().__init__()
self.conv1 = SAGEConv(hidden_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
def forward(self, x: Tensor, edge_index: Tensor):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
今回、 pytorch-geometric を使用しましたがドキュメントが充実しているため、こちらを見ていただけるとより理解が深まるかと思います。
結果
全体のデータの 1 割をテストデータとし評価を行った結果、 Accuracy: 0.9150 となりました。試してみるという目的においては十分な精度を得ることができたと思われます。
しかし、今回は少ないデータ数で行ったため良い結果となりましたが、実際の大規模なデータなどを使用すると精度は低くなる可能性があります。
また、今回は短い期間のニュースを使用したため時間軸を考えなくても大きな問題はありませんでしたが、広い時間でとってしまうと精度は悪化する可能性があります。
さいごに
類似記事のグルーピングをするために ChatGPT と GNN を用いてみることを試してみました。
まだ課題は残っていますが、1 つの方法として可能性がありそうということがわかりました。グラフで使用する特徴量を変更することや、リンク予測ではなくグルーピング自体を予測してみるという方法に変えてみるということも可能だと思いますので気になった方はぜひ試してみると良いかと思います。
最後まで読んでいただきありがとうございます。明日以降も SimpleForm Advent Calendar は続きますのでよろしくお願いします。
また、シンプルフォーム株式会社では大規模言語モデルに関する輪読会も開催をしています。ぜひ気になった方は覗いてみてください!
Discussion