⛏️

[Package] 最適輸送問題のパッケージを試す | POT, OTT-jax, OptimalTransport.jl

2021/09/16に公開

最適輸送 (Optimal Transport) は近年機械学習などの分野で注目されているトピックです.

https://speakerdeck.com/eumesy/how-to-leverage-optimal-transport

https://sites.google.com/view/uda-0x-seminar/home/0x03

私も前世に,Optimal Transportに関するPDFファイル (というか本?でしょうか) を読んだ記録を書いていたことがあります (以下のQiita検索を見てください).

https://qiita.com/search?q=takilog+computational+optimal+transport

これまでPythonの代表的なパッケージとしてPOT (Python Optimal Transport),JuliaのパッケージとしてOptimalTransport.jl などがありましたが,最近ネットサーフィンしていると Google Research謹製のJAXパッケージ Optimal Transport Tools (OTT) と呼ばれるものが登場していたことに気付いたので比較をしてみました.またPOTとOTT-jaxの比較については公式ドキュメントが公開しているものを実行して挙動を確認しました.

前提

そもそもJAXパッケージというものについて考えてみます.残念ながら私はそこまで詳しくないのですが

https://zenn.dev/koshian2/articles/af6758a5f3efc2

このような記事を参考にすると,GPUやTPUなどをサポートしつた高速なNumpyと思えば良さそうです.OTの計算で,SinkhornやEntropy-regularized Sinkhornでは行列計算をたくさん使うことになりますので,高速な実装ができそうな気配がありますね.まずは問題としては次のようなOTのインスタンス (\mathbf{a}, \mathbf{b}, \mathbf{C}) を解きます (これは公式のドキュメントのものです).以下はJuliaのコードです.

using Distances

n = 10 # data size
dims = 3 # random data dimension
a = rand(n)
b = rand(n)
x = rand(n, dim) # data は 1次元目
y = rand(n, dim) # ↑

# 1次元目をデータとしたpair-wise distance
# 距離はSqEuclidean()
C = pairwise(Distances.SqEuclidean(), x, y, dims=1) 

こうして周辺分布\mathbf{a}, \mathbf{b}と輸送コスト\mathbf{C}を作成したら,あとはPython/Juliaのコードを呼び出すのみです.

Python

実装については公式のドキュメントにあるnotebookを参考にしてください.

CPU

まずはCPU環境で実験しました.反復回数は1000,誤差の終了条件は公式通り 1e-2 になっています.

実行したローカル環境はRyzen 7 Pro 4750G + 32GB main memoryのローカルPCです.データの次元は公式では大きいですが,ローカルの環境だったので n\in[32, 64, 128, 256, 512, 1024] としました.エントロピー正則化の係数は公式の通り \epsilon\in[0.1, 0.01]としました.結果は以下のグラフの通りです.

行列計算の重さに従って計算時間が増えていく様子が見れます.

GPU

次にGPU環境で実験しました.反復回数は1000,誤差の終了条件は公式通り 1e-2 になっています.

実行した環境はGoogle ColaboratoryのFreeで使えるノートブック環境で,ランタイムをGPUにしてから同じノートブックで実験を行いました.

結果は図の通りです.POTについてはCPUの実行とほとんど同じ挙動を示します.一方で,OTT-jaxについてはn=1024の環境までほとんど計算時間が一定になりました.背景で利用されているJAXの性質 (日本語が変ですが) によって,高速な演算が行われている様子が見て取れます.公式のノードブックではn=4096の範囲までなかなかの速度で計算が可能であることが示唆されています.

Julia

ところで最近界隈でたまに名前を聞くJuliaという言語がありますが,こちらにもOptimal Transport用のライブラリが用意されています.似たような例題をJuliaで実行することにし,挙動を見てみましょう.こちらについてもCUDA.jlなどのサポートがありますが,今回はJuliaのGPU環境についてまでは検証できなかったので,CPU版とのみ比較することにします.

Pythonでは*%timeit%*を利用して時間計測しているのですが,Julia版では怠けて適当に5回試した平均だけを取得することにしました.実験部分の実装はこのようなものです.

using Statistics
using Distances
using OptimalTransport

dim = 3
for n in [32, 64, 128, 256, 512, 1024]
    # 人工データ
    x = rand(n, dim)
    y = rand(n, dim)
    a = rand(n)
    b = rand(n)
    a ./= sum(a)
    b ./= sum(b)
    C = pairwise(Distances.SqEuclidean(), x, y, dims=1)

    # ここはクソ測定コードです
    sol = Float64[]
    for eps in [0.01, 0.1]
        for _ in 1:5
            t1 = time()
            P = sinkhorn(a, b, C, eps; maxiter=1000, atol=1e-2)
            tsolve = time() - t1
            push!(sol, tsolve)
        end

        # store mean(sol) to plot
    end
end

アルゴリズムの実装ですが,通常のSinkhornアルゴリズム (SinkhornGibbs),log-domainで実装したもの (SinkhornStabilized) などが実装されている様子だったので,ここではその2つを比較しました.Sinkhornアルゴリズムの終了条件ですが,POTやOTT-jaxで threshold=1e-2 など指定されていたものは atol という引数に相当するように読んだので,ここでは atol に指定しました.

結果です.SSがSinkhornStabilized,SOが通常のSinkhornを使っています.

PythonとJuliaの乱数データを揃えていないので見た感触だけですが,軽く見た感じだとJuliaの実装もなかなかそこそこ良い感じがしますね.今度はCUDA.jlを使える環境で加速したOptimalTransport.jlも比較してみたいですね.

Discussion