🥷
transformersライブラリを使い大規模言語モデルから尤度を取得する
概要
transformersライブラリを使い大規模言語モデルから尤度を取得する方法を紹介します。コードの全体はここにあります。
コード
def likelihood(self, batch: torch.Tensor) -> torch.Tensor:
"""
Calculate the likelihood of the batch.
Args:
batch (torch.Tensor): A tensor of shape (batch_size, sequence_length).
Returns:
torch.Tensor: A tensor of shape (batch_size,).
"""
encoded_batch = tokenize(batch)
outputs = model(**encoded_batch)
logits = outputs.logits
all_scores = logits.gather(-1, encoded_batch.input_ids.unsqueeze(-1)).squeeze()
attention_mask = encoded_batch.attention_mask
unmasked_scores = all_scores * attention_mask
avg_likelihood = unmasked_scores.sum(1) / attention_mask.sum(1)
return avg_likelihood
解説
- まず、
tokenizeで入力テキストをID化します。-
encoded_batchはinput_idsとattention_maskを持っています。input_idsはトークン化されたテキストをIDに変換したものです。 -
attention_maskはinput_idsにおいて[PAD]などの特殊トークンでパディングされた箇所が0(パティング)と1(トークン)で表記されています。
-
-
modelにencoded_batchを入力します。- 出力
outputsのlogitsは入力テキストのトークンごとの尤度です。
- 出力
-
logitsには全語彙の尤度が格納されているため、gatherを使いlogitsからinput_idsのトークンIDに対応した尤度を取得します。-
logitsは[batch_size, sequence_length, vocab_size]の形をしており、input_idsは[batch_size, sequence_length]の形をしている。 -
unsqueezeでinput_idsに次元を追加してinput_idsは[batch_size, sequence_length, 1]の形になります。 -
gatherを使いlogitsに対してinput_idsの最後の次元の値をインデックスとして取得します。 -
gatherの出力は[batch_size, sequence_length, 1]の形をしており、squeezeで最後の次元を削除しています。
-
-
attention_maskをかけることでパディングされた箇所の尤度を0にしunmasked_scoresを計算します。 - パディング箇所の尤度を0にしたものをパティング箇所を除いた入力テキスト長で平均することで、テキスト全体の尤度の平均を計算します。
Discussion