Closed4
transformersのTrainer利用時にTensorBoardのHPARAMSでハイパラ管理したい
つまづき
huggingface transformersのTrainer使用後にlogging_dir出力先をTensorBoardで見てみるとHPARAMSが表示されていない
調査
- TensorBoardCallbackでtb_writer.add_hparamsがmetric_dictの値が入った状態で呼ばれていないのが原因か
- 特に今後対応の予定もなさそう
解決
- TensorBoardCallbackを書き換え、Trainerからこちらを呼び出すように修正
class MyTensorBoardCallback(TensorBoardCallback):
def on_evaluate(self, args, state, control, **kwargs):
if "metrics" in kwargs:
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict = kwargs["metrics"])
main.py
tensorboard_callback = MyTensorBoardCallback()
trainer = Trainer(
...
callbacks=[tensorboard_callback],
)
- trainer.evaluateを修正し、on_evaluateが呼ばれる条件を制御できるように
my_trainer.py
def evaluate(
...
log_metric: bool = False
) -> Dict[str, float]
...
if log_metric:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics = metrics)
main.py
if training_args.do_eval:
// 学習後のbestチェックポイントモデルvalidation評価時のみログを送信
results = trainer.evaluate(log_metric=True)
もっとすっきりした解決策を募集中です🙇♂️
このスクラップは2024/01/01にクローズされました