👏

Rustを使って円周率1億桁計算したお話

2021/02/11に公開

今回は最近界隈で流行りのプログラミング言語Rustを使ってなんと円周率小数点以下1億桁を計算したというお話です!!

RustはC++に変わるシステムプログラミング言語として開発されてめちゃめちゃ高速です。せっかく高速なのだからそれが役に立つ処理ということで円周率を計算させてみました〜

アルゴリズムはChudnovskyの公式で高速化手法としてBinary Splitting Methodを用いました。アルゴリズムの説明は以下の記事を参考にして下さい。

参考記事↓

https://qiita.com/peria/items/c02ef9fc18fb0362fb89

では実際に実装コードの説明していきたいと思います〜

今回はシングルスレッド版とマルチスレッド版の両方を作ってみました〜

シングルスレッド版

コード

use rug::{ops::Pow, Float, Integer};

// 定数たち
const N: i64 = 100000000;
const SN: i64 = N / 14; // n: small N
const A: i64 = 13591409;
const B: i64 = 545140134;
const C: i64 = 640320;
const CT: i64 = C * C * C;
const CTD24: i64 = CT / 24;

fn calc_x(k: i64) -> Integer {
    if k == 0 {
        return Integer::from(1);
    }
    Integer::from(k).pow(3) * CTD24
}

fn calc_y(k: i64) -> Integer {
    A + Integer::from(B) * k
}

fn calc_z(k: i64) -> Integer {
    if k == SN - 1 {
        return Integer::from(0);
    }
    (-1) * Integer::from((6 * k + 1) * (2 * k + 1)) * (6 * k + 5)
}

fn calc(left: i64, right: i64) -> (Integer, Integer, Integer) {
    if right - left == 1 {
        return (calc_x(left), calc_y(left), calc_z(left));
    }

    let mid = (left + right) >> 1;

    let (lx, ly, lz) = calc(left, mid);
    let (rx, ry, rz) = calc(mid, right);

    (lx * &rx, &rx * ly + ry * &lz, &lz * rz)
}

fn main() {
    let (x, y, _z) = calc(0, SN);

    // with_valのprocはあくまでも有効桁のビット長(N / log_10^2), 10進数の桁数とは違う
    // 1e8 / log10 ^2の演算結果を四捨五入した値
    let prec: u32 = 332192810;

    // precは10進数1e8桁の場合u32の制限に引っかからない(u32のMAXが4294967295)
    let ans = Float::with_val(prec, CT).sqrt() * x / 12 / y;
    println!("{}", ans);
}

浮動小数点の扱い

小数点以下1億桁まで計算する必要があるので多倍長整数or少数演算をどうするか考える必要があります。
幸いRustにはrugというGMPをラップした多倍長ライブラリがあるので今回はそれを用いることにします。

因みに事前に必要な定数は全てi64で定義しています。これは定数呼び出しをInteger型ではできないからです(恐らくスタック領域に定数が保存されるためだと思われる)。
今回の定数は全てi64に収まるのでまあ特段大きな問題はないです。

次に注意点として、Float型の数を作成するwith_val関数で第一引数に取るprecは有効桁のビット長であり、10進数の桁とは異なるということです。
そのため求めたい10進数の桁をNとするとprecを

prec = \frac{N}{ \log _{10} 2 }

のようにして導出する必要があります。今回はprecを事前計算してmain関数内にぶち込んでおきました。

浮動小数点の扱いに関してはこのぐらいでしょうか。

所有権のお話

calc関数の返り値を計算時、変数にところどころアンパサンドが付いているのが気になるかも知れませんが、Integer型のデータがヒープ領域に保存されており、かつ複数回値が使われている参照を行っているという認識で良さそうです(所有権が移るのを防ぐ)。

Macでコンパイル時の注意点

https://github.com/rust-lang/rust/issues/59164

Homebrewでgccをインストールした場合、Cのランタイムライブラリが見つからずにコンパイル失敗してしまいます。

対策として、.cargo/config.toml

[target.x86_64-apple-darwin]
rustflags = ["-C", "link-args=-L /usr/local/Cellar/gcc/10.2.0_2/lib/gcc/10 -lgcc_ext.10.5"]

と書いておきましょう。これでライブラリを見つけられるようになります(gccのバージョン、ライブラリのパスは適宜自分の環境に合わせてください)。

次はいよいよマルチスレッド版です。

マルチスレッド版

マルチスレッド用のライブラリはthreadのみで済みました。共有変数の話とかを考えずに済んでよかった。。。

コード

use rug::{ops::Pow, Float, Integer};
use std::thread;

// 定数たち
const N: i64 = 100000000;
const SN: i64 = N / 14; // n: small N
const A: i64 = 13591409;
const B: i64 = 545140134;
const C: i64 = 640320;
const CT: i64 = C * C * C;
const CTD24: i64 = CT / 24;

// スレッド数
const NTHREADS: usize = 4;

fn calc_x(k: i64) -> Integer {
    if k == 0 {
        return Integer::from(1);
    }
    Integer::from(k).pow(3) * CTD24
}

fn calc_y(k: i64) -> Integer {
    A + Integer::from(B) * k
}

fn calc_z(k: i64) -> Integer {
    if k == SN - 1 {
        return Integer::from(0);
    }
    (-1) * Integer::from((6 * k + 1) * (2 * k + 1)) * (6 * k + 5)
}

fn calc(left: i64, right: i64) -> (Integer, Integer, Integer) {
    if right - left == 1 {
        return (calc_x(left), calc_y(left), calc_z(left));
    }

    let mid = (left + right) >> 1;

    let (lx, ly, lz) = calc(left, mid);
    let (rx, ry, rz) = calc(mid, right);

    (lx * &rx, &rx * ly + ry * &lz, &lz * rz)
}

fn main() {
    let mut handles = Vec::new();

    let len = (SN + NTHREADS as i64 - 1) / NTHREADS as i64;
    let mut c = SN % len;
    for _x in 1..NTHREADS {
        handles.push(thread::spawn(move || calc(c, c + len)));
        c += len;
    }

    let (fx, fy, fz) = calc(0, SN % len);

    let (x, y, _z) = handles.into_iter().fold((fx, fy, fz), |(x, y, z), handle| {
        let (tx, ty, tz) = handle.join().unwrap();
        (x * &tx, &tx * y + ty * &z, &z * tz)
    });

    // with_valのprocはあくまでも有効桁のビット長(N / log_10^2), 10進数の桁数とは違う
    // 1e8 / log10 ^2の演算結果を四捨五入した値
    let prec: u32 = 332192810;

    // precは10進数1e8桁の場合u32の制限に引っかからない(u32のMAXが4294967295)
    let ans = Float::with_val(prec, CT).sqrt() * x / 12 / y;
    println!("{}", ans);
    // println!("{num:.prec$}", prec = N as usize - 1, num = ans);
}

スレッディングの話

まずBinary Splitting Methodで計算する範囲をスレッド数で分割します。
分割した演算をそれぞれhandlesにpushしていき、最後にjoinすることで演算結果を取得しそれを合成することで最終的なx, y, zの値を求めます。

それ以降はシングルスレッド版と対して内容は変わりありません。

因みにiterではなくinto_iterを用いたのはjoinを用いるのに参照ではなく所有権を移行する必要があったからです。

あとHaskellチックに書けるfold関数すこ。エレガント

演算結果

timeコマンドを用いてお互いの実行速度を比較しました。

環境は

CPU: 2.3 GHz 8-Core Intel Core i9

メモリ: 32 GB 2667 MHz DDR4

です。贅沢の限りを尽くしています。。。

user system cpu total
シングルスレッド版 127.83s 3.74s 99% 2:12.01
マルチスレッド版(2 コア) 126.00s 3.65s 145% 1:28.99
マルチスレッド版(4 コア) 135.70s 3.47s 172% 1:20.49
マルチスレッド版(6 コア) 147.82s 4.11s 180% 1:24.04
マルチスレッド版(12 コア) 198.86s 4.21s 179% 1:53.22

4コアから利用コア数を上げると逆に時間かかっているの、スレッディングから取ってきた値をかけ合わせているところで近い桁同士で掛け算できていないから説があります。
実装の工夫のしがいがあるな。

ということで早速簡単にできそうな掛け算の最適化を行いました〜。

実装は基のコードのmain関数内の

    let (x, y, _z) = handles.into_iter().fold((fx, fy, fz), |(x, y, z), handle| {
        let (tx, ty, tz) = handle.join().unwrap();
        (x * &tx, &tx * y + ty * &z, &z * tz)
    });

    let mut thread_vec: Vec<(Integer, Integer, Integer)> = handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect();

    let (x, y, _z) = calc_thread_vec(&mut thread_vec);

に変えました。関数calc_thread_vecはこんな感じ。

fn calc_thread_vec(vec: &mut Vec<(Integer, Integer, Integer)>) -> (Integer, Integer, Integer) {
    if vec.len() == 1 {
        return vec.pop().unwrap();
    }
    let mid = vec.len() >> 1;
    let mut rest_vec = vec.split_off(mid);

    let (lx, ly, lz) = calc_thread_vec(vec);
    let (rx, ry, rz) = calc_thread_vec(&mut rest_vec);

    (lx * &rx, &rx * ly + ry * &lz, &lz * rz)
}

統治作業を再帰的にやることで掛け算する桁を近く合わせました。

実行結果、少し早くなった、8コアを有効活用できた気がします。

user system cpu total
マルチスレッド版(4 コア) 127.05s 2.76s 176% 1:13.47
マルチスレッド(8 コア) 137.46s 3.08s 195% 1:11.97
マルチスレッド版(16 コア) 185.82s 3.55s 240% 1:18.68

まだ高速化できる余地は残ってそう、というか多倍長演算の根本的仕組み(FFTとか)が怪しいんだよな。
頑張っていつか理解できるようになりたいですね。

続く。。。

Discussion