🎉

無料版Google ColabでMoE付きTransformerの事前学習をしてみる

に公開

はじめに

先日投稿した下記記事の続きになります。
https://zenn.dev/asap/articles/d3bada2f005330

上記記事ではVRAM12GBのRTX3060を利用して、400Mパラメータ程度のMoE付きTransformerを構築して、ランダム初期値のパラメータから事前学習を行い、テキストを出力させるところまで実施しました。

とはいえ、自宅にGPUがない方もいらっしゃるかと思いますので、今回は無料版のGoogle Colabでどこまでできるかを試してみた記事になります

無料版Google Colabで事前学習する上で

無料版のGoogle Colabで事前学習する上で課題となるのは下記の部分です。

  • 無料版のGoogle Colabでは、GPU利用時、1日あたり3-4時間程度しか利用できない。
  • 無料版のGoogle Driveを利用することが前提だと思うので、重みを保存するストレージが15GBしかない
  • 無料版のGoogle Colabでは、学習時にかなり古いT4 GPUを利用する必要がある

上記3つの制限が非常に大きいです。

例えば、保存領域が最大15GBしか利用できないため、学習済みモデルの重みのサイズを気をつける必要があります、最低限15GB以下にする必要があります。

また、重みのチェックポイントというのは、一定間隔で常に保存し続ける必要があります。
なぜなら、Google Colabの無料版は3時間程度でセッションが切れてしまうため、学習の途中経過を定期的に保存し、再度Google Colabの利用制限が解除されたら途中から学習を実施できるようにしなければなりません。
(3時間程度で学習は完了しません)

自宅PCでの処理であれば、チェックポイントが貯まってきたら古いものを自動消去することで、ストレージを圧迫しないようにすることは可能です。
しかし、Google Colabでは、削除データは自動的にゴミ箱に貯まってしまい、ゴミ箱から手動で消さなければストレージを圧迫し続けます。

また、このストレージの中には、キャッシュしたデータセットもおく必要があります。
(学習のたびにデータセットのダウンロードや、トークン化処理を実施していたら時間がもったいない)

前回の記事で使っていたデータセットのキャッシュサイズは約30GB、モデルのチェックポイントのサイズが約5GBであるため、このままでは無料版のGoogle Colabでは学習させることはできません。
したがって、学習データ量を小さくすること、モデルサイズを小さくすることは必須になります。

また、T4 GPUは非常に古いGPUのため、bfloat16が利用できません。それも課題になってきます。

環境

無料版Google Colab
GPU:T4 GPU
Google Driveストレージ:まっさらな状態を推奨(15GBの空き)

リポジトリ

学習や推論のコードは、以下のリポジトリをご覧ください。
https://github.com/personabb/LightLM_public_repo

なお、本リポジトリやコードの詳細などは前回の記事をご覧ください。

今回の設定での学習済み重みは以下をご覧ください
https://huggingface.co/asap-bb/mylightlm_small_sample

事前準備

リポジトリをクローン

下記コマンドでクローンしてください

git clone https://github.com/personabb/LightLM_public_repo.git

Google Driveにアップロード

マイドライブ直下に「LightLM」フォルダを作成し、リポジトリの中身のファイルを全てLightLMフォルダにアップロードしてください。

HF tokenを取得して、Google Colabに登録(任意)

学習したモデルをHuggingFaceにアップロードしたい場合は、HF tokenを取得してください。
Huggingfaceページの右上アイコンをクリックして、「Access Tokens」から作成することができます。

取得したtokenは、下記の部分に設定してください。
これにより、同じアカウントを利用している限り、ノートブックが変わったとしても利用可能です。

事前学習

Google ColabにてLightLM_Colab_train.ipynbを開いて、すべてのセルを実行してください。
実行すると、250stepごとに、モデルのチェックポイントの保存(model_testing-small)と、検証データによる評価が保存(log/eval-small.txt)されます。

また、このチェックポイントには学習中の全てのデータ(モデル重みだけでなく、OptimizerやScheduler、lr、dataset idxなど)が保存されます。従って途中からの学習の再開も可能です。

以下に学習設定や、モデル設定を記載します。
修正しても問題ないですが、無料版のGoogle Colab, Google Drive上で学習する都合上、大きく増加させることはできないかなと思います。

学習設定

前回の記事からの差分を記載します。

LightLM_Colab_train.ipynb
train_config = TrainerConfig(
    ・・・
    use_dtype="float16" if device == 'cuda' else "float32", #T4はbfloatに未対応

    ・・・

    checkpoints_frequency=250,
    path_to_checkpoints="/content/drive/MyDrive/LightLM/model_testing-small",
    max_checkpoints_to_keep=4, # 0の場合は全て保持、-1の場合は最新1つのチェックポイントを保持 colabの場合はゴミ箱システムのせいでどれを設定しても結局重みは圧迫する

    tokenized_dataset_path = "HuggingFaceFW/fineweb-edu",
    #sub_target_files = "", #all data
    #sub_target_files = "data/CC-MAIN-2025-26/*.parquet",
    #sub_target_files = "data/CC-MAIN-2025-26/000_00049.parquet",
    sub_target_files = [
        #"data/CC-MAIN-2025-26/000_00047.parquet",
        "data/CC-MAIN-2025-26/000_00048.parquet",
        "data/CC-MAIN-2025-26/000_00049.parquet"
    ],
    eval_log_file="/content/drive/MyDrive/LightLM/log/eval-small.txt",

    ・・・
)

Google Colabの無料枠で利用できるGPUであるT4 GPUは、bfloat16を利用できないので、float16を利用しています。

また、学習用のデータセットに関しても、CC-MAIN-2025-26フォルダすらすべてダウンロードしたキャッシュをすることは、無料版のGoogle Driveのストレージでは無理なので、2ファイル分(000_00048.parquet000_00049.parquet)だけで学習することにします。
この2ファイルだけだと、215,939,584token(約2億token)の学習データになります。

モデル設定

LightLM_Colab_train.ipynb
config = ModelConfig(
    vocab_size=tokenizer.vocab_size,

    num_dims=512,     
    num_heads=16,
    num_kv_heads=4,    # GQA による効率化
    num_layers=12,     
    ffn_hidden_dims=512 * 4,
    # 無料版google Driveの少量すぎる保存容量と、貧弱な計算資源を考慮し、GPT-2リスペクトでさらにモデルサイズを小さく
    rmsnorm_eps=1e-6,
    rope_theta=1e5,

    context_len=512,  

    use_cache=False,
    use_flash=True,    # 利用可能な場合
    use_moe=True,    

    moe_num_experts=3, 
    moe_active_experts=1,
    moe_eps=1e-6,
    moe_aux_loss_coef=0.01,
    moe_shared_experts=1,
    use_lossfreebalance=False,
)

前回の記事で利用したモデルよりも、Transformerブロックの数や、MoEのExpert数を減らしています。
モデルパラメータとしては、184.06M(Active:108.55M)のTransformerになります。

Huggingfaceへのアップロード(必要であれば)

以下の4セル目を適切に修正したのちに、HF_Colab.ipynbのすべてのセルを実行してください。

HF_Colab.ipynb
# デフォルトのチェックポイントパス(train.pyから)
default_checkpoint = lightlm_path + "/model_testing-small/model.checkpoint.epoch0_step23500_global23500.pt"
model_dir = lightlm_path + "/hf_model-small"
repo_name = "your_username/your_repo_name"
private = False

default_checkpointは学習した上で、評価データの損失が最も低いチェックポイントを指定してください。
model.checkpoint.epoch0_step23500_global23500.pt」の部分は、学習の進捗によって変わります」
model_dirはHF形式に変換したデータを保存するディレクトリです。変更不要です。
repo_nameは自身のユーザネームと保存したいリポジトリ名を指定してください。
privateはHFのリポジトリが公開か非公開かを選択します。ストレージ圧迫するのもどうかと思うので公開で良いかと思います。

本ノートブックを実行後、Huggingfaceにアップロードがなされているかと思います。

推論実施

checkpointファイルを利用した推論

学習途中のチェックポイント(例えばmodel.checkpoint.epoch0_step16000_global16000.ptなど)を利用して推論をする場合はLightLM_Colab_infer.ipynbを実行します。

実行前に4セル名のモデルパラメータを学習時と同じものに設定し、5セル目のcheckpoint_pathを利用したいチェックポイントを指定します。
また、6セル目でプロンプトと生成パラメータ(temperatureなど)を設定し実行してください。

以下のパラメータをコード内で設定できます

  • text
    • 入力されるプロンプト
    • baseモデルになるので、プロンプトの続きから出力されます。
  • max_tokens
    • 出力できる最大トークン数
  • temperature
  • top_k
  • top_p
  • repetition_penalty
    • 繰り返しのかかるペナルティ
  • use_cache
    • KVキャッシュの利用有無

Huggingfaceモデルを利用した推論

Huggingfaceへアップロードしたモデルを推論に利用する場合は、HF_inference_Colab.ipynbを実行します。

実行前に、2セル目でリポジトリ名の指定や、プロンプトの設定、3セル目のgenerateメソッドでtemperatureなどのパラメータを設定してください。
設定可能なパラメータはLightLM_Colab_infer.ipynbと同様です。

出力

実際に今回の設定で学習を行いました

学習経過内容

各ステップごとの検証データにおけるlossの推移

今回は、学習データ量を減らしたため、1epochあたり3279stepのみでした。
自宅のPCであれば、丸一日学習できれば4epochの学習が完了します。

Global Step: 250, Epoch: 0, Step: 250, val_loss: 5.8105, norm: 0.6994, lr: 1.1594202899e-04, time: 5.91s, tok/s: 11067.0 | dataset idx: 421649/421757
Global Step: 500, Epoch: 0, Step: 500, val_loss: 4.9471, norm: 0.7153, lr: 2.0652173913e-04, time: 5.63s, tok/s: 11610.4 | dataset idx: 421541/421757
Global Step: 750, Epoch: 0, Step: 750, val_loss: 4.4959, norm: 0.5326, lr: 2.9710144928e-04, time: 5.57s, tok/s: 11751.1 | dataset idx: 421433/421757
Global Step: 1000, Epoch: 0, Step: 1000, val_loss: 4.1430, norm: 0.5438, lr: 3.8768115942e-04, time: 5.63s, tok/s: 11617.0 | dataset idx: 421325/421757
Global Step: 1250, Epoch: 0, Step: 1250, val_loss: 3.9000, norm: 0.4086, lr: 4.7826086957e-04, time: 5.61s, tok/s: 11659.7 | dataset idx: 421217/421757
Global Step: 1500, Epoch: 0, Step: 1500, val_loss: 3.6893, norm: 0.3595, lr: 4.9971243571e-04, time: 5.76s, tok/s: 11359.8 | dataset idx: 421109/421757
Global Step: 1750, Epoch: 0, Step: 1750, val_loss: 3.5402, norm: 0.3253, lr: 4.9845925999e-04, time: 5.65s, tok/s: 11570.8 | dataset idx: 421001/421757
Global Step: 2000, Epoch: 0, Step: 2000, val_loss: 3.4457, norm: 0.3139, lr: 4.9621733556e-04, time: 5.57s, tok/s: 11738.2 | dataset idx: 420893/421757
Global Step: 2250, Epoch: 0, Step: 2250, val_loss: 3.3438, norm: 0.3039, lr: 4.9299658233e-04, time: 5.58s, tok/s: 11713.8 | dataset idx: 420785/421757
Global Step: 2500, Epoch: 0, Step: 2500, val_loss: 3.2853, norm: 0.2956, lr: 4.8881125131e-04, time: 5.59s, tok/s: 11705.9 | dataset idx: 420677/421757
Global Step: 2750, Epoch: 0, Step: 2750, val_loss: 3.2458, norm: 0.2930, lr: 4.8367986147e-04, time: 5.65s, tok/s: 11580.7 | dataset idx: 420569/421757
Global Step: 3000, Epoch: 0, Step: 3000, val_loss: 3.1882, norm: 0.2930, lr: 4.7762511788e-04, time: 5.62s, tok/s: 11648.0 | dataset idx: 420461/421757
Global Step: 3250, Epoch: 0, Step: 3250, val_loss: 3.1629, norm: 0.3240, lr: 4.7067381120e-04, time: 5.61s, tok/s: 11654.2 | dataset idx: 420353/421757
Global Step: 3278, Epoch: 0, Step: 3278, val_loss: 3.1273, norm: 0.2951, lr: 4.6984074466e-04, time: 5.72s, tok/s: 11425.0 | dataset idx: 420245/421757
Global Step: 3500, Epoch: 1, Step: 221, val_loss: 3.1290, norm: 0.3036, lr: 4.6285669913e-04, time: 5.61s, tok/s: 11651.3 | dataset idx: 420137/421757
Global Step: 3750, Epoch: 1, Step: 471, val_loss: 3.0477, norm: 0.2877, lr: 4.5420837035e-04, time: 5.57s, tok/s: 11743.0 | dataset idx: 420029/421757
Global Step: 4000, Epoch: 1, Step: 721, val_loss: 3.0432, norm: 0.2976, lr: 4.4476709145e-04, time: 5.58s, tok/s: 11719.1 | dataset idx: 419921/421757
Global Step: 4250, Epoch: 1, Step: 971, val_loss: 3.0405, norm: 0.2892, lr: 4.3457463762e-04, time: 5.64s, tok/s: 11588.3 | dataset idx: 419813/421757
Global Step: 4500, Epoch: 1, Step: 1221, val_loss: 3.0247, norm: 0.2925, lr: 4.2367610780e-04, time: 5.59s, tok/s: 11693.3 | dataset idx: 419705/421757
Global Step: 4750, Epoch: 1, Step: 1471, val_loss: 2.9962, norm: 0.3036, lr: 4.1211972513e-04, time: 5.59s, tok/s: 11699.4 | dataset idx: 421705/421757
Global Step: 5000, Epoch: 1, Step: 1721, val_loss: 2.9384, norm: 0.2959, lr: 3.9995662357e-04, time: 5.58s, tok/s: 11729.6 | dataset idx: 421597/421757
Global Step: 5250, Epoch: 1, Step: 1971, val_loss: 2.9432, norm: 0.3000, lr: 3.8724062167e-04, time: 5.56s, tok/s: 11760.1 | dataset idx: 421489/421757
Global Step: 5500, Epoch: 1, Step: 2221, val_loss: 2.8901, norm: 0.2942, lr: 3.7402798440e-04, time: 5.61s, tok/s: 11654.3 | dataset idx: 421381/421757
Global Step: 5750, Epoch: 1, Step: 2471, val_loss: 2.8764, norm: 0.3046, lr: 3.6037717423e-04, time: 5.61s, tok/s: 11669.1 | dataset idx: 421273/421757
Global Step: 6000, Epoch: 1, Step: 2721, val_loss: 2.8534, norm: 0.3114, lr: 3.4634859242e-04, time: 5.56s, tok/s: 11755.0 | dataset idx: 421165/421757
Global Step: 6250, Epoch: 1, Step: 2971, val_loss: 2.8333, norm: 0.3161, lr: 3.3200431176e-04, time: 5.55s, tok/s: 11789.0 | dataset idx: 421057/421757
Global Step: 6500, Epoch: 1, Step: 3221, val_loss: 2.7851, norm: 0.3120, lr: 3.1740780195e-04, time: 5.57s, tok/s: 11732.9 | dataset idx: 420949/421757
Global Step: 6557, Epoch: 1, Step: 3278, val_loss: 2.8190, norm: 0.3156, lr: 3.1405118376e-04, time: 5.71s, tok/s: 11449.7 | dataset idx: 420841/421757
Global Step: 6750, Epoch: 2, Step: 192, val_loss: 2.7864, norm: 0.3142, lr: 3.0262364872e-04, time: 5.60s, tok/s: 11674.0 | dataset idx: 420733/421757
Global Step: 7000, Epoch: 2, Step: 442, val_loss: 2.7736, norm: 0.3155, lr: 2.8771726808e-04, time: 5.59s, tok/s: 11708.7 | dataset idx: 420625/421757
Global Step: 7250, Epoch: 2, Step: 692, val_loss: 2.7894, norm: 0.3266, lr: 2.7275461685e-04, time: 5.61s, tok/s: 11658.2 | dataset idx: 420517/421757
Global Step: 7500, Epoch: 2, Step: 942, val_loss: 2.7963, norm: 0.3200, lr: 2.5780190086e-04, time: 5.64s, tok/s: 11599.9 | dataset idx: 420409/421757
Global Step: 7750, Epoch: 2, Step: 1192, val_loss: 2.7613, norm: 0.3261, lr: 2.4292528196e-04, time: 5.54s, tok/s: 11809.7 | dataset idx: 420301/421757
Global Step: 8000, Epoch: 2, Step: 1442, val_loss: 2.7409, norm: 0.3304, lr: 2.2819058528e-04, time: 5.55s, tok/s: 11793.6 | dataset idx: 420193/421757
Global Step: 8250, Epoch: 2, Step: 1692, val_loss: 2.7475, norm: 0.3265, lr: 2.1366300801e-04, time: 5.60s, tok/s: 11676.8 | dataset idx: 420085/421757
Global Step: 8500, Epoch: 2, Step: 1942, val_loss: 2.7326, norm: 0.3301, lr: 1.9940683087e-04, time: 5.61s, tok/s: 11656.2 | dataset idx: 419977/421757
Global Step: 8750, Epoch: 2, Step: 2192, val_loss: 2.7202, norm: 0.3333, lr: 1.8548513371e-04, time: 5.56s, tok/s: 11754.6 | dataset idx: 419869/421757
Global Step: 9000, Epoch: 2, Step: 2442, val_loss: 2.6826, norm: 0.3494, lr: 1.7195951639e-04, time: 5.56s, tok/s: 11757.7 | dataset idx: 419761/421757
Global Step: 9250, Epoch: 2, Step: 2692, val_loss: 2.6671, norm: 0.3477, lr: 1.5888982624e-04, time: 5.58s, tok/s: 11718.8 | dataset idx: 419653/421757
Global Step: 9500, Epoch: 2, Step: 2942, val_loss: 2.6512, norm: 0.3331, lr: 1.4633389321e-04, time: 5.59s, tok/s: 11704.0 | dataset idx: 421653/421757
Global Step: 9750, Epoch: 2, Step: 3192, val_loss: 2.6311, norm: 0.3370, lr: 1.3434727402e-04, time: 5.55s, tok/s: 11784.4 | dataset idx: 421545/421757
Global Step: 9836, Epoch: 2, Step: 3278, val_loss: 2.6606, norm: 0.3489, lr: 1.3036514311e-04, time: 5.61s, tok/s: 11651.6 | dataset idx: 421437/421757
Global Step: 10000, Epoch: 3, Step: 163, val_loss: 2.6370, norm: 0.3509, lr: 1.2298300631e-04, time: 5.58s, tok/s: 11728.8 | dataset idx: 421329/421757
Global Step: 10250, Epoch: 3, Step: 413, val_loss: 2.6324, norm: 0.3576, lr: 1.1229137400e-04, time: 5.54s, tok/s: 11805.5 | dataset idx: 421221/421757
Global Step: 10500, Epoch: 3, Step: 663, val_loss: 2.6414, norm: 0.3553, lr: 1.0231968476e-04, time: 5.65s, tok/s: 11584.2 | dataset idx: 421113/421757
Global Step: 10750, Epoch: 3, Step: 913, val_loss: 2.6010, norm: 0.3533, lr: 9.3112060706e-05, time: 5.54s, tok/s: 11796.9 | dataset idx: 421005/421757
Global Step: 11000, Epoch: 3, Step: 1163, val_loss: 2.5890, norm: 0.3563, lr: 8.4709243161e-05, time: 5.53s, tok/s: 11836.0 | dataset idx: 420897/421757
Global Step: 11250, Epoch: 3, Step: 1413, val_loss: 2.5630, norm: 0.3537, lr: 7.7148412395e-05, time: 5.55s, tok/s: 11784.6 | dataset idx: 420789/421757
Global Step: 11500, Epoch: 3, Step: 1663, val_loss: 2.5723, norm: 0.3672, lr: 7.0463023103e-05, time: 5.59s, tok/s: 11704.5 | dataset idx: 420681/421757
Global Step: 11750, Epoch: 3, Step: 1913, val_loss: 2.5652, norm: 0.3678, lr: 6.4682656383e-05, time: 5.57s, tok/s: 11751.5 | dataset idx: 420573/421757
Global Step: 12000, Epoch: 3, Step: 2163, val_loss: 2.5753, norm: 0.3747, lr: 5.9832888844e-05, time: 5.53s, tok/s: 11832.4 | dataset idx: 420465/421757
Global Step: 12250, Epoch: 3, Step: 2413, val_loss: 2.5752, norm: 0.3672, lr: 5.5935179439e-05, time: 5.65s, tok/s: 11583.0 | dataset idx: 420357/421757
Global Step: 12500, Epoch: 3, Step: 2663, val_loss: 2.5639, norm: 0.3732, lr: 5.3006774510e-05, time: 5.69s, tok/s: 11495.3 | dataset idx: 420249/421757
Global Step: 12750, Epoch: 3, Step: 2913, val_loss: 2.5446, norm: 0.3675, lr: 5.1060631482e-05, time: 5.55s, tok/s: 11790.5 | dataset idx: 420141/421757
Global Step: 13000, Epoch: 3, Step: 3163, val_loss: 2.5736, norm: 0.3794, lr: 5.0105361530e-05, time: 5.58s, tok/s: 11719.2 | dataset idx: 420033/421757
Global Step: 13115, Epoch: 3, Step: 3278, val_loss: 2.5512, norm: 0.3927, lr: 5.0000000000e-05, time: 5.55s, tok/s: 11791.4 | dataset idx: 419925/421757

出力設定

./HF_inference_Colab.ipynb
        output_ids = model.generate(
            input_ids,
            max_tokens=100,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            use_cache=True,
        )

prompt

text = "I am Mike. I live in"

出力結果

Google ColabでHF_inference_Colab.ipynbを実行した際の出力結果になります。

==================================================
📄 Generated Text:
==================================================
I'm Mike. I live in the United States, and he's a big fan of programming. But have you ever wondered what those guys actually do?
Well, we know that computers are probably one of my favourite subjects at college. And if you’re like me, you've got a problem with your computer! If anyone can, let us give it a try. You'll be able to make your life easier by following these easy steps:
1) Go on a trip
If you're just starting out, here
==================================================

翻訳

私はマイクです。アメリカに住んでいて、プログラミングが大好きです。でも、あの連中が実際に何をしているのか、考えたことありますか?
さて、コンピュータは大学で私が最も好きな科目の一つです。もしあなたが私と同じなら、きっとコンピュータに問題を抱えていることでしょう!誰かできるなら、私たちに試させてください。以下の簡単な手順に従えば、あなたの生活を楽にできるはずです:
1) 旅行に出かける
もしあなたが始めたばかりなら、ここ

まとめ

意外と、無料版のGoogle Colabでも小さなTransformerの事前学習が可能なんだなというのがわかって少し驚きました。
小さめのLanguage Modelであれば個人でもどんどん作れそうなので、何か特化させて処理させるAIを作るのも面白そうですね。(自然言語処理ではなくとも)

前回の記事もぜひご覧ください!
https://zenn.dev/asap/articles/d3bada2f005330

Discussion