🥷

transformersライブラリを使い大規模言語モデルから尤度を取得する

2023/12/31に公開

概要

transformersライブラリを使い大規模言語モデルから尤度を取得する方法を紹介します。コードの全体はここにあります。

コード

def likelihood(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Calculate the likelihood of the batch.

        Args:
            batch (torch.Tensor): A tensor of shape (batch_size, sequence_length).

        Returns:
            torch.Tensor: A tensor of shape (batch_size,).
        """
        encoded_batch = tokenize(batch)
        outputs = model(**encoded_batch)
        logits = outputs.logits
        all_scores = logits.gather(-1, encoded_batch.input_ids.unsqueeze(-1)).squeeze()
        attention_mask = encoded_batch.attention_mask
        unmasked_scores = all_scores * attention_mask
        avg_likelihood = unmasked_scores.sum(1) / attention_mask.sum(1)
        return avg_likelihood

解説

  • まず、tokenize で入力テキストをID化します。
    • encoded_batchinput_idsattention_mask を持っています。input_ids はトークン化されたテキストをIDに変換したものです。
    • attention_maskinput_ids において [PAD] などの特殊トークンでパディングされた箇所が0(パティング)と1(トークン)で表記されています。
  • modelencoded_batch を入力します。
    • 出力 outputslogits は入力テキストのトークンごとの尤度です。
  • logits には全語彙の尤度が格納されているため、 gather を使い logits から input_ids のトークンIDに対応した尤度を取得します。
    • logits[batch_size, sequence_length, vocab_size] の形をしており、input_ids[batch_size, sequence_length] の形をしている。
    • unsqueezeinput_ids に次元を追加して input_ids[batch_size, sequence_length, 1] の形になります。
    • gather を使い logits に対して input_ids の最後の次元の値をインデックスとして取得します。
    • gather の出力は [batch_size, sequence_length, 1] の形をしており、 squeeze で最後の次元を削除しています。
  • attention_mask をかけることでパディングされた箇所の尤度を0にし unmasked_scores を計算します。
  • パディング箇所の尤度を0にしたものをパティング箇所を除いた入力テキスト長で平均することで、テキスト全体の尤度の平均を計算します。

Discussion