🌊

ModernBERTを用いた固有表現抽出をCoNLL-2003をファインチューニングして試す

2025/01/05に公開

ModernBERTを用いた固有表現抽出を試してみます。
ModernBERT自体は以下の記事を書きました。

https://zenn.dev/tossy21/articles/93591442269292

ModernBERT-baseをCoNLL-2003データセットを用いてファインチューニングしていきます。
具体的には、ModernBertForTokenClassificationをCoNLL-2003データセットでファインチューニングします。

Google Colabでは、T4 GPUを1台用いました。

データセット:CoNLL-2003

今回は、CoNLL-2003を用います。

CoNLL-2003は、1996年8月から1997年8月までのロイターのニュース記事で、固有表現として「PER(人物)」、「ORG(組織)」、「LOC(場所)」、「MISC(その他)」が付与されたデータセットです。

https://www.clips.uantwerpen.be/conll2003/ner/

スペース区切りの4列で、単語、品詞、チャンク、固有表現の順に並んだデータとなっています。

   U.N.         NNP  I-NP  I-ORG 
   official     NN   I-NP  O 
   Ekeus        NNP  I-NP  I-PER 
   heads        VBZ  I-VP  O 
   for          IN   I-PP  O 
   Baghdad      NNP  I-NP  I-LOC 
   .            .    O     O 
   ...

固有表現は、IOB2形式でタグ付けされています。
先頭の文字のIOには、意味があり、Bが付与されると始まりを表し、IはBからの続きを表します。
Oは固有表現外であることを表します。

https://en.wikipedia.org/wiki/Inside–outside–beginning_(tagging)

CoNLL-2003は、固有表現抽出の評価のためのデータセットとして利用できます。

学習データ、検証データ、テストデータは以下の通りです。

Dataset Articles Sentences Tokens LOC MISC ORG PER
Training set 946 14,987 203,621 7140 3438 6321 6600
Development set 216 3,466 51,362 1837 922 1341 1842
Test set 231 3,684 46,435 1668 702 1661 1617

Google ColabでModernBERTをファインチューニング

ここからは実際にGoogle Colabにて、ModernBERTをファインチューニングしていきます。

パッケージのインストール

各種パッケージをimportしておきます。

# https://huggingface.co/answerdotai/ModernBERT-base Usage
!pip install git+https://github.com/huggingface/transformers.git
!pip install datasets
from transformers import ModernBertForTokenClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

tokenizerの定義

tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-base')

データセットのロード

CoNLL-2003をダウンロードします。

raw_datasets = load_dataset('conll2003')

print(f"{raw_datasets}\n")
DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

ラベルの確認

データセットに含まれるラベルを確認しますl。

ner_feature = raw_datasets["train"].features["ner_tags"]
ner_feature

PER、ORG、LOC、MISCの4種類のラベルが定義されています。

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)
label_names = ner_feature.feature.names
label_names
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

ラベルリストの拡張

特殊なトークンには「-100」というラベルを付けます。これは、特殊なトークンであるCLSやSEPに対して実施します。CLSは、文章全体を捉えた表現であり文章の分散表現として用いられます。例えば、CLSに対応するBERTの出力を分類器に入力して分類問題を解くような使い方ができます。SEPは、文章のペアの境界を示す役割や入力の終わりを示す役割があります。

これらの特殊トークンは、これから使う損失関数(クロスエントロピー)で無視される数であるため、-100を付与します。また、B-をI-に置き換え、同じタグとして扱います。

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)

    return new_labels

最初の文章で意図したものになっているか確認します。

labels = raw_datasets["train"][0]["ner_tags"]
inputs = tokenizer(raw_datasets["train"][0]["tokens"], is_split_into_words=True)
word_ids = inputs.word_ids()
print(labels)
print(align_labels_with_tokens(labels, word_ids))
[3, 0, 7, 0, 0, 0, 7, 0, 0]
[-100, 3, 0, 0, 7, 0, 0, 0, 0, 7, 0, 0, 0, -100]

データセット全体の前処理

データセット全体の前処理として、すべての入力をトークン化し、すべてのラベルに対して align_labels_with_tokens() を適用します。

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    all_labels = examples["ner_tags"]
    new_labels = []
    for i, labels in enumerate(all_labels):
        word_ids = tokenized_inputs.word_ids(i)
        new_labels.append(align_labels_with_tokens(labels, word_ids))

    tokenized_inputs["labels"] = new_labels
    return tokenized_inputs
tokenized_datasets = raw_datasets.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

データ照合

from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

評価指標の定義

トークン分類予測の評価には、seqevalを用います。seqevalは固有表現抽出の評価の際に役立ちます。

https://github.com/chakki-works/seqeval

!pip install seqeval evaluate

import evaluate

metric = evaluate.load("seqeval")

最初の学習サンプルを用いて、動作確認を行います。

labels = raw_datasets["train"][0]["ner_tags"]
labels = [label_names[i] for i in labels]
labels
['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']

インデックス2の値を変更して、擬似的な予測結果を確認します。

predictions = labels.copy()
predictions[2] = "O"
metric.compute(predictions=[predictions], references=[labels])

各エンティティの精度、再現率、F1スコア、そして総合的なスコアが出力されます。

{'MISC': {'precision': 1.0,
  'recall': 0.5,
  'f1': 0.6666666666666666,
  'number': 2},
 'ORG': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 0.6666666666666666,
 'overall_f1': 0.8,
 'overall_accuracy': 0.8888888888888888}

compute_metricsを定義します。

import numpy as np

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

モデルとtokenizerのダウンロード

モデルとして、ModernBertForTokenClassificationを用います。
ModernBertForTokenClassificationは、ModernBERTの上にtoken classification用のヘッドをつけたモデルです。
各トークンに対して、固有表現を付与するために、今回は本モデルを用います。

id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}
model = ModernBertForTokenClassification.from_pretrained(
    'answerdotai/ModernBERT-base',
    id2label=id2label,
    label2id=label2id,
)

モデルのラベル数を確認しておきます。

model.config.num_labels
9

WandBの無効化

WandBを無効化しておきます。

!env WANDB_DISABLED=true
!pip uninstall wandb -y

学習設定

args = TrainingArguments(
    "modernbert-finetuned-ner",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

学習

約16分ほどで学習が完了しました。

trainer.train()
| Epoch | Training Loss | Validation Loss | Precision | Recall  | F1       | Accuracy |
|-------|---------------|-----------------|-----------|---------|----------|----------|
| 1     | 0.066600      | 0.062593        | 0.911503  | 0.930831 | 0.921066 | 0.984028 |
| 2     | 0.019600      | 0.064478        | 0.919039  | 0.939919 | 0.929362 | 0.985220 |
| 3     | 0.008100      | 0.072610        | 0.918049  | 0.940761 | 0.929266 | 0.985140 |
TrainOutput(global_step=5268, training_loss=0.03176335226701984, metrics={'train_runtime': 991.6589, 'train_samples_per_second': 42.477, 'train_steps_per_second': 5.312, 'total_flos': 1376902220737806.0, 'train_loss': 0.03176335226701984, 'epoch': 3.0})

評価

trainer.evaluate()
{'eval_loss': 0.07261043787002563,
 'eval_precision': 0.9180489407127607,
 'eval_recall': 0.9407606866374958,
 'eval_f1': 0.9292660626714321,
 'eval_accuracy': 0.9851404505542533,
 'eval_runtime': 15.7061,
 'eval_samples_per_second': 206.926,
 'eval_steps_per_second': 25.913,
 'epoch': 3.0}

予測結果の確認

# 実際の予測結果を確認する
predicted_labels = []

# テストデータセットから10個をサンプリング
for i in range(0, 10):
    # テストデータセットからサンプルを1つ選択
    sample = tokenized_datasets['test'][i]

    predictions = trainer.predict([sample])  # sampleをリストで渡す

    # 予測結果を解釈
    predicted_label = predictions.predictions.argmax(-1)[0]

    # special tokenを除外
    true_predictions = [
        label_names[p] for (p, l) in zip(predicted_label, sample["labels"]) if l != -100
    ]
    true_labels = [
        label_names[l] for (p, l) in zip(predicted_label, sample["labels"]) if l != -100
    ]

    tokens = tokenizer.convert_ids_to_tokens(sample['input_ids'])
    
    # 元の文章と予測結果を出力
    print("元の文章:", ' '.join(tokens[1:-1]))
    print("予測結果:", true_predictions)  # special tokenを除外した予測結果を出力
    print("正解ラベル:", " ".join(true_labels))
    print("正誤:", " ".join(["○" if p == l else "×" for p, l in zip(true_predictions, true_labels)]))
    print("---")

元の文章と予測結果および正解ラベル、正誤を出力しました。
例えば、最初の文章であれば、「JAPAN」に対して、B-LOC, I-LOC, I-LOCが正しく付与されています。

元の文章: SO CC ER - J AP AN GET LU CK Y WIN , CH INA IN S UR PR ISE DE FE AT .
予測結果: ['O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O O O B-LOC I-LOC I-LOC O O O O O O B-PER I-PER O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ × × ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: N ad im L ad ki
予測結果: ['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER']
正解ラベル: B-PER I-PER I-PER I-PER I-PER I-PER
正誤: ○ ○ ○ ○ ○ ○
---
元の文章: AL - AIN , United Arab Em ir ates 1996 - 12 - 06
予測結果: ['B-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC I-LOC I-LOC O B-LOC I-LOC I-LOC I-LOC I-LOC O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Japan b egan the def ence of their Asian C up title with a l ucky 2 - 1 win against Sy ria in a Group C ch ampionship match on Friday .
予測結果: ['B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC O O O O O O O B-MISC I-MISC I-MISC O O O O O O O O O O B-LOC I-LOC O O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ × × ○ ○ ○ ○ ○ ○
---
元の文章: But China saw their l uck des ert them in the second match of the group , cr ashing to a sur prise 2 - 0 de feat to new com ers U z bek istan .
予測結果: ['O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O']
正解ラベル: O B-LOC O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-LOC I-LOC I-LOC I-LOC O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: China controlled most of the match and saw several ch ances miss ed until the 78 th minute when U z bek stri ker I gor Sh kv yr in took advant age of a m isd irected def ensive header to l ob the ball over the adv ancing Chinese keeper and into an empty net .
予測結果: ['B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-LOC O O O O O O O O O O O O O O O O O O B-MISC I-MISC I-MISC O O B-PER I-PER I-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O B-MISC O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: O leg S hat sk iku made sure of the win in injury time , h itting an un st opp able left foot shot from just outside the area .
予測結果: ['B-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: B-PER I-PER I-PER I-PER I-PER I-PER O O O O O O O O O O O O O O O O O O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: The former Soviet re public was playing in an Asian C up finals tie for the first time .
予測結果: ['O', 'O', 'B-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O B-MISC O O O O O O B-MISC I-MISC I-MISC O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Despite winning the Asian G ames title two years ago , U z bek istan are in the finals as outs iders .
予測結果: ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
正解ラベル: O O O B-MISC I-MISC I-MISC O O O O O B-LOC I-LOC I-LOC I-LOC O O O O O O O O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---
元の文章: Two go als from def ensive errors in the last six minutes allowed Japan to come from behind and collect all three points from their opening me eting against Sy ria .
予測結果: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O']
正解ラベル: O O O O O O O O O O O O O B-LOC O O O O O O O O O O O O O O O B-LOC I-LOC O
正誤: ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○
---

まとめ

今回は、ModernBERTをCoNLL-2003データセットでファインチューニングし、固有表現抽出を試してみました。

学習と評価のコードは以下を参考にさせていただきました。
https://huggingface.co/learn/nlp-course/ja/chapter7/2

最後までお読み頂きありがとうございました。本記事が参考になれば、幸いです。

GitHubで編集を提案

Discussion