🚴

【AIモデル長時間学習を可能に】Google Colabの24時間制限を突破する

2025/01/12に公開

導入

性能の高いGPUを使いたくて、Google Colabに課金したけど、
連続実行時間の制限で困った経験はありませんか?

頑張ってColab Pro+まで課金しても、最大24時間までしか連続実行が出来ません。
松尾研LLM講座コンペにおいて、24時間を越える学習を実施する際に困ったため、本コードを作成しました。

Google Colab上でのモデル学習において、Google Driveへのチェックポイントの保存と読み込みによる学習再開機能の実装方法を解説します。
GoogleColab上で、Hugging Faceを利用したモデル開発の際に使用していました。
コンペや開発などでご活用いただければ幸いです。

1. 事前準備

1.1 必要なライブラリのインポート

from transformers import TrainerCallback
import os
import json
from datetime import datetime

1.2 Google Driveのマウント

from google.colab import drive
drive.mount('./gdrive')

2. チェックポイント管理機能の実装

2.1 CheckpointCallbackクラスの実装

class CheckpointCallback(TrainerCallback):
    def __init__(self, output_dir="/content/gdrive/MyDrive/model_checkpoints"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def on_save(self, args, state, control, **kwargs):
        #チェックポイント保存時の処理
        try:
            if state.is_world_process_zero:
                checkpoint_dir = os.path.join(
                    self.output_dir,
                    f"checkpoint-{state.global_step}"
                )
                os.makedirs(checkpoint_dir, exist_ok=True)

                if kwargs.get('trainer'):
                    kwargs['trainer'].save_model(checkpoint_dir)
                    state.save_to_json(os.path.join(checkpoint_dir, "trainer_state.json"))
                    print(f"チェックポイントを保存しました: {checkpoint_dir}")
                else:
                    print("trainerが取得できず、保存に失敗しました")

        except Exception as e:
            print(f"チェックポイント保存中にエラーが発生: {str(e)}")

    def get_latest_checkpoint(self):
        #最新のチェックポイントパスを取得
        if not os.path.exists(self.output_dir):
            return None

        checkpoints = [d for d in os.listdir(self.output_dir)
                      if d.startswith('checkpoint-')]
        if not checkpoints:
            return None

        return os.path.join(
            self.output_dir,
            max(checkpoints, key=lambda x: int(x.split('-')[1]))
        )

2.2 Trainerへの組み込み

checkpoint_callback = CheckpointCallback(
    output_dir="/content/gdrive/MyDrive/model_checkpoints"
)

# TrainingArgumentsの設定
training_args = TrainingArguments(
    output_dir="/content/gdrive/MyDrive/model_checkpoints",
    save_steps=500,  # 500ステップごとにチェックポイントを保存
    save_total_limit=10,  # 保存する最新チェックポイントの数
    # その他
)

# Trainerの設定
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[checkpoint_callback],
)

2.3 学習実行と再開処理

try:
    # 最新のチェックポイントを確認
    latest_checkpoint = checkpoint_callback.get_latest_checkpoint()

    if latest_checkpoint:
        print(f"チェックポイントから学習を再開します: {latest_checkpoint}")
        trainer_stats = trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
        print("新規に学習を開始します")
        trainer_stats = trainer.train()

except Exception as e:
    print(f"エラーが発生しました: {str(e)}")
    raise

チェックポイントの保存間隔はsave_stepsパラメータ、保存するチェックポイント数はsave_total_limitパラメータで制御しています。
古いチェックポイントは自動的に削除されます。Google Driveの容量を考慮して設定してください。

3. 使用例

# チェックポイントの保存先設定
CHECKPOINT_DIR = "/content/gdrive/MyDrive/model_checkpoints"

# コールバックの設定
checkpoint_callback = CheckpointCallback(output_dir=CHECKPOINT_DIR)

# 学習の実行
try:
    latest_checkpoint = checkpoint_callback.get_latest_checkpoint()
    
    if latest_checkpoint:
        print(f"Resuming from checkpoint: {latest_checkpoint}")
        trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
        trainer.train()
        
except RuntimeError as e:
    print(f"Training interrupted: {str(e)}")

本実装により

  • チェックポイントを定期的にGoogle Driveに保存
  • チェックポイントをGoogle Driveから読み込んで再開
    ができるようになります。

4. 最後に

エラー発生とその対応については、コメントの方に記載していただけると、本記事を閲覧された方の参考になり、大変ありがたいです。皆さんが笑顔になれるよう、ご協力頂けますと幸いです。

Discussion