🐁

バッチ学習における計算で次元を削減して良い理由

2023/09/28に公開
3

今回はバッチ学習における計算で次元を削減して良い理由について説明します。

何これ?

タイトルの通りです。
バッチ学習における計算で次元を削減して良い理由を考えます。

具体的には、自然言語処理モデルで次のようなコードを見かけ、直感的に納得できなかったので視覚的に考えました。

xs = input # 入力(3次元)
N, T, V = xs.shape # 入力の形状の取得
xs = xs.reshape(N * T, V) # 次元削減(3次元→2次元)

ys = softmax(xs) # 2次元データに対してソフトマックス関数

入力は3次元なのに、2次元に圧縮して計算しても問題ないのか?という部分が気になりました。
※補足:Nはバッチサイズを表しています。

説明

結論から言うと問題なく計算できます。

  1. 次元削減について、私は次のようなイメージを持っていました。
# 次元削減
dim2 = np.array([[1,1,1],[2,2,2]])
# [[1 1 1]
#  [2 2 2]]
dim1 = dim2.reshape(-1) # 2次元→1次元
print(dim1)
# 出力:[1 1 1 2 2 2]

これは正しい結果で、2次元→1次元にするとこうなります。特に問題なく(列方向で)計算できますが、個人的になんとなく行方向に混ざる(横方向に別々のデータが並ぶ)イメージがついていました。

  1. 本題ですが、3次元→2次元の場合は次のようになります。
import numpy as np

# 3次元入力 形状は (2, 3, 4)
x = np.array([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
])
print("Original shape x:", x)
# Original shape x: 
# [[[ 1  2  3  4]
#   [ 5  6  7  8]
#   [ 9 10 11 12]]

#  [[13 14 15 16]
#   [17 18 19 20]
#   [21 22 23 24]]]

# 2次元→3次元
# 形状が (2 * 3, 4) = (6, 4)になる
x_reshaped = x.reshape(-1, 4)

print("Reshaped x:")
print(x_reshaped)
# Reshaped x:
# [[ 1  2  3  4]
#  [ 5  6  7  8]
#  [ 9 10 11 12]
#  [13 14 15 16]
#  [17 18 19 20]
#  [21 22 23 24]]

視覚的によくわかるように、データは縦方向に結合されるので、行方向で問題なく計算できます。これなら、直感的にも要素ごとにsoftmax関数を適用しても問題ないように感じますね。

まとめ

今回はバッチ学習における全結合で次元を削減して良い理由を説明しました。
視覚化すると分かりやすいですね。

2次元→1次元でも、3次元→2次元でも、x[0],x[1]…のように要素ごとに計算する場合は、問題なく計算できそうです。
ちなみに、簡単に元の形状に戻すこともできます。^{※1}

今回はここまでになります。最後まで読んでいただきありがとうございました!


補遺

今回はsoftmaxを取り上げましたが、同様の理由で全結合なども次元削減しての計算が可能です。
コードもほとんど同じになるので、全結合に対しても成り立つ理由がわかると思います。

^{※1}returnで元の形状に戻しています。

x = input # 入力(3次元)
N, T, D = x.shape # 入力の形状の取得
rx = x.reshape(N*T, -1) # 次元削減(3次元→2次元)

out = np.dot(rx, W) + b # 2次元データに対して全結合
return out.reshape(N, T, -1) # 元の形状に戻して結果を出力

以上です。

Discussion

yKesamaruyKesamaru

素晴らしい記事をありがとうございます😊

このような次元削減は初めて見ました😲
計算量が少し減る…というのは分かるのですが、実際嬉しいことは何でしょうか?🤔

教えて頂けると幸いです🙇

YutoYuto

コメントありがとうございます。

今回の次元削減の一番のメリットは仰る通り、計算量が減ることです。目的の大部分がここにあると思います。

副次的なメリットとしては、既存のライブラリ(損失関数など)への入力コードが簡潔になります。3次元の入力より2次元の入力に対応しているものが多いと思います。
個人的には、コードが頭の中で理解しやすくなるのも少し嬉しいですね笑

yKesamaruyKesamaru

なるほど。😲

既存のライブラリ(損失関数など)への入力コードが簡潔になります。3次元の入力より2次元の入力に対応しているものが多いと思います

勉強になりました!ありがとうございます!😆⭐️