🤗

HuggingfaceのTrainer.train()でログを残す方法

2023/09/09に公開

結論とソースコード

こんな感じでTrainerのcallbacksにtransformers.TrainerCallbackを継承したクラスを渡せばOK。

def init_logger(log_file='train.log'):
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

logger = init_logger(log_file=f'train.log')

import transformers
class LogCallback(transformers.TrainerCallback):
    def on_evaluate(self, args, state, control, **kwargs):
        logger.info(state.log_history[-1])

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    train_dataset=ds_train_tokenized,
    eval_dataset=ds_valid_tokenized,
    callbacks=[
        LogCallback()
    ],
)

trainer.train()

出力は以下のように、evalのlossなどが出力される

{
'eval_loss': 1.6095116138458252,
'eval_runtime': 39.4781,
'eval_samples_per_second': 5.066,
'eval_steps_per_second': 1.267,
'epoch': 0.08,
'step': 20
}

参考

https://huggingface.co/docs/transformers/v4.33.0/en/main_classes/callback#transformers.TrainerCallback.example

https://discuss.huggingface.co/t/logs-of-training-and-validation-loss/1974/5

メモ

state.log_historyの中にはlearning_rateなども入っている(奇数が学習関係のパラメーター、偶数が評価データの評価値)ので、そちらも出力できるみたいです

trainingargsにlogsってあるけど出力されなかった。。

Discussion