Closed1

pytorchのKLDivLossとCrossEntropyLossの違いについて

bilzardbilzard

KLダイバージェンスとCross Entropyの間には以下の関係がある。

\mathcal{L}_\text{KL}(y, \hat{y}) = \mathcal{L}_\text{CE}(y, \hat{y}) - \mathcal{H}(y)

数学的には両者は定数項(エントロピーの項)を除いて一致するが、nn.CrossEntropyLoss[1]とnn.KLDivLoss[2]では微妙な違いがある。

CrossEntropyLoss

  1. 入力値として正規化されていない対数確率を想定する
  2. 内部で確率的な正規化処理を実施する(Softmaxの全確率で割る部分)
  3. reduce="mean"は集計結果を「全ての次元で」平均する

KLDivLoss

  1. 入力値として正規化されていない数確率を想定する
  2. 内部で確率的な正規化処理を実施しない
  3. reduce="batchmean"は集計結果を加えた後で「バッチの次元」で割る(注1)

注1:デフォルトのreduce="mean"は現状ではCrossEntropyLossと同じ振る舞いをするが、将来的にreduce="batchmean"と同じ振る舞いに統一されるらしい。

両者の処理的な違い

両者の計算過程を簡単なグラフで表すと以下のようになる。KLDivLossの計算ではエントロピーの項を無視している。(ブラケットで囲われた名前は入力と出力を表す。それ以外の名前は処理を表す。)

  • CrossEntropyLoss: [input] -> exp -> norm -> log -> mul(gt) -> [output]
  • KLDivLoss: [input] -> mul(gt) -> [output]

すなわち、KLDivLoss内部的な正規化処理(exp -> norm -> log)を省略して、入力の対数確率を直接target変数とかけてpointwise lossを計算している。したがって、CrossEntropyLossは「入力の対数確率が正規化されている」という制約を自動的に付加しているのに対し、KLDivLossにはそのような制約が加わっていないという違いがある。

実際、入力の対数確率が正規化されていない前提であれば両者は(集約方法やエントロピーによる違いを無視しても)異なる値となる。

def ce_loss(input, target, eps=1e-4):
    batch_size = input.shape[0]
    pred = input.exp()
    pred /= pred.sum(dim=-1, keepdim=True)
    return -(target * (pred + eps).log()).sum() / batch_size


def kl_div_loss(input, target, eps=1e-4):
    batch_size = input.shape[0]
    loss = target * ((target + eps).log() - input)
    return loss.sum() / batch_size

Reference

このスクラップは2023/11/16にクローズされました