ModernBERTを用いた固有表現抽出をCoNLL-2003をファインチューニングして試す
ModernBERTを用いた固有表現抽出を試してみます。
ModernBERT自体は以下の記事を書きました。
ModernBERT-baseをCoNLL-2003データセットを用いてファインチューニングしていきます。
具体的には、ModernBertForTokenClassificationをCoNLL-2003データセットでファインチューニングします。
Google Colabでは、T4 GPUを1台用いました。
データセット:CoNLL-2003
今回は、CoNLL-2003を用います。
CoNLL-2003は、1996年8月から1997年8月までのロイターのニュース記事で、固有表現として「PER(人物)」、「ORG(組織)」、「LOC(場所)」、「MISC(その他)」が付与されたデータセットです。
スペース区切りの4列で、単語、品詞、チャンク、固有表現の順に並んだデータとなっています。
U.N. NNP I-NP I-ORG
official NN I-NP O
Ekeus NNP I-NP I-PER
heads VBZ I-VP O
for IN I-PP O
Baghdad NNP I-NP I-LOC
. . O O
...
固有表現は、IOB2形式でタグ付けされています。
先頭の文字のIOには、意味があり、Bが付与されると始まりを表し、IはBからの続きを表します。
Oは固有表現外であることを表します。
CoNLL-2003は、固有表現抽出の評価のためのデータセットとして利用できます。
学習データ、検証データ、テストデータは以下の通りです。
Dataset | Articles | Sentences | Tokens | LOC | MISC | ORG | PER |
---|---|---|---|---|---|---|---|
Training set | 946 | 14,987 | 203,621 | 7140 | 3438 | 6321 | 6600 |
Development set | 216 | 3,466 | 51,362 | 1837 | 922 | 1341 | 1842 |
Test set | 231 | 3,684 | 46,435 | 1668 | 702 | 1661 | 1617 |
Google ColabでModernBERTをファインチューニング
ここからは実際にGoogle Colabにて、ModernBERTをファインチューニングしていきます。
パッケージのインストール
各種パッケージをimportしておきます。
# https://huggingface.co/answerdotai/ModernBERT-base Usage
!pip install git+https://github.com/huggingface/transformers.git
!pip install datasets
from transformers import ModernBertForTokenClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
tokenizerの定義
tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-base')
データセットのロード
CoNLL-2003をダウンロードします。
raw_datasets = load_dataset('conll2003')
print(f"{raw_datasets}\n")
DatasetDict({
train: Dataset({
features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
num_rows: 14041
})
validation: Dataset({
features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
num_rows: 3250
})
test: Dataset({
features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
num_rows: 3453
})
})
ラベルの確認
データセットに含まれるラベルを確認しますl。
ner_feature = raw_datasets["train"].features["ner_tags"]
ner_feature
PER、ORG、LOC、MISCの4種類のラベルが定義されています。
Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)
label_names = ner_feature.feature.names
label_names
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
ラベルリストの拡張
特殊なトークンには「-100」というラベルを付けます。これは、特殊なトークンであるCLSやSEPに対して実施します。CLSは、文章全体を捉えた表現であり文章の分散表現として用いられます。例えば、CLSに対応するBERTの出力を分類器に入力して分類問題を解くような使い方ができます。SEPは、文章のペアの境界を示す役割や入力の終わりを示す役割があります。
これらの特殊トークンは、これから使う損失関数(クロスエントロピー)で無視される数であるため、-100を付与します。また、B-をI-に置き換え、同じタグとして扱います。
def align_labels_with_tokens(labels, word_ids):
new_labels = []
current_word = None
for word_id in word_ids:
if word_id != current_word:
# Start of a new word!
current_word = word_id
label = -100 if word_id is None else labels[word_id]
new_labels.append(label)
elif word_id is None:
# Special token
new_labels.append(-100)
else:
# Same word as previous token
label = labels[word_id]
# If the label is B-XXX we change it to I-XXX
if label % 2 == 1:
label += 1
new_labels.append(label)
return new_labels
最初の文章で意図したものになっているか確認します。
labels = raw_datasets["train"][0]["ner_tags"]
inputs = tokenizer(raw_datasets["train"][0]["tokens"], is_split_into_words=True)
word_ids = inputs.word_ids()
print(labels)
print(align_labels_with_tokens(labels, word_ids))
[3, 0, 7, 0, 0, 0, 7, 0, 0]
[-100, 3, 0, 0, 7, 0, 0, 0, 0, 7, 0, 0, 0, -100]
データセット全体の前処理
データセット全体の前処理として、すべての入力をトークン化し、すべてのラベルに対して align_labels_with_tokens()
を適用します。
def tokenize_and_align_labels(examples):
tokenized_inputs = tokenizer(
examples["tokens"], truncation=True, is_split_into_words=True
)
all_labels = examples["ner_tags"]
new_labels = []
for i, labels in enumerate(all_labels):
word_ids = tokenized_inputs.word_ids(i)
new_labels.append(align_labels_with_tokens(labels, word_ids))
tokenized_inputs["labels"] = new_labels
return tokenized_inputs
tokenized_datasets = raw_datasets.map(
tokenize_and_align_labels,
batched=True,
remove_columns=raw_datasets["train"].column_names,
)
データ照合
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
評価指標の定義
トークン分類予測の評価には、seqevalを用います。seqevalは固有表現抽出の評価の際に役立ちます。
!pip install seqeval evaluate
import evaluate
metric = evaluate.load("seqeval")
最初の学習サンプルを用いて、動作確認を行います。
labels = raw_datasets["train"][0]["ner_tags"]
labels = [label_names[i] for i in labels]
labels
['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
インデックス2の値を変更して、擬似的な予測結果を確認します。
predictions = labels.copy()
predictions[2] = "O"
metric.compute(predictions=[predictions], references=[labels])
各エンティティの精度、再現率、F1スコア、そして総合的なスコアが出力されます。
{'MISC': {'precision': 1.0,
'recall': 0.5,
'f1': 0.6666666666666666,
'number': 2},
'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
'overall_precision': 1.0,
'overall_recall': 0.6666666666666666,
'overall_f1': 0.8,
'overall_accuracy': 0.8888888888888888}
compute_metricsを定義します。
import numpy as np
def compute_metrics(eval_preds):
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-1)
# Remove ignored index (special tokens) and convert to labels
true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
true_predictions = [
[label_names[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
return {
"precision": all_metrics["overall_precision"],
"recall": all_metrics["overall_recall"],
"f1": all_metrics["overall_f1"],
"accuracy": all_metrics["overall_accuracy"],
}
モデルとtokenizerのダウンロード
モデルとして、ModernBertForTokenClassificationを用います。
ModernBertForTokenClassificationは、ModernBERTの上にtoken classification用のヘッドをつけたモデルです。
各トークンに対して、固有表現を付与するために、今回は本モデルを用います。
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
model = ModernBertForTokenClassification.from_pretrained(
'answerdotai/ModernBERT-base',
id2label=id2label,
label2id=label2id,
)
モデルのラベル数を確認しておきます。
model.config.num_labels
9
WandBの無効化
WandBを無効化しておきます。
!env WANDB_DISABLED=true
!pip uninstall wandb -y
学習設定
args = TrainingArguments(
"modernbert-finetuned-ner",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=2e-5,
num_train_epochs=3,
weight_decay=0.01,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
)
学習
約16分ほどで学習が完了しました。
trainer.train()
| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|-------|---------------|-----------------|-----------|---------|----------|----------|
| 1 | 0.066600 | 0.062593 | 0.911503 | 0.930831 | 0.921066 | 0.984028 |
| 2 | 0.019600 | 0.064478 | 0.919039 | 0.939919 | 0.929362 | 0.985220 |
| 3 | 0.008100 | 0.072610 | 0.918049 | 0.940761 | 0.929266 | 0.985140 |
TrainOutput(global_step=5268, training_loss=0.03176335226701984, metrics={'train_runtime': 991.6589, 'train_samples_per_second': 42.477, 'train_steps_per_second': 5.312, 'total_flos': 1376902220737806.0, 'train_loss': 0.03176335226701984, 'epoch': 3.0})
評価
trainer.evaluate()
{'eval_loss': 0.07261043787002563,
'eval_precision': 0.9180489407127607,
'eval_recall': 0.9407606866374958,
'eval_f1': 0.9292660626714321,
'eval_accuracy': 0.9851404505542533,
'eval_runtime': 15.7061,
'eval_samples_per_second': 206.926,
'eval_steps_per_second': 25.913,
'epoch': 3.0}
予測結果の確認
# 実際の予測結果を確認する
predicted_labels = []
# テストデータセットから10個をサンプリング
for i in range(0, 10):
# テストデータセットからサンプルを1つ選択
sample = tokenized_datasets['test'][i]
predictions = trainer.predict([sample]) # sampleをリストで渡す
# 予測結果を解釈
predicted_label = predictions.predictions.argmax(-1)[0]
# special tokenを除外
true_predictions = [
label_names[p] for (p, l) in zip(predicted_label, sample["labels"]) if l != -100
]
true_labels = [
label_names[l] for (p, l) in zip(predicted_label, sample["labels"]) if l != -100
]
tokens = tokenizer.convert_ids_to_tokens(sample['input_ids'])
# 元の文章と予測結果を出力
print("元の文章:", ' '.join(tokens[1:-1]))
print("予測結果:", true_predictions) # special tokenを除外した予測結果を出力
print("正解ラベル:", " ".join(true_labels))
print("正誤:", " ".join(["○" if p == l else "×" for p, l in zip(true_predictions, true_labels)]))
print("---")
元の文章と予測結果および正解ラベル、正誤を出力しました。
例えば、最初の文章であれば、「JAPAN」に対して、B-LOC
, I-LOC
, I-LOC
が正しく付与されています。
元の文章: SO CC ER - J AP AN GET LU CK Y WIN , CH INA IN S UR PR ISE DE FE AT .
予測結果: ['O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O O O B-LOC I-LOC I-LOC O O O O O O B-PER I-PER O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ × × ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: N ad im L ad ki
予測結果: ['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER']
正解ラベル: B-PER I-PER I-PER I-PER I-PER I-PER
正誤: ○ ○ ○ ○ ○ ○
---
元の文章: AL - AIN , United Arab Em ir ates 1996 - 12 - 06
予測結果: ['B-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC I-LOC I-LOC O B-LOC I-LOC I-LOC I-LOC I-LOC O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Japan b egan the def ence of their Asian C up title with a l ucky 2 - 1 win against Sy ria in a Group C ch ampionship match on Friday .
予測結果: ['B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC O O O O O O O B-MISC I-MISC I-MISC O O O O O O O O O O B-LOC I-LOC O O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ × × ○ ○ ○ ○ ○ ○
---
元の文章: But China saw their l uck des ert them in the second match of the group , cr ashing to a sur prise 2 - 0 de feat to new com ers U z bek istan .
予測結果: ['O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O']
正解ラベル: O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC I-LOC I-LOC O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: China controlled most of the match and saw several ch ances miss ed until the 78 th minute when U z bek stri ker I gor Sh kv yr in took advant age of a m isd irected def ensive header to l ob the ball over the adv ancing Chinese keeper and into an empty net .
予測結果: ['B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC O O O O O O O O O O O O O O O O O O B-MISC I-MISC I-MISC O O B-PER I-PER I-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: O leg S hat sk iku made sure of the win in injury time , h itting an un st opp able left foot shot from just outside the area .
予測結果: ['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-PER I-PER I-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: The former Soviet re public was playing in an Asian C up finals tie for the first time .
予測結果: ['O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O B-MISC O O O O O O B-MISC I-MISC I-MISC O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Despite winning the Asian G ames title two years ago , U z bek istan are in the finals as outs iders .
予測結果: ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O O B-MISC I-MISC I-MISC O O O O O B-LOC I-LOC I-LOC I-LOC O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Two go als from def ensive errors in the last six minutes allowed Japan to come from behind and collect all three points from their opening me eting against Sy ria .
予測結果: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O']
正解ラベル: O O O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O B-LOC I-LOC O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
まとめ
今回は、ModernBERTをCoNLL-2003データセットでファインチューニングし、固有表現抽出を試してみました。
学習と評価のコードは以下を参考にさせていただきました。
最後までお読み頂きありがとうございました。本記事が参考になれば、幸いです。
Discussion