点群の最適輸送アルゴリズムで遊んでみた
この記事は、「最適輸送の理論とアルゴリズム」を読んでいて気になった部分を自分で実装して試してみたものを雑にまとめたものです。今回は点群の最適輸送について、実際にPythonで実装をして実験してみた結果を書きます。最適輸送理論は機械学習においてロス関数の設計に役立つらしい(本をまだ全部読んでない...)ので、シンプルな線形計画問題で解ける簡単なアルゴリズムを一度実装してみました。
最適輸送問題とは
まず、そもそも最適輸送問題って何?という話なのですが、
- 2つの分布
,D_1 があるD_2 - 分布を移動させるときのコストが定義されている
という状況で
点群の最適輸送問題の定式化
点群Aを別の点群Bに移動させる問題を考えます。
と書くことができます。また制約条件としては、
の3つがあります。2番目と3番目の制約条件は質量保存の法則的なもので、輸送の過程で点群の質量が失われないという条件です。
Pythonによる実装
上の定式化を見るとわかるように、これは制約付きの線形計画問題なのでScipyなどに用意されているソルバーを使って解くことができます。
if __name__ == "__main__":
start = np.array([[2.2, 2.1], [3.2, 5.3], [4.5, 4.4], [3.1, 3.8]])
end = np.array([[4.8, 1.9], [4.1, 3.3], [2.0, 5.5], [3.4, 2.5]])
a = np.ones(4) / 4
b = np.ones(4) / 4
C = calc_cost(start, end)
P = solve_transport_problem(C, a, b)
print("P: ", P)
print("Minimum Cost: ", np.sum(P * C))
ここでは本と同じ例を使って解いてみます。start
と end
がそれぞれ輸送前と輸送のターゲット先の点群の座標です。分布の点の重みは等しく、総和が1になるようにしています(a
, b
)
def calc_cost(start: np.ndarray, end: np.ndarray) -> np.ndarray:
cost = np.zeros((len(start), len(end)))
for i in range(len(start)):
for j in range(len(end)):
cost[i, j] = np.linalg.norm(start[i] - end[j]) ** 2
return cost
輸送コストの設計ですが、簡単のため点間の二乗距離で定義します。
def solve_transport_problem(
C: np.ndarray,
a: np.ndarray,
b: np.ndarray,
) -> np.ndarray:
num_x = len(a)
num_y = len(b)
c = C.flatten()
A = []
# \sum_{j}P_{i, j} = a_i
for i in range(num_x):
A_i = np.zeros((num_x, num_y))
A_i[i, :] = 1
A.append(A_i.flatten())
# \sum_{i}P_{i, j} = b_j
for j in range(num_y):
A_j = np.zeros((num_x, num_y))
A_j[:, j] = 1
A.append(A_j.flatten())
A = np.array(A)
b = np.concatenate([a, b])
res = linprog(c, A_eq=A, b_eq=b, method="highs")
P = res.x.reshape((num_x, num_y))
return P
scipyの linprog
メソッドを利用して解くと
P: [[-0. 0. 0. 0.25]
[ 0. 0. 0.25 0. ]
[ 0.25 0. 0. 0. ]
[-0. 0.25 0. 0. ]]
Minimum Cost: 2.6675000000000004
が得られます。これは start
の最初の点 [2.2, 2,1]
は end
の最後の要素の点 [3.4, 2.5]
に移動させるのが最適ということを示してます。視覚的に表すならば、
のようになります(赤: start
、青:end
)。
code
def plot_transport(start: np.ndarray, end: np.ndarray, P: np.ndarray):
plt.figure()
plt.scatter(start[:, 0], start[:, 1], c="r", label="start")
plt.scatter(end[:, 0], end[:, 1], c="b", label="end")
for i in range(P.shape[0]):
for j in range(P.shape[1]):
if P[i, j] > 0:
plt.annotate(
"",
xy=end[j],
xytext=start[i],
arrowprops=dict(arrowstyle="->", color="k", lw=1),
)
plt.show()
コスト関数を変えたときの挙動の変化
コスト関数の定義の仕方を色々変えてみたときに最適な点群の輸送がどう変わるか見てみます。
逆二乗
ユークリッド距離が大きくなるほどコストが小さくなるような構造なので、なるべく離れた点に輸送が行われるようになっていることがわかります。
非対称なコスト関数
感想
この記事の内容は本の最初の10 %ぐらいの内容なのですが、もう既に残り90 %を理解しきれるか心配です。ただ、実際に数値例を使って計算してみることで機械学習にどのように使えるのか、なんとなくですが理解できた気がするのかな、と思っています。
Discussion