💭
【LLM】BERTで512トークンを超える長文を扱う方法
BERTは、テキスト分類や感情分析などの自然言語処理タスクにおいて、
革命的な成果をもたらしたモデルの一つです。
しかしながら、その使用にあたっては、致命的な制約があります。
それは、モデルが一度に処理できる文章の長さが512トークンまでという制約です。
この制約のため、BERTでは長文を扱えないと考えていらっしゃる方も多いと思います。
私もそうでした。
この記事では、BERTで512トークンを超える長文を扱う方法について解説していきます。
BERTで長文を扱う方法
BERTで長文を扱うための方法として、いくつかのアプローチが提案されています。
ここでは、最も、原始的な方法として、長文を分割して処理する方法をご紹介します。
具体的には、
- 1つの文書をBERTが扱える512トークンの範囲で複数のセグメント(=チャンクといいます)に分割します。
 - 各チャンクを個別にBERTに入力します。
 - BERTが出力したベクトルを平均します。
これにより、512トークンを超える長文でも文脈をとらえた処理を行うことができます。 
タスクの概要
上場企業等が有価証券報告書に記載している文書を分析します。
各社とも、軽く512トークンは超える文書です(5,000〜8,000トークン)
BERTで全文書のベクトルを取得し、その中で特定企業の文書と類似している文書を検索するタスクを行います。
今回、simCSEモデルとFaissライブラリーを使っていますが、こちらの書籍を参考にしたものです。
ライブラリーのインストール
!pip install datasets faiss-cpu scipy transformers[ja,torch]
モデルとトークナイザーの読み込み
# Hugging Face Hubにアップロードされた
# 教師なしSimCSEのトークナイザとエンコーダを読み込む
from transformers import AutoModel, AutoTokenizer
import numpy as np
import torch
import torch.nn.functional as F
model_name = "llm-book/bert-base-japanese-v3-unsup-simcse-jawiki"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()
データセットの読み込み
- 2023年3月期の有価証券報告書の「事業等のリスク」をデータセットとして使います。
 - 約2,500社分のデータがあります。
 
from pprint import pprint
from datasets import load_dataset
dataset = load_dataset("aino813/yuho-risk-202303",split="train")
データセットの中身を確認してみます
pprint(dataset)
<output>
Dataset({
    features: ['company', 'text', 'industry', 'code'],
    num_rows: 2551
})
datasetをpandasのデータフレームに変換します
dataset.set_format(type="pandas")
df = dataset[:]
df

GPUに移動
- まず、読み込んだモデルをGPUのメモリに移動させます。
 
device = "cuda:0"
model = model.to(device)
BERTモデルでベクトルを計算する関数
- トークナイザーで文章をID化した後、BERTモデルで読み込みベクトルを出力する関数です。
 - BERTモデルには、学習済みのsimCSEモデルを使っています。
 
def get_cls_token_vector(encoded_input, model):
    with torch.no_grad():
        outputs = model(**encoded_input)
    last_hidden_states = outputs.last_hidden_state
    cls_token_vector = last_hidden_states[:, 0, :].squeeze().cpu().numpy() 
    return cls_token_vector
512トークンを超える文書のベクトルを計算する関数
こちらが、512トークンを超える文書をBERTで扱う際のコードになります。
- まず、トークナイザーを用いて文章全てをトークンに分け、トークンIDを取得します。
 - トークンIDをchunk_size(512トークン)で分解し、chunksというリストに格納します。
 - chunksから、順番に512トークンずつ取り出し、それぞれベクトルを計算していきます。
 - 各chunkで計算したベクトルを平均します。
 
def get_cls_mean_vector(text,model):
    chunk_size = 512
    
    encoded_input = tokenizer(
        text,
        padding=False,
        truncation=False,
        max_length=8000,
        return_tensors='pt')
    input_ids = encoded_input['input_ids'][0]
    chunks = [input_ids[i:i + chunk_size] for i in range(0, len(input_ids), chunk_size)]
    cls_vectors = []
    for chunk in chunks:
        chunk_encoded = {
            'input_ids': chunk.unsqueeze(0).cuda(),  # GPUにデータを移動
            'attention_mask': (chunk != tokenizer.pad_token_id).unsqueeze(0).cuda()}  # GPUにデータを移動
        cls_vector = get_cls_token_vector(chunk_encoded, model)
        cls_vectors.append(cls_vector)  # ここでNumPy配列をそのまま追加
    average_cls_vector = np.mean(cls_vectors, axis=0)
    return average_cls_vector
テキストデータを読み込み、ベクトルを計算
- dfから該当のテキストデータを読み込み、リストに格納します。
 - 各テキストについて、ベクトルを計算していきます。
 - 時間を計算します。
 
from time import time
from tqdm import tqdm
texts =[]
for i, row in df.iterrows():
    texts.append(row["text"])
    
start_time = time()
vecs = []
for text in tqdm(texts):
    vecs.append(get_cls_mean_vector(text,model))
sec = time() -start_time
print(f"計算時間{sec:.2f}s")
faissライブラリーを使った検索
import faiss
#   vecsをNumpyに変換します
vecs_array = np.array(vecs).astype("float32")
# ベクトルを正規化します
vecs_array /= np.linalg.norm(vecs_array, axis=1)[:, np.newaxis]
# FAISS インデックスの作成
index = faiss.IndexFlatIP(768)  # IP = Inner Product
# ベクトルをインデックスに追加(正規化済み)します
index.add(vecs_array)
検索対象のインデックス番号をデータフレームから取得
company = "トヨタ自動車"
i = df[df["company"].str.contains(company)].index[0]
i
<output>
2393
検索
# 検索対象のベクトルをインデックスで指定
query_vec = vecs_array[i].reshape(1,-1)
# 検索(最も類似する上位20個のベクトルを検索)
k = 20
D, I = index.search(query_vec, k)
import pandas as pd
result = pd.DataFrame({"index":I[0],"類似度":D[0]})
result["company"] = result["index"].apply(lambda x : df.loc[x,"company"])
result

Discussion
コメント失礼します。
「データセットの読み込み」で、23年3月期の有報のMD&A(経営者による財政状態および経営成績の検討と分析)情報を取得したいのですが、hugging faceなどを参照しても引数をどう指定すればよいのかわかりません、、、
そこで、事業等のリスクの引数(aino813/yuho-risk-202303)はどこで参照されたか教えていただけないでしょうか。
コメントありがとうございます。
こちらの「事業等のリスク」は、私がEDINETから取得して、huggingfaceにアップロードしたものになります。
返信ありがとうございます。
すみません、勘違いしていました。返信を見て納得しました。ありがとうございます。