📈

gensimでWord2VecのLoss値を求める

2024/04/01に公開

概要

表題通りgensimのWord2VecのLoss値を求める方法について備忘録として残します。

バージョン

gensim:4.3.1

実装

先にCallbackAny2Vecを継承したクラスを定義します。
gensimのWord2Vecのloss値は今までのエポックごとに学習したloss値を合算した値を持ちます。
そのために現在のエポックのloss値と1つ前のloss値の差を求める必要があります。

# CallbackAny2Vecを継承したクラス
class LossCallback(CallbackAny2Vec):
    def __init__(self):
        self.epoch = 0
        self.losses = []
        # 今のロス値
        self.cumu_loss = 0.0
        # 1つ前のロス値
        self.pre_cumu_loss = 0.0
        self.now_loss = 0
    
    def on_epoch_end(self, model):
     # 現在のエポックのロス値を取得
        loss = model.get_latest_training_loss()
        self.cumu_loss = float(loss)
     # 現在のロス値
        self.now_loss = self.cumu_loss if self.epoch == 0 else self.cumu_loss - self.pre_cumu_loss
        self.pre_cumu_loss = self.cumu_loss
        self.epoch += 1
        self.losses.append(self.now_loss)

上記で実装したクラスをWord2Vecのモデルを定義するときに読み込みます。

    loss_calc = LossCallback()  # 上記で定義したクラスの定義
    model = Word2Vec(
        •••
        compute_loss=True,
        callbacks=[loss_calc]
    )

参考

https://stackoverflow.com/questions/54888490/gensim-word2vec-print-log-loss

https://github.com/piskvorky/gensim/issues/2735

https://github.com/piskvorky/gensim/issues/2743

Discussion