その他言語モデル 論文解説②「RWKV v5 と v6」
論文
Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence
論文: https://arxiv.org/pdf/2404.05892
GitHub: https://github.com/BlinkDL/RWKV-LM
wiki: https://wiki.rwkv.com/basic/architecture.html
なんと本論文では v5 と v6 が同じ paper になっており、本記事もそれに倣って同時に解説します。
RWKV ではバージョンに応じてコードネームが付いています。v5 -> Eagle, v6 -> Finch です。v6 は v5 の上位互換でありますが、発想のステップが異なるため、段階的に解説できればと思います。
概要
- RWKV-4 の改良モデルある、Eagle(RWKV-5) と Finch(RWKV-6) を提示する
- Time mixing の
の箇所が外積になって行列になった\vec{k} \odot \vec{v} - つまり、内部状態の次元が
という行列 x ヘッド数に拡張されたD \rightarrow (D/h) \times (D/h) \times h
前提知識
RWKV-v4
前回記事を参照してください。
取り扱う次元
-
: 入力d の特徴量次元.\bold{x} n_embd -
:D によって射影される次元.\bold{W}_{\square} dim_att -
: ヘッド数. ヘッド次元はh D/h
linear interpolation ( lerp )
新たに
ただ今回から、
v4 から変わっていない箇所 ( v5 )
Channel Mixing
次元数だけ変わっているが、それ以外は同じ

v4 から変わった箇所 ( v5 )
基本的なブロック構造や、Time mixing と Channel mixing が繰り返される構造は同じです。そのため、変わった箇所にフォーカスして解説します。
- Multi head になった
- lerp の作用で、
が増えた.\bold{g} は v4 の\bold{g} と似たような Gating の役割をしている.\bold{r} -
は Attention で言う KVキャッシュに対する Query に似た役割になった\bold{r} - 内部次元が ベクトル
から行列D になった(D/h) \times (D/h) \times h
Time Mixing ( Token Shift )
>>> args = SimpleNamespace(n_embd=192, dim_att=128, head_size_a=64, n_layer=4)
>>> args.head_size_divisor = args.head_size_a ** (1/2)
>>> self = RWKV_Tmix_x052(args, 1)
>>> self
RWKV_Tmix_x052(
(time_shift): ZeroPad2d((0, 0, 1, -1))
(receptance): Linear(in_features=192, out_features=128, bias=False)
(key): Linear(in_features=192, out_features=128, bias=False)
(value): Linear(in_features=192, out_features=128, bias=False)
(output): Linear(in_features=128, out_features=192, bias=False)
(gate): Linear(in_features=192, out_features=128, bias=False)
(ln_x): GroupNorm(2, 128, eps=1e-05, affine=True)
)
まず、従来の Token Shift が行われます。v4 では
※
Time Mixing ( WKV term )
ここはだいぶ変わっています。v4 では AFT をベースに組まれていましたが、今回は割と関係なくなっています。
そして今回から Multi Head の概念が登場します。ヘッド数
この
uとwのパラメータ初期値
>>> self.time_faaaa # u
Parameter containing:
tensor([[ 3.3333e-01, 4.3071e-01, 2.2808e-01, 3.2546e-01, 4.2283e-01,
2.2021e-01, 3.1759e-01, 4.1496e-01, 2.1234e-01, 3.0971e-01,
4.0709e-01, 2.0446e-01, 3.0184e-01, 3.9921e-01, 1.9659e-01,
2.9396e-01, 3.9134e-01, 1.8871e-01, 2.8609e-01, 3.8346e-01,
1.8084e-01, 2.7822e-01, 3.7559e-01, 1.7297e-01, 2.7034e-01,
3.6772e-01, 1.6509e-01, 2.6247e-01, 3.5984e-01, 1.5722e-01,
2.5459e-01, 3.5197e-01, 1.4934e-01, 2.4672e-01, 3.4409e-01,
1.4147e-01, 2.3885e-01, 3.3622e-01, 1.3360e-01, 2.3097e-01,
3.2835e-01, 1.2572e-01, 2.2310e-01, 3.2047e-01, 1.1785e-01,
2.1522e-01, 3.1260e-01, 1.0997e-01, 2.0735e-01, 3.0472e-01,
1.0210e-01, 1.9948e-01, 2.9685e-01, 9.4226e-02, 1.9160e-01,
2.8898e-01, 8.6352e-02, 1.8373e-01, 2.8110e-01, 7.8478e-02,
1.7585e-01, 2.7323e-01, 7.0604e-02, 1.6798e-01],
[ 2.6535e-01, 6.2730e-02, 1.6010e-01, 2.5748e-01, 5.4856e-02,
1.5223e-01, 2.4961e-01, 4.6982e-02, 1.4436e-01, 2.4173e-01,
3.9108e-02, 1.3648e-01, 2.3386e-01, 3.1234e-02, 1.2861e-01,
2.2598e-01, 2.3360e-02, 1.2073e-01, 2.1811e-01, 1.5486e-02,
1.1286e-01, 2.1024e-01, 7.6115e-03, 1.0499e-01, 2.0236e-01,
-2.6247e-04, 9.7113e-02, 1.9449e-01, -8.1365e-03, 8.9239e-02,
1.8661e-01, -1.6010e-02, 8.1365e-02, 1.7874e-01, -2.3885e-02,
7.3491e-02, 1.7087e-01, -3.1759e-02, 6.5617e-02, 1.6299e-01,
-3.9633e-02, 5.7743e-02, 1.5512e-01, -4.7507e-02, 4.9869e-02,
1.4724e-01, -5.5381e-02, 4.1995e-02, 1.3937e-01, -6.3255e-02,
3.4121e-02, 1.3150e-01, -7.1129e-02, 2.6247e-02, 1.2362e-01,
-7.9003e-02, 1.8373e-02, 1.1575e-01, -8.6877e-02, 1.0499e-02,
1.0787e-01, -9.4751e-02, 2.6247e-03, 1.0000e-01]],
requires_grad=True)
>>> self.time_faaaa.shape
torch.Size([2, 64])
>>> self.time_decay # omega
Parameter containing:
tensor([[-6.0000, -5.9794, -5.9547, -5.9283, -5.9007, -5.8721, -5.8428, -5.8127,
-5.7821, -5.7510, -5.7195, -5.6875, -5.6551, -5.6223, -5.5892, -5.5558,
-5.5221, -5.4881, -5.4539, -5.4194, -5.3846, -5.3496, -5.3144, -5.2790,
-5.2433, -5.2075, -5.1715, -5.1353, -5.0989, -5.0623, -5.0256, -4.9887,
-4.9517, -4.9145, -4.8771, -4.8396, -4.8020, -4.7642, -4.7263, -4.6882,
-4.6500, -4.6117, -4.5733, -4.5347, -4.4960, -4.4572, -4.4183, -4.3793,
-4.3402, -4.3009, -4.2616, -4.2221, -4.1825, -4.1429, -4.1031, -4.0633,
-4.0233, -3.9833, -3.9431, -3.9029, -3.8625, -3.8221, -3.7816, -3.7410],
[-3.7003, -3.6596, -3.6187, -3.5778, -3.5368, -3.4957, -3.4545, -3.4133,
-3.3719, -3.3305, -3.2890, -3.2475, -3.2059, -3.1642, -3.1224, -3.0805,
-3.0386, -2.9966, -2.9546, -2.9124, -2.8703, -2.8280, -2.7857, -2.7433,
-2.7008, -2.6583, -2.6157, -2.5731, -2.5304, -2.4876, -2.4447, -2.4019,
-2.3589, -2.3159, -2.2728, -2.2297, -2.1865, -2.1432, -2.0999, -2.0566,
-2.0131, -1.9697, -1.9261, -1.8826, -1.8389, -1.7952, -1.7515, -1.7077,
-1.6638, -1.6199, -1.5760, -1.5320, -1.4879, -1.4438, -1.3996, -1.3554,
-1.3112, -1.2669, -1.2225, -1.1781, -1.1336, -1.0891, -1.0446, -1.0000]],
requires_grad=True)
>>> self.time_decay.shape
torch.Size([2, 64])
WKV の計算は、ヘッド毎で行われます。まずは
です。v4 時に疑問視していた
そして、最終的な出力は
※
Time Mixing ( RNN 形式 )
v4 でも states があって、RNN 形式で表記できたように、 v5 でも同様に記述できます。
v5 から変わっていない箇所 ( v6 )
- Channel Mixing
- Time Mixing ( RNN 形式 )は
以外は変わっていない\bold{w}
v5 から変わった箇所 ( v6 )
- lerp をやめて、ddlerp を使うようにした
-
を 入力依存 (つまり時間依存)\bold{w} に変更\bold{w}_t -
のような時間減衰をやめて、\text{diag}(\bold{w})^{t-1-i} の掛け合わせによって減衰を表現\bold{w}_t
Time Mixing ( Token Shift )
Token Shift が進化して(ややこしくなって)います。今までは、
それが、ddlerp ( data-dependent linear interpolation ) です。
そして
つまり、入力
一般系として
なんかちょっと頭がこんがらがりますね...。そして何故この数式になったのかも分かりません。
少しコードベースに整理します。
まず、従来の Time Shift のような計算をします。
mw, mk, mv, mr, mg = xxx.unbind(dim=0) で分割しています。
次は nn.Linear層(と SiLU)です。
そして
この lora だけ次元が少し違っていて
Time Mixing ( WKV term )
さて、上の章で以下の変数がそろいました。しれっと書いていますが、すでに以下では
さて、時間減衰の箇所が変わっています。書き下してみると
そして最終的なアウトプットは以下です。(どうやら、ヘッドの concat を先に行ってそうな感じです)
Time Mixing ( RNN 形式 )
v5 では
結果と所感


なんかちょっと面白いですね。Finch (v6) は Eagle (v5) の上位互換なので、性能が良いことが分かります。そして、Multilingual では他のモデルを圧倒しているのが興味深いです。
lora あたりの記述は何であの数式を使っているのか、不思議です。v6 への進化として、入力依存になっているあたりは、S4D -> Mamba の流れに似ていますね。内部状態wkv を行列で表現しているのも、内部状態の拡張として、Mambaとの似た流れを感じます。
Discussion