Approximations for Sin/Log2/Exp2
What
Tinygradの$400 bounty「Approximations for Sin/Log2/Exp2」を解いてそこそこ頭を使ったので適当にメモ(と言っても解いたの数ヶ月前だけど)
まだ解かれてないBountyはここにある。
Bountyの難しさ
「Taylor approximations (or other approximation) for LOG2/EXP2/SIN in function.py passing all tests (function.pyで完結するLOG2/EXP2/SINの完全な近似)」
求められてるのは,(1) C言語やCUDA等にbuiltinで搭載されている数学関数を一切使わずに,(2) TernaryOps.WHERE
, 四則演算,bitcastなど基本的な操作のみ (3) FP16~64のデータ型の全範囲に対して (4) atolが[1e-3, 1e-8]
に収まるようにする 近似実装を見つけてあげる必要がある
(1) については,SIMDのIntrinsicでsin/log2/exp2を近似するときに同じことを考えるので,flibやlibc, SLEEFが採用してるRange Reductionアルゴリズムを参考にすれば難しくない。自分がよく使っているSLEEFを参考にした。
Paper: https://arxiv.org/abs/2001.09258
SLEEFの論文の近似手法をTinygradに持ってきて,下の行からEXP2/LOG2/SINを削除すればOK
(2) について,当初使える計算の手札はこれしかなかった。
-
ビット演算 (BinaryOps.AND
a & b
, BinaryOps.ORa | b
) -
加算乗算 (BinaryOps.ADD
a + b
, BinaryOps.MULa * b
) -
整数除算 (BinaryOps.IDIV
a // b
) -
符号反転 (UnaryOps.NEG
-a
) -
比較演算 (BinaryOps.CMPNE
!=
, BinaryOps.CMPLT<
) -
Blend操作 (TernaryOps.WHERE
np.where(a, b, c)
)
Bountyを解き始めた当初TinygradにBitwiseのANDとORが存在しなく,代わりにBooleanにbitcastして加算/乗算するアプローチでやろうとしたがコンパイラのバグで断念。GithubでDiscussionして最低限ANDとORは必要だろうということでprepreq(tinygrad用語)でPRをMergeしてもらった
近似実装のアルゴリズム自体は単純だけど,tinygradのweirdnessに悩まされて他にも4つくらいPRを追加しないと動作しなかった。例えばSchedulerが同一のカーネルにFuseするグループを見つけるアルゴリズムがTopological Sortされてなくて,20分以内にGithubのCIが終わらず通らないなど... $400以上のBountyは基本的にこうしたコンパイラのweirdnessとの戦いになる。幸いなことにtinygradのコードは<8500 linesに収まっているので,ちょっと時間をかければ全体像を把握しやすいから自分で直せるし,ちゃんとしたIssueを投げれば社員のすごい人が対処してくれる。
暇な時間を使って一週間くらいで最初のコードを書いて,その後Mergeされるまでの二週間はずっとこのコンパイラに対するバグ修正を考えてた。
このBountyが誕生した経緯についてだが,丁度これを解いたときにこのBountyを作成した17歳の高校生から(???)Discordで説明を受けた。TinygradのmacOS+Clang Backendは共用ライブラリをコンパイルしてdlopenする仕組みになっていて,これが非常に遅い。で,彼がClangをJITコンパイラに変えてV8 Turbofanのようにshell codeとしてプログラムをロードする仕組みを作ったらしいのだが,ここでほとんどのCPUにはsqrt/sin/exp/logの命令なんてないので,dynamic linkingを実装する必要がある。これがbottle neckになって困っていたので近似計算が必要だったらしい。
でもその後Discordで色々揉めてて結局Mergeされてない。(僕の努力...)
後は検証用に作ったJuliaのコードでも適当に投げておく log/expはどっか行っちゃった,ゴメン
xsin
import Base: TwicePrecision, significand_bits, significand_mask, exponent_mask, exponent_bias
const two_over_pi_f = [
0x00000000,
0x28be60db,
0x9391054a,
0x7f09d5f4,
0x7d4d3770,
0x36d8a566,
0x4f10e410
]
function float_to_bits(x::Float32)
return reinterpret(UInt32, x)
end
function bits_to_float(x::UInt32)
return reinterpret(Float32, x)
end
function float_to_bits(x::Float64)
return reinterpret(UInt64, x)
end
function bits_to_float(x::UInt64)
return reinterpret(Float64, x)
end
function float_to_bits(x::Float16)
return reinterpret(UInt16, x)
end
function bits_to_float(x::UInt16)
return reinterpret(Float16, x)
end
function my_frexp(value::Float16)
if value == Float16(0.0)
return Float16(0.0), 0
end
bits = float_to_bits(value)
exponent = (bits >> 10) & 0x1F
if exponent == 0
return value, 0
end
exp = exponent - 15
return bits_to_float((bits & 0x7f80) | 0x3C00), exp
end
function my_frexp(value::Float32)
if value == 0.0f0
return 0.0f0, 0
end
bits = float_to_bits(value)
exponent = (bits >> 23) & 0xFF
if exponent == 0
return value, 0
end
exp = exponent - 126
return bits_to_float((bits & 0x807FFFFF) | 0x3F000000), exp
end
function my_frexp(value::Float64)
if value == 0.0
return 0.0, 0
end
bits = float_to_bits(value)
exponent = (bits >> 52) & 0x7FF
if exponent == 0
return value, 0
end
exp = exponent - 1022
return bits_to_float((bits & 0x800FFFFFFFFFFFFF) | 0x3FE0000000000000), exp
end
```julia
function reduce_large(a::Float64)
f, e = my_frexp(a)
ia = UInt64(abs(f) * 0x1.0p32)
i = UInt64(e) >> 5
e = UInt64(e) & 31
if e != 0
hi = (two_over_pi_f[Int(i) + 1] << e) | (two_over_pi_f[Int(i) + 2] >> (32 - e))
mid = (two_over_pi_f[Int(i) + 2] << e) | (two_over_pi_f[Int(i) + 3] >> (32 - e))
lo = (two_over_pi_f[Int(i) + 3] << e) | (two_over_pi_f[Int(i) + 4] >> (32 - e))
else
hi = two_over_pi_f[Int(i) + 1]
mid = two_over_pi_f[Int(i) + 2]
lo = two_over_pi_f[Int(i) + 3]
end
p = UInt64(ia) * UInt64(lo)
p = UInt64(ia) * UInt64(mid) + (p >> 32)
p = (UInt64(ia) * UInt64(hi) << 32) + p
q = Int32(p >> 62)
p = p & 0x3fffffffffffffff
if p & 0x2000000000000000 != 0
p -= 0x4000000000000000
q += 1
end
d = Float64(p)
d *= 0x1.921fb54442d18p-62
r = Float64(d)
if a < 0.0
r = -r
q = -q
end
return r, q
end
function reduce_large(a::Float32)
f, e = my_frexp(a)
ia = UInt64(abs(float(f)) * 0x1.0p32)
i = UInt64(e) >> 5
e = UInt64(e) & 31
if e != 0
hi = (two_over_pi_f[Int(i) + 1] << e) | (two_over_pi_f[Int(i) + 2] >> (32 - e))
mid = (two_over_pi_f[Int(i) + 2] << e) | (two_over_pi_f[Int(i) + 3] >> (32 - e))
lo = (two_over_pi_f[Int(i) + 3] << e) | (two_over_pi_f[Int(i) + 4] >> (32 - e))
else
hi = two_over_pi_f[Int(i)+ 1]
mid = two_over_pi_f[Int(i) + 2]
lo = two_over_pi_f[Int(i) + 3]
end
p = UInt64(ia) * UInt64(lo)
p = UInt64(ia) * UInt64(mid) + (p >> 32)
p = ((UInt64(ia) * UInt64(hi)) << 32) + p
q = Int32(p >> 62)
p = p & 0x3fffffffffffffff
if p & 0x2000000000000000 != 0
p -= 0x4000000000000000
q += 1
end
d = Float64(p)#Int64(p))
d *= 0x1.921fb54442d18p-62
r = Float32(d)
if a < 0.0
r = -r
q = -q
end
return r, q
end
function reduce_large(a::Float16)
f, e = my_frexp(a)
m = abs(Float16(f))
ia = UInt64(m * 0x1.0p32)
i = UInt64(e) >> 5
e = UInt64(e) & 31
if e != 0
hi = (two_over_pi_f[Int(i) + 1] << e) | (two_over_pi_f[Int(i) + 2] >> (32 - e))
mid = (two_over_pi_f[Int(i) + 2] << e) | (two_over_pi_f[Int(i) + 3] >> (32 - e))
lo = (two_over_pi_f[Int(i) + 3] << e) | (two_over_pi_f[Int(i) + 4] >> (32 - e))
else
hi = two_over_pi_f[Int(i) + 1]
mid = two_over_pi_f[Int(i) + 2]
lo = two_over_pi_f[Int(i) + 3]
end
p = UInt64(ia) * UInt64(lo)
p = UInt64(ia) * UInt64(mid) + (p >> 32)
p = (UInt64(ia) * UInt64(hi) << 32) + p
q = Int64(p >> 62)
p = p & 0x3fffffffffffffff
if p & 0x2000000000000000 != 0
p -= 0x4000000000000000
q += 1
end
d = Float32(p)
d *= 0x1.921fb54442d18p-62
r = Float16(d)
if a < Float16(0.0)
r = -r
q = -q
end
return r, q
end
function xsin(x)
r, q = reduce_large(x)
q = q % 4
if q == 0
return sin(r)
elseif q == 1
return sin(r + π/2)
elseif q == 2
return sin(-r)
elseif q == 3
return -sin(r + π/2)
end
end
for i in 0:2e5
i = i
approx = xsin(Float32(i))
exact = sin(Float32(i))
diff = abs(out - approx)
if diff > 1e-3
println("Failed at $i. diff=$diff")
end
end
xlog2
xexp2 (どっか行った)
400行くらいのコードで6万くらい貰えて学生にはめちゃめちゃ美味しいのと,これ解いて本社でインターンする権利が与えられるのでめちゃめちゃ美味しい (噂では月100万円以上くれるらしい) あとは,海外だとGeorge Hotzはカルト的な人気を誇っているというのを直接感じられてよかった。日本の某界隈もこういうちゃんとした人やプロジェクトを持ち上げて欲しいなと思ってる。
次やるとしたら2x times faster pattern matcherのやつだけど,最近はCommon LispでDLコンパイラ自作する方が楽しくてあまりPythonを書く気がしない。
独立したPRは分割して送らないとRejectされたり,意外と行数制限は緩かったり,そこら辺のローカルなルールは実際にやって知見を溜めないとわからなかった。これ見て自分もやろうと思った方がいたらぜひDMください。(単純に自分がDLコンパイラとか数値計算やってる人とお話ししたいというのがありますが...)
References
- SLEEF: A Portable Vectorized Library of C Standard Mathematical Functions (https://arxiv.org/abs/2001.09258, https://github.com/shibatch/sleef/tree/master)
- A new range reduction algorithm (https://ieeexplore.ieee.org/document/987766)
- Payne-Hanek reduction in Julia (https://gist.github.com/simonbyrne/d640ac1c1db3e1774cf2d405049beef3)
- Payne Hanek algorithm implementation in C (https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c)
- A faster and more accurate implementation of sincosf() (https://forums.developer.nvidia.com/t/a-faster-and-more-accurate-implementation-of-sincosf/44620)
- ARGUMENT REDUCTION FOR HUGE ARGUMENTS:
Good to the Last Bit (https://redirect.cs.umbc.edu/~phatak/645/supl/Ng-ArgReduction.pdf).
Discussion