🍉

np.newaxis を使ってユークリッド距離を計算する方法

2022/05/28に公開

この記事の概要

np.newaxisを使って賢く行列計算する方法が理解できなすぎて、書きながら分かったら良いなとの思いでこの記事を書いています。この記事では、np.newaxisを使ってユークリッド距離の計算を高速化する方法を丁寧に追っていきます。

結論

\boldsymbol{X}=(\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_{N})^{T}\in\mathbb{R}^{N\times D}\\ \boldsymbol{Y}=(\boldsymbol{y}_1,\boldsymbol{y}_2,\cdots,\boldsymbol{y}_{M})^{T}\in\mathbb{R}^{M\times D}

を入力として、

\boldsymbol{Z}\in\mathbb{R}^{N\times M}\\ z_{nm}=d(\boldsymbol{x}_n,\boldsymbol{y}_m)

を出力することを考える。
ただしdはユークリッド距離

d(x,y)=\sqrt{(x_1-y_1)^2+(x_2-y_2)^2+\cdots(x_D-y_D)^2}

とする。

このとき、\boldsymbol{Z}は以下の通り計算できる。

def euclidean_distance(X,Y):
  X = X[:, np.newaxis, :]
  Y = Y[np.newaxis, :, :]
  Z = np.sqrt(np.sum((X - Y) ** 2, axis=2))
  return Z

何がどうなった

全然分からないので、処理を書き下して追ってみることにします。
簡単のためN=2, M=3, D=2で考えます(一般化は諦めた)
まず、X[:, np.newaxis, :]は以下のようになります。

[[[x11 x12]]
 [[x21 x22]]]

Y[np.newaxis, :, :]は以下になります。

[[[y11 y12]
  [y21 y22]
  [y31 y32]]]

X - Yの計算でブロードキャストが効いて、X

[[[x11 x12][x11 x12][x11 x12]]
 [[x21 x22][x21 x22][x21 x22]]]

に、Y

[[[y11 y12][y21 y22][y31 y32]]
 [[y11 y12][y21 y22][y31 y32]]]

になります。X - Y

[[[x11-y11 x12-y12][x11-y21 x12-y22][x11-y31 x12-y32]]
 [[x21-y11 x22-y12][x21-y21 x22-y22][x21-y31 x22-y32]]]

になるので、** 2してaxis=2に関してsumをとると、

[[(x11-y11)**2+(x12-y12)**2 (x11-y21)**2+(x12-y22)**2 (x11-y31)**2+(x12-y32)**2]
 [(x21-y11)**2+(x22-y12)**2 (x21-y21)**2+(x22-y22)**2 (x21-y31)**2+(x22-y32)**2]]

これのsqrtなので、確かに\boldsymbol{Z}に一致していますね。

まとめ

追ったら計算が合うことは分かった。
全然抽象化して理解できてないので、考え方のコツとかあったら教えてください。

Discussion