🙆

Sparse Transformerを理解したい

2022/12/06に公開

 Sparse Transformerを理解したい記事です[1]。Sparse TransformerはTransformerにスパース性を導入することで、必要メモリと計算量を減らした方法です。GPT-3のような巨大なモデルにも使用されています[2]。本記事ではまず、sparse transformerで新しく導入されたfactorized self-attentionについて紹介し、次にsparse transformerについてまとめます。

Sparse Transformerの凄いところ

  1. 省メモリなので、長い系列を扱える
  2. 計算が高速なので大規模ネットワークに使いやすい
  3. attentionをスパースにしているのに、性能が落ちていない

 では、どのようにしてこれらを実現しているのか見ていきましょう。

Factorized Self-Attention

 Sparse Transformerは、self-attetnionを複数のattentionに分解することで計算を効率化したTransfomerです。新しいattention (factorized attention) の計算方法として、strided attentionとfixed attentionが導入されました(下図)。


[1] Figure 3より

 図上段は 6 \times 6 の画像に対する2つのattention headsの例です。濃い青の位置に対するattetnionの範囲を薄い青で表しています。

 下段は行を出力、列を入力とした時のattentionの参照パターンの模式図です。つまり、(i, j)は、i番目の出力に対するj番目の入力のattentionの有無を示しています。

  • (a):通常のself attentionです。各出力が過去の時点全てを参照しています。
  • (b):strided attetnionと呼ばれるattentionです。遠くの点を周期的に参照していることと、直近数点を参照しています。
  • (c):fixed attetnionと呼ばれるattentionです。固定された過去の時点と、直近数点を参照しています。

 strided attentionとfixed attentionの参照パターンがこんなスカスカで大丈夫なのか、詳細を見ていきます。

Strided Attention

 説明のために、ここでは単純な系列を考えます。この時、strided attentionの参照パターンは以下のようになります(下図)。


strided attetnionの参照パターン

 上図からわかる通り、strided attentionを2回通すことで過去の時点を全て参照できています。

Fixed Attention

 ここでも単純な系列を考えます。この時、fixed attentionの参照パターンは以下のようになります(下図)。


fixed attetnionの参照パターン

 ここでは、fixed attentionを2回通した例を数パターン示しました。fixed attentionはattention head 2の参照先が固定されていることに注意してください。どのパターンでも全ての過去の点を参照できていることがわかります。図右側はstrided attentionと同じに見えますが、strided attentionではどの出力に対しても同じ参照パターンになる点で異なります。

Sparse Transformerの構成方法

基本構成

 Sparse Transformerのresidual blockは以下の構成になっています。


Sparse Transformerのresidual blockの構成([1] Figure 4より)

 ちょっと分かりにくいですが、GPT-2と同じ構造をしています[3]。GPT-2に関しては以下を参照してください。

https://zenn.dev/sunbluesome/articles/775ffd67fb7454

定式化

 [1]ではfactorized attentionを用いてsparse transformerを構成する方法を3通り紹介しています。数式を使うので、ここで導入の準備をしておきます。

 通常のattentionは以下のように定義されます。

\begin{align} \text{attention}(X) &= W_p \cdot \text{attend}(X, S) \\ \text{attend}(X, S) &= \left(a(\boldsymbol{x}_i, S_i)\right)_{i \in \{1, \ldots, n\}} \\ a(\boldsymbol{x}_i, S_i) &= \text{softmax}\left(\frac{\left(W_q, \boldsymbol{x}_i\right) K_{S_i}^T}{\sqrt{d}} V_{S_i}\right) \\ K_{S_i} &= \left(W_k \boldsymbol{x}_j\right)_{j \in S_i} \\ V_{S_i} &= \left(W_v \boldsymbol{x}_j\right)_{j \in S_i} \\ \end{align}

 ここで、x_i \in Xi番目の入力ベクトル、S_i \in Si番目の出力が参照する入力ベクトルのインデックス集合、W_{\{p, q, k, v \}}は重み行列です。特に、W_q, W_k, W_vは入力ベクトルx_iをクエリ、キー、バリューへ変換する為に使われます。\text{attend}\text{attention}に重み行列がかかっていないだけです。K_{S_i}, V_{S_i}はキーとバリューです。もしわからない場合は、以下の記事を参照してください。

https://zenn.dev/sunbluesome/articles/078ac9a9afca6a

 また、m番目のfactorized attention headのインデックス集合を以下のように表します。

\begin{align} A^{(m)} \sub S \end{align}

 さて、これで準備ができたので、それぞれの手法について見ていきます。

方法1:Interleave

 Factorized attention headを適用したresidual blockを順番に繋ぐだけです。以下のように定式化されます。

\begin{align} \text{attention}(X) = W_p \cdot \text{attend}(X, A^{(r \mod p)}) \end{align}

rは現在のresidual blockのインデックス、pはfactorized attention headの数です。

方法2:Merged head

 factorized attention headを1つのheadにまとめてしまう方法です。以下のように定式化されます。

\begin{align} \text{attention}(X) = W_p \cdot \text{attend}\left(X, \bigcup_{m=1}^{p} A^{(m)}\right) \end{align}

方法3:Multi-head attention

\begin{align} \text{attention}(X) = W_p \left(\text{attend}(X, A)^{(i)}\right)_{i \in \left\{1, \ldots, n_h\right\}} \end{align}

 ここで、n_hはheadの数で、各headは並列計算されます。最後に計算結果が特徴量次元に沿って連結されます。また、Aの構成方法には自由度があり、上記の方法1、2の両方とも使えます。

Positional Encoding

 Transformerでは、positional encodingを使用して系列における要素の位置情報と、特徴量次元に対する情報を付与していました。Sparse Transformerでは以下の式でデータ構造だけでなく、attentionのパターン情報も付与します。

\begin{align} \text{embed}(X, W_e) = \left(\boldsymbol{x}_i W_e + \sum_{j=1}^{n_{emb}} \boldsymbol{o}_i^{(j)} W_j\right)_{\boldsymbol{x}_i \in X} \end{align}

 ここで、\boldsymbol{x}_iは系列のone-hot encodingされたi番目の要素、\boldsymbol{o}_i^{(j)}\boldsymbol{x}_ij番目の次元を表すone-hotベクトルです。n_{emb}n_{emb} = d_{data}もしくはn_{emb} = d_{attn}です。d_{data}はデータの次元数、d_{attn}はfactorized attentionの次元数です。

d_{emb}についてですが、画像に対してはd_{emb} = d_{data} = 3、テキストや音声データに対してはd_{emb} = d_{attn} = 2を使用します。

実験結果

 最後に、Sparse Transformerの有効性を確認して終わりましょう(下図)。

Bits/byte


[1] Table 1より

 density modelingというタスクについて、CIFAR-10、Enwik8、ImageNet 64x64で性能評価を行っています。評価指標はBits/byte(画像タスクにおけるbits/dimと同義)です。Bits/byteについてはここでは解説しませんが、気になる方はPixelCNNの論文を読むとよいでしょう。要は負の対数尤度なので、小さいほど良い指標となります。

 いずれのデータセットでもSparse Transformerが最も良い性能になっています。

計算速度


[1] Table 2より

 Dense Attentionより高速化されていることがわかります。また、Bits/byteで見ても、通常のattentionより性能が良いことが分かります。

長期依存性の獲得


[1] Table 3より

 Enwik8において、モデルが受け入れるコンテキストウィンドウを大きくしていったところ、性能がどんどん良くなっていることがわかります。これはSparse Transformerが長期依存性を獲得できていることを示唆しています。

参考文献

  1. Child, R., Gray, S., Radford, A. & Sutskever, I. (2019). Generating Long Sequences with Sparse Transformers. arXiv. https://doi.org/10.48550/arxiv.1904.10509
  2. Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., … Amodei, D. (2020). Language Models are Few-Shot Learners. arXiv. https://doi.org/10.48550/arxiv.2005.14165
  3. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D. & Sutskever, I. (2019). Language Models are Unsupervised Multitask Learners.
  4. Chen, T., Xu, B., Zhang, C. & Guestrin, C. (2016). Training Deep Nets with Sublinear Memory Cost. arXiv. https://doi.org/10.48550/arxiv.1604.06174
  5. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L. & Polosukhin, I. (2017). Attention Is All You Need. arXiv. https://doi.org/10.48550/arxiv.1706.03762

Discussion