🔔

【効率化】機械学習モデルの学習進捗をSlackで通知し監視する

2025/01/12に公開

導入

モデルの学習を進める際
・どのぐらい進んでいるのか
・いつ終わるのか
・中断していないか
が気になりませんか?

本記事では、モデル学習の進捗状況をSlackに定期通知する機能の実装方法を解説します。
松尾研LLM講座コンペ参加時に、本コードを作成しました。
コンペや開発などでご活用いただければ幸いです。

1. 基本的なSlack通知機能の実装

1.1 事前準備

  1. Slackワークスペースの設定
    • Incoming Webhookを有効化
    • Webhook URLを取得(形式: https://hooks.slack.com/services/~~~

こちらの記事をご参照ください
[Slack:Webhook URL取得してSlackに通知する] (https://zenn.dev/hotaka_noda/articles/4a6f0ccee73a18)

  1. 必要なライブラリのインポート
import json
import requests
from datetime import datetime, timezone, timedelta

1.2 実装

def send_slack_message(message: str):
    #Slackにメッセージを送信する基本機能
    webhook_url = 'https://hooks.slack.com/services/~~~'  # 取得したWebhook URLを設定
    data = json.dumps({'text': message})
    headers = {'content-type': 'application/json'}
    
    try:
        response = requests.post(webhook_url, data=data, headers=headers)
        response.raise_for_status()
        print(f"Notification sent: {message}")
    except Exception as e:
        print(f"Failed to send notification: {str(e)}")

1.3 使用例

# 通知テスト
send_slack_message("テスト通知")

2. モデル学習進捗通知機能の実装

2.1 TrainerCallbackの実装

Hugging Face TransformersのTrainerCallbackを継承して進捗通知機能を実装します。

from transformers import TrainerCallback
from datetime import datetime, timedelta
import math

class ProgressCallback(TrainerCallback):
    def __init__(self, slack_notification_interval=3600):  # 1時間間隔
        self.last_notification_time = datetime.now()
        self.notification_interval = slack_notification_interval
        self.training_start_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        #学習開始時の処理
        self.training_start_time = datetime.now()
        total_steps = state.max_steps if state.max_steps > 0 else state.num_train_epochs * state.num_train_epochs
        message = f"学習を開始しました\n総ステップ数: {total_steps}"
        send_slack_message(message)

    def on_step_end(self, args, state, control, **kwargs):
        #各ステップ終了時の処理
        current_time = datetime.now()
        
        # 前回の通知から指定時間経過している場合のみ通知
        if (current_time - self.last_notification_time).total_seconds() >= self.notification_interval:
            # 進捗率の計算
            total_steps = state.max_steps if state.max_steps > 0 else math.ceil(args.num_train_epochs * state.num_train_epochs)
            progress = (state.global_step / total_steps) * 100
            
            # 残り時間の推定
            elapsed_time = current_time - self.training_start_time
            steps_per_second = state.global_step / elapsed_time.total_seconds()
            remaining_steps = total_steps - state.global_step
            estimated_remaining_seconds = remaining_steps / steps_per_second
            estimated_finish_time = current_time + timedelta(seconds=estimated_remaining_seconds)
            
            # メッセージ作成
            message = (
                f"学習進捗状況:\n"
                f"- 進捗: {progress:.2f}% ({state.global_step}/{total_steps}ステップ)\n"
                f"- 経過時間: {str(elapsed_time).split('.')[0]}\n"
                f"- 推定残り時間: {str(timedelta(seconds=int(estimated_remaining_seconds)))}\n"
                f"- 推定完了時刻: {estimated_finish_time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            
            send_slack_message(message)
            self.last_notification_time = current_time

    def on_train_end(self, args, state, control, **kwargs):
        #学習終了時の処理
        total_time = datetime.now() - self.training_start_time
        message = f"学習が完了しました\n学習時間: {str(total_time).split('.')[0]}"
        send_slack_message(message)

2.2 Trainerへの組み込み方

# Progress Callbackのインスタンス化
progress_callback = ProgressCallback(slack_notification_interval=3600)  # 1時間間隔で通知

# Trainerの設定
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[progress_callback],  # コールバックを追加
    # その他
)

本実装により、以下のタイミングで自動的にSlack通知を受け取ることができます

  • 学習開始時
  • 定期的な進捗報告(デフォルト1時間間隔)
  • 学習終了時
    この機能により、長時間のモデル学習をより効率的に監視することが可能になります。
    自身の実行環境では通知エラーは起きませんでしたが、心配な方はエラーハンドリングを追加されるのが良いと思います。
    通知間隔や表示内容はProgressCallbackクラスのパラメータを調整することで、カスタマイズ可能です。

3. 通知メッセージ例

学習開始時

学習を開始しました
総ステップ数: 5000

進捗報告時

学習進捗状況:
- 進捗: 35.80% (1790/5000ステップ)
- 経過時間: 2:15:30
- 推定残り時間: 4:03:45
- 推定完了時刻: 2024-01-11 18:45:22

学習完了時

学習が完了しました
学習時間: 6:19:15

4. 最後に

エラー発生とその対応については、コメントの方に記載していただけると、本記事を閲覧された方の参考になり、大変ありがたいです。皆さんが笑顔になれるよう、ご協力頂けますと幸いです。

Discussion