自作インタープリターをJITコンパイルで高速化する その2: Whileの高速化
はじめに
前回の続きです。
前回はJITコンパイラを書き、関数ごとにネイティブコードを生成して高速化するというのをやってみました。
今回はTracing JITっぽい手法を入れてWhile文の高速化と、あと命令の最適化を行ってみたのでそれについて書きます。
今回の記事は以下のリポジトリのコミット時点でのコードをベースにしています。
高速化するコードについて
今回高速化するのは以下のコードです。
平たく言うと、「5000x5000の掛け算表の総和を求める」というコードです。
fun calc() do
let n = 5000;
let sum = 0;
let i = 1;
while (i < n) do
let j = 1;
while (j < n) do
sum = sum + i * j;
j = j + 1;
end
i = i + 1;
end
return sum;
end
fun main() do
return calc();
end
初回時点(何も最適化を入れていない時点)では3秒ほどかかります。
hyperfine --warmup 3 "zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob --noopt --nojit"
Benchmark 1: zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob --noopt --nojit
Time (mean ± σ): 3.155 s ± 0.067 s [User: 3.087 s, System: 0.108 s]
Range (min … max): 3.052 s … 3.241 s 10 runs
参考実装で、ほぼそのままRustに移植したコードを計測します。
hyperfine --warmup 3 "cargo run --release"
Benchmark 1: cargo run --release
Time (mean ± σ): 178.8 ms ± 3.7 ms [User: 30.3 ms, System: 9.1 ms]
Range (min … max): 173.1 ms … 189.4 ms 16 runs
3000msと180msで大体15倍くらい差があることになります。これを高速化しましょう。
[2025/02/24 追記] 今回計測しているのは以下の環境です。
- OS: MacOS 14.6.1
- CPU: M2 Pro
- Zig: 0.13.0
- Rustc: 1.83.0
TraceとJIT
世の中にはTracing JITと言われる手法があり、ざっくりと言えば「命令列のトレース」をとってそこをネイティブコードに置き換える手法です。
例えば上記のコードだと2重whileになっているところが一番重いですが、IR上だと以下のようなコードになっています。
while_cond:
CONDITION ;; 条件
jump_if_not while_end ;; 条件が成立しなければwhile_endに飛ぶ
BODY ;; 処理
jump while_cond
while_end:
これの実行された命令列のトレースをとると、whileが結構な回数回っていると仮定するなら以下のパートが大量に並ぶことになります。
while_cond:
CONDITION
jump_if_not while_end
BODY
jump while_cond
ここの部分をネイティブコードにしてしまおうと言うことですね。
本来のTracing JITであればトレースをとっておき、hot spotを探してそこを高速化の対象とします。が、実装が大変なので今回は「何回も呼ばれるWhile文」だけを捉えるようにしました。
具体的には、上記コードのjump while_condに相当する、前方向にジャンプする処理に着目し、それの回数をラベルごとに記録します。これが一定数(今回は10回としました)貯まったら、次回のラベルジャンプで命令列の記録をとって、それをネイティブコードに落とすというような実装です。
hot spotを探すのがラベルのカウントで済むのでわかりやすいでしょう。
exit pathのチェックとIPの制御
さて上記のような実装をするわけですが、一つ困ることがあります。
それは命令列のトレースをとっても必ずしもコードがその中で完結しないと言うことです。具体的には、上記の命令列の「jump_if_not while_end」の部分で、while_endのラベルはトレースに含まれません。
しかしこれがないとネイティブコードを生成するときにジャンプ先がなくてコンパイルしようがないので困ってしまいます。これはつまり、関数の実行が終わったときにインストラクションポインタ(ip)をどうしたらいいのかをコンパイラが事前に知る術がないことによります。
今回の実装では、ブロックから脱出する部分はfallback用のブロックを手動で追加し、VM側で持っているipを示す変数のポインタをJIT側のネイティブ関数にわたしてそれをfallbackブロックから設定してもらうというような方法を取りました。
言葉で書いてもあれなので、要は以下のような命令列を生成したということです。
while_cond:
CONDITION
jump_if_not fallback_while_end ;; ラベルをすり替える
BODY
jump while_cond
fallback_while_end
set_ip while_end ;; while_endのラベル位置に相当する値をipにセット
ret
fallback_while_end_2 ;; もし2つ以上脱出口があれば以下に続ける
...
set_ipと書いてますが、ipを示す変数は単にJIT関数側の引数として渡されるので要は引数レジスタに値をセットするだけです。while_endのラベル位置に相当する値は事前に計算できるのでコンパイラ側から埋め込んでおくことができます。
JITによる最適化結果
上記のようなトリックにより上手くトレースを最適化できた結果、以下のような結果となりました。
hyperfine --warmup 3 "zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob --noopt"
Benchmark 1: zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob --noopt
Time (mean ± σ): 1.137 s ± 0.005 s [User: 1.099 s, System: 0.080 s]
Range (min … max): 1.131 s … 1.146 s 10 runs
3000msから1100msと、大体3倍くらい早くなりました。
しかしこれでもまだRust実装と比べて5倍は開きがあります。まだ高速化する方法はないでしょうか。
スタックマシンの命令列最適化
そもそも今のVMはスタックマシンで、明らかに無駄な実行が山ほどあります。
IRをダンプして、特に重いであろうWhile文の内側の部分を詳細にみてみます。
...
while_cond_736036777:
get_local_d [3] ;; スタックからの値の取得
get_local_d [0] ;; スタックからの値の取得
lt
jump_ifzero while_end_736036777
while_body_736036777:
get_local_d [1] ;; スタックからの値の取得
get_local_d [2] ;; スタックからの値の取得
get_local_d [3] ;; スタックからの値の取得
mul
add
set_local_d [1] ;; スタックへの値のセット
get_local_d [3] ;; スタックからの値の取得
push #1
add
set_local_d [3] ;; スタックへの値のセット
jump while_cond_736036777
...
軽く説明しておくと、 get_local_d [n]
はスタックの bp+n (bpはベースポインタ)の値を取得してスタックの一番上にpushする命令です。 set_local_d [n]
はスタックの一番上の値をpopして、 bp+n の値に上書きする命令です。
この2つの命令はいずれもスタックへの参照と書き換えが走る重い処理です。
そう言う観点で見てみると、最初のltの部分でもう無駄があります。
最初の3つは「bp+3の値を取り出してpush」「bp+0の値を取り出してpush」「2回popして比較した結果をpush」という操作ですが、わざわざpushしなくてもレジスタに入れたまま比較した方が明らかに早いです。
ということで、lt命令の代わりに、スタックからpopしなくてもレジスタに入った値を直接比較できるような命令があると良さそうです。
今回はレジスタ割り付けなどを考えたくなかったことと、get_local_dを減らすことを念頭において以下のような新命令を実装することにしました。
lt_d [n] [m]
という命令で、次と等価です: get_local_d [n]; get_local_d [m]; lt
これにより、スタックから値を引いてレジスタに入れ、スタックに積み直さなくてもレジスタに入れた値同士で比較ができます。
また同様に、 get_local_d [b]; get_local_d [m]; get_local_d [n]; mul; add
と等価な madd [b] [m] [n]
、そして get_local_d [m]; push #imm; add
と等価な add_di [m] #imm
命令を作りました。
また、さらに madd; set_local_d
のような、スタックの一番上に積んでからスタックの中ほどにコピーするような命令も見受けられたのでここも最適化したくなり、maddやadd_diにはセットするoffsetを直接持たせられるようにしました。
これらの最適化を行ってあげると、IRは以下のようになります。
...
while_cond_382205881:
lt_d [3] [0]
jump_ifzero while_end_382205881
while_body_382205881:
madd_d [2] [3] [1] -> [1]
add_di [3] #1 -> [3]
jump while_cond_382205881
...
だいぶ短くなりましたね。
そして計測もしてみると、実際にかなり早くなっているのがわかります。
hyperfine --warmup 3 "zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob"
Benchmark 1: zig build run -Doptimize=ReleaseSafe -- run ./example/calc_squares.ob
Time (mean ± σ): 180.9 ms ± 1.8 ms [User: 153.1 ms, System: 71.0 ms]
Range (min … max): 177.5 ms … 183.7 ms 16 runs
1100msから180msと、ほぼ6倍の高速化です。
また、Rustと遜色ない時間に落ち着いたので無事に最適化は成功したと言えます。
最適化の振り返り
さて、今回はTraceをとったJIT(ネイティブコード生成)とVMの命令最適化の2つの最適化を行いました。
上記の結果だけ見ると命令最適化の方が効いているようにも見えてしまいますが、実際にどっちがより強く効いているだろうと思って計測してみると以下のようになりました。
JITなし | JITあり | |
---|---|---|
命令最適化なし | 3000ms | 1130ms |
命令最適化あり | 1180ms | 180ms |
ということで、実はこの2つの最適化については片方だけだとそこまで差はありません。両方揃って初めて効果を発揮するということで、やはり最適化というのはボトルネックを正しく潰せるかどうかであって手法だけの問題ではないということがわかります。
終わりに
今回の最適化はあくまで今回対象としたコードに対してよく効く最適化でありますが、当然他のコードに対しては無力なこともあります。特に命令列の最適化はlt, madd, addくらいしか見ていないので他のコードでは他の上手い命令を考えてあげる必要が出てくると思います。
しばらくの間は遅いと思ったらIRを眺めて上手い命令を考えるのでも割と凌げる気はしますが、本気でここを追求するならレジスタ割り付けに真面目に取り組む必要がありそうです。特に今回のようなget_local_dやset_local_dを減らすことがスタックマシンにおいては重要なので、スタックに置かれた値の生存解析をしてあげてよしなにレジスタに逃がしてあげることでそれなりに良い感じのコードが生成できるんじゃないかなと思っていますが、まあこれは本当に行き詰まってしまってからでもいいかなと思っています。
また、今回はトレース解析はサボりましたがこの辺りのアルゴリズムも興味があるので、余裕があればどこかで実装できたらいいかもなと思っています。やる前は難しそうでびびっていたexit pathの特定とipの復元処理ですが、意外となんとかなったので個人的には割と満足しました。
今後はもう少し本腰を入れて言語自体の拡張や機能追加をやっていきたいと思います。
Discussion
動作速度が記述されてますが、実行環境(CPU/OS/VMなど)を書いていただけると、比較/検討の参考になります。よろしくお願い申し上げます。
コメントありがとうございます。
計測環境について追記しておきました。