🧯

pytorch, BERT, batch処理, paddingするとしない時とは値が異なる

2023/11/04に公開

回答:気のせいです.

正確には,誤差の範囲,で良いと思います.

以降より細かい検証の話です.

環境情報

  • OS: win 10
  • pytorch==2.10

以下のコードでCLSの値を取り出します

from transformers import BertJapaneseTokenizer, BertModel
model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking'

tokenizer = BertJapaneseTokenizer.from_pretrained( model_name )
pt_model = BertModel.from_pretrained( model_name )
pt_model.eval()


def encoding( text ) -> np.array:
    tokenized = tokenizer( text, padding=True, return_tensors="pt" )
    batched_ids = tokenized.input_ids
    mask = tokenized.attention_mask
    
    print( "***********************" )
    print( f"batched_ids: {batched_ids}" ) # 入力テキストがどの番号にtokenizeされたかの確認
    print( f"mask: {mask}" )        # 入力テキストのどこがmaskされているかの確認
    print( "***********************" )

    with torch.no_grad():
        outputs = pt_model( batched_ids, attention_mask=mask )

    last_hidden_states = outputs[0]
    cls = last_hidden_states[ :, 0, : ]
    return cls.numpy()

まずは同じ文をBERTに与えて,それらのCLSの値が一致するかを確かめます

e1 = encoding( "こんにちは" )
e2 = encoding( "こんにちは" )
print( f"shapes: {e1.shape}, {e2.shape}" )
print( e1[0, :30] ) # embeddingの先頭30要素の値を確認しておきます.念のため.
print( e2[0, :30] )
print( f"matched?: {(e1[0, :] == e2[0, :]).all()}" ) # e1とe2の要素の値が同じならTrue, そうでなければFalse.

# 以下実行結果
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3]])
mask: tensor([[1, 1, 1, 1, 1]])
***********************
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3]])
mask: tensor([[1, 1, 1, 1, 1]])
***********************
shapes: (1, 768), (1, 768)
[-0.12317186 -0.29084066  0.13944848  0.15317392 -0.21099006  0.24076676
 -0.27982163 -0.27193725 -0.33552122  0.1790561  -0.38695636  0.19878092
  0.08561726 -0.11280812  0.6676775  -0.2985249  -0.52726907  0.75172484
  0.35113254  0.22260407 -0.1364403   0.52606434 -1.0147591   0.10886844
 -0.19263619 -0.25389215 -0.10554688 -0.04117253  0.18171763  0.28062537]
[-0.12317186 -0.29084066  0.13944848  0.15317392 -0.21099006  0.24076676
 -0.27982163 -0.27193725 -0.33552122  0.1790561  -0.38695636  0.19878092
  0.08561726 -0.11280812  0.6676775  -0.2985249  -0.52726907  0.75172484
  0.35113254  0.22260407 -0.1364403   0.52606434 -1.0147591   0.10886844
 -0.19263619 -0.25389215 -0.10554688 -0.04117253  0.18171763  0.28062537]
matched?: True

e1とe2は完全に同じものです.matched?: Trueなため,二つのリストが格納している値は全て完全に同じものです.

paddingするとどうなるか確かめる:条件1

e1 = encoding( "こんにちは" )
e2 = encoding( ["こんにちは", "猫"] )
print( f"shapes: {e1.shape}, {e2.shape}" )
print( e1[0, :30] ) # [0, :30]は,0番目の文(こんにちは)のembeddingの先頭30要素,を意味する
print( e2[0, :30] ) # 同上
print( f"matched?: {(e1[0, :] == e2[0, :]).all()}" )

# 以下実行結果
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3]])
mask: tensor([[1, 1, 1, 1, 1]])
***********************
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3],
                     [    2,  6040,     3,     0,     0]]) # <- 文「猫」がパディングされた.パディングトークンは0.
mask: tensor([[1, 1, 1, 1, 1],
              [1, 1, 1, 0, 0]]) # <- 0 はマスクされた箇所.
***********************
shapes: (1, 768), (2, 768) # <- e2には2文入力したので,shape=(2, 768)
[-0.12317186 -0.29084066  0.13944848  0.15317392 -0.21099006  0.24076676
 -0.27982163 -0.27193725 -0.33552122  0.1790561  -0.38695636  0.19878092
  0.08561726 -0.11280812  0.6676775  -0.2985249  -0.52726907  0.75172484
  0.35113254  0.22260407 -0.1364403   0.52606434 -1.0147591   0.10886844
 -0.19263619 -0.25389215 -0.10554688 -0.04117253  0.18171763  0.28062537]
[-0.12317186 -0.29084066  0.13944848  0.15317392 -0.21099006  0.24076676
 -0.27982163 -0.27193725 -0.33552122  0.1790561  -0.38695636  0.19878092
  0.08561726 -0.11280812  0.6676775  -0.2985249  -0.52726907  0.75172484
  0.35113254  0.22260407 -0.1364403   0.52606434 -1.0147591   0.10886844
 -0.19263619 -0.25389215 -0.10554688 -0.04117253  0.18171763  0.28062537]
matched?: True

e2は二つの入力文のエンコード結果になっています.e2の二つめの文は,一つめの文こんにちはよりトークン数が少ないため,適当にpaddingされたことが上記の表示から確認できます.

現時点では,e1のこんにちはのembeddingとe2のこんにちはのembeddingは同じものです.

paddingするとどうなるか確かめる:条件2

e1 = encoding( "こんにちは" )
e2 = encoding( ["こんにちは", "今日の晩御飯は唐揚げの予定です"] )
print( f"shapes: {e1.shape}, {e2.shape}" )
print( e1[0, :30] )
print( e2[0, :30] )
print( f"matched?: {(e1[0, :] == e2[0, :]).all()}" )

# 以下実行結果
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3]])
mask: tensor([[1, 1, 1, 1, 1]])
***********************
***********************
batched_ids: tensor([[    2, 10350, 25746, 28450,     3,     0,     0,     0,     0,     0,     0,     0,     0], # <- 文「こんにちは」がパディングされた.
                     [    2,  3246,     5,  4423,  1351, 29916,     9,  3425, 14702,     5,  1484,  2992,     3]])
mask: tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
***********************
shapes: (1, 768), (2, 768)
[-0.12317186 -0.29084066  0.13944848  0.15317392 -0.21099006  0.24076676
 -0.27982163 -0.27193725 -0.33552122  0.1790561  -0.38695636  0.19878092
  0.08561726 -0.11280812  0.6676775  -0.2985249  -0.52726907  0.75172484
  0.35113254  0.22260407 -0.1364403   0.52606434 -1.0147591   0.10886844
 -0.19263619 -0.25389215 -0.10554688 -0.04117253  0.18171763  0.28062537]
[-0.1231717  -0.29084083  0.13944837  0.1531737  -0.21098952  0.24076748
 -0.27982208 -0.2719369  -0.33552155  0.17905572 -0.38695666  0.19878091
  0.08561811 -0.1128076   0.66767853 -0.2985254  -0.52726924  0.7517249
  0.35113272  0.22260398 -0.13643995  0.5260654  -1.0147597   0.10886877
 -0.19263577 -0.25389168 -0.10554667 -0.04117169  0.18171719  0.28062564]
matched?: False

この条件ではe2の二文のうち,一つめの文こんにちはにパディングがかかるように,二つめの文を長くしました.意図した通りに文こんにちはにはパディングがかかりました.

そこでまずmatched?の表示を見ると,Falseです.すなわちe1のこんにちはとe2のこんにちはは完全に同じものではありません.しかしembeddingの値を確認すると,e1の先頭3つは-0.12317186 -0.29084066 0.13944848であり,e2は-0.1231717 -0.29084083 0.13944837です.それぞれ先頭から,-0.123171, -0.290840, 0.139448と,小数点以下6桁までは一致しているようです.

すなわちこの結果は,バッチ処理のために文をpaddingすると,謎の計算誤差が生じて,paddingしない場合とは値が完全に同じにはならない,ことを示唆しています.といっても,小数点以下6桁までは一致しているので,あまり気にすることは無いと思います.

従って,気のせい,で済ませましょう.

おわりに

やり方間違ってるよ~,という場合は教えてください.

GitHubで編集を提案

Discussion