🤗

transformersでのモデルの学習状況をSlackに通知する

2024/07/16に公開

どうも、SpiralAIの@ksterxです。

みなさんは、モデルの学習を行う時、なんのライブラリを使用して学習していますでしょうか?自分自身は、なんだかんだpytorch lightningを使っている時期が長かったのかなと思いますが、最近は言語系を触っていることもあり、🤗transformersを使用する機会が多いです。(megatronやllama-factoryに浮気したい気もしていますが、、、)

🤗Hugging Face Hubにある多くのモデルが、🤗transformersを使用してロード、推論が簡単に行えます。また、🤗transformersは比較的新しい論文のアルゴリズムであってもTrainerとして提供されることが多いため、簡単に多様なアルゴリズムを試せることが多いです。

実際に🤗transformersを使用してモデルの学習をする場合、学習状況を確認することがあると思います。この辺は、TrainingArgumentsやそのラッパーSFTConfig等のreport_toにお好みの出力先(wandbやtensorboard、あるいは何も設定しないと、ユーザーがインストールしているインテグレーション全部に!)を指定すると、デフォルト(train/eval lossやリソース使用状況等)のロギングをしてくれます。

report_toSource: Hugging Face Documentation

from trl import SFTConfig, SFTTrainer

train_args = SFTConfig(
    report_to="wandb",
    # ...
)
...
trainer = SFTTrainer(
    args=train_args,
    # ...
)

ただ、我々のような心配性and面倒くさがりな人間は、学習が進んでいるのかが心配になるわけです。わざわざ、wandbのダッシュボードを見る必要もない、それでも、どのぐらい進んでいるかプッシュ通知ぐらいで知れれば良い―――slackに通知すれば良くね。

本題

前置きが長くなりました。今回実現するのは、エポックが終了するタイミングとトレーニング自体が終了するタイミングに通知を飛ばすシステムです。

できるもの


↑↑↑↑こんな感じ

必要なもの

  • SlackのToken (参考ページ)
    • 環境変数SLACK_BOT_TOKENとして設定
  • Pythonライブラリ
    • transformers
    • slack_sdk

実装

🤗transformersのモデルトレーニング時の挙動を変える方法の一つが、TrainerCallbackを継承した独自のコールバックを実装することです。独自のコールバッククラスに、ステップする前/後、エポックが終わった時、保存を行う時のメソッドをオーバーライドすることで実現されます。(公式ドキュメント)
オーバーライドしたメソッドには、自分で渡した学習条件TrainingArgumentsや学習の現在の状態TrainerState等のオブジェクトを渡すことができます。今回で言えば現在のグローバルステップや保存先ディレクトリにアクセスしてSlackのチャンネルにポストします。

from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class SlackCallback(TrainerCallback):
    default_text = "\n\nProject: {proj}\nRepo: {repo}\nLog: {log_dir}\n\n"

    def __init__(self, channel, client=None):
        super().__init__()
        self.client = (
            WebClient(token=os.environ["SLACK_BOT_TOKEN"]) if client is None else client
        )
        self.channel = channel

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text=f"✅Epoch {state.epoch:.1f} (global step {state.global_step}) finished!!✅"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text="🚀Training started🚀"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text="🎉Training finished🎉"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

使い方

SlackCallbackクラスをTrainerに渡すことで動きます。

slack_callback = SlackCallback("write権限のあるチャンネル名")
trainer = SFTTrainer(
    ...,
    callbacks=[slack_callback]
)
trainer.train()

これでプログラムを動かすと…

🎉無事Slackの任意のチャンネルに通知することができる様になりました🎉

まとめ

今回は、🤗transformersを使用してモデルの学習進捗をSlackに通知する方法をご紹介しました。これで、学習の進捗状況をリアルタイムで把握するのが楽になりますね。わざわざダッシュボードを確認しに行かなくても、通知が来るので安心して学習を進められます。

Slack通知はチームメンバーとも共有しやすく、プロジェクト全体の透明性を高める効果も期待できます。もちろん、今回のコードをさらにカスタマイズして、特定のイベントやログメッセージを追加することも可能です。皆さんのプロジェクトに合わせて、最適な通知システムを構築してみてください。

LLMエンジニアを募集しています!!

SpiralAIでは、生成AI×エンタメをテーマに様々なプロジェクトが立ち上がっています!もし、ご興味があれば@ksterx採用ページまでご連絡ください〜

GitHubで編集を提案
SpiralAIテックブログ

Discussion