RWKV hxa079 アーキテクチャ解説
RWKV hxa079 アーキテクチャ解説
こんにちは。お元気ですか?
今日は、私がRWKV v7をベースに魔改造した RWKV hxa079 アーキテクチャ についてのメモです
背景
RWKV hxa079 は、BlinkDL 氏が提案した RWKV v7 (x070) “Goose” アーキテクチャをベースに、以下の目的で改造したモデルです。
- Transformer Attention への変換を容易にする
- トレーニング・推論速度を向上させる
RWKV v7 "Goose" は、Receptance / Key / Value のプロジェクション層を中心に構成されており、次のような特徴があります。
- Tokenshift(1トークン前の HiddenState との差分を残差接続)
- GroupNorm(出力正規化)
- v_first(Layer 0 の Value 出力を後段層に残差接続)
x070 の構造
RWKV v7 の Attention ブロックは、以下のような構造になっています。
xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g
...
v = self.value(xv)
if self.layer_id == 0:
v_first = v
else:
v = v + (v_first - v) * torch.sigmoid(...)
Tokenshift により、各プロジェクション入力が「前のトークンとの差分付き HiddenState」になります。この構造は RWKV 的な時間依存性を持たせる上で有効ですが、推論速度や Transformer との互換性の面では課題があるんです。
課題
1. 推論速度の制約
Tokenshift が各プロジェクションに適用されるため、並列化の際に BMM (Batch Matrix Multiply) を使う必要があります。
しかし、多くの量子化実装は BMM に非対応で、結局シリアル実行せざるを得ない場面が多くなります。
2. Transformer 変換時の学習難易度
Transformer から RWKV へ変換する際、初期重みを教師モデルから継承しますが、Tokenshift が入っていると構造が変わりすぎており、学習が非常に難しくなります。
収束する場合もありますが、基本的に学習時間が2倍以上かかってしまいました。
hxa079 の改良点
こうした背景から、hxa079 では以下の改造を行いました。
- Tokenshift 削除 — 並列計算性を向上
- GroupNorm 削除 — 学習安定性を向上
- k_first の追加 — Layer 0 の Key を残差接続し、勾配が流れやすくなるよう改善
- meta-ICL の計算方法変更 — 内部のmeta icl学習率調整処理を再設計
hxa079 のコード例
def forward(self, x, v_first, k_first, attention_mask, position_embeddings, position_ids, x_emb):
B, T, C = x.size()
H = self.num_attention_heads
# Tokenshift 削除
r = self.receptance(x).view(B, T, H, -1)
k = self.key(x).view(B, T, self.num_key_value_heads, -1)
...
if self.layer_id == 0:
v_first = v
k_first = k
else:
v = v + (v_first - v) * torch.sigmoid(...)
k = k + (k_first - k) * torch.sigmoid(...)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
...
x = RUN_CUDA_RWKV7g(r, w, k, v, -kk, kk*a, self.head_size, attention_mask)
...
return x, v_first, k_first
効果
-
Transformer → RWKV 変換が容易
- 一般的な Transformer Attention を、そのまま hxa079 に変換可能
-
学習コスト削減
- 約 500M トークン で変換学習が完了
-
推論高速化
- Tokenshift 削除により、量子化実装でも並列処理がしやすくなる
まとめ
hxa079 は、RWKV の持つ「時間方向の再帰性」という強みを残しつつ、Transformer からの移植性と推論速度をすこしだけ改善したアーキテクチャです。
特に、量子化や高速化実装との相性が良いため、大規模モデルの推論基盤としても活用可能だと信じています。
RWKVの人口増えるといいなぁ
Discussion