transformersでのモデルの学習状況をSlackに通知する
どうも、SpiralAIの@ksterxです。
みなさんは、モデルの学習を行う時、なんのライブラリを使用して学習していますでしょうか?自分自身は、なんだかんだpytorch lightningを使っている時期が長かったのかなと思いますが、最近は言語系を触っていることもあり、🤗transformersを使用する機会が多いです。(megatronやllama-factoryに浮気したい気もしていますが、、、)
🤗Hugging Face Hubにある多くのモデルが、🤗transformersを使用してロード、推論が簡単に行えます。また、🤗transformersは比較的新しい論文のアルゴリズムであってもTrainerとして提供されることが多いため、簡単に多様なアルゴリズムを試せることが多いです。
実際に🤗transformersを使用してモデルの学習をする場合、学習状況を確認することがあると思います。この辺は、TrainingArguments
やそのラッパーSFTConfig
等のreport_to
にお好みの出力先(wandbやtensorboard、あるいは何も設定しないと、ユーザーがインストールしているインテグレーション全部に!)を指定すると、デフォルト(train/eval lossやリソース使用状況等)のロギングをしてくれます。
Source: Hugging Face Documentation
from trl import SFTConfig, SFTTrainer
train_args = SFTConfig(
report_to="wandb",
# ...
)
...
trainer = SFTTrainer(
args=train_args,
# ...
)
ただ、我々のような心配性and面倒くさがりな人間は、学習が進んでいるのかが心配になるわけです。わざわざ、wandbのダッシュボードを見る必要もない、それでも、どのぐらい進んでいるかプッシュ通知ぐらいで知れれば良い―――slackに通知すれば良くね。
本題
前置きが長くなりました。今回実現するのは、エポックが終了するタイミングとトレーニング自体が終了するタイミングに通知を飛ばすシステムです。
できるもの
↑↑↑↑こんな感じ
必要なもの
- SlackのToken (参考ページ)
- 環境変数
SLACK_BOT_TOKEN
として設定
- 環境変数
- Pythonライブラリ
transformers
slack_sdk
実装
🤗transformersのモデルトレーニング時の挙動を変える方法の一つが、TrainerCallback
を継承した独自のコールバックを実装することです。独自のコールバッククラスに、ステップする前/後、エポックが終わった時、保存を行う時のメソッドをオーバーライドすることで実現されます。(公式ドキュメント)
オーバーライドしたメソッドには、自分で渡した学習条件TrainingArguments
や学習の現在の状態TrainerState
等のオブジェクトを渡すことができます。今回で言えば現在のグローバルステップや保存先ディレクトリにアクセスしてSlackのチャンネルにポストします。
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
class SlackCallback(TrainerCallback):
default_text = "\n\nProject: {proj}\nRepo: {repo}\nLog: {log_dir}\n\n"
def __init__(self, channel, client=None):
super().__init__()
self.client = (
WebClient(token=os.environ["SLACK_BOT_TOKEN"]) if client is None else client
)
self.channel = channel
def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
self.client.chat_postMessage(
channel=self.channel,
text=f"✅Epoch {state.epoch:.1f} (global step {state.global_step}) finished!!✅"
+ self.default_text.format(
proj=args.run_name,
repo=args.hub_model_id,
log_dir=args.output_dir,
),
)
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
self.client.chat_postMessage(
channel=self.channel,
text="🚀Training started🚀"
+ self.default_text.format(
proj=args.run_name,
repo=args.hub_model_id,
log_dir=args.output_dir,
),
)
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
self.client.chat_postMessage(
channel=self.channel,
text="🎉Training finished🎉"
+ self.default_text.format(
proj=args.run_name,
repo=args.hub_model_id,
log_dir=args.output_dir,
),
)
使い方
SlackCallback
クラスをTrainerに渡すことで動きます。
slack_callback = SlackCallback("write権限のあるチャンネル名")
trainer = SFTTrainer(
...,
callbacks=[slack_callback]
)
trainer.train()
これでプログラムを動かすと…
🎉無事Slackの任意のチャンネルに通知することができる様になりました🎉
まとめ
今回は、🤗transformersを使用してモデルの学習進捗をSlackに通知する方法をご紹介しました。これで、学習の進捗状況をリアルタイムで把握するのが楽になりますね。わざわざダッシュボードを確認しに行かなくても、通知が来るので安心して学習を進められます。
Slack通知はチームメンバーとも共有しやすく、プロジェクト全体の透明性を高める効果も期待できます。もちろん、今回のコードをさらにカスタマイズして、特定のイベントやログメッセージを追加することも可能です。皆さんのプロジェクトに合わせて、最適な通知システムを構築してみてください。
LLMエンジニアを募集しています!!
SpiralAIでは、生成AI×エンタメをテーマに様々なプロジェクトが立ち上がっています!もし、ご興味があれば@ksterxや採用ページまでご連絡ください〜
Discussion