Closed1
pytorchのKLDivLossとCrossEntropyLossの違いについて

KLダイバージェンスとCross Entropyの間には以下の関係がある。
数学的には両者は定数項(エントロピーの項)を除いて一致するが、nn.CrossEntropyLoss
[1]とnn.KLDivLoss
[2]では微妙な違いがある。
CrossEntropyLoss
- 入力値として正規化されていない対数確率を想定する
- 内部で確率的な正規化処理を実施する(Softmaxの全確率で割る部分)
- reduce="mean"は集計結果を「全ての次元で」平均する
KLDivLoss
- 入力値として正規化されていない数確率を想定する
- 内部で確率的な正規化処理を実施しない
- reduce="batchmean"は集計結果を加えた後で「バッチの次元」で割る(注1)
注1:デフォルトのreduce="mean"
は現状ではCrossEntropyLossと同じ振る舞いをするが、将来的にreduce="batchmean"
と同じ振る舞いに統一されるらしい。
両者の処理的な違い
両者の計算過程を簡単なグラフで表すと以下のようになる。KLDivLoss
の計算ではエントロピーの項を無視している。(ブラケットで囲われた名前は入力と出力を表す。それ以外の名前は処理を表す。)
-
CrossEntropyLoss
: [input
] ->exp
->norm
->log
->mul(gt)
-> [output
] -
KLDivLoss
: [input
] ->mul(gt)
-> [output
]
すなわち、KLDivLoss内部的な正規化処理(exp
-> norm
-> log
)を省略して、入力の対数確率を直接target変数とかけてpointwise lossを計算している。したがって、CrossEntropyLossは「入力の対数確率が正規化されている」という制約を自動的に付加しているのに対し、KLDivLossにはそのような制約が加わっていないという違いがある。
実際、入力の対数確率が正規化されていない前提であれば両者は(集約方法やエントロピーによる違いを無視しても)異なる値となる。
def ce_loss(input, target, eps=1e-4):
batch_size = input.shape[0]
pred = input.exp()
pred /= pred.sum(dim=-1, keepdim=True)
return -(target * (pred + eps).log()).sum() / batch_size
def kl_div_loss(input, target, eps=1e-4):
batch_size = input.shape[0]
loss = target * ((target + eps).log() - input)
return loss.sum() / batch_size
Reference
このスクラップは2023/11/16にクローズされました