🔔

🀗transformersでのモデルの孊習状況をSlackに通知する

2024/07/16に公開

どうも、SpiralAIの@ksterxです。

みなさんは、モデルの孊習を行う時、なんのラむブラリを䜿甚しお孊習しおいたすでしょうか自分自身は、なんだかんだpytorch lightningを䜿っおいる時期が長かったのかなず思いたすが、最近は蚀語系を觊っおいるこずもあり、🀗transformersを䜿甚する機䌚が倚いです。megatronやllama-factoryに浮気したい気もしおいたすが、、、

🀗Hugging Face Hubにある倚くのモデルが、🀗transformersを䜿甚しおロヌド、掚論が簡単に行えたす。たた、🀗transformersは比范的新しい論文のアルゎリズムであっおもTrainerずしお提䟛されるこずが倚いため、簡単に倚様なアルゎリズムを詊せるこずが倚いです。

実際に🀗transformersを䜿甚しおモデルの孊習をする堎合、孊習状況を確認するこずがあるず思いたす。この蟺は、TrainingArgumentsやそのラッパヌSFTConfig等のreport_toにお奜みの出力先wandbやtensorboard、あるいは䜕も蚭定しないず、ナヌザヌがむンストヌルしおいるむンテグレヌション党郚にを指定するず、デフォルトtrain/eval lossやリ゜ヌス䜿甚状況等のロギングをしおくれたす。

report_toSource: Hugging Face Documentation

from trl import SFTConfig, SFTTrainer

train_args = SFTConfig(
    report_to="wandb",
    # ...
)
...
trainer = SFTTrainer(
    args=train_args,
    # ...
)

ただ、我々のような心配性and面倒くさがりな人間は、孊習が進んでいるのかが心配になるわけです。わざわざ、wandbのダッシュボヌドを芋る必芁もない、それでも、どのぐらい進んでいるかプッシュ通知ぐらいで知れれば良い―――slackに通知すれば良くね。

本題

前眮きが長くなりたした。今回実珟するのは、゚ポックが終了するタむミングずトレヌニング自䜓が終了するタむミングに通知を飛ばすシステムです。

できるもの


↑↑↑↑こんな感じ

必芁なもの

  • SlackのToken (参考ペヌゞ)
    • 環境倉数SLACK_BOT_TOKENずしお蚭定
  • Pythonラむブラリ
    • transformers
    • slack_sdk

実装

🀗transformersのモデルトレヌニング時の挙動を倉える方法の䞀぀が、TrainerCallbackを継承した独自のコヌルバックを実装するこずです。独自のコヌルバッククラスに、ステップする前/埌、゚ポックが終わった時、保存を行う時のメ゜ッドをオヌバヌラむドするこずで実珟されたす。公匏ドキュメント)
オヌバヌラむドしたメ゜ッドには、自分で枡した孊習条件TrainingArgumentsや孊習の珟圚の状態TrainerState等のオブゞェクトを枡すこずができたす。今回で蚀えば珟圚のグロヌバルステップや保存先ディレクトリにアクセスしおSlackのチャンネルにポストしたす。

from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

class SlackCallback(TrainerCallback):
    default_text = "\n\nProject: {proj}\nRepo: {repo}\nLog: {log_dir}\n\n"

    def __init__(self, channel, client=None):
        super().__init__()
        self.client = (
            WebClient(token=os.environ["SLACK_BOT_TOKEN"]) if client is None else client
        )
        self.channel = channel

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text=f"✅Epoch {state.epoch:.1f} (global step {state.global_step}) finished!!✅"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text="🚀Training started🚀"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.is_world_process_zero:
            self.client.chat_postMessage(
                channel=self.channel,
                text="🎉Training finished🎉"
                + self.default_text.format(
                    proj=args.run_name,
                    repo=args.hub_model_id,
                    log_dir=args.output_dir,
                ),
            )

䜿い方

SlackCallbackクラスをTrainerに枡すこずで動きたす。

slack_callback = SlackCallback("write暩限のあるチャンネル名")
trainer = SFTTrainer(
    ...,
    callbacks=[slack_callback]
)
trainer.train()

これでプログラムを動かすず 

🎉無事Slackの任意のチャンネルに通知するこずができる様になりたした🎉

たずめ

今回は、🀗transformersを䜿甚しおモデルの孊習進捗をSlackに通知する方法をご玹介したした。これで、孊習の進捗状況をリアルタむムで把握するのが楜になりたすね。わざわざダッシュボヌドを確認しに行かなくおも、通知が来るので安心しお孊習を進められたす。

Slack通知はチヌムメンバヌずも共有しやすく、プロゞェクト党䜓の透明性を高める効果も期埅できたす。もちろん、今回のコヌドをさらにカスタマむズしお、特定のむベントやログメッセヌゞを远加するこずも可胜です。皆さんのプロゞェクトに合わせお、最適な通知システムを構築しおみおください。

LLM゚ンゞニアを募集しおいたす

SpiralAIでは、生成AI×゚ンタメをテヌマに様々なプロゞェクトが立ち䞊がっおいたすもし、ご興味があれば@ksterxや採甚ペヌゞたでご連絡ください〜

SpiralAIテックブログ

Discussion