BP-Transformer: 二分木構造を用いた効率的なAttention
はじめに
この記事は論文 [1]"BP-Transformer: Modelling Long-Range Context via Binary Partitioning" の解説です。要旨を一言で言うと、Transformer の自己注意機構において、入力系列を二分木構造に基づいた階層的スパン(multi-scale spans)に分割することで、計算量を
論文内の手法の解説
本論文の手法は、図 3 に要約されます。以下ではこれの図の意味を説明します。
考える設定と既存手法の問題点
Transformer は自然言語処理の多くのタスクで優れた性能を示していますが、自己注意機構(self-attention)の計算量が入力系列長
従来の Transformer の計算時間、パラメータ数
長さ
本論文の手法:BP-Transformer
BP-Transformer (BPT) は、グラフニューラルネットワークの観点から Transformer を再定式化し、効率的なグラフ構造を構築することで計算量を削減します。
Transformer をグラフニューラルネットワークとして捉える
まず、Transformer の自己注意機構をグラフニューラルネットワーク (GNN) として解釈します。入力トークンをノード、注意機構による情報伝達をエッジとみなすと、通常の Transformer は完全連結グラフ上でのメッセージパッシングと考えられます。
グラフ
注記(表記について): 本稿では実装や一般的な表記に合わせ、モデル次元を
-
: グラフ(ノードと有向エッジの集合)。ノードにはトークンノードとスパンノードが含まれる。\mathcal{G} -
: グラフのノード。通常はノードu, v の更新を考える時、u は近傍ノードを意味する。v\in\mathcal{A}(u) -
: ノード\mathcal{A}(u) の近傍集合(incoming predecessors)。GSA ではこの集合だけに attention を向ける。u -
: ノード\mathbf{h}_u,\,\mathbf{h}_v\in\mathbb{R}^{d} の表現ベクトル(層入力)。ここでu,v はモデル(トークン)次元。トd -
: 近傍ノードの表現を縦に並べた行列(形状:\mathbf{A}^u )。|\mathcal{A}(u)|\times d -
: attention ヘッドの数(head 数)。式中の添字h はヘッドを表す(i )。i=1,\dots,h -
: 各ヘッドの出力次元(per-head dimension)。d_k -
: ヘッド\mathbf{W}^Q_i,\mathbf{W}^K_i,\mathbf{W}^V_i \in\mathbb{R}^{d\times d_k} の線形投影行列(入力はモデル次元i 、出力はヘッド次元d )。d_k -
: ノード\mathbf{Q}^u_i=\mathbf{h}_u\mathbf{W}^Q_i\in\mathbb{R}^{d_k} のヘッドu によるクエリベクトル。i -
: 近傍ノード群のヘッド\mathbf{K}^u_i=\mathbf{A}^u\mathbf{W}^K_i\in\mathbb{R}^{|\mathcal{A}(u)|\times d_k} によるキー行列(行ごとに各近傍ノードのキー)。i -
: 近傍ノード群のヘッド\mathbf{V}^u_i=\mathbf{A}^u\mathbf{W}^V_i\in\mathbb{R}^{|\mathcal{A}(u)|\times d_k} によるバリュー行列。i -
: ヘッド\mathrm{head}^u_i=\mathrm{softmax}\left(\dfrac{\mathbf{Q}^u_i(\mathbf{K}^u_i)^T}{\sqrt{d_k}}\right)\mathbf{V}^u_i\in\mathbb{R}^{d_k} の出力(i 個の注意重みを用いて加重平均したベクトル)。|\mathcal{A}(u)| -
: 複数ヘッドを結合した後に使う線形射影行列(concat(heads) をモデル次元\mathbf{W}^O\in\mathbb{R}^{h d_k\times d} に戻す)。d
通常の Transformer では、全てのトークンが互いに接続された完全グラフを使用しています。つまり、通常の Transformer では、
BPT の核心は、この完全グラフをより疎なグラフ構造に置き換えることです。具体的には、適当な定数
この構造の詳細を以下で説明します。
グラフの構築:Binary Partitioning (二分分割)
BPT の鍵となるアイデアは、細かい粒度から粗い粒度へ(fine-to-coarse) という帰納的バイアスを導入することです。つまり、あるトークンから見て、近い文脈は細かい粒度で、遠い文脈は粗い粒度で表現します。
これを実現するために、入力系列を Binary Partitioning(二分分割) によって階層的なスパンに分割します。
ノードの構築
長さ
- 系列を再帰的に 2 つに分割していきます
- 各分割が単一のトークンになるまで続けます
- 結果として、
個のトークンノード(葉ノード)とn 個のスパンノード(内部ノード)が生成されますn-1
形式的には、ノード
BPT では、2 種類のノードを区別します:
- トークンノード (Token nodes): 二分木の葉ノード。入力系列の各トークンに対応
- スパンノード (Span nodes): 二分木の内部ノード。複数のトークンをまとめた範囲を表現
エッジの構築
二分木構造をそのまま使用すると、遠いトークン間の情報伝達パスが長くなってしまいます。そこで、BPT では 2 種類のエッジを導入します:
1. Affiliated Edges (所属エッジ)
スパンノード
所属エッジにより、スパンノードの表現は、そのスパンに含まれるトークンノードの情報を直接集約して計算されます。
2. Contextual Edges (文脈エッジ)
トークンノード
トークンノード
- レベル 0(トークンレベル):
u_{0,p_0}, \ldots, u_{0,p_0+k-1} - レベル 1(小スパン):
u_{1,p_1}, \ldots, u_{1,p_1+k-1} - レベル 2(中スパン):
u_{2,p_2}, \ldots, u_{2,p_2+k-1} - ...以降同様
ここで、レベル
p_0 = i+1 p_l = \mathrm{parent}(p_{l-1} + k)
左側の文脈も同様に構築します。ここで、
この構造により、各レベルで(ほぼ※)
(※)実装上の効率のため(?)、もし
計算量とメモリ使用量
長さ
-
ノード数:
(O(2n) 個のトークンノード +n 個のスパンノード)n-1 -
エッジ数:
O(k \cdot n \log(n/k)) -
計算量:
O(d \cdot k \cdot n \log(n/k)) -
メモリ使用量:
O(k \cdot n \log(n/k))
例えば、計算量について比較すると、通常の Transformer の
重要な性質として、BPT のグラフでは 任意の 2 つのトークンノード間の距離が最大でも 2 です。なぜなら、トークンノードが直接結ばれる場合(距離
グラフの更新
グラフ構造を構築した後、全てのノードの表現を Graph Self-Attention を用いて更新します。
有向グラフ
- トークンノードの先行ノードは、それに注意を向ける各層のノード(トークンノードもスパンノードも含まれる)
- スパンノードの先行ノードは、それに含まれる全てのトークンノード
スパンノードの表現は初期値としてゼロベクトルで初期化され、トークンノードの表現は対応する単語埋め込みで初期化されます。
通常の Transformer と同様に、複数のグラフ層をスタックできます。各層で、以下の更新を行います:
-
: グラフ\text{GSA}^{(t)}(\mathcal{G}, \cdot) の下での graph self-attention の出力\mathcal{G} -
: feed foward net の出力\text{FFN}^{(t)}(\cdot) -
: 層\mathbf{H}^{t-1} における attention 層への入力(t は初期トークン埋め込み)\mathbf{H}^{0} -
: 層\mathbf{Z}^{t} における feed foward 層への入力t
(本記事の著者注)
つまり、Post‑Norm による学習(残差和のあとに LayerNorm を適用する)を採用していまるようです。
この論文出版時点と同時期〜それ以降頃に、 Pre-Norm ( LayerNorm を先に適用してからサブ層=Attention を通過する )ような学習法が学習安定性の観点から好ましいことが知られており、ここは1つ改善点になるかと考えています。
木構造上の相対位置エンコーディング
通常の Transformer では絶対位置エンコーディングや系列上の相対位置エンコーディングが使用されますが、BPT では木構造に基づく相対位置エンコーディングを導入します。
ノード
-
:r_{v,u} = r^{\text{self}} のとき(自己ループ)v = u -
またはr_{v,u} = r^{\text{left}}_{j,i} :r^{\text{right}}_{j,i} がv の左/右文脈の第u レベルでj 番目に追加されたノードのときi -
:r_{v,u} = r^{\text{anc}}_j がu のレベルv での祖先ノードのときj
これらの位置表現を用いて、attention の計算を以下のように修正します:
全ての
(本記事の著者注)
相対位置埋め込みについても、上記の方針ではノード
この埋め込み手法は比較的シンプルなタスクでは精度が良い一方、埋め込みの相対距離に対しての連続性が担保されないことや、埋め込みを行う相対距離を伸ばした際の計算量の問題などがあります。
近年では、より効率的な相対位置埋め込みとして本論文出版後に RoPE (Rotary Positional Embeddings) のような手法が提案され、実務でもよく使われています。標準的な RoPE は sin/cos による回転を用するため追加の学習パラメータを持たず、計算コストが低く、訓練時より長い系列長への外挿(extrapolation)にも比較的強いという利点があります。一方で RoPE は連続的な位置情報を扱う設計であり、BP のような木構造に固有の「左/右」や「どの階層で追加されたか」といったカテゴリ的な関係をそのまま明示的に表現するわけではありません。
簡潔に利点・注意点をまとめると:
- RoPE の利点:追加学習パラメータが不要で外挿に強く、トークン単位の位置処理に低コストで適用可能。
- 学習ベクトル r の利点:木構造固有のカテゴリ情報(階層性・左右性)を明示的に表現できる。
そのため、BP-Transformer の実装では次のようなハイブリッド設計が現実的で有用となる可能性が考えられます。すなわち、トークン間の相対位置情報には RoPE を適用して連続的な距離情報と外挿性を確保し、スパン間やレベル間の構造的な関係(
k の役割
ハイパーパラメータ
-
が小さい:疎なグラフ、計算量削減、長距離文脈は粗い粒度のみk -
が大きい:密なグラフ、計算量増加、長距離文脈も細かい粒度で捉えるk -
を極端に大きくする(k )とほぼ全トークンを参照するため計算量・メモリ使用量は通常の Transformer に近づき、k\to N においては実効的に通常の Transformer と一致します。(ただしスパンノードや所属エッジ自体は残るため、不要な計算が追加された状態になります)k>=N
実験では、単語レベルのタスクでは
実用例:テキスト分類タスク
IMDB データセット(平均長 294 単語のレビュー文)でのテキスト分類を例に説明します。
- 各レビュー文を二分分割により階層的スパンに分割
-
とすると、各トークンは左右それぞれ約k=4 個のノードに接続4 \log_2 294 \approx 32 - 通常の Transformer では約 294 個のノードに接続するのと比べ、大幅に疎なグラフ
- 最終層のルートノード(全系列を表すスパン)の表現を分類に使用
この設定で、BPT は通常の Transformer を精度で上回りつつ、メモリ使用量と計算時間を大幅に削減します。
結果
論文では、文レベルと文書レベルの様々なタスクで実験が行われています。以下に主要な結果を要約します。
テキスト分類
SST-5 データセット(短文)
平均長 19 単語の感情分類タスク(5 クラス)です。
- BPT (
): 52.71% (標準偏差 0.32)k=2 - Star Transformer: 52.9%
- Transformer: 50.4%
- Bi-LSTM: 49.8%
BPT は短文でも通常の Transformer を上回る性能を示し、構造的な帰納的バイアスが有効であることを示しています。
IMDB データセット(長文)
平均長 294 単語のレビュー文の感情分類タスク(2 クラス)です。
- BPT (
): 92.12% (標準偏差 0.11)k=4 - BCN+Char+CoVe(追加の事前学習利用): 91.8%
- QRNN: 91.4%
- Star Transformer: 90.50%
- Transformer: 89.24%
BPT は長文において特に大きな改善を示し、通常の Transformer を 2.88 ポイント、Star Transformer を 1.62 ポイント上回りました。
k の影響
IMDB データセットで
系列シフトへの頑健性
二分分割の性質上、系列の先頭に要素を挿入するとグラフ構造が変わります。しかし、SST-5 で先頭に 0〜7 個のパディングを追加する実験を行った結果、精度の変化は小さく(最大で 1.2 ポイント程度)、BPT がシフトに対して頑健であることが示されました。
言語モデリング
文字レベルの言語モデリングデータセット Enwik8 と Text8 で評価しました(BPC: bits-per-character で評価、低い方が良い)。
Enwik8
- BPT (
, 文脈長 8192): 1.02 BPCk=64 - Adaptive Span Transformer: 1.02 BPC
- Transformer-XL: 1.06 BPC
- Transformer: 1.11 BPC
- mLSTM: 1.24 BPC
Text8
- BPT (
, 文脈長 8192): 1.11 BPCk=64 - Adaptive Span Transformer: 1.11 BPC
- Recurrent Highway Networks: 1.27 BPC
- HM-LSTM: 1.29 BPC
BPT は、パラメータ数 38M で最先端の性能を達成しました。これは他の手法と同等かより少ないパラメータ数です。
文脈長の影響
- Enwik8: 1.07 (512) → 1.02 (8192)
- Text8: 1.16 (512) → 1.11 (8192)
これは、BPT が長距離依存関係を効果的に利用できることを示しています。
Attention Degree(注意密度)の比較
文脈長 512 に固定し、異なるスパースアテンション手法を比較しました:
- BPT: 様々な
値で最も効率的な性能k - Restricted Attention(ウィンドウサイズ制限): 同じ attention degree で BPT より低性能
- Sparse Transformer: 同じ attention degree で BPT より低性能
これは、BPT の fine-to-coarse 構造が他のスパース化手法より効果的であることを示しています。
機械翻訳
文書レベル機械翻訳(IWSLT 2015 中英翻訳)
平均文書長 120 文に対して評価しました(BLEU スコアで評価)。
- BPT (
, 文脈長 64): 19.84 BLEUk=4 - BPT (
, 単一文): 19.19 BLEUk=4 - Transformer(単一文、著者実装): 18.91 BLEU
- HAN-NMT: 17.78 BLEU
- Transformer+cache: 17.32 BLEU
BPT は文書レベルの文脈を利用することで、単一文のみのモデルより 0.65 ポイント改善しました。
文脈長の影響
文脈長を増やしても、BPT は性能が安定していました:
- Transformer: 18.85 (文脈 0) → 15.55 (文脈 128) と大幅に悪化
- BPT (
): 19.19 (文脈 0) → 19.84 (文脈 128) と改善を維持k=4
これは、BPT の帰納的バイアスが過学習を防ぎ、長文脈でも効果的に学習できることを示しています。
文レベル機械翻訳(WMT14 英独翻訳)
大規模データセット(4.5M 文対)での評価です。
- BPT (
): 27.6 BLEUk=4 - Transformer: 27.3 BLEU
- ConvS2S: 25.16 BLEU
- GNMT+RL: 24.6 BLEU
BPT は同じパラメータ数の通常の Transformer を上回り、
スループットと GPU メモリ使用量
BPT では、#計算量とメモリ使用量 に記載したように、計算量及びメモリ使用量の面で通常の Transformer より優れています。
そこで、疎なアテンションのための一連の CUDA カーネルを設計し、言語モデリングの設定(6 層、
GPU メモリ使用量
論文 Figure.6 に示されるように、系列長が増加するにつれて:
- Transformer: 急激に増加(8192 で約 10.5GB)
- BPT (
): 緩やかに増加(8192 で約 4GB)k=1 - BPT (
): 緩やかに増加(8192 で約 5.5GB)k=4 - BPT (
): やや増加(8192 で約 8GB)k=64
BPT は常に Transformer より少ないメモリで動作します。
スループット(トークン/秒)
論文 Figure.7 に示されるように、
- 短い系列(〜1024): Transformer の方が速い(ノード数が 2 倍になるオーバーヘッドのため)
- 長い系列(2048〜): BPT の方が速く、差は系列長が長いほど顕著
- 系列長 8192: BPT (
)は約 50,000 トークン/秒、Transformer は処理困難k=4
ことがわかります。すなわち、BPT は Transformer が処理困難な長い系列でも、速度とメモリ効率の両面で実用的な性能を示します。
まとめ
BP-Transformer の主要な成果:
-
計算量削減:
からO(n^2) へ削減O(k \cdot n \log(n/k)) - 長距離依存: 任意のトークンペア間の距離が最大 2、効率的な長距離依存学習
- 帰納的バイアス: fine-to-coarse 構造により、短文でも過学習を防ぎ、長文で効果的
- 汎用性: テキスト分類、言語モデリング、機械翻訳で一貫して高性能
- 効率性: 実装レベルでも、長文で大幅なメモリ削減と速度向上を実現
参考文献
[1]: BP-Transformer: Modelling Long-Range Context via Binary Partitioning(https://arxiv.org/abs/1911.04070)
Discussion