🎲
確率単体への射影操作について(行列版の実装)
まえがき
こちらの記事の続きです。
原典です。
N 本のベクトルを射影する
まとめて 以前
function X = SimplexProj(Y)
[N, D] = size(Y);
X = sort(Y, 2, 'descend');
Xtmp = (cumsum(X, 2) - 1) * diag(sparse(1 ./ (1:D)));
X = max(bsxfun(@minus, Y, Xtmp(sub2ind([N, D], (1:N)', sum(X > Xtmp, 2)))), 0);
以下に1つの
上のMatlab版コードでは、それぞれ次のような実装がされていることが分かります。
- 入力は
本のN 次元ベクトルです。各行を確率単体へ射影します。D - Algorithm 1 で入力をソートしていますが、行列
の列ごとにソートをしています。Y - 各行について、ソートした値の累積和
を計算しつつ、\sum_{i=1}^{j} u_i を対角行列にして行列計算で求めます(Algorithm 1 とは符合が反転しています)。\frac{1}{j} - 各行のデータに対して
に相当するインデクスを求め、sub2indで行列のインデクスへ変更して値を取ってきて、各行に対してbroadcastして引き算します。\rho
Julia版での実装
昔(0.5の時代とか)はMatlabコードをそのままパクれば動いたのですが、最近は関数が消されてきれいになってしまっているので、真面目に書く必要があります。メモリ効率や真面目な最適化を無視するならば、以下ぐらいのコードが出来上がります。
using LinearAlgebra
function euclidean_projection_matrix(Y)
N, D = size(Y);
X = sort(Y, dims=2, rev=true)
# diagmを使うのが無駄な気がするので、broadcastとtransposeする方がいいのかも
Xtmp = (cumsum(X, dims=2) .- 1) * diagm(1 ./ (1:D))
# インデクス処理が難しいので、各行について使うべき rho を求める
rhovec = sum(X .> Xtmp, dims=2)
Vrho = [Xtmp[j, rhovec[j]] for j in 1:N]
# broadcastして計算する
max.(Y .- Vrho, 0)
end
以前の記事で作成した1本の euclidean_projection
と比較してみます。次のような頭の悪い実装を用意します。
function euclidean_projection_stupid(Y)
X = zeros(size(Y))
for j in 1:size(Y, 1)
# 前回作った関数を各行に適用する
X[j, :] = euclidean_projection(Y[j, :])
end
X
end
あとはBenchmarkToolsで比較してみましょう。
N = 1000
D = 10
# 行列版 vs 繰り返し版
Y = rand(N, D)
@benchmark euclidean_projection_matrix(Y)
@benchmark euclidean_projection_stupid(Y)
データ数
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 217.996 μs … 2.894 ms ┊ GC (min … max): 0.00% … 85.23%
Time (median): 246.576 μs ┊ GC (median): 0.00%
Time (mean ± σ): 274.965 μs ± 173.486 μs ┊ GC (mean ± σ): 6.27% ± 8.77%
██
██▇▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂ ▂
218 μs Histogram: frequency by time 1.62 ms <
Memory estimate: 492.59 KiB, allocs estimate: 36.
データ数
BenchmarkTools.Trial: 8177 samples with 1 evaluation.
Range (min … max): 416.836 μs … 7.037 ms ┊ GC (min … max): 0.00% … 84.75%
Time (median): 470.058 μs ┊ GC (median): 0.00%
Time (mean ± σ): 598.489 μs ± 606.219 μs ┊ GC (mean ± σ): 17.96% ± 15.32%
█▇▅▃▃▂ ▁
███████▇▆▆▄▁▄▃▃▁▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅██▇▇▇▇▇▆▅▆ █
417 μs Histogram: log(frequency) by time 3.87 ms <
Memory estimate: 1.44 MiB, allocs estimate: 10891.
まとめ
今度こそ終わりです。
Discussion