🍉
np.newaxis を使ってユークリッド距離を計算する方法
この記事の概要
np.newaxis
を使って賢く行列計算する方法が理解できなすぎて、書きながら分かったら良いなとの思いでこの記事を書いています。この記事では、np.newaxis
を使ってユークリッド距離の計算を高速化する方法を丁寧に追っていきます。
結論
を入力として、
を出力することを考える。
ただし
とする。
このとき、
def euclidean_distance(X,Y):
X = X[:, np.newaxis, :]
Y = Y[np.newaxis, :, :]
Z = np.sqrt(np.sum((X - Y) ** 2, axis=2))
return Z
何がどうなった
全然分からないので、処理を書き下して追ってみることにします。
簡単のため
まず、X[:, np.newaxis, :]
は以下のようになります。
[[[x11 x12]]
[[x21 x22]]]
Y[np.newaxis, :, :]
は以下になります。
[[[y11 y12]
[y21 y22]
[y31 y32]]]
X - Y
の計算でブロードキャストが効いて、X
は
[[[x11 x12][x11 x12][x11 x12]]
[[x21 x22][x21 x22][x21 x22]]]
に、Y
は
[[[y11 y12][y21 y22][y31 y32]]
[[y11 y12][y21 y22][y31 y32]]]
になります。X - Y
は
[[[x11-y11 x12-y12][x11-y21 x12-y22][x11-y31 x12-y32]]
[[x21-y11 x22-y12][x21-y21 x22-y22][x21-y31 x22-y32]]]
になるので、** 2
してaxis=2
に関してsumをとると、
[[(x11-y11)**2+(x12-y12)**2 (x11-y21)**2+(x12-y22)**2 (x11-y31)**2+(x12-y32)**2]
[(x21-y11)**2+(x22-y12)**2 (x21-y21)**2+(x22-y22)**2 (x21-y31)**2+(x22-y32)**2]]
これのsqrt
なので、確かに
まとめ
追ったら計算が合うことは分かった。
全然抽象化して理解できてないので、考え方のコツとかあったら教えてください。
Discussion