printfの再実装をしたたかにやってみる(2)
まえおき
このページはシリーズ記事になります。
今回からコード例を掲載しますが、これらは説明のために書かれたものであり動作の保証はありません。また、完成したコードは公開しません。ぜひご自身で書いてください。
今回は多倍長整数演算です
printfをわりとしっかり再実装してみる企画の第二弾。
前回に続いて浮動小数点数の文字列化についてやっていきます。
double型を文字列化すると1000文字を超える桁数の演算が必要になる、という話をしました。
つまり標準的なC言語の型では処理できません。
そこで登場するのが多倍長整数というデータ構造です。
intやlong、long long、unsigned longなどの標準的な(ネイティブな)整数型は固定長であるのに対し、多倍長整数は必要に応じて任意の長さをあてがうことが出来ます。たとえば演算により桁数が不足すれば確保している領域を大きくしてやるような仕組みを作ることで、非常に大きな数値を扱うことが可能になります。
ひとくちに多倍長整数と言ってもその実装には様々な形がありますが、もっともシンプルなものは文字列型を使うものでしょう。今回の目的はprintfの機能として文字列化して出力することですから、データを文字列として保持することは理にかなっています。その上、デバッグ作業においても有利です。
シンプルな多倍長整数
たとえば以下のような型を作ります。
#define NUMSTRING_BUFFER_SIZE 1080
typedef struct {
char str[NUMSTRING_BUFFER_SIZE];
size_t len;
size_t dot_pos;
} numstring_t;
必要な桁数を賄えるだけの領域を確保しなければなりません。
今回はmallocを使わないことが前提であり(第一弾参照)、また、最大の長さが決まっていることからスタック領域を使っています。つまり配列です。サイズは前回の検証に若干の余裕をもたせました。もちろんヒープ領域から確保してもいいです。
さらに、数値文字列の長さを表すlenと、小数点位置を格納するdot_posを加えています。
数値の格納は最上位の桁から行うこととします。最終的に文字列として出力する際に都合が良いからです。dos_posは整数第1位(1の位)のインデックスとします。
たとえば8.5という数値ならば
str[0] = '8';
str[1] = '5';
len = 2;
dot_pos = 0;
という内容になります。
必要最低限の多倍長整数演算
ところで、浮動小数点数の文字列化処理で必要となる演算はなんでしょうか。
第一弾で触れた1.625の例に戻って考えることにします。
1.625のビット表現は以下のようになります。
1.625: 0b0011111111111010000000000000000000000000000000000000000000000000
これは、以下のように分解されます。
符号部(1bit)|指数部(11bit)|仮数部(52bit)
符号部: 0
指数部: 01111111111
仮数部: 1010000000000000000000000000000000000000000000000000
符号 符号部のビットが1なら負、0なら正です。
指数部 バイアスを差し引いた値が指数となります。
仮数部 正規化数の場合は最上位に桁を加えて1とし、この桁を2^0とする2進数の小数点数となります。
というわけで1.625を10進数出力する際には、
- 符号部のビットが0なので正
- 指数部は0x3FFFなのでバイアス0x3FFFを差し引いて0
- 正規化数になるため仮数部の最上位に1の桁を追加
- 仮数部を10進数化
- 指数と仮数部を乗算
- 小数点位置にドットを出力しながら数値を文字として出力
という流れになります。
1.625は正規化数ですので、仮数部には最上位に暗黙の1の桁が新たに付加されます。
11010000000000000000000000000000000000000000000000000
これにより53ビットになりました。
前述のように最上位が2^0のビットとし、以下、ビットを下位に1つ進むごとに指数を-1していきます。
この過程で2の除算と加算が最低限必要になります。
さらに指数を乗算するので乗算も必要になります。手抜きするために、ここでは仮数部の10進数化を終えてから指数部を乗じるより、先に指数を多倍長整数にしてその指数をベースとして仮数部を演算していく方法を選んでみます。
疑似コードは以下のようになります。
LOOP(52bit -> 0bit):
if (bit)
baseの初期値を求めるために、指数が正の場合は2の乗算、負の場合は2の除算を使います。任意の数による乗算・除算よりもシンプルです。
どんなに手抜きをしても加算、2の除算、2の乗算の3つが必要ということになります。
2の除算
加算よりも簡単な2の除算から行きましょう。小学校で習う筆算を愚直にコード化しているのみで工夫はありません。上の桁(インデックス0)から順に2で割って、余りがあればborrowに入れます。
get_digit()
は第2引数インデックスと第3引数オフセットによって指定した桁の数値を取り出す関数です。
void div2numstring(numstring_t *num) {
// 2による除算 num /= 2
int borrow;
int n;
size_t i;
borrow = 0;
i = 0;
while (i < num->len) {
n = borrow * 10 + get_digit(num, i, 0);
num->str[i] = n / 2 + '0';
borrow = n % 2;
i++;
}
if (borrow) {
num->str[i++] = borrow * 10 / 2 + '0';
}
num->str[i] = '\0';
num->len = i;
}
2の乗算
2の乗算も同じように書くことができます。乗算では最下位の桁から上の桁に向かって順番に掛けていくことが大きな違いでしょうか。また、桁が上位に増えることがあるため、その際には1桁(1文字)挿入する必要があります。
void mul2numstring(numstring_t *num) {
// 2による乗算 num *= 2
int carry;
int n;
size_t i;
carry = 0;
i = num->len;
while (i-- > 0) {
n = get_digit(num, i, 0) * 2 + carry;
num->str[i] = n % 10 + '0';
carry = n / 10;
}
if (carry) {
memmove(num->str + 1, num->str, num->len);
num->len++;
num->dot_pos++;
num->str[0] = carry + '0';
num->str[num->len] = '\0';
}
}
加算
加算する際には桁あわせをする必要がありますので少しだけ複雑になります。
get_offset()は桁合わせのためのオフセット量を取得します。2つの引数のdot_posを比較し、後者の方が大きい場合にはその差を返します。また、MAXは2つの引数のうち大きい値を返すマクロです。
void add_numstring(numstring_t *dst, numstring_t *a, numstring_t *b) {
// 加算 dst = a + b
const size_t a_offset = get_offset(a, b);
const size_t b_offset = get_offset(b, a);
int carry;
size_t i;
size_t newlen = MAX(a->len + a_offset, b->len + b_offset);
i = newlen;
carry = 0;
while (i-- > 0) {
int sum = get_digit(a, i, a_offset) + get_digit(b, i, b_offset) + carry;
dst->str[i] = (sum % 10) + '0';
carry = sum / 10;
}
if (carry) {
memmove(dst->str + 1, dst->str, newlen);
dst->str[0] = carry + '0';
dst->len = newlen + 1;
dst->dot_pos = MAX(a->dot_pos + a_offset, b->dot_pos + b_offset) + 1;
} else {
dst->len = newlen;
dst->dot_pos = MAX(a->dot_pos + a_offset, b->dot_pos + b_offset);
}
dst->str[dst->len] = '\0';
}
これで最低限必要となる多倍長整数の演算関数はすべて揃いました。
指数の計算
指数部から指数を算出するには当然ながら2の累乗を計算することになります。上記の簡素な関数から作ります。パフォーマンスは悪いですが、出発点としてはありでしょう。
void make_pow2numstring(numstring_t *base, int exp) {
convert_ull2numstring(base, 1);
if (exp > 0) {
for (int i = 0; i < exp; i++) {
mul2numstring(base);
}
} else if (exp < 0) {
for (int i = 0; i > exp; i--) {
div2numstring(base);
}
}
}
仮数部の加算
指数をbaseとして仮数部の53ビットを評価し、ビットが立っていればbaseを加算、baseを2で割るというループです。
for (int i = 52; i >= 0; i--) {
if ((bin.mantissa >> i) & 1) {
add_numstring(&tmp, &result, &base);
numstringcpy(&result, &tmp);
}
div2numstring(&base);
}
全体のコード
主要な部品が揃ったところで全体のコードを貼ります。
解説用に書き直した「doubleの正の正規化数をとりあえず表示する朴訥なコード例」でしかありません。
フォーマットのパースもしないのでprintfではないですし処理速度が遅いです。
#include <stdio.h>
#include <stdbool.h>
#include <stdint.h>
#include <unistd.h>
#include <float.h>
#include <string.h>
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define NUMSTRING_BUFFER_SIZE 1080
#define IEEE754_DOUBLE_SIGN_BITS 1
#define IEEE754_DOUBLE_SIGN_MASK 1
#define IEEE754_DOUBLE_SIGN_SHIFT 63
#define IEEE754_DOUBLE_EXPONENT_BIAS 0x3FF //1023
#define IEEE754_DOUBLE_EXPONENT_BITS 11
#define IEEE754_DOUBLE_EXPONENT_MASK ((1ULL << IEEE754_DOUBLE_EXPONENT_BITS) - 1)
#define IEEE754_DOUBLE_EXPONENT_SHIFT 52
#define IEEE754_DOUBLE_MANTISSA_BITS 52
#define IEEE754_DOUBLE_MANTISSA_MASK ((1ULL << IEEE754_DOUBLE_MANTISSA_BITS) - 1)
typedef struct {
char str[NUMSTRING_BUFFER_SIZE];
size_t len;
size_t dot_pos;
} numstring_t;
typedef struct {
bool sign;
int exp;
uint64_t mantissa;
} binary64_t;
typedef union {
double d;
uint64_t u;
} converter_t;
bool is_normal = true; // 今のところは正規化数のみとする
size_t num_of_digits(unsigned long long n) {
int count;
if (n == 0) return 1;
count = 0;
while (n) {
count++;
n /= 10;
}
return count;
}
void convert_ull2numstring(numstring_t *dst, unsigned long long ul) {
// 整数型を代入
size_t len;
memset(dst->str, 0, NUMSTRING_BUFFER_SIZE);
dst->str[0] = '0';
len = num_of_digits(ul);
dst->len = len;
while (len-- > 0 && ul > 0) {
dst->str[len] = (ul % 10) + '0';
ul /= 10;
}
dst->str[dst->len] = '\0';
dst->dot_pos = dst->len - 1;
}
size_t get_offset(numstring_t *a, numstring_t *b) {
// aとbの小数点位置を比較し、bの方が整数部の桁が多ければその桁数を返す
if (a->dot_pos < b->dot_pos)
return b->dot_pos - a->dot_pos;
return 0;
}
int get_digit(numstring_t *a, size_t index, size_t offset) {
if (index >= offset && (index - offset) < a->len) {
return a->str[index - offset] - '0';
}
return 0;
}
void add_numstring(numstring_t *dst, numstring_t *a, numstring_t *b) {
// 加算 dst = a + b
const size_t a_offset = get_offset(a, b);
const size_t b_offset = get_offset(b, a);
int carry;
size_t i;
size_t newlen = MAX(a->len + a_offset, b->len + b_offset);
i = newlen;
carry = 0;
while (i-- > 0) {
int sum = get_digit(a, i, a_offset) + get_digit(b, i, b_offset) + carry;
dst->str[i] = (sum % 10) + '0';
carry = sum / 10;
}
if (carry) {
memmove(dst->str + 1, dst->str, newlen);
dst->str[0] = carry + '0';
dst->len = newlen + 1;
dst->dot_pos = MAX(a->dot_pos + a_offset, b->dot_pos + b_offset) + 1;
} else {
dst->len = newlen;
dst->dot_pos = MAX(a->dot_pos + a_offset, b->dot_pos + b_offset);
}
dst->str[dst->len] = '\0';
}
void div2numstring(numstring_t *num) {
// 2による除算 num /= 2
int borrow;
int n;
size_t i;
borrow = 0;
i = 0;
while (i < num->len) {
n = borrow * 10 + get_digit(num, i, 0);
num->str[i] = n / 2 + '0';
borrow = n % 2;
i++;
}
if (borrow) {
num->str[i++] = borrow * 10 / 2 + '0';
}
num->str[i] = '\0';
num->len = i;
}
void mul2numstring(numstring_t *num) {
// 2による乗算 num *= 2
int carry;
int n;
size_t i;
carry = 0;
i = num->len;
while (i-- > 0) {
n = get_digit(num, i, 0) * 2 + carry;
num->str[i] = n % 10 + '0';
carry = n / 10;
}
if (carry) {
memmove(num->str + 1, num->str, num->len);
num->len++;
num->dot_pos++;
num->str[0] = carry + '0';
num->str[num->len] = '\0';
}
}
void numstringcpy(numstring_t *dst, numstring_t *src) {
strncpy(dst->str, src->str, NUMSTRING_BUFFER_SIZE);
dst->len = src->len;
dst->dot_pos = src->dot_pos;
}
void make_pow2numstring(numstring_t *base, int exp) {
convert_ull2numstring(base, 1);
if (exp > 0) {
for (int i = 0; i < exp; i++) {
mul2numstring(base);
}
} else if (exp < 0) {
for (int i = 0; i > exp; i--) {
div2numstring(base);
}
}
}
void convert_double2numstring(numstring_t *dst, double f) {
binary64_t bin;
numstring_t base, result, tmp;
long long exp;
converter_t bits;
bits.d = f;
bin.sign = (bits.u >> IEEE754_DOUBLE_SIGN_SHIFT) & IEEE754_DOUBLE_SIGN_MASK;
bin.exp = (bits.u >> IEEE754_DOUBLE_EXPONENT_SHIFT) & IEEE754_DOUBLE_EXPONENT_MASK;
bin.mantissa = bits.u & IEEE754_DOUBLE_MANTISSA_MASK;
exp = (bin.exp - IEEE754_DOUBLE_EXPONENT_BIAS);
if (is_normal) {
bin.mantissa |= (1ULL << 52);
}
make_pow2numstring(&base, exp);
convert_ull2numstring(&result , 0);
for (int i = 52; i >= 0; i--) {
if ((bin.mantissa >> i) & 1) {
add_numstring(&tmp, &result, &base);
numstringcpy(&result, &tmp);
}
div2numstring(&base);
}
numstringcpy(dst, &result);
}
int mk_putnstr(const char *s, size_t len) {
return write(STDOUT_FILENO, s, len);
}
int mk_putchar(const int c) {
return write(STDOUT_FILENO, &c, 1);
}
void putnumstring(numstring_t *num) {
mk_putnstr(num->str, num->dot_pos + 1);
mk_putchar('.');
mk_putnstr(num->str + num->dot_pos + 1, num->len - num->dot_pos);
mk_putchar('\n');
}
void putdouble(double f) {
numstring_t num;
convert_double2numstring(&num, f);
putnumstring(&num);
}
printfとの結果を見比べてみましょう。
以下のコードを足してビルド&&実行してみます。
void compare_put(double f) {
putdouble(f);
printf("%.100f\n", f);
}
int main() {
compare_put(1.625);
compare_put(0.1);
compare_put(123456789.012345);
compare_put(123456789);
compare_put(1234567890123456789012345.12345);
compare_put(0.000000000000000125);
compare_put(DBL_MIN);
}
1.625
1.6250000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0.1000000000000000055511151231257827021181583404541015625
0.1000000000000000055511151231257827021181583404541015625000000000000000000000000000000000000000000000
123456789.01234500110149383544921875
123456789.0123450011014938354492187500000000000000000000000000000000000000000000000000000000000000000000000000
123456789.
123456789.0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
1234567890123456824475648.
1234567890123456824475648.0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
0.00000000000000012500000000000000971317498458263490478839820014937689318657021431135945022106170654296875
0.0000000000000001250000000000000097131749845826349047883982001493768931865702143113594502210617065430
0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002225073858507201383090232717332404064219215980462331830553327416887204434813918195854283159012511020564067339731035811005152434161553460108856012385377718821130777993532002330479610147442583636071921565046942503734208375250806650616658158948720491179968591639648500635908770118304874799780887753749949451580451605050915399856582470818645113537935804992115981085766051992433352114352390148795699609591288891602992641511063466313393663477586513029371762047325631781485664350872122828637642044846811407613911477062801689853244110024161447421618567166150540154285084716752901903161322778896729707373123334086988983175067838846926092773977972858659654941091369095406136467568702398678315290680984617210924625396728515625
0.0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
うまく動いているようです。
printfのように精度設定がなく、そのかわりに桁が存在するところまではすべて表示されます。末尾の0はつきません。また、小数部がないのに小数点が出ていますが、これは少しコードを追加することで消すことも可能です。しかしここではあまり意味がない(実際にはprintfのフォーマットの指定により別途さまざまなオプション的処理が必要になります)ので無視します。
compare_put(0.000000000000000125)
の出力を見ると、printfと末尾が違うことがわかります。これは、printfでは精度の指定(ここでは100桁)によって、小数点以下第101位で「丸め」が発生しているためです。一方、作成した関数では丸めることなく全桁が表示されています。こうした部分もprintfとして実装する際には考慮する必要があります。
0.00000000000000012500000000000000971317498458263490478839820014937689318657021431135945022106170654296875
0.0000000000000001250000000000000097131749845826349047883982001493768931865702143113594502210617065430
ラストのDBL_MINの出力も同じです。DBL_MINは小数点以下第308位で初めて非0が現れますから、精度100では0しか表示されないのです。試しに"%.1022f"に変更すると全桁が表示されます。
繰り返しになりますが、putdouble()
は正の正規化数にのみ対応した関数です。負の数のハンドルや非正規化数、InfinityやNaNへの対応も必要になってきます。これらは指数部と仮数部の条件によって分岐させるだけですのでここでは触れません。
また、処理速度の観点から様々な最適化を必要とします。わたしは2の乗算、除算を拡張し、整数型で収まる範囲に分けて多倍長整数の演算が行えるようにしました。加えてビットごとに2の除算を行うのは無駄なので必要なときのみbaseを再計算するようにもしています。
多倍長整数同士の乗算ならばカラツバ法などを導入することで高速化が可能です。そもそも文字列で持つというデータ構造が遅いということもありますので、別の形式にすることを検討してもいいでしょう。
次回はlong doubleについて簡単に触れておこうと思います。4倍精度になるとさらにロジックの朴訥さが目に見えて処理時間に響いてきます。
Discussion