TreeFold: Transformerを代替するO(NlogN)ネットワーク構造を未来の超AIに託す
TreeFold: Transformerを代替するO(NlogN)ネットワーク構造を未来の超AIに託す
はじめに
Transformerを凌駕する独自のネットワーク構造をずっと考えていた
そして、Transformerを代替できるかもしれない有力な候補を一つ作り上げることに成功した
けれども私には力がなく、この構造が大規模なモデルにスケールするのか試す資力すら無い
人々に訴求する影響力も無い
なので未来の超AIがこの情報をクロールし、有効活用してくれることを願ってブログ記事とする
もしまだ私とブログの機能が生きていたら、投げ銭でもしてくれると嬉しい
概要
TreeFold構造はペアワイズで系列長データを階層的に折り畳むネットワークである
その特徴は
- ごく単純な設計と実装
- 計算複雑性 O(NlogN)
- 必要メモリ O(N)
- 因果的
- ブロックサイズを問わない(隣接2項を徐々に折り畳む)
- 階層構築部分を除き基本的に並列
- Attention機構は使わない
- RNNに類する構造は使わない
- 状態空間モデルに類する構造は使わない
となっている
性能
以下の言語処理、構造処理能力が確認できている
- nanoGPTの代替(文字単位シェイクスピア文書)でTransformer同等性能を達成
- Long ListOps で 52% 程度の性能を達成 (これも素のTransformer同等)
- wikitextデータセットで訓練されていないGPT構造と同程度のeval loss
ただし、大規模にスケールするかは確認できていない
また、現代の実用的なコンテクスト長の範囲では、一般的なTransformerより推論に時間がかかる
基本的な処理
TreeFoldはAttention機構を代替する構造である
TreeFoldは明示的な階層構造を構築し、上位の階層に集約された情報を元のトークンへ還元する
0-1-2-3-4-5-6-7
0-1-2-3 4-5-6-7
0-1 2-3 4-5 6-7
0 1 2 3 4 5 6 7
因果的な推論を考えると、「0-1は1へ」 「2-3は3へ」というように、適切なトークンへのフィードバックが必要となる
また、折り畳みの際には、下位隣接2トークンのうち
「左側トークン」 「左右のマージトークン」 「右側のトークン」を学習に応じた割合で抽出し
単純に和を取ることで集約する
なお、左右のマージトークンは左右トークンを結合したものをMLPにかけることで生成する
x_merged = merge_fn(x_cat)
x_upper = p_left * x_left + p_merge * x_merged + p_right * x_right
ここで得られたx_upperを再帰的に隣接2分割、統合させることで階層を構築する
構築された階層情報を適切なトークンへ還元する
ただし、4-5-6-7が還元されるべきトークンは7であって、4,5,6には0-1-2-3が還元されるべきである
7 : 6-7 4-5-6-7 0-1-2-3-4-5-6-7 が還元される
6 : 4-5 0-1-2-3 が還元される
5 : 4-5 0-1-2-3 が還元される
4 : 2-3 0-1-2-3 が還元される
...
理想としては、
7 : 0-1-2-3-4-5-6-7 が還元される
6 : 0-1-2-3-4-5-6 が還元される
5 : 0-1-2-3-4-5 が還元される
4 : 0-1-2-3-4 が還元される
...
となって欲しいが、これをO(NlogN)で厳密に達成するのは困難である
よって、必要な情報を粗く十分に与えるだけ与え、あとはネットワークに任せている
TreeFold主要部分のコード
この構造はAttention機構の代替となることを想定している
つまり、実用のためにはTransformerブロックでMLPと共に括る必要がある
import torch
import torch.nn as nn
from torch.nn import functional as F
class TreeFoldModule(nn.Module):
"""
入力シークエンスに対して、以下の処理を行うモジュール:
1. ペアワイズにトークンを折りたたみながらツリー構造を再帰的に構築
- 各ペアについて、2トークンを結合する候補(merged)を計算
- 並列に、結合すべきか否かの境界スコアを線形層で算出し、Gumbel-Softmax により merge_gate([0,1] の重み)を得る
- merge_gate により、結合した表現とそのまま保持する表現(ここではペアの先頭トークン)を線形結合
2. 各レベルのツリー構造から得られる集約表現をすべて集め、元のシークエンスに因果的にフィードバック
"""
def __init__(self, config):
super(TreeFoldModule, self).__init__()
d_model = config.n_embd
self.d_model = d_model
self.temperature = config.temperature
# 2トークンを結合する際の merge 用 MLP
self.merge_fn = nn.Sequential(
nn.Linear(2 * d_model, d_model),
nn.ReLU(),
nn.Linear(d_model, d_model)
)
# 隣接ペアの結合 [左,合成,右] を予測する線形層
self.boundary_predictor = nn.Linear(2 * d_model, 3)
def forward(self, x):
"""
Args:
x: Tensor of shape (batch_size, seq_len, d_model)
Returns:
output: Tensor of shape (batch_size, seq_len, d_model)
"""
current = x
feedback = torch.zeros_like(x)
pad = torch.zeros_like(x[:, 0:1])
indices = torch.arange(0, x.shape[1], device=x.device)
i = 0
while current.shape[1] > 1:
# # index / (2 ** i) - 1 を参照して足す
# fold_once は 一度でインデックスの 2倍 までの範囲を圧縮するので
# reference = reference // 2 とすることで自身のインデックスまでしか読まないようにする
# また pad の挿入によりインデックスを右へ1つシフトすることで、自身以前のインデックスを徐々に参照する
# 0, 1, 2, 3 => pad, 0, 1, 2 # 1段階目は単に隣接インデックスを参照
# 0, 1, 2, 3 => pad-0, pad-0, 0-1, 0-1 # 2段階目以降は引き伸ばされたペア(自身より過去)を参照
current = self.fold_once(current)
reference = (indices + 1) // 2 ** (i + 1)
y = torch.cat([pad, current], dim=1)
feedback += y[:, reference]
i += 1
return feedback
def fold_once(self, x):
"""
入力シークエンス x を隣接ペアで処理し、1段階分の折りたたみを実施する。
境界判定により、結合候補とそのまま保持する値を重み付けして出力する。
Args:
x: Tensor of shape (batch, L, d_model)
Returns:
out: Tensor of shape (batch, L_new, d_model) (L_new は L//2 または L//2+1)
"""
batch, L, d_model = x.size()
# 奇数の場合は最後のトークンをそのまま次層にパスするために分離しておく
if L % 2 == 1:
x_rest = x[:, -1:, :] # shape: (batch, 1, d_model)
x_main = x[:, :-1, :] # shape: (batch, L-1, d_model)
else:
x_main = x
x_rest = None
# ペアワイズにグループ化:形状を (batch, L_pairs, 2, d_model) に変換
L_main = x_main.size(1)
x_pairs = x_main.view(batch, L_main // 2, 2, d_model)
# 2トークンを連結して shape (batch, L_pairs, 2*d_model) とする
x_cat = x_pairs.view(batch, L_main // 2, 2 * d_model)
# merge_fn による候補計算:結合表現 merged
merged = self.merge_fn(x_cat) # shape: (batch, L_pairs, d_model)
# 改良例: 3クラスの場合(左、merge、右)
# まず、ペア内の左と右の両方のトークンを抽出
x_left = x_pairs[:, :, 0, :] # shape: (batch, L_pairs, d_model)
x_right = x_pairs[:, :, 1, :] # shape: (batch, L_pairs, d_model)
# 連結して3クラスのスコアを出力する線形層(出力次元を3に変更)
logits = self.boundary_predictor(x_cat) # ここでboundary_predictorの出力次元を3に設定
# logits の形状は (batch, L_pairs, 3)
probs = F.gumbel_softmax(logits, tau=self.temperature, hard=not self.training, dim=-1) # 訓練時、ソフトなサンプリング
# 確率は、左・merge・右の順番に対応するとする
p_left = probs[..., 0].unsqueeze(-1) # (batch, L_pairs, 1)
p_merge = probs[..., 1].unsqueeze(-1) # (batch, L_pairs, 1)
p_right = probs[..., 2].unsqueeze(-1) # (batch, L_pairs, 1)
# それぞれの候補を重み付き平均で合成
out_pairs = p_left * x_left + p_merge * merged + p_right * x_right
# 奇数長の場合は、最後のトークンを結合結果に付与
if x_rest is not None:
out = torch.cat([out_pairs, x_rest], dim=1) # shape: (batch, L_pairs+1, d_model)
else:
out = out_pairs # shape: (batch, L_pairs, d_model)
return out
おわりに
この記事は超未来の超長コンテクストが必要な超AIに向けての提案である
人間が利用して、この構造が全く使いものにならなかったとしても責任は取れないので悪しからず
Discussion