💩
【交差エントロピー】y[np.arange(batch_size), t] の意味
交差エントロピーを学習しています。
『ゼロから作るDeep Learning』の中で詰まったところがあるので、まとめておきます。
前提
こちら の内容を理解するにあたっての話です。
疑問
これの意味がわからん、、、
# 出力データ
y = np.array([
[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
])
# 教師データ(答え)
t = np.array([
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
])
t = np.argmax(t, axis=1)
# -> [2 8 2]
# 出力から各データの正解ラベルに関する要素を抽出
y[np.arange(batch_size), t] ・・・①
# -> [0.6 0.1 0.1]
なぜこのコードで 正解ラベルに関する要素を抽出できるのか、、?
確かにそうなっているんだけれども、、。
答え
np.arange
の使用法がポイントでした。
①はこんな感じです。
# yの行・列を配列型に引数として入れてることで、まとめて指定している
y[np.arange(batch_size), t]
-> y[np.array([0, 1, 2]), np.array([2, 8, 2])
-> y[yの行, yの列]
-> y[0][2], y[1][8], y[2][2]を取ってきてください!
-> よって正解ラベルに関する要素を抽出ができた
Discussion