💩

【交差エントロピー】y[np.arange(batch_size), t] の意味

2022/07/26に公開約900字

交差エントロピーを学習しています。
『ゼロから作るDeep Learning』の中で詰まったところがあるので、まとめておきます。

前提

こちら の内容を理解するにあたっての話です。

疑問

これの意味がわからん、、、

# 出力データ
y = np.array([
    [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0], 
    [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], 
    [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
])

# 教師データ(答え)
t = np.array([
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
])

t = np.argmax(t, axis=1)
# -> [2 8 2]

# 出力から各データの正解ラベルに関する要素を抽出
y[np.arange(batch_size), t] ・・・①
# -> [0.6 0.1 0.1]

なぜこのコードで 正解ラベルに関する要素を抽出できるのか、、?
確かにそうなっているんだけれども、、。

答え

np.arange の使用法がポイントでした。
①はこんな感じです。

# yの行・列を配列型に引数として入れてることで、まとめて指定している
y[np.arange(batch_size), t]
-> y[np.array([0, 1, 2]), np.array([2, 8, 2])
-> y[yの行, yの列]
-> y[0][2], y[1][8], y[2][2]を取ってきてください!
-> よって正解ラベルに関する要素を抽出ができた

Discussion

ログインするとコメントできます