AVX-512による最速のexp(x)を目指して
初めに
ここではfloat配列に対する指数関数exp(x)のAVX-512による近似計算例を紹介します。
exp(x)の近似計算方法
まず
ここで
すると
以前は一度
1b-23
=
Sollyaという近似計算のためのソフトを利用します。guessdegree
を使って [-0.5, 0.5]
で精度 1b-23
を得るための次数を求めます。
プロンプトで
> guessdegree(2^x,[-0.5,0.5],1b-23);
[5;5]
とします。 5次まで求めるとよいようです。
次に近似式を求めます。fpminimax
は引数にかなり癖があってマニュアルを読んでもよく分かりません。とりあえず次のようにしたらうまく出来ました。
fpminimax(2^x,[|1,2,3,4,5|],[|D...|],[-0.5,0.5],1);
Warning: For at least 5 of the constants displayed in decimal, rounding has happened.
1 + x * (0.69314697759916432673321651236619800329208374023437
+ x * (0.24022242085378028852993281816452508792281150817871
+ x * (5.5507337432541360711102385039339424110949039459229e-2
+ x * (9.6715126395259202324306002651610469911247491836548e-3
+ x * 1.326472719636653634089906717008489067666232585907e-3))))
それぞれの係数を配列float c[6] = {1, 0.6931, 0.24022, ... }
として利用します。
最終的に次のステップでexp(x)を計算します。
y \leftarrow x \times \log_2(e). -
ここでn \leftarrow {\tt round}(x). は四捨五入関数。{\tt round} a \leftarrow x - n. w=2^a \leftarrow 1 + a(c[1] + a(c[2] + a(c[3] + a(c[4] + a c[5])))). z \leftarrow 2^n. - return
.zw
AVX-512による実装
前節の方針にしたがってAVX-512による実装を行います。
2^n x の計算
順序が前後しますが、先にステップ5とステップ6をまとめた
AVX2までは次の方法をとっていました。(非負)整数 1 << n
ですが、ここで必要なのはfloat型なのでちょっとしたビット演算をします。
floatのビット表現(符号s : 1ビット、指数部e : 8ビット、仮数部f : 23ビット)
floatのビット表現 | 符号s | 指数部e | 仮数部f |
---|---|---|---|
ビット長 | 1 | 8 | 23 |
に合わせて(n + 127) << 23
を求めてそれをfloat値として扱うのです。しかし、AVX-512ではvscalefps
があるのでそれを使って直接計算できます。
しかもレイテンシ4と乗算と同じコストでできます。とても便利ですね。ただ
四捨五入
続いてステップの最初の四捨五入に戻ります。
SSE時代からあるcvtps2dq
(float→int変換)はAVX-512で拡張されて丸めモードを指定できます。しかし、結果はint型なので、小数部を求めたり前述のvscalefps
に渡すためにはfloat型に戻さなければなりません。
SSE4(AVX)で登場したvroundps
は結果をfloatの型として受け取れますが、AVX-512には拡張されていません。
そこでAVX-512ではvrndscaleps
を使います。これは
を求める、ちょっと変わった命令です。
が、今回はvreduceps
を使うことにしました。これは小数部
を直接求める命令です。つまり、上記ステップ3の小数部
n \leftarrow {\tt vrndscaleps}(x). a \leftarrow x - n.
を
a \leftarrow {\tt vreduceps}(x). n \leftarrow x - a.
とする。
加えて、レイテンシはvrndscaleps
が8clkなのに対してvreduceps
だと4clkです。より早くローラン展開を開始できます(マニュアルを何度も眺めていて気がついた)。
ローラン展開
この多項式の計算は前回AVX-512のFMAを用いた多項式の評価で紹介したFMAを使います。
実装例
v0
が入力値を表すzmmレジスタ、v1
, v2
は一時レジスタ、self.log2_e
やself.expCoeff[]
は定数を格納しているレジスタとします。
vmulps(v0, v0, self.log2_e)
vreduceps(v1, v0, 0) # a = x - n
vsubps(v0, v0, v1) # n = x - a = round(x)
vmovaps(v2, self.expCoeff[5])
for i in range(4, -1, -1):
vfmadd213ps(v2, v1, self.expCoeff[i])
vscalefps(v0, v2, v0) # v2 * 2^n
ループアンロール
ここは純粋にPythonによる話になります。
ループアンロールするにはそれぞれの命令に対して必要なレジスタをいくつか用意し、繰り返し並べる必要があります。
たとえばv0 = [zmm0, zmm1, zmm2]
, v1=[zmm3, zmm4, zmm5]
, v2=[zmm6, zmm7, zmm8]
のとき、
Unroll(3, vaddps, v0, v1, v2)
と書くと
vaddps(zmm0, zmm3, zmm6)
vaddps(zmm1, zmm4, zmm7)
vaddps(zmm2, zmm5, zmm8)
となって欲しいです。引数がアドレスだったら、
# Unroll(2, vaddps, [zmm0, zmm1], [zmm2, zmm3], ptr(rax))
vaddps(zmm0, zmm3, ptr(rax))
vaddps(zmm1, zmm2, ptr(rax+64))
のようにオフセットがずれて欲しいです。また多項式の計算では引数の一部が配列ではない(定数なので)ときもあります。
これらのことを考慮して
def Unroll(n, op, *args, addrOffset=None):
xs = list(args)
for i in range(n):
ys = []
for e in xs:
if isinstance(e, list): # 引数が配列ならi番目を利用する
ys.append(e[i])
elif isinstance(e, Address): # 引数がアドレスなら
if addrOffset == None:
if e.broadcast:
addrOffset = 0 # broadcastモードならオフセット0
else:
addrOffset = SIMD_BYTE # そうでないときはSIMDのサイズずらす(addrOffsetで細かい制御はできる)
ys.append(e + addrOffset*i)
else:
ys.append(e)
op(*ys)
という関数を作ってみました。そしてアンロール回数を毎回書かずに、また一斉置換しやすいように次のヘルパー関数を用意しました。
def genUnrollFunc(n):
"""
return a function takes op and outputs a function that takes *args and outputs n unrolled op
"""
def fn(op, addrOffset=None):
def gn(*args):
Unroll(n, op, *args, addrOffset=addrOffset)
return gn
return fn
命令オペランドを引数にとり、Unrollするための関数を返す関数です。これらを使うと前節のAVX-512のコードは次のように書けます。
un = genUnrollFunc(n) # アンロール回数を指定する
un(vmulps)(v0, v0, self.log2_e)
un(vreduceps)(v1, v0, 0) # a = x - n
un(vsubps)(v0, v0, v1) # n = x - a = round(x)
un(vmovaps)(v2, self.expCoeff[5])
for i in range(4, -1, -1):
un(vfmadd213ps)(v2, v1, self.expCoeff[i])
un(vscalefps)(v0, v2, v0) # v2 * 2^n
元のASMのオペコードopをun(op)に置換しただけです。。v0
などはアンロールしたいだけのレジスタを割り当てておきます。C++の場合は命令ごとに型が異なって変なマクロや、マクロを使わないならトリッキーなtemplateが必要でしたが、Pythonだと自由度が高いので便利ですね。
ベンチマーク
今回計算中に必要な定数は6個、exp一つあたりに必要なレジスタは3個なので、8回アンロールしても
アンロール回数を1から8まで変更しながら測定してみました(Xeon Platinum 8280 Turbo Boost off)。
アンロール回数 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
allreg | 17.91 | 15.89 | 14.14 | 13.85 | 13.68 | 13.08 | 13.03 | 13.78 |
allmem | 18.06 | 16.21 | 14.82 | 14.37 | 14.54 | 14.61 | 14.66 | 16.19 |
allregが全てレジスタに載せた状態だとN=7が最も速かったです。allmemはptr_bを使うバージョン。こちらはN=4が最速でした。
allregでN=8が遅くなっている要因の一つはコードが肥大化し過ぎたせいかなと思います。全体のコードはfmathのgen_fmath.pyです。
まとめ
AVX-512を使ったstd::expfの近似計算例を紹介しました。小数部を求めるvreduceps
がポイントです。今回s_xbyakを使ってみて、かなり便利だと思いました。
Discussion
Sollya いいソフトですよね