TenosrflowのMultiHeadAttentionのMaskについて調査
はじめに
MultiHeadAttention は Transformer を構成する重要なレイヤーである。Transformer を機械翻訳に利
用する際には、mask を利用し一部の要素を隠して学習することになる。この記事では、論文「Attention is All You Need」[1]で利用されている 3 つの MultiHeadAttention について、それぞれの mask をどう設定すべきかを考察し、それを Tensorflow でどのように実装するのかを説明する。
結果だけ先に述べると、
- Embedding で zero_mask=True を設定する。
- Encoder の Self-Attention は何も設定しない。
- Decoder の Self-Attention は use_causal_mask=True を設定する。
- Encoder-Decoder-Attention は何も設定しない。
を設定すれば正しく学習できている。この記事ではこれを順番に確認していく。
モデルの解説
Transformer を用いた翻訳機のモデルは次のようなものである。(図は論文[1:1]から引用)
このモデル全体の詳細な説明は Tensorflow のチュートリアル[2]に譲り、この記事ではこのモデルの学習に必要な mask の部分について説明を行う。
この記事では図中に現れる 3 つの MultiHeadAttention について、Encoder の Self-Attention、Decoder の Self-Attention、Encoder-Decoder-Attention と名付けた。これらの MultiHeadAttention の mask を考察する。
import tensorflow as tf
from tensorflow.keras.layers import MultiHeadAttention, Embedding, Layer
EMBEDDING_DIM = 128
VOCAB_SIZE = 20000
NUM_HEADS = 8
Encoder の Self-Attention
はじめに、Encoder の Self-Attention の mask を確認する。以下のような入力を考える。
encoder_inputs = Embedding(VOCAB_SIZE, EMBEDDING_DIM, mask_zero=True)(
tf.constant([[1, 2, 3, 4, 0, 0]], dtype=tf.int32)
)
print(encoder_inputs.shape) # (1, 6, 128)
ここで、0 は padding であり、1,2,3,4 は単語 ID とする。mask_zero=True
を指定することでencoder_inputs
に mask の情報が付属するようになる[3]。(ここで付属するとはencoder_inputs._keras_mask
に mask 情報が書き込まれることを指す。)
encoder_inputs
の mask を確認する。
class Mask(Layer):
def __init__(self):
super(Mask, self).__init__()
self.supports_masking = True
def call(self, inputs, mask):
return inputs, mask
_, encoder_mask = Mask()(encoder_inputs)
print(encoder_mask) # [[ True True True True False False]]
付属している mask は call メソットの引数として取り出すことができる。自作で定義したレイヤーでは、supports_masking=True
を設定することで、入力に付属する mask を出力にも付属させることが可能になる[3:1]。TensorFlow のこの機能を用いて、Embedding レイヤーで取得した mask 情報を MultiHeadAttention に伝搬させることが可能になっている。
ちなみに、上記のコードは以下のようにしても出力に mask が付属する。
class Mask(Layer):
def call(self, inputs, mask):
return inputs, mask
encoder_inputs_2, encoder_mask_2 = Mask()(encoder_inputs)
print(encoder_mask_2.numpy()) # [[ True True True True False False]]
encoder_inputs_3, encoder_mask_3 = Mask()(encoder_inputs_2)
print(encoder_mask_3.numpy()) # [[ True True True True False False]]
これはこの Layer が inputs の mask 付きオブジェクトから、別のオブジェクトを生成せずにそのまま返却しているためである。ただ、PositionalEmbedding を普通に実装した場合は、Layer 内で outputs のオブジェクトが新たに生成されるので、inputs の mask は勝手には伝搬されない。そのため最初に記載したように、supports_masking=True
を書く必要がある。
この mask 付きのencoder_inputs
を Encoder の MultiHeadAttention に入力する。
encoder = MultiHeadAttention(EMBEDDING_DIM, NUM_HEADS)
encoder_outputs, encoder_attention = encoder(
encoder_inputs, encoder_inputs, return_attention_scores = True
)
print(encoder_outputs.shape, encoder_attention.shape) # (1, 6, 128) (1, 8, 6, 6)
ここでencoder_attention
は入力の単語同士の関係を表すテンソルである。8 は Head の数を反映しており、(6, 6)は単語同士の関係の重みである。(本来、単語同士の関係を数値にしようとしたら、単語数 × 単語数分の情報を持つ必要がある。ただ、Self-Attention ではこれを単語数分の情報で表現している。これによって少ない情報量で多くの情報を保持しようとしているのがこの論文の一つのキモだと思う。少し Factorization Machine に似た雰囲気も感じる。)
Encoder では全ての単語同士の関係に重みが必要である。その一方で単語と padding の間には重みは必要ない。そのため、mask としては次のようなものが適用されていれば良いことになる。
[[ True True True True False False]
[ True True True True False False]
[ True True True True False False]
[ True True True True False False]
[False False False False False False]
[False False False False False False]]
TensorFlow の MultiHeadAttention では入力のオブジェクトが持つ mask を自動で計算して、上記のような mask を自動で適用してくれている。それはここを見ればよく分かる。実際にこれを動かしてみると以下のようになる。
mask = encoder._compute_attention_mask(
encoder_inputs, encoder_inputs
)
print(mask)
# [[[ True True True True False False]
# [ True True True True False False]
# [ True True True True False False]
# [ True True True True False False]
# [False False False False False False]
# [False False False False False False]]]
確かに自動的に mask 情報が計算されている。またencoder_attention
を見ても padding との関係が無視されていることが確認できるだろう。
Decoder の Self-Attention
次に、Decoder の Self-Attention の mask を確認する。Encoder の時と同様に、mask が付属する入力を作成する。
decoder_inputs = Embedding(VOCAB_SIZE, EMBEDDING_DIM, mask_zero=True)(
tf.constant([[1, 2, 3, 0, 0, 0]], dtype=tf.int32)
)
decoder は次のように作成する。重要なのはuse_causal_mask=True
を指定するところである。
decoder = MultiHeadAttention(NUM_HEADS, EMBEDDING_DIM)
decoder_outputs, decoder_attention = decoder(
decoder_inputs, decoder_inputs, return_attention_scores = True, use_causal_mask=True
)
print(decoder_outputs.shape, decoder_attention.shape) # (1, 6, 128) (1, 8, 6, 6)
機械翻訳において、decoder は次の単語を予想するタスクを解く。そのため、次の単語の情報を入力として使ってしまっては学習ができない。それを解決するため Decoder の Self-Attention では以下のような mask を用いる。
[[ True False False False False False]
[ True True False False False False]
[ True True True False False False]
[False False False False False False]
[False False False False False False]
[False False False False False False]]
当然、単語と padding 間の値にも意味がないので、その部分は mask している。それに加えて右上の三角形の部分も取り除いている。これによって、次の単語以降の情報を入力せずに学習を進めることができている。(ちなみに、このように右上の三角形を取り除くと、Decoder 全体で上記行列の基底で n 行目のベクトルが、文の n 番目の単語までの情報しか与えられていないベクトルとなっている。一回の計算で多段に情報が制限されたベクトルを同時に計算しているのがこの論文のキモとなっている。)
実際に動かして確認すると次のようになる。
mask = decoder._compute_attention_mask(
decoder_inputs, decoder_inputs, use_causal_mask=True
)
print(mask)
# [[[ True False False False False False]
# [ True True False False False False]
# [ True True True False False False]
# [False False False False False False]
# [False False False False False False]
# [False False False False False False]]]
これは正しい結果である。入力に mask を付属させて、use_causal_mask=True
を指定すれば良い。
Encoder-Decoder-Attention
最後に、Encoder-Decoder-Attention の mask について説明する。Encoder-Decoder-Attention は次のように作成する。
attention = MultiHeadAttention(NUM_HEADS, EMBEDDING_DIM)
decoder_outputs, decoder_attention = attention(
query = decoder_outputs,
key = encoder_outputs,
value = encoder_outputs,
return_attention_scores = True
)
print(decoder_outputs.shape, decoder_attention.shape) # (1, 6, 128) (1, 8, 6, 6)
ここで、MultiHeadAttention にはquery = decoder_outputs
, key = value = encoder_outputs
を入力している。Attention は以下のような数式で表される。(スケールは除いている。)
Encoder の言語の padding と Decoder の言語の単語の間の重みに意味がないことから、この Attention の mask は Q(decoder_outputs)と K(encoder_outputs)の mask を合わせたものにする必要がある。具体的には次のような mask が必要になる。
[[ True True True True False False]
[ True True True True False False]
[ True True True True False False]
[False False False False False False]
[False False False False False False]
[False False False False False False]]
K が転置されて Q にかけられてることから、縦方向は Decoder の文字列の方向であり、横方向が Encoder の文字列の方向である。(ちなみに、K が転置されて Q にかけ合わさっていることで前に書いた n 行目が n 単語目までの情報しか保持しないという性質が保たれている。この性質が保たれてるからこそ Decoder の Self-Attention と Encoder-Decoder-Attention が交互に行えるようになり、言語間で混じり合いを強いモデルの作成に成功したのもこの論文のキモだと思う。)
実際に動かしてみて確認する。
mask = attention._compute_attention_mask(
decoder_inputs, encoder_inputs
)
print(mask)
# [[[ True True True True False False]
# [ True True True True False False]
# [ True True True True False False]
# [False False False False False False]
# [False False False False False False]
# [False False False False False False]]]
これは正しい結果である。ちなみに MultiHeadAttention の Layer は、supports_masking がデフォルトで有効化されているため、Encoder-Decoder-Attention まで mask が伝搬してきている。
おわりに
TensorFlow で transformer を実装する際の mask の仕様について説明した。基本的には何も考えずに作ればいい感じに実装できることがわかった。これまでやってきた LSTM の実装においても mask を意識して見直しをしたい。
この記事を書くうえで TensorFlow の実装の中身を少しみた。といってもドキュメントの view source のボタンを押しただけだが。view source を押したら keras の github リポジトリに飛ばされたのは少しびっくりした。いつも使ってる関数が TensorFlow 由来のものなのか Keras 由来のものなのかその辺は理解をしていきたい。
今回、mask を整理したことで MultiHeadAttention の理解がよく進んだ。特に Decoder の Attention は行列の掛け算の性質をうまく使って、情報が一方通行にしか流れないようにしているのが理解できて良かった。また、Encoder-Decoder-Attention の query と key と value について、どれをどれに設定しなければならないのかがよく理解ができた。(これをバラバラに指定すると因果律が崩れてうまく学習が進まないと思う。)記事中で紹介したキモの部分は別記事で解説したい。
Discussion