機械に優しい統計計算
TL;DR
- Welfold's Online Algorithm
- Kahan Algorithm
- 世の中もう少し証明を丁寧にしてほしい(願望)
数値データの解析
数値データの集合を分析する際に、大量のデータを俯瞰するために代表的な値を求めて、そこから分析することが多い。
- 最大値
- 最小値
- 中央値
- 平均値
- 分散(または、標準偏差)
などなど、どれか一つだけでなく、複数の値を用いたり、またデータのパーセンタイルなどで、数値を総合的に分析する。
平均値と分散の公式
平均値
データ1個あたりの量。
平均値は総和と個数の商。
[1]
分散データのばらつきを表す量。平均との残差の平方和の平均であらわされる。
あるあるな疑問として「どうして2乗するのか」、というものがあるが、実際計算すると普通の残差の和は結局
プログラムでの計算の問題点
数学で分散、平均値の公式を習った人は、プログラムでの計算でも同様に実装できるだろう。ただ、いくつか実行時に問題が生じる場合がある。
- 型の範囲外の値になる
- 例えばC言語のint型は、環境にもよるが最大値は大体
(32bit整数型の場合。)2^{31} - 1 = 2147483647
- 例えばC言語のint型は、環境にもよるが最大値は大体
- メモリ、時間を大量に消費する場合がある
- データ群が大規模な場合、一つのプログラムで使用可能なメモリを超えてしまう場合もある。
- (今回主題ではないが)
0.1 + 0.2 = 0.30000001
など、小数の表現方法の問題で、誤差が生じる場合がある
これらはあくまで単純な公式での導出の場合に発生する。
単純な公式を用いた実装の場合、分散の計算で平均が必要な都合上、平均を求めたあとにまた同じ回数だけ計算を行うことに成る。
もちろんこうした解決にはPythonを使うとか型を変えるなどが考えられます。
単純にアルゴリズムで改善しようというのが今回のテーマです。
計算誤差を抑えるアルゴリズム
Kahanのアルゴリズム
計算上カットされる誤差を補正項として保存することで、計算時に補正項を考慮した計算を行うことで計算機の加算によって生じる誤差を抑えるアルゴリズム。
class Kahan {
public:
Kahan () {}
Kahan (double value) {}
void add(double x) {
double y = x - _c;
double t = _s + y;
_c = (t - _s) - y;
_s = t;
}
void operator=(double x) {
_s = x;
_c = 0.0;
}
void operator+=(double x) {
add(x);
}
operator double() const {
return _s;
}
private:
double _s = 0.0;
double _c = 0.0;
};
Welfoldのアルゴリズム
数式での導出の方針はシンプルです。目的は、無駄な計算を減らすこと・型からあふれるようなオーダーの大きい計算をしないこと・値を取り出していくループ内で逐次計算できることです。
こうした方針の場合、特に最後の目的の達成のために考えられるのは、公式の変形で漸化式にすることで、値を加えていく、ということが可能です。
こうしたアルゴリズムは、Welfold's Online Algorithm
として知られています。
もちろんこのアルゴリズムには欠点があり、計算機のデータの型の都合上、計算するたびに誤差が発生しやすくなります。この問題については、上に記したKahanのアルゴリズムで誤差を抑える、などが対策として考えれられます。
平均値
逐次計算における平均値は以下のように求めます。
詳細な証明
を、
式(1)について:
n-1個のデータの平均を求めるには、n-1を分母にもってくる必要があるため、倍分(言い換えると、かけても変わらない1をかけていることにして分子と分母に必要な数を導く)してn-1を使えるようにする。
今までのサンプルでの平均を求めて置き、そこにデータを1個加えたときにどうなるか、がポイント。数式での導出の方針は今までの平均と、新たなデータ、またループ内で扱える変化量を用いて計算できる式をゴールにするといい感じに求められます。
分散の導出
平方和の場合、
詳細な証明
データの個数は、ループ内で取得できるので、ここではn個時点での残差の平方和
をn-1個の残差の平方和とn個目のデータで計算できるように変形します。
やることは単純ですが純粋に項数が多い導出です。
実装 ― C++
class Welford {
public:
void addSample(double x) {
_n++;
double delta = x - _m;
_m += delta / _n;
double delta2 = x - _m;
_M2 += delta * delta2;
}
double variance() const {
// return _M2 / (_n - 1);
return _M2 / _n;
}
double avarage() const {
return _m;
}
private:
Kahan _M2 = 0.0;
Kahan _m = 0.0;
int _n = 0;
};
実装 ― Rust
せっかくなのでRustでも考えてみましょう。浮動小数点はf64
を用います。
Rustだと難しそうなイメージを持つ方もいると思いますが、単純な計算であれば他のプログラミング言語とあまり変わらないと思います。
struct Welford {
n: u32,
mean: f64,
m2: f64,
}
impl Welford {
fn new() -> Self {
Welford {
n: 0,
mean: 0.0,
m2: 0.0,
}
}
fn add(&mut self, x: f64) {
self.n += 1;
let delta = x - self.mean;
self.mean += delta / self.n as f64;
let delta2 = x - self.mean;
self.m2 += delta * delta2;
}
fn variance(&self) -> f64 {
self.m2 / self.n as f64
}
fn average(&self) -> f64 {
self.mean
}
}
Kahanのアルゴリズムを適用させる場合
C++同様に、Kahan用の構造体を用意して、加算はstd::ops::AddAssign
トレイトを新たに実装することで表現できます。
また計算時にはf64
として扱えるようにFrom
トレイトを実装します。
use rand_distr::StandardNormal;
use rand::prelude::*;
use std::ops::AddAssign;
#[derive(Clone)]
struct Kahan {
sum: f64,
compensation: f64,
}
impl Kahan {
fn new() -> Self {
Self { sum: 0.0, compensation: 0.0 }
}
fn add(&mut self, x: f64) {
let y = x - self.compensation;
let t = self.sum + y;
self.compensation = (t - self.sum) - y;
self.sum = t;
}
}
impl AddAssign for Kahan {
fn add_assign(&mut self, rhs: Kahan) {
self.add(rhs.sum);
}
}
impl From<f64> for Kahan {
fn from(x: f64) -> Self {
Self { sum: x, compensation: 0.0 }
}
}
impl From<Kahan> for f64 {
fn from(x: Kahan) -> Self {
x.sum
}
}
struct Welford {
n: u32,
mean: Kahan,
m2: Kahan,
}
impl Welford {
fn new() -> Self {
Welford {
n: 0,
mean: Kahan::new(),
m2: Kahan::new(),
}
}
fn add(&mut self, x: f64) {
self.n += 1;
let delta = x - self.mean.sum;
self.mean += (delta / self.n as f64).into();
let delta2 = x - self.mean.sum;
self.m2 += (delta * delta2 as f64).into();
}
fn variance(&self) -> f64 {
self.m2.sum / self.n as f64
}
fn average(&self) -> f64 {
self.mean.clone().into()
}
}
副次的効果: 代表値のマージが容易になる
Welfordのアルゴリズムの実装は、残差の平方和、平均、個数を保存するため、分散コンピューティングでの結果を再計算する必要なく求めることが可能です。
公式でそのまま求めるより、再設計を考えたことで、精度の安定化やメモリ節約だけでなく、分散システムへの適性も持つようになりました。
// Welfordのメンバ関数に追加
Welford merge(const Welford &rhs) const {
Welford r;
double ma = _m;
double mb = rhs._m;
double N = _n;
double M = rhs._n;
double N_M = N + M;
double a = N / N_M;
double b = M / N_M;
r._mean = a * ma + b * mb;
r._M2 = _M2 + rhs._M2;
r._n = N_M;
return r;
}
残差の平方和は単純に足して個数の合計で割ることで求められ、平均も計算すると内分点の公式みたいな感じになります。 ここでは導出は省略しますが、見たい方は下の参考サイトに導出している方がいらっしゃるので、そちらも見てみてください。
実際の動作
正規分布を生成して、その平均や標準偏差を求める。
int main() {
Welford stats;
std::default_random_engine generator;
std::normal_distribution<double> distribution(0.0, 1.0);
const int N = 10000;
for (int i = 0; i < N; i++) {
double x = distribution(generator);
stats.addSample(x);
}
double mean = stats.average();
double variance = stats.variance();
double stdev = sqrt(variance);
std::cout << "Mean: " << mean << std::endl;
std::cout << "Standard deviation: " << stdev << std::endl;
return 0;
}
実行結果:
Mean: -0.000139979
Standard deviation: 0.993975
ほぼ近い値に計算できている。
だいぶ誤差をおさえらていることが確認できる。
同条件でRust(ただし、Kahanの補正はなし)
use rand_distr::StandardNormal;
use rand::prelude::*;
struct Welford {
n: u32,
mean: f64,
m2: f64,
}
impl Welford {
fn new() -> Self {
Welford {
n: 0,
mean: 0.0,
m2: 0.0,
}
}
fn add(&mut self, x: f64) {
self.n += 1;
let delta = x - self.mean;
self.mean += delta / self.n as f64;
let delta2 = x - self.mean;
self.m2 += delta * delta2;
}
fn variance(&self) -> f64 {
self.m2 / self.n as f64
}
fn average(&self) -> f64 {
self.mean
}
}
fn main() {
let mut rng = thread_rng();
let normal = StandardNormal;
let mut stats = Welford::new();
const N: u32 = 10000;
for _ in 0..N {
let x = normal.sample(&mut rng);
stats.add(x);
}
let mean = stats.average();
let variance = stats.variance();
let stdev = variance.sqrt();
println!("Mean: {}", mean);
println!("Standard deviation: {}", stdev);
}
実行結果:
Mean: 0.025845755396603033
Standard deviation: 0.9996649616854083
//Kahanアルゴリズムの場合
Mean: -0.0033980121855476195
Standard deviation: 0.9947184865426071
違う言語とはいえ精度の差はやはりKahanの補正の有無が大きい気がする。
おまけ: マージを試してみる
// stats2の分布: (0.0, 2.0)
Welford stats3 = stats1.merge(stats2);
double mean3 = stats3.average();
double variance3 = stats3.variance();
double stdev3 = sqrt(variance3);
std::cout << "Mean: " << mean3 << std::endl;
std::cout << "Standard deviation: " << stdev3 << std::endl;
実行結果:
Mean: -0.000139979
Standard deviation: 0.993975
Mean: 0.0398515
Standard deviation: 2.03155
Mean: 0.0198558
Standard deviation: 1.59924
設計の恩恵を感じる・・・・・・。
参考にしたサイト
日本語版Wikiほしいなぁ()
-
高校などは
で表している。s^2 だと標準偏差になる。s は統計や機械学習でよく見られる。平均は\sigma^2 など。 ↩︎\mu
Discussion