🧮

多倍長整数の実装8(LLVMを用いたasmコード生成)

2022/08/18に公開

初めに

前回まではx64用asmコードを用いた実装を紹介しました。高速な実装のためにはCPUの特性と命令を駆使しなければなりません。今回はLLVMを用いてより汎用的で(そこそこ)高速なコードを効率よく生成することを目指します。
記事全体の一覧は多倍長整数の実装1(C/C++)参照。

方針

LLVMは仮想機械用の中間コード(LLVM IR)をもち、そのアセンブリ言語(以下ll)の文法はLLVM Language Reference Manualに記されています。
llファイルを作ってclangに渡すとターゲットCPUに応じた最適化されたコードが生成されます。従って、llファイルを一つ作るとLLVMが対応する様々なアーキテクチャに対応できます(実際にはそううまくはいかないことが多いですが)。そこで、今回は今まで作ってきたadd, mulUnitなどの低レベル関数をllで記述してみましょう。

llの概略

詳細は本家マニュアルを参照していただくことにして、ここでは今回のターゲットに必要最小限の話を紹介します。
LLMV IRは任意個の整数レジスタを扱えます。レジスタのビットサイズも固定ですが、任意サイズです。ただしキャリーフラグは存在しません。

まず32bit整数同士の加算結果を返す関数

extern "C" int add(int x, int y)
{
    return x + y;
}

をどう表現するかを見てみましょう。

clangで-S -O2 -emit-llvmオプションでコンパイルするとllファイルが生成されます。
生成されたファイルにはいろいろな付加情報がついていますが、最小限を抜き出すと次のようになります。

define i32 @add(i32 %0, i32 %1) { ; 32bitレジスタ%0と%1を引数にもちi32を返す関数
  %3 = add i32 %1, %0             ; 32bit加算をして%3にセット
  ret i32 %3                      ; %3をreturnする
}

%で始まる変数がレジスタです。ここでは数字しか出ていませんが%abcのようなものでもOKです。
このファイル(add32.ll)と次の(main1.cpp)を一緒にコンパイルするとちゃんと動きます。clangはllを普通にコンパイルできるのが便利です。gccも対応してくれるとありがたいのですが。

#include <stdio.h>
#include <stdint.h>

extern "C" int add(int x, int y);

int main()
{
  for (int i = 0; i < 10; i++) {
    printf("%d + %d = %d\n", i, i + 3, add(i, i + 3));
  }
}
$ clang++-12 -O2 main1.cpp add32.ll
$ ./a.out
0 + 3 = 3
1 + 4 = 5
2 + 5 = 7
3 + 6 = 9
4 + 7 = 11
5 + 8 = 13
6 + 9 = 15
7 + 10 = 17
8 + 11 = 19
9 + 12 = 21

add関数に対応するコードは-S -o - -O2で確認できます(余計な部分を除去しています)。

$ clang++-12 -S -O2 -o - add32.ll
        .text
        .globl  add
add:
        leal    (%rdi,%rsi), %eax
        retq

--targetオプションでLLVMがサポートする他のアーキテクチャのasmコードも生成できます。
たとえばM1などのAArch64なら

$ clang++-12 -O2 -S -o - --target=aarch64 add32.ll
        .text
        .globl  add
add:
        add     w0, w1, w0
        ret

clangが対応しているアーキテクチャは-print-targetで確認できます。

$ clang++-12 -print-targets
  Registered Targets:
    aarch64    - AArch64 (little endian)
    aarch64_32 - AArch64 (little endian ILP32)
    aarch64_be - AArch64 (big endian)
    amdgcn     - AMD GCN GPUs
    arm        - ARM
    arm64      - ARM64 (little endian)
    arm64_32   - ARM64 (little endian ILP32)
    armeb      - ARM (big endian)
    avr        - Atmel AVR Microcontroller
    ...

RISC-Vにも対応してます。便利ですね。

$ clang++-12 -S -o - -O2 --target=riscv64 a.ll
        .text
        .globl  add
add:
        addw    a0, a0, a1
        ret

メモリの読み書きはloadとstoreを使います。レジスタのサイズ変更はtrunc(小さくする方)やzext(0拡張)などがあります。こちらも詳細はマニュアルを参照ください。

clangの_ExtIntの紹介(余談)

clangには整数型の拡張として_ExtIntというのがあります。
たとえば256bit整数と64bit整数を掛けて320bit整数を得る関数はC++で

#include <stdint.h>
typedef unsigned _ExtInt(256) uint256_t; // 256bit符号無し整数
typedef unsigned _ExtInt(256+64) uint320_t; // 320bit符号無し整数

uint64_t mulUnit256(uint256_t *pz, const uint256_t *px, uint64_t y)
{
  uint256_t x = *px;
  uint320_t z = uint320_t(x) * uint320_t(y);
  *pz = uint256_t(z);
  return uint64_t(z >> 256);
}

と記述できます(clang-12で確認)。
typedef unsigned _ExtInt(n)でnビットの整数を定義します。演算は同じサイズの型同士でしかできないのでxとyの両方をuint320_tにキャストして掛けます。下位256bitをpzに保存し、上位の64bitをreturnします。

このコードのllファイルは

define i64 @mulUnit256(i256* %0, i256* readonly %1, i64 %2) {
  %4 = load i256, i256* %1    ; %1(=px)から256bit読み込んで%4(=x)に代入
  %5 = zext i256 %4 to i320   ; %4を320bitにゼロ拡張
  %6 = zext i64 %2 to i320    ; %2(=y)を320bitにゼロ拡張
  %7 = mul i320 %5, %6        ; 掛けて320bitの値を得る
  %8 = trunc i320 %7 to i256  ; それを256bitに切り捨てて
  store i256 %8, i256* %0     ; %0(=pz)に書き込み
  %9 = lshr i320 %7, 256      ; 残りを256bitシフトして
  %10 = trunc i320 %9 to i64  ; 64bitに切り捨てて
  ret i64 %10                 ; returnする
}

となりました。コメントをつけたので概ね内容は理解できるでしょう。コードを見れば分かるようにllは静的単一代入SSA(Static Single Assignment form)で記述します。
このllファイルを-S -O2 -mbmi2をつけてコンパイルすると

>clang++-12 -O2 -mbmi2 -S -o - a.cpp -masm=intel
        .text
        .globl  mulUnit256
mulUnit256:
        mulx    r9, r8, qword ptr [rsi + 16]
        mulx    rax, rcx, qword ptr [rsi + 24]
        add     rcx, r9
        adc     rax, 0
        mulx    r9, r10, qword ptr [rsi]
        mulx    rsi, rdx, qword ptr [rsi + 8]
        add     rdx, r9
        adc     rsi, r8
        adc     rcx, 0
        adc     rax, 0
        mov     qword ptr [rdi], r10
        mov     qword ptr [rdi + 8], rdx
        mov     qword ptr [rdi + 16], rsi
        mov     qword ptr [rdi + 24], rcx
        ret

となりました。元のllファイルでは320bit同士の乗算であっても元がそれぞれ256bit, 64bitをゼロ拡張したものであると判断して必要最小限の乗算しかしていないことに注意してください。256bitの切り捨てやシフトもレジスタ上では不要なのでそのようなコードもありません。
adcとmulxのintrinsicによる実装で紹介した生成コードにかなり近いです。addが一つ余計ですが、たいしたものです。

ここまで出来るなら、全部_ExtIntで書けばええやん、と思ってしまいます。しかしここで残念なお知らせです。片方が64bitなら乗算はできるのですがそれより大きいビットサイズを指定するとランタイムエラーになります。自分でmulUnitを組み合わせて教科書ライクな乗算コードを書かねばなりません。
また、_ExtIntはCの新しい規格として実験的に実装されていたのですが、_BitIntという名前に変わってしまいました(cf. Introduce _BitInt, deprecate _ExtInt)。更にclang-13より新しいバージョンのclangでは何故か128ビットまでしか指定できません

clang++-16が出すエラー。

add-extint.cpp(2,18): warning: '_ExtInt' is deprecated; use '_BitInt' instead [-Wdeprecated-type]
typedef unsigned _ExtInt(256) uint256_t;
                 ^~~~~~~
                 _BitInt
add-extint.cpp(2,1): error: unsigned _BitInt of bit sizes greater than 128 not supported
  • _ExtIntは非推奨になったので代わりに_BitIntを使いなさい。
  • _BitIntは128より大きい値は対応していない。

なんじゃそりゃ。おそらく、そのうち解決されるのでしょうが今のところ使えません。
というわけで、現状ではclangに頼らずにllファイルを直接書くのが望ましいです。

llファイル生成補助ツール

LLVM IRの命令は全ての演算にレジスタのビット情報を書かなければなりません。最初にi256やi32と宣言してるのだからその情報を使ってくれればよいのにやってくれないのです。またSSAなので一行ごとに新しい変数名を定義する必要があり、手書きの場合はとても面倒です。

そこで、llファイルを作るための簡単なツールmcl/src/llvm_gen.hppを開発しました。前回のXbyakライクなPython DSL(gen_x64_asm.py)と似たものですが、開発時期はllvm_gen.hppのほうがずっと前です。将来JIT生成したいなと思ってC++で作ったのですが未だに実装していません(必要性が低かった)。素直にPythonで書けばよかったですね。
それはともかく、llvm_gen.hppを使ったコードがmcl/src/gen_bint.cppです。一部を紹介しましょう。

256bit整数加算Unit addPre(Unit *pz, const Unit *px, const Unit *py)は次のように書けます(細部省略)。

void gen_add()
{
    const int bit = 256;
    const int unit = 64;
    const int N = bit / unit;
    Operand r(Int, unit);                       // 戻り値
    Operand pz(IntPtr, unit);                   // pz
    Operand px(IntPtr, unit);                   // px
    Operand py(IntPtr, unit);                   // py
    Function f("mclb_add", r, pz, px, py);      // 関数のプロトタイプ宣言を表すインスタンス
    beginFunc(f);                               // 関数始まり
    Operand x = zext(loadN(px, N), bit + unit); // pxからN個分メモリから読み込みbit + unitにゼロ拡張
    Operand y = zext(loadN(py, N), bit + unit); // pyからN個分メモリから読み込みbit + unitにゼロ拡張
    Operand z;
    z = add(x, y);                              // 加算
    storeN(trunc(z, bit), pz);                  // 下位256ビットをpzにstore
    r = trunc(lshr(z, bit), unit);              // 256bit右シフトして64bitにtruncate
    ret(r);                                     // 値rを返す
    endFunc();                                  // 関数終わり
}

Operandは整数レジスタやポインタなどを表すクラスです。用途(IntIntPtr)とサイズ(自然数)を指定してインスタンスを作ります。
LLVM IRの主に整数演算命令に対応する関数が用意されていて、それを呼び出すと対応するサイズつきのllコードを生成します。複数回代入可能で、自動的に名前の番号が増えます。

loadNやstoreNはN個分シフトしながら読み書きするサブ関数です。このコードをコンパイルして実行すると

define i64 @mclb_add4(i64* noalias  %r2, i64* noalias  %r3, i64* noalias  %r4)
{
%r6 = bitcast i64* %r3 to i256*
%r7 = load i256, i256* %r6
%r8 = zext i256 %r7 to i320
%r10 = bitcast i64* %r4 to i256*
%r11 = load i256, i256* %r10
%r12 = zext i256 %r11 to i320
%r13 = add i320 %r8, %r12
%r14 = trunc i320 %r13 to i256
%r16 = getelementptr i64, i64* %r2, i32 0
%r17 = trunc i256 %r14 to i64
store i64 %r17, i64* %r16
%r18 = lshr i256 %r14, 64
%r20 = getelementptr i64, i64* %r2, i32 1
%r21 = trunc i256 %r18 to i64
store i64 %r21, i64* %r20
%r22 = lshr i256 %r18, 64
%r24 = getelementptr i64, i64* %r2, i32 2
%r25 = trunc i256 %r22 to i64
store i64 %r25, i64* %r24
%r26 = lshr i256 %r22, 64
%r28 = getelementptr i64, i64* %r2, i32 3
%r29 = trunc i256 %r26 to i64
store i64 %r29, i64* %r28
%r30 = lshr i320 %r13, 256
%r31 = trunc i320 %r30 to i64
ret i64 %r31
}

このようなllファイルが作成されます(cf. mcl/src/bint64.ll)。このllファイルをx64用のasmに変換すると

$ clang++-12 -S -o - -O2 -masm=intel a.ll
        .text
        .globl  mclb_add4
mclb_add4:
    mov     r8, qword ptr [rdx]
    add     r8, qword ptr [rsi]
    mov     r9, qword ptr [rdx + 8]
    adc     r9, qword ptr [rsi + 8]
    mov     rcx, qword ptr [rdx + 16]
    adc     rcx, qword ptr [rsi + 16]
    mov     rdx, qword ptr [rdx + 24]
    adc     rdx, qword ptr [rsi + 24]
    setb    al
    movzx   eax, al
    mov     qword ptr [rdi], r8
    mov     qword ptr [rdi + 8], r9
    mov     qword ptr [rdi + 16], rcx
    mov     qword ptr [rdi + 24], rdx
    ret

最適化により冗長なtruncやzest, lshrなどは全て消えて望みのコードが生成されました。もちろんAArch64用のコードも生成できます。

$ clang++-12 -S -o - -O2 --target=aarch64 a.ll
        .globl  mclb_add4
mclb_add4:
    ldp     x9, x8, [x1]
    ldp     x11, x10, [x2]
    ldp     x13, x12, [x1, #16]
    ldp     x14, x15, [x2, #16]
    adds    x9, x11, x9
    adcs    x8, x10, x8
    adcs    x10, x14, x13
    stp     x9, x8, [x0]
    adcs    x9, x15, x12
    adcs    x8, xzr, xzr
    stp     x10, x9, [x0, #16]
    mov     x0, x8
    ret

llファイルを手書きするのはなかなか大変ですが、今回のケースではllvm_gen.hppのような補助ツールを使うと読みやすく、それなりによいコードを得られることが分かりました。
mclのAArch64向けMontgomery乗算などはこの方法を用いて実装しています。

GitHubで編集を提案

Discussion