
Text Classification with EmbeddingGemma-300m (LoRA)


Introduction

Google has released EmbeddingGemma-300m, so I tried fine-tuning it.
In my earlier post, Text Classification with EmbeddingGemma-300m (Logistic Regression), the accuracy was 89.6%.

This time I try fine-tuning with LoRA.
The dataset is again the livedoor news corpus.
The result this time was accuracy = 95.7%, which is better than the result from fine-tuning Gemma 3 270M.


Installation

!pip install evaluate
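Only evaluate is missing on a typical Colab runtime; the other libraries used below (transformers, datasets, peft, accelerate, scikit-learn, seaborn) are usually preinstalled there. Outside Colab you may need something like:

!pip install evaluate transformers datasets peft accelerate scikit-learn seaborn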

Data preparation

Download the livedoor news corpus

%%capture
!wget https://www.rondhuit.com/download/ldcc-20140209.tar.gz
!tar xvf ldcc-20140209.tar.gz

Load the articles by category

import os
import glob

livedoor_news = {}
for folder in glob.glob("text/*"):
  if os.path.isdir(folder):
    texts = []
    for txt in glob.glob(os.path.join(folder, "*.txt")):
      with open(txt, "r") as f:
        lines = f.readlines()
        # The first three lines are the URL, date, and title; keep only the article body
        texts.append('\n'.join([line.strip() for line in lines[3:]]))

    label = os.path.basename(folder)
    livedoor_news[label] = texts
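As a quick check (my addition, not in the original flow), you can confirm that the nine livedoor categories were read in and see how many articles each contains:

for label, texts in livedoor_news.items():
  print(label, len(texts))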

Create the training and test datasets

from datasets import Dataset, concatenate_datasets

train_dataset = None
val_dataset = None

for label, texts in livedoor_news.items():
  data = []
  for text in texts:
    data.append({"text": text, "label": label})
  dataset = Dataset.from_list(data)
  tmp_train = dataset.train_test_split(test_size=0.25, shuffle=True, seed=0)
  if train_dataset is None:
    train_dataset = tmp_train["train"]
    val_dataset = tmp_train["test"]
  else:
    train_dataset = concatenate_datasets([train_dataset, tmp_train["train"]])
    val_dataset = concatenate_datasets([val_dataset, tmp_train["test"]])
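The 75/25 split is done per category, so both sets keep the class balance. A quick look at the sizes (optional):

print(len(train_dataset), len(val_dataset))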

Load the embedding model

from transformers import AutoTokenizer, AutoModel
import torch

model_name = "google/embeddinggemma-300m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name)  # used later as the backbone of the classifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model.to(device)

Tokenization

classes = list(livedoor_news.keys())

def preprocess(batch):
    inputs = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=2048,
        return_tensors="pt"
    )
    inputs["label"] = [classes.index(label) for label in batch["label"]]
    return inputs

train_ds = train_dataset.map(preprocess, batched=True).shuffle(seed=0)
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
val_ds = val_dataset.map(preprocess, batched=True)
val_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
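After preprocessing, each example is a fixed-length sequence of 2048 token ids with an attention mask and an integer label; a quick look at one example (optional):

print(train_ds[0]["input_ids"].shape, train_ds[0]["attention_mask"].shape, train_ds[0]["label"])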

Training

Create the model class for training

import torch.nn as nn

class GemmaForClassification(nn.Module):
    def __init__(self, embedding_model, hidden_size, num_labels):
        super().__init__()
        self.embedding_model = embedding_model
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.embedding_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        # Mean-pool the token embeddings over non-padding positions
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden)
        summed = (last_hidden * mask).sum(1)
        divisor = mask.sum(1).clamp(min=1e-9)
        pooled = summed / divisor

        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

model = GemmaForClassification(base_model, hidden_size=base_model.config.hidden_size, num_labels=len(classes))
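Before setting up the Trainer, a quick forward pass on one example confirms that the wrapper returns one logit per category (a sanity check I added; not part of the original flow):

model.to(device)
sample = {k: train_ds[0][k].unsqueeze(0).to(device) for k in ("input_ids", "attention_mask")}
with torch.no_grad():
    out = model(**sample)
print(out["logits"].shape)  # expected: (1, len(classes))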

Apply LoRA

from peft import LoraConfig, get_peft_model, TaskType

# Apply LoRA
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
        "gate_proj", "up_proj", "down_proj"      # MLP
    ]
)

model.embedding_model = get_peft_model(model.embedding_model, lora_config)
model.embedding_model.print_trainable_parameters()
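print_trainable_parameters() shows that only a small fraction of the weights is trainable. If you want to see where the adapters were actually injected, you can list the LoRA modules (optional check):

lora_modules = [name for name, _ in model.embedding_model.named_modules() if "lora_A" in name]
print(len(lora_modules), lora_modules[:3])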

Define the evaluation metrics

import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1_macro": f1.compute(predictions=preds, references=labels, average="macro")["f1"]
    }

Set the training parameters

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./gemma_results",
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=1,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=5,
    learning_rate=2e-4,
    weight_decay=0.01,
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Train!

trainer.train()
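The article does not save the trained model, but if you want to reuse it, a minimal sketch might look like this (paths are arbitrary; the LoRA adapter and the classification head are saved separately because only the backbone is a PEFT model):

# Save the LoRA adapter of the backbone
model.embedding_model.save_pretrained("./gemma_lora_adapter")
# Save the classification head separately (it is not part of the PEFT model)
torch.save(model.classifier.state_dict(), "./gemma_classifier_head.pt")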

Inference

import torch

def predict(text):
    inputs = tokenizer(
      text,
      return_tensors="pt",
      truncation=True,
      max_length=2048,
      padding="max_length"
    ).to(device)

    # Inference
    with torch.no_grad():
      outputs = model(**inputs)
      preds = outputs["logits"].argmax(dim=-1).item()

    return preds

model.eval()

h = 0
with open('predict.txt', 'w') as f:
  for i, val in enumerate(val_dataset):
    t = val['label']
    p = classes[predict(val['text'])]
    #print("Predicted category:", p)
    if t == p:
      h += 1
    f.write(f"{t},{p}\n")
print(h, i + 1)  # correct predictions / total validation samples
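The same predict helper also works for a single new article; the text below is just a placeholder:

sample_text = "Paste the body of the article you want to classify here."
print(classes[predict(sample_text)])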

Aggregating the results

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt

results = pd.read_csv('predict.txt', header=None)

figure, ax1 = plt.subplots()

confusion_df = pd.crosstab(results[0], results[1], rownames=['Actual'], normalize='index')
sn.heatmap(confusion_df, annot=True, cmap="YlGnBu", ax=ax1, cbar=False)

from sklearn.metrics import classification_report

# Pass labels= so that target_names lines up with the class order used above
report = classification_report(y_pred=results[1], y_true=results[0], labels=classes, target_names=classes, output_dict=True)
pd.DataFrame(report).T
