
Notes from tracing how the OpenAI Whisper code works

Published 2024/06/22

Simply tracing through the execution, step by step.

Reference

https://youtu.be/2rS3DdvW-pQ

Environment setup

Repository

git clone https://github.com/openai/whisper.git

Clone the repository with the command above.

Preparing files

Audio data

Add a suitable wav file at the top of the repository (here, the jfk.wav sample audio that ships with whisper.cpp):
https://github.com/ggerganov/whisper.cpp/blob/master/samples/jfk.wav

Add trymain.py

The base-size model is used; it is downloaded automatically on first run. Point the script at the wav file from above and run it:

import whisper

model = whisper.load_model("base")
result = model.transcribe("jfk.wav", beam_size=5)  # beam search
#result = model.transcribe("jfk.wav")  # greedy search
print(result["text"])

How to inspect tokens while debugging

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(True)  # multilingual=True
tokenizer.decode([50258])
# '<|startoftranscript|>'

Execution

Start tracing from transcribe in transcribe.py. First, the specified file is read and converted into a log-Mel spectrogram:

mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
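For reference, the same data can be produced directly from the wav file; a minimal sketch using whisper's audio helpers (jfk.wav is the sample from above):

import whisper
from whisper.audio import log_mel_spectrogram, N_SAMPLES

audio = whisper.load_audio("jfk.wav")                # 16 kHz mono float32 waveform
mel = log_mel_spectrogram(audio, padding=N_SAMPLES)  # pads 30 s (3000 frames) of silence
print(mel.shape)                                     # (80, n_audio_frames + 3000)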

The first step is language detection:

mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)

mel_segment has shape (80, 3000): 80 Mel features by 3000 frames, i.e. 30 seconds of audio (one frame per 10 ms hop).
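The frames-to-seconds arithmetic can be confirmed from the constants in whisper.audio:

from whisper.audio import SAMPLE_RATE, HOP_LENGTH, N_FRAMES

print(SAMPLE_RATE, HOP_LENGTH, N_FRAMES)    # 16000 160 3000
print(N_FRAMES * HOP_LENGTH / SAMPLE_RATE)  # 30.0 (seconds per segment)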

First, the AudioEncoder processes the mel.

After encoding, x has shape (1, 1500, 512): the convolutional front end halves the 3000 frames to 1500, and 512 is the base model's embedding width.

"|StartOfTranscript|"のトークンIDだけを引数にするようにtokenizer.sotをもったxを定義

x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)
# [n_audio, 1]
# tensor([[50258]])  50258 = '<|startoftranscript|>'

Decode using the x and mel above:

logits = model.logits(x, mel)[:, 0]

Internally this is:

# Whisper
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
    return self.decoder(tokens, audio_features)

This just runs the decoder. The resulting logits are then used to obtain the language probs and tokens:

# create an all-True mask over the vocabulary dimension of logits
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
# tensor([True, True, True,  ..., True, True, True]) : size 51865
# list(tokenizer.all_language_tokens) = [50261, 50357, 50302, 50334, 50281, 50282, 50264, 50269, 50324, 50291, 50355, 50338, 50272, 50289, 50354, 50312, 50279, 50345, 50313, 50304, 50274, 50308, 50306, 50326, 50267, 50333, 50284, 50273, 50309, 50328, 50339, 50277, 50319, 50298, 50278, 50299, 50331, 50266, 50262, 50337, 50290, 50303, 50286, 50314, 50322, 50323, 50271, 50347, 50348, 50352, 50270, 50320, 50318, 50330, 50294, 50341, 50288, 50342, 50343, ...]
# the token IDs of the 99 languages; these stay unmasked (set to False)
mask[list(tokenizer.all_language_tokens)] = False
# set masked positions to -inf, leaving only the language tokens' logits
logits[:, mask] = -np.inf
# take the index of the highest-scoring language token
language_tokens = logits.argmax(dim=-1)
# tensor([50259])
language_token_probs = logits.softmax(dim=-1).cpu()
# language_token_probs
# tensor([[0., 0., 0.,  ..., 0., 0., 0.]])
# language_token_probs.shape
# torch.Size([1, 51865])
language_probs = [
    {
        c: language_token_probs[i, j].item()
        for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
    }
    for i in range(n_audio)
]
# language_probs
# [{'de': 0.001759388018399477, 'su': 2.2345496120124153e-07, 'bn': 0.00023899634834378958, 'am': 1.9668059394462034e-06, 'el': 0.0017844138201326132, 'ms': 0.0001395636354573071, 'ko': 0.000954355753492564, 'pl': 0.00028663044213317335, 'sn': 0.00023973184579517692, 'hr': 2.0213905372656882e-05, 'ba': 3.014122995637081e-08, 'fo': 1.4884884876664728e-05, 'ar': 0.0019773580133914948, 'th': 3.658004061435349e-05, 'ha': 2.6089003313245485e-07, 'hy': 1.838958269217983e-05, 'he': 0.0012696300400421023, 'lb': 2.870650064323854e-07, 'ne': 7.937302143545821e-06, ...}]

This is received as probs, and the key with the largest value is extracted:

decode_options["language"] = max(probs, key=probs.get)
# 'en'
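For reference, the language-detection step can be reproduced standalone; this is essentially the snippet from the project README:

import whisper

model = whisper.load_model("base")
audio = whisper.load_audio("jfk.wav")
audio = whisper.pad_or_trim(audio)  # pad/cut the waveform to exactly 30 s
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
print(max(probs, key=probs.get))    # 'en'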

Decoding

After the language is detected, recognition proceeds 3000 frames at a time. Each segment is passed to decode_with_fallback with shape (80, 3000).

In DecodingTask.run in decoding.py, the segment is first encoded by running _get_audio_features. x, i.e. audio_features, is:

x.shape
# torch.Size([1, 1500, 512])

tokens is initialized to '<|startoftranscript|><|en|><|transcribe|>', then replicated so that each of the 5 beams holds a copy:

tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# tensor([[50258, 50259, 50359]]) : '<|startoftranscript|><|en|><|transcribe|>'
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# tensor([[50258, 50259, 50359],
#         [50258, 50259, 50359],
#         [50258, 50259, 50359],
#         [50258, 50259, 50359],
#         [50258, 50259, 50359]])
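self.initial_tokens is built from the tokenizer's sot_sequence (plus any prompt/prefix tokens); it can be checked directly:

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(True, language="en", task="transcribe")
print(tokenizer.sot_sequence)  # (50258, 50259, 50359)
print(tokenizer.decode(list(tokenizer.sot_sequence)))
# '<|startoftranscript|><|en|><|transcribe|>'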

Now the decode loop proper starts in _main_loop:

tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

Inside, sum_logprobs starts out as five zeros, one running score per beam.

sample_len is 224, the maximum number of decode iterations. Each iteration computes logits from the tokens and audio_features inputs:

for i in range(self.sample_len):
    logits = self.inference.logits(tokens, audio_features)

    if (
        i == 0 and self.tokenizer.no_speech is not None
    ):  # save no_speech_probs
        probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
        no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

    # now we need to consider the logits at the last token only
    logits = logits[:, -1]

    # apply the logit filters, e.g. for suppressing or applying penalty to
    for logit_filter in self.logit_filters:
        logit_filter.apply(logits, tokens)

    # expand the tokens tensor with the selected next tokens
    tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

    if completed or tokens.shape[-1] > self.n_ctx:
        break

As expected, logits holds beam_size = 5 copies:

logits.shape
# torch.Size([5, 3, 51865])  [beam_size, 3 tokens on the first step (1 afterwards), vocabulary size]
logits
# tensor([[[-2.9108, -4.6617,  1.1199,  ...,  1.6766,  1.0728,  1.9878],
#          [10.9311,  9.8526,  9.3970,  ...,  9.6659,  9.1890,  9.8944],
#          [ 9.2720,  7.3892,  3.9997,  ...,  4.2641,  4.2296,  0.8123]],
#         ... (the same three rows repeated for each of the 5 beams) ...])

The logit filters then mask the unneeded parts of logits to -inf: unwanted symbols, language tokens, and so on.

for logit_filter in self.logit_filters:
    logit_filter.apply(logits, tokens)

logits ends up masked as below (not every position; the printout is truncated):

tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf]])
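Each logit filter is just an object with an apply(logits, tokens) method that mutates logits in place; whisper's SuppressTokens filter, for example, is essentially this (base class omitted):

import numpy as np
import torch

class SuppressTokens:
    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: torch.Tensor, tokens: torch.Tensor):
        # drive the suppressed token IDs to -inf so they can never be chosen
        logits[:, self.suppress_tokens] = -np.inf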

Then:

tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

Inside this call, the next token for each beam is predicted from logits and tokens. Afterwards, if completed is set, the loop ends.

Inside the logits call

Inside inference.logits, only the trailing token of each of the 5 candidate sequences is kept:

tensor([[370],
        [370],
        [452],
        [ 11],
        [452]])
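This trimming happens in PyTorchInference.logits in decoding.py, which looks roughly like this:

def logits(self, tokens, audio_features):
    if not self.kv_cache:
        self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
    if tokens.shape[-1] > self.initial_token_length:
        # with the kv cache in place, only the newest token needs a forward pass
        tokens = tokens[:, -1:]
    return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)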

The decoder's forward pass below then runs on these tokens; blocks holds 6 decoder blocks in the base model:

offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
    self.token_embedding(x)
    + self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)

for block in self.blocks:
    x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

x = self.ln(x)
logits = (
    x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()

BeamSearch

Here, beam_size = 5, so logits is a 5x51865 array (with greedy search it would be 1x51865).

logprobs has the same shape; it is the log-softmax of logits.

n_audio is 1.

From here, candidates are enumerated from each beam's logits:

# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
    idx = i * self.beam_size + j
    prefix = tokens[idx].tolist()
    for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
        new_logprob = (sum_logprobs[idx] + logprob).item()
        sequence = tuple(prefix + [token.item()])
        scores[sequence] = new_logprob
        sources[sequence] = idx

First, loop over the 5 beams (self.beam_size). For each one, topk on its logprobs yields the beam_size + 1 = 6 most likely tokens and their logprob values. Each logprob is added to that beam's sum_logprobs, and a candidate sequence is built by appending the token to the prefix.

A sequence value looks like:

(50258, 50259, 50359, 50364)

new_logprob is stored in scores with the sequence as the key. Once all 6 candidates have been computed for every beam, identical sequences collapse onto the same key, leaving a scores dict like:

{(50258, 50259, 50359, 50364): -0.18531185388565063, (50258, 50259, 50359, 50379): -4.650999546051025, (50258, 50259, 50359, 50380): -4.690244197845459, (50258, 50259, 50359, 50376): -4.711496829986572, (50258, 50259, 50359, 50374): -4.738053798675537, (50258, 50259, 50359, 50378): -4.740087032318115}

In the next STEP 2, candidates are taken in descending score order; once the top five have been collected, this beam-search step is finished:

for sequence in sorted(scores, key=scores.get, reverse=True):
    if sequence[-1] == self.eot:
        finished[sequence] = scores[sequence]
    else:
        sum_logprobs[len(next_tokens)] = scores[sequence]
        next_tokens.append(sequence)
        source_indices.append(sources[sequence])

        saved += 1
        if saved == self.beam_size:
            break

Since the keys of scores are the sequences themselves, they are appended to next_tokens, the candidate token sequences for this step; the loop stops once saved reaches 5.

When this finishes, next_tokens holds the 5 candidate token sequences.
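To make the bookkeeping concrete, here is a self-contained toy version of one such step over a tiny vocabulary (the numbers are made up, not Whisper's):

beam_size = 2
# toy prefixes with their cumulative log probabilities
beams = {(1,): -0.1, (2,): -0.5}
# toy next-token log probabilities for each prefix (vocabulary of 3 tokens)
next_logprobs = {(1,): [-0.2, -1.0, -2.0], (2,): [-0.3, -0.7, -1.5]}

# STEP 1: score every (prefix + token) continuation
scores = {}
for prefix, base in beams.items():
    for token, lp in enumerate(next_logprobs[prefix]):
        scores[prefix + (token,)] = base + lp

# STEP 2: keep only the beam_size best sequences
best = sorted(scores, key=scores.get, reverse=True)[:beam_size]
print(best)  # [(1, 0), (2, 0)]  (scores of about -0.3 and -0.8)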

After several token updates, scores, which at first held only six entries, fills up to the full 30 (5 beams x 6 candidates each):

scores
{(50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 428): -2.5755414962768555, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 291): -6.958094596862793, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 50644): -9.61194133758545, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 2260): -9.743788719177246, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 527): -9.947134971618652, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 50642): -10.028239250183105, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 437): -3.06215500831604, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 11): -5.478387832641602, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 13): -6.601102828979492, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 485): -7.215543746948242, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 50614): -8.077592849731445, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 0): -8.501394271850586, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 406): -3.26399302482605, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 12854): -4.529715538024902, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 11): -6.782641410827637, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 1726): -8.242207527160645, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 426): -8.829968452453613, (50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 9146): -9.037169456481934, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 428): -3.4666168689727783, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 291): -7.964332103729248, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 50644): -10.476215362548828, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 2260): -10.741695404052734, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 527): -10.778972625732422, (50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 2351, 406, 437, 50642): -10.84210205078125, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 428): -3.6312332153320312, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 291): -8.141895294189453, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 50644): -10.527690887451172, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 50642): -10.755844116210938, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 50640): -10.79433822631836, (50258, 50259, 50359, 50364, 293, 370, 452, 7177, 6280, 1029, 406, 437, 2260): -11.004585266113281}

Looking at tokens before and after one beam-search update:

Before:

'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans ask not what your'
'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans, ask not what'
'<|startoftranscript|><|en|><|transcribe|> And so, my fellow Americans, ask not'
'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans asked not what your'
'<|startoftranscript|><|en|><|transcribe|> and so my fellow Americans ask not what your'

After:

'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans ask not what your country'
'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans, ask not what your'
'<|startoftranscript|><|en|><|transcribe|> And so, my fellow Americans, ask not what'
'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans asked not what your country'
'<|startoftranscript|><|en|><|transcribe|> and so my fellow Americans ask not what your country'

The sum_logprobs for each of these hypotheses is carried along at the same time:

tensor([ -7.7090,  -8.0222,  -8.6720, -10.6214, -10.7251])

When a sequence takes the other branch in STEP 2 (its last token is <|endoftext|>), it goes into finished_sequences:

'<|startoftranscript|><|en|><|transcribe|> And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.<|endoftext|>'

However, only the sequence above (out of the 5 candidates) has finished, so processing continues. Once decoding runs to the end, finished_sequences holds:

(50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50914, 50257):
-8.072649955749512
(50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50889, 50257):
-8.992630004882812
(50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 50764, 50764, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50914, 50257):
-8.70301342010498
(50258, 50259, 50359, 50364, 400, 370, 452, 7177, 6280, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 50764, 50764, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50964, 50257):
-10.751739501953125
(50258, 50259, 50359, 50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 50764, 50764, 1029, 437, 291, 393, 360, 337, 428, 1941, 13, 50914, 50257):
-8.878769874572754

These five entries make up the single per-audio dict inside finished_sequences (a length-1 list). Since it now holds beam_size = 5 sequences, completed becomes True and decoding ends.
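The check itself in BeamSearchDecoder.update compares against max_candidates = beam_size * patience (5 with the default patience of 1.0), roughly:

completed = all(
    len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)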

Finalizing the text

After leaving _main_loop, tokens is trimmed into the following state:

[[tensor([50364,   400,   370,   452,  7177,  6280,    11,  1029,   406,   437,
       ...  360,   337,   428,  1941,    13, 50914]),
 tensor([50364,   400,   370,   452,  7177,  6280,    11,  1029,   406,   437,
       ...  360,   337,   428,  1941,    13, 50889]),
 tensor([50364,   400,   370,   452,  7177,  6280,  1029,   406,   437,   428,
       ...  360,   337,   428,  1941,    13, 50914]),
 tensor([50364,   400,   370,   452,  7177,  6280,  1029,   406,   437,   428,
       ...  360,   337,   428,  1941,    13, 50964]),
 tensor([50364,   400,   370,    11,   452,  7177,  6280,    11,  1029,   406,
       ...  360,   337,   428,  1941,    13, 50914])]]

Then the best string is selected:

# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
# [4]
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
# [[50364, 400, 370, 11, 452, 7177, 6280, 11, 1029, 406, 437, 428, 1941, 393, 360, 337, 291, 11, 50764, ...]]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
# ['And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country.']
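The default sequence_ranker is MaximumLikelihoodRanker: it divides each candidate's cumulative log probability by a length penalty (simply the token count unless a length_penalty option is given) and picks the maximum. A simplified sketch with made-up numbers:

from typing import List

def rank(sum_logprobs: List[float], lengths: List[int]) -> int:
    # the highest average log probability per token wins
    avg = [lp / n for lp, n in zip(sum_logprobs, lengths)]
    return max(range(len(avg)), key=avg.__getitem__)

print(rank([-7.7, -8.0, -8.7], [32, 31, 29]))  # 0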
