Sparse Transformerを理解したい
Sparse Transformerを理解したい記事です[1]。Sparse TransformerはTransformerにスパース性を導入することで、必要メモリと計算量を減らした方法です。GPT-3のような巨大なモデルにも使用されています[2]。本記事ではまず、sparse transformerで新しく導入されたfactorized self-attentionについて紹介し、次にsparse transformerについてまとめます。
Sparse Transformerの凄いところ
- 省メモリなので、長い系列を扱える
- 計算が高速なので大規模ネットワークに使いやすい
- attentionをスパースにしているのに、性能が落ちていない
では、どのようにしてこれらを実現しているのか見ていきましょう。
Factorized Self-Attention
Sparse Transformerは、self-attetnionを複数のattentionに分解することで計算を効率化したTransfomerです。新しいattention (factorized attention) の計算方法として、strided attentionとfixed attentionが導入されました(下図)。
[1] Figure 3より
図上段は
下段は行を出力、列を入力とした時のattentionの参照パターンの模式図です。つまり、
- (a):通常のself attentionです。各出力が過去の時点全てを参照しています。
- (b):strided attetnionと呼ばれるattentionです。遠くの点を周期的に参照していることと、直近数点を参照しています。
- (c):fixed attetnionと呼ばれるattentionです。固定された過去の時点と、直近数点を参照しています。
strided attentionとfixed attentionの参照パターンがこんなスカスカで大丈夫なのか、詳細を見ていきます。
Strided Attention
説明のために、ここでは単純な系列を考えます。この時、strided attentionの参照パターンは以下のようになります(下図)。
strided attetnionの参照パターン
上図からわかる通り、strided attentionを2回通すことで過去の時点を全て参照できています。
Fixed Attention
ここでも単純な系列を考えます。この時、fixed attentionの参照パターンは以下のようになります(下図)。
fixed attetnionの参照パターン
ここでは、fixed attentionを2回通した例を数パターン示しました。fixed attentionはattention head 2の参照先が固定されていることに注意してください。どのパターンでも全ての過去の点を参照できていることがわかります。図右側はstrided attentionと同じに見えますが、strided attentionではどの出力に対しても同じ参照パターンになる点で異なります。
Sparse Transformerの構成方法
基本構成
Sparse Transformerのresidual blockは以下の構成になっています。
Sparse Transformerのresidual blockの構成([1] Figure 4より)
ちょっと分かりにくいですが、GPT-2と同じ構造をしています[3]。GPT-2に関しては以下を参照してください。
定式化
[1]ではfactorized attentionを用いてsparse transformerを構成する方法を3通り紹介しています。数式を使うので、ここで導入の準備をしておきます。
通常のattentionは以下のように定義されます。
ここで、
また、
さて、これで準備ができたので、それぞれの手法について見ていきます。
方法1:Interleave
Factorized attention headを適用したresidual blockを順番に繋ぐだけです。以下のように定式化されます。
方法2:Merged head
factorized attention headを1つのheadにまとめてしまう方法です。以下のように定式化されます。
方法3:Multi-head attention
ここで、
Positional Encoding
Transformerでは、positional encodingを使用して系列における要素の位置情報と、特徴量次元に対する情報を付与していました。Sparse Transformerでは以下の式でデータ構造だけでなく、attentionのパターン情報も付与します。
ここで、
実験結果
最後に、Sparse Transformerの有効性を確認して終わりましょう(下図)。
Bits/byte
[1] Table 1より
density modelingというタスクについて、CIFAR-10、Enwik8、ImageNet 64x64で性能評価を行っています。評価指標はBits/byte(画像タスクにおけるbits/dimと同義)です。Bits/byteについてはここでは解説しませんが、気になる方はPixelCNNの論文を読むとよいでしょう。要は負の対数尤度なので、小さいほど良い指標となります。
いずれのデータセットでもSparse Transformerが最も良い性能になっています。
計算速度
[1] Table 2より
Dense Attentionより高速化されていることがわかります。また、Bits/byteで見ても、通常のattentionより性能が良いことが分かります。
長期依存性の獲得
[1] Table 3より
Enwik8において、モデルが受け入れるコンテキストウィンドウを大きくしていったところ、性能がどんどん良くなっていることがわかります。これはSparse Transformerが長期依存性を獲得できていることを示唆しています。
参考文献
- Child, R., Gray, S., Radford, A. & Sutskever, I. (2019). Generating Long Sequences with Sparse Transformers. arXiv. https://doi.org/10.48550/arxiv.1904.10509
- Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., … Amodei, D. (2020). Language Models are Few-Shot Learners. arXiv. https://doi.org/10.48550/arxiv.2005.14165
- Radford, A., Wu, J., Child, R., Luan, D., Amodei, D. & Sutskever, I. (2019). Language Models are Unsupervised Multitask Learners.
- Chen, T., Xu, B., Zhang, C. & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv. https://doi.org/10.48550/arxiv.1604.06174
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L. & Polosukhin, I. (2017). Attention Is All You Need. arXiv. https://doi.org/10.48550/arxiv.1706.03762
Discussion