🎲

確率単体への射影操作について(行列版の実装)

2022/07/16に公開

まえがき

こちらの記事の続きです。

https://zenn.dev/takilog/articles/639ef513465fd0

原典です。

https://arxiv.org/abs/1309.1541

まとめて N 本のベクトルを射影する

以前 D 次元ベクトルを射影する操作について書きました。arXivの原典を読んでいくと、Matlab版のコード(まとめて射影する版)がありましたので、そちらをJuliaで再実装してみたいと思います。Matlab版のコードはこちらです。

function X = SimplexProj(Y)

[N, D] = size(Y);
X = sort(Y, 2, 'descend');
Xtmp = (cumsum(X, 2) - 1) * diag(sparse(1 ./ (1:D)));
X = max(bsxfun(@minus, Y, Xtmp(sub2ind([N, D], (1:N)', sum(X > Xtmp, 2)))), 0);

以下に1つの D 次元ベクトルに対するアルゴリズムを再掲します。

上のMatlab版コードでは、それぞれ次のような実装がされていることが分かります。

  • 入力は N 本の D 次元ベクトルです。各行を確率単体へ射影します。
  • Algorithm 1 で入力をソートしていますが、行列 Y の列ごとにソートをしています。
  • 各行について、ソートした値の累積和 \sum_{i=1}^{j} u_i を計算しつつ、\frac{1}{j} を対角行列にして行列計算で求めます(Algorithm 1 とは符合が反転しています)。
  • 各行のデータに対して \rho に相当するインデクスを求め、sub2indで行列のインデクスへ変更して値を取ってきて、各行に対してbroadcastして引き算します。

Julia版での実装

昔(0.5の時代とか)はMatlabコードをそのままパクれば動いたのですが、最近は関数が消されてきれいになってしまっているので、真面目に書く必要があります。メモリ効率や真面目な最適化を無視するならば、以下ぐらいのコードが出来上がります。

using LinearAlgebra

function euclidean_projection_matrix(Y)
    N, D = size(Y);
    X = sort(Y, dims=2, rev=true)
    # diagmを使うのが無駄な気がするので、broadcastとtransposeする方がいいのかも
    Xtmp = (cumsum(X, dims=2) .- 1) * diagm(1 ./ (1:D))

    # インデクス処理が難しいので、各行について使うべき rho を求める
    rhovec = sum(X .> Xtmp, dims=2)
    Vrho = [Xtmp[j, rhovec[j]] for j in 1:N]

    # broadcastして計算する
    max.(Y .- Vrho, 0)
end

以前の記事で作成した1本の D 次元ベクトル向けの関数 euclidean_projection と比較してみます。次のような頭の悪い実装を用意します。

function euclidean_projection_stupid(Y)
    X = zeros(size(Y))
    for j in 1:size(Y, 1)
        # 前回作った関数を各行に適用する
        X[j, :] = euclidean_projection(Y[j, :])
    end
    X
end

あとはBenchmarkToolsで比較してみましょう。

N = 1000
D = 10

# 行列版 vs 繰り返し版
Y = rand(N, D)
@benchmark euclidean_projection_matrix(Y)
@benchmark euclidean_projection_stupid(Y)

データ数 N=1000、次元 D=10 の場合の行列版の出力です。

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  217.996 μs …   2.894 ms  ┊ GC (min … max): 0.00%85.23%
 Time  (median):     246.576 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   274.965 μs ± 173.486 μs  ┊ GC (mean ± σ):  6.27% ±  8.77%

  ██                                                             
  ██▇▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂ ▂
  218 μs           Histogram: frequency by time         1.62 ms <

 Memory estimate: 492.59 KiB, allocs estimate: 36.

データ数 N=1000、次元 D=10 の場合の繰り返し版の出力です。Nが大きくなってくると差が出てくるみたいです(手元ではN=10 ぐらいまでは行列版が負けていた)。とはいえメモリの配置なども真面目に実装していないので、あくまでも参考までの結果になります。一方で次元数 D が上がってくると行列版が負ける場合が多かったです。真面目に実装すべきですね~(雑)。

BenchmarkTools.Trial: 8177 samples with 1 evaluation.
 Range (min … max):  416.836 μs …   7.037 ms  ┊ GC (min … max):  0.00%84.75%
 Time  (median):     470.058 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   598.489 μs ± 606.219 μs  ┊ GC (mean ± σ):  17.96% ± 15.32%

  █▇▅▃▃▂                                                        ▁
  ███████▇▆▆▄▁▄▃▃▁▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅██▇▇▇▇▇▆▅▆ █
  417 μs        Histogram: log(frequency) by time       3.87 ms <

 Memory estimate: 1.44 MiB, allocs estimate: 10891.

まとめ

今度こそ終わりです。

Discussion