Transformers の SequenceClassification モデルが何をしているのか見る
忙しい人向けの要約
transformers の AutoModelForSequenceClassification は元のモデル AutoModel の hidden_size から指定したクラス数に射影する線形層 (score) を追加して、入力テキストの最終トークンの logits を利用しているよ。
self.model = Gemma2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
はじめに
transformers には LLM でクラス分類が簡単に行えるように *ForSequenceClassification という命名規則に従ったクラスが用意されています。
例えば、Gemma2 を例にとると、モデルの推論を行い hidden_states を返す Gemma2Model に対し、分類ヘッドを付与した Gemma2ForSequenceClassification といった具合です。
この記事では Gemma2ForSequenceClassification を題材に、そういった分類用のクラスがどのように動いているのかを確認します。
コードで確認してみる
Gemma2ForSequenceClassification の実装
以下、次のような環境で試しています。
torch==2.5.1transformers==4.48.1
まずは実際に動かして挙動を確認してみます。Gemma-2 の 2B モデル (google/gemma-2-2b)を例に取って見てみましょう。
import torch
from transformers import GemmaTokenizerFast
model_id = "google/gemma-2-2b"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
seq_model = Gemma2ForSequenceClassification.from_pretrained(model_id)
text = "Hello, I'm prgckwb."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = seq_model(**inputs)
logits = outputs.logits # tensor([[-2.2487, 1.3710]])
probs = logits.softmax(dim=-1) # tensor([[0.0261, 0.9739]])
preds = probs.argmax(-1) # tensor([1])
print(f'Pred: {preds.item()}') # Pred: 1
というように、なぜか分類用に学習されていない Gemma-2 モデルを Gemma2ForSequenceClassification でインスタンス化しただけで 2クラス分類っぽい動作ができてしまいました。
Gemma2ForSequenceClassification の実装を見てみると、init 関数に次のようなコードがあります。
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Gemma2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
ここで、Gemma2ForSequenceClassification には、元のモデル Gemma2Model のインスタンス model と、Gemma2Model の hidden_size から指定したクラス数 num_labels に射影する線形層 score が定義されています。
ここから、Gemma-2 から得た埋め込みをどうにかしてクラス数と同じ形状のテンソルに変換していることが分かります。
次に forward 関数の流れを見ていきます。なお、必要な部分をかいつまんで記述するので気になる人は上記の実際のコードを見に行ってみてください。
今、入力テキストは "I am pen" として、これが Gemma2Tokenizer によってトークナイズされているとします。
text = "I am pen"
inputs = tokenizer(text, return_tensors="pt")
print(inputs.input_ids) # Shape: (1, 4)
tensor([[ 2, 235285, 1144, 3008]])
トークンは <bos> | I | am | pen | のように分かれています。
Gemma2ForSequenceClassification にこの inputs を通すと、
# Gemma2Model の推論
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# last_hidden_state の取得
hidden_states = transformer_outputs[0] # Shape: (1, 4, 2304)
# hidden_size -> num_labels への Projection
logits = self.score(hidden_states) # Shape: (1, 4, 2)
ここが重要な点で、今 入力テキストのトークンすべてに対して logits を計算していますが、Gemma2ForSequenceClassification はその全てを logits として返すわけではなく、プーリングの処理が入っています。プーリングは、入力テキストの最後のトークン位置のものを採用しています[1]。
# Padding Token ID がない時
if self.config.pad_token_id is None:
sequence_lengths = -1
# Padding Token ID がある時
else:
# 入力として input_ids が与えられている時 (大体これ)
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
# 参考:
# logits.shape -> (batch_size, num_tokens, num_labels)
# pooled_logits.shape -> (batch_size, num_labels)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
少しややこしいですが、要は "I am pen" が今、<bos> | I | am | pen | のようにトークナイズされている時、最後のトークンである pen の位置の logits が使われています。これが Gemma2ForSequenceClassification で行われている処理です。
ここで、もう一度推論のコードを見てみましょう。
import torch
from transformers import GemmaTokenizerFast
model_id = "google/gemma-2-2b"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
seq_model = Gemma2ForSequenceClassification.from_pretrained(model_id)
text = "Hello, I'm prgckwb."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = seq_model(**inputs)
logits = outputs.logits # tensor([[-2.2487, 1.3710]])
probs = logits.softmax(dim=-1) # tensor([[0.0261, 0.9739]])
preds = probs.argmax(-1) # tensor([1])
print(f'Pred: {preds.item()}') # Pred: 1
先述した通り、入力テキストが "Hello, I'm prgckwb." で <bos> | Hello | , | I | ' | m | pr | g | ck | wb | . | のようにトークナイズされているとき、線形層 score に通された最後の "." トークンに対する logits が返されています。
まとめ
この記事では、transformers のテキスト分類用のクラスがどのように機能しているのかを紹介しました。実際にテキスト分類モデルを学習したい場合は transformers のドキュメントに良いチュートリアルが載っているので、ぜひ試してみてください。
-
筆者は NLP の素人なので、この辺りのトークンのプーリング処理が紹介されている分かりやすい記事や論文があったら教えてください。 ↩︎
Discussion