Transformers の SequenceClassification モデルが何をしているのか見る
忙しい人向けの要約
transformers
の AutoModelForSequenceClassification
は元のモデル AutoModel
の hidden_size
から指定したクラス数に射影する線形層 (score
) を追加して、入力テキストの最終トークンの logits
を利用しているよ。
self.model = Gemma2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
はじめに
transformers
には LLM でクラス分類が簡単に行えるように *ForSequenceClassification
という命名規則に従ったクラスが用意されています。
例えば、Gemma2 を例にとると、モデルの推論を行い hidden_states
を返す Gemma2Model
に対し、分類ヘッドを付与した Gemma2ForSequenceClassification
といった具合です。
この記事では Gemma2ForSequenceClassification
を題材に、そういった分類用のクラスがどのように動いているのかを確認します。
コードで確認してみる
Gemma2ForSequenceClassification
の実装
以下、次のような環境で試しています。
torch==2.5.1
transformers==4.48.1
まずは実際に動かして挙動を確認してみます。Gemma-2 の 2B モデル (google/gemma-2-2b)を例に取って見てみましょう。
import torch
from transformers import GemmaTokenizerFast
model_id = "google/gemma-2-2b"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
seq_model = Gemma2ForSequenceClassification.from_pretrained(model_id)
text = "Hello, I'm prgckwb."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = seq_model(**inputs)
logits = outputs.logits # tensor([[-2.2487, 1.3710]])
probs = logits.softmax(dim=-1) # tensor([[0.0261, 0.9739]])
preds = probs.argmax(-1) # tensor([1])
print(f'Pred: {preds.item()}') # Pred: 1
というように、なぜか分類用に学習されていない Gemma-2 モデルを Gemma2ForSequenceClassification
でインスタンス化しただけで 2クラス分類っぽい動作ができてしまいました。
Gemma2ForSequenceClassification
の実装を見てみると、init
関数に次のようなコードがあります。
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Gemma2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
ここで、Gemma2ForSequenceClassification
には、元のモデル Gemma2Model
のインスタンス model
と、Gemma2Model
の hidden_size
から指定したクラス数 num_labels
に射影する線形層 score
が定義されています。
ここから、Gemma-2 から得た埋め込みをどうにかしてクラス数と同じ形状のテンソルに変換していることが分かります。
次に forward
関数の流れを見ていきます。なお、必要な部分をかいつまんで記述するので気になる人は上記の実際のコードを見に行ってみてください。
今、入力テキストは "I am pen" として、これが Gemma2Tokenizer
によってトークナイズされているとします。
text = "I am pen"
inputs = tokenizer(text, return_tensors="pt")
print(inputs.input_ids) # Shape: (1, 4)
tensor([[ 2, 235285, 1144, 3008]])
トークンは <bos> | I | am | pen |
のように分かれています。
Gemma2ForSequenceClassification
にこの inputs
を通すと、
# Gemma2Model の推論
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# last_hidden_state の取得
hidden_states = transformer_outputs[0] # Shape: (1, 4, 2304)
# hidden_size -> num_labels への Projection
logits = self.score(hidden_states) # Shape: (1, 4, 2)
ここが重要な点で、今 入力テキストのトークンすべてに対して logits
を計算していますが、Gemma2ForSequenceClassification
はその全てを logits
として返すわけではなく、プーリングの処理が入っています。プーリングは、入力テキストの最後のトークン位置のものを採用しています[1]。
# Padding Token ID がない時
if self.config.pad_token_id is None:
sequence_lengths = -1
# Padding Token ID がある時
else:
# 入力として input_ids が与えられている時 (大体これ)
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
# 参考:
# logits.shape -> (batch_size, num_tokens, num_labels)
# pooled_logits.shape -> (batch_size, num_labels)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
少しややこしいですが、要は "I am pen" が今、<bos> | I | am | pen |
のようにトークナイズされている時、最後のトークンである pen
の位置の logits
が使われています。これが Gemma2ForSequenceClassification
で行われている処理です。
ここで、もう一度推論のコードを見てみましょう。
import torch
from transformers import GemmaTokenizerFast
model_id = "google/gemma-2-2b"
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
seq_model = Gemma2ForSequenceClassification.from_pretrained(model_id)
text = "Hello, I'm prgckwb."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = seq_model(**inputs)
logits = outputs.logits # tensor([[-2.2487, 1.3710]])
probs = logits.softmax(dim=-1) # tensor([[0.0261, 0.9739]])
preds = probs.argmax(-1) # tensor([1])
print(f'Pred: {preds.item()}') # Pred: 1
先述した通り、入力テキストが "Hello, I'm prgckwb." で <bos> | Hello | , | I | ' | m | pr | g | ck | wb | . |
のようにトークナイズされているとき、線形層 score
に通された最後の "."
トークンに対する logits
が返されています。
まとめ
この記事では、transformers
のテキスト分類用のクラスがどのように機能しているのかを紹介しました。実際にテキスト分類モデルを学習したい場合は transformers
のドキュメントに良いチュートリアルが載っているので、ぜひ試してみてください。
-
筆者は NLP の素人なので、この辺りのトークンのプーリング処理が紹介されている分かりやすい記事や論文があったら教えてください。 ↩︎
Discussion