ModernBERTをIMDb映画レビューコメントでファインチューニングして、テキスト分類を試す
ModernBERTを用いたテキスト分類を試してみます。
ModernBERT自体は以下の記事を書きました。
ModernBERT-baseをIMDbデータセットを用いてファインチューニングします。
Colabでは、T4 GPUを1台用いました。
データセット
今回は、IMDbデータセットを用います。
IMDbデータセットは、映画のレビューコメントに対して「ネガティブ」、「ポジティブ」がラベル付けされた感情分析用データセットです。
そのため、IMDbデータセットは、主に二値分類の評価のために利用することができます。
学習データとテストデータそれぞれ、25,000件のレビューコメントとラベルが用意されています。
Google ColabでModernBERTをファインチューニング
ここからは実際にGoogle ColabでModernBERTをファインチューニングしていきます。
パッケージのインストール
まず初めにパッケージのインストールを行います。
transformersの最新版の4.47.1にModernBERTのリリースが含まれていないため、以下を実行します。
You can use these models directly with the transformers library. Until the next transformers release, doing so requires installing transformers from main:
# https://huggingface.co/answerdotai/ModernBERT-base Usage
!pip install git+https://github.com/huggingface/transformers.git
load_datasetを使うため、datasetsをインストールします。
!pip install datasets
各種パッケージをimportしておきます。
from transformers import ModernBertForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
モデルとtokenizerのダウンロード
ModernBERT-baseのモデルとtokenizerをダウンロードします。
model = ModernBertForSequenceClassification.from_pretrained('answerdotai/ModernBERT-base')
tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-base')
tokenize関数定義とデータセット分割
AttributeError: module 'dill' has no attribute 'PY3'
が出た時の回避策です。
!pip install dill==0.3.5.1
padding=True
は削除しています。こちらの設定と合わせました。
Next we define our Tokenizer and a preprocess function to create the input_ids, attention_mask, and token_type_ids the model nees to train. For this example, including truncation=True is enough as we'll rely on our data collation function below to put our batches into the correct shape.
https://github.com/AnswerDotAI/ModernBERT/blob/main/examples/finetune_modernbert_on_glue.ipynb
ちなみに、padding=True, truncation=True
とした場合は、trainer.train()
で以下のエラーが出ました。
RuntimeError: stack expects each tensor to be equal size, but got [159] at entry 0 and [51] at entry 1
def tokenize(batch):
return tokenizer(batch['text'], truncation=True)
train_dataset, test_dataset = load_dataset('imdb', split=['train', 'test'])
print(f"{train_dataset}\n")
train_dataset = train_dataset.map(tokenize, batched=True, batch_size=4)
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=4)
train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
学習とテストのデータセットのラベル総数を確認。
今回は、train、testともにlabelが0と1の総数はそれぞれ12,500件です。
def labelcheck(dataset):
labels = dataset['label']
count_0 = np.count_nonzero(labels_np == 0)
count_1 = np.count_nonzero(labels_np == 1)
print(f"labelが0の総数: {count_0}")
print(f"labelが1の総数: {count_1}")
labelcheck(train_dataset)
labelcheck(test_dataset)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
評価指標の定義
Accuracy、Precision、Recall、F1-scoreで評価します。
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
acc = accuracy_score(labels, preds)
return {
'accuracy': acc,
'f1': f1,
'precision': precision,
'recall': recall
}
学習の引数設定
TrainingArgumentsとTrainerの設定。
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=1,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
fp16=True
)
trainer = Trainer(
model=model,
args=training_args,
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=test_dataset
)
学習
WandBのアカウントのAPIキーの入力が求められるため、WandBを予め無効化しておきます。
!env WANDB_DISABLED=true
今回のバッチサイズの設定でGPUメモリは約12GB程度使用していました。
また、学習時間はColabのT4 GPUで約39分かかりました。
trainer.train()
...
TrainOutput(global_step=6250, training_loss=0.2615371661376953, metrics={'train_runtime': 2343.7489, 'train_samples_per_second': 10.667, 'train_steps_per_second': 2.667, 'total_flos': 9173118687459888.0, 'train_loss': 0.2615371661376953, 'epoch': 1.0})
StepごとのTraining lossは以下のような推移でした。
Step Training Loss
500 0.374700
1000 0.381800
1500 0.310200
2000 0.290600
2500 0.291500
3000 0.259300
3500 0.244500
4000 0.236700
4500 0.224700
5000 0.202300
5500 0.178900
6000 0.173600
評価
trainer.evaluate()
F1-scoreも0.9568
ということで高い分類性能が出ています。
{'eval_loss': 0.17229044437408447,
'eval_accuracy': 0.95676,
'eval_f1': 0.9568031968031968,
'eval_precision': 0.9558483033932136,
'eval_recall': 0.95776,
'eval_runtime': 414.2581,
'eval_samples_per_second': 60.349,
'eval_steps_per_second': 15.087,
'epoch': 1.0}
TensorBoardで学習状況の確認
%load_ext tensorboard
%tensorboard --logdir logs
予測結果の確認
実際の予測結果をいくつか確認してみました。
0がnegative、1がpositiveです。
This is the kind of movie which is loved by 50-year old schoolteachers and people who consider themselves aware in social issues - but really haven´t got a clue. The actors - I think all of them are amateurs - do their best, but the script is so full of cliches and stupidities that they can´t save it.<br /><br />Worst of all though is the scool cabaret that the kids are working on - brings back all your worst memories from acting classes in school. The lyric to one of the songs goes something like this in a fast-translation 2.30 in the morning: "I´m the dwarf of society, an emotionally crippled individual."<br /><br />Please!
predicted: 0
correct: 0
Ignore everyone else's comments for this movie and watch it on pay cable (like I did) or rent it. You owe it to yourself. This film is what movies are (supposed to be) all about. Hard to categorize (and God knows how this was pitched as a "high concept"!), but this is one for the angels. Check it out. What have you got to lose?
predicted: 1
correct: 1
きちんと予測できていそうです。
まとめ
今回は、ModernBERTをIMDbデータセットでファインチューニングし、テキスト分類を試してみました。
学習と評価のコードは以下を参考にさせていただきました。
最後までお読み頂きありがとうございました。本記事が参考になれば、幸いです。
Discussion
Google ColaboratoryのWandBは、環境変数
WANDB_DISABLED=true
で切っちゃう方が楽だと思うのです。Qiitaの記事に少しだけ書いておいたので、よければ御参考までに。情報共有ありがとうございます!
WANDB_DISABLED=true
で動作することを確認できました。WANDBを無効化する方が楽だと思いましたので、コードを修正したいと思います。