Instruction tuningをやってみて分かった5つこと

2024/02/07に公開

今回の実験の目的ときっかけ

日本語のLLMを勉強していたところ、ふと「Baseモデルを自分が学習するのは難しいけどInstruction tunedモデルは学習出来そうなのでその性能を再現してみる価値があるのではないか?」と考えました。普段よく使う StableLM-instruct-gamma-7bの性能の再現を目指して、以下のことを確認しようと決定しました。

  • Base modelからのInstruction tuningを行い、公開されているInstructモデルの性能を再現する。
  • その過程でどの要素が良い性能を作るために重要なのか確認する。

実験のセッティング

  • 使用したtransformersのバージョンは以下の通りです。

    • transformers==4.35.0
    • trl==0.7.6
  • notus のコミットハッシュ5cd49123b2fd8622a656afb441361e27db791898を使用しました。

  • japanese-stablelm-instruct-gamma-7bの性能を再現したいのでデータをそのまま使用しました。

    • kunishou/databricks-dolly-15k-ja
    • fujiki/japanese_hh-rlhf-49k
    • fujiki/llm-japanese-dataset_wikinews
  • Chat templateもjapanese-stablelm-instruct-gamma-7bのものをそのまま使用しました。

    chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '\n### 指示: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '\n### 応答: \n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '\n### 応答: \n' }}\n{% endif %}\n{% endfor %}"
    
  • 以下はフォーマットされたデータの例です。

    以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
    
    ### 指示:
    バナナの栄養素を教えてください。
    
    ### 応答:
    確かに、バナナにはあなたの健康にとって重要な多くの栄養素が含まれています。ビタミンB6、繊維、カリウム、およびビタミンA、C、およびEの配列もあり、マグネシウムと銅も含まれており、マンガン、カルシウム、マグネシウムの優れた供給源です。
    
    以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
    
    ### 指示:
    SATスコアは、良い大学に入ろうとするときに最も重要なことですか?
    
    ### 応答:
    大学は、テストスコアに加えて、多くの要因を検討しています。最も重要なことは、生徒が高校でさまざまな科目を取ることです。もちろん、テストのスコアは重要ですが、大学は他の資質も持っている候補者を探しています。
    
    以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
    
    ### 指示:
    Euxoa brevipennisは、どの蛾科に属するのでしょうか?
    
    ### 入力:
    Euxoa brevipennisは、1888年にSmithによって初めて記述されたNoctuidae科の蛾である。カナダでは、ブリティッシュコロンビア州、アルバータ州、サスカチュワン州に生息する。アメリカではユタ州、コロラド州、カリフォルニア州から記録されている。
    
    ### 応答:
    Euxoa brevipennis mothは、ノクトウガ科に属するガです。
    

実験結果のサマリー

実施した実験の数が多いためまずは結果を要約し詳細は下で説明します。以下の表は各実験で学習したモデルをJP Language Model Evaluation Harness(effdbeaf742e74ea1787871e99272c12146ba346)で評価した結果です。

  • Instruction tuningの場合、最適なlearning_rate がかなり小さくなることがあります。(T1 vs. T10)
    • 今回の実験で一番良い性能を出したモデルの学習のlearning_rate1e-7でした。
  • 性能の良いオープンモデルのハイパーパラメータをそのまま使用しても良い結果が得られるという保証はありません。(T1 vs. T4 vs. T10)
    • 英語モデルと日本語モデルの違いかもしれません。
    • データの量の違いかもしれません。
  • Instructionはマスキングしてロスから除外する必要ないです。(T10 vs. T11)
    • 除外する場合、むしろ性能が落ちました。
  • Efficient Sequence Packing without Cross-contaminationは学習速度も上げつつ性能向上の効果もありました。(T11 vs. T18)
    • trlConstantLengthDatasetpackingオプションをTrueにすると使用できます。
    • packing=Trueで学習した後、packing=Falseで追加学習すると性能がむしろ低下しました。 (T15 vs. T16, T17)
  • Square root ruleの効果がありました。 (T11 vs. T14 vs. T15)
    • batch_sizelearning_rateが一緒に動くので、batch_sizeをGPUに最大に入る大きさに設定してlearning_rateを調整しながら良いハイパーパラメータを探すといいと思います。
  • 実験結果はここにアップロードしました。

実験のヒストリー

  • 最初の実験(Trial 1)はalignment-handbookconfigを使用しました。

    • num_train_epochsを3に増やしました。コンフィグのままエポックを1に設定して学習してみましたが、性能が上がりませんでした。
    • その後、lr_min_ratioを変えたり、weight_decayを与えたり、instruction部分をマスキングしてロスから除外するなどの実験を行ってみましたが性能の改善はありませんでした。
      • Trial 4, Trial 5, Trial 9はその実験の中で比較的結果が良かったいくつかの例です。
  • Trial 4はOpenchatを参考にしてbatch_sizeを変更した実験です。

    # https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py#L24-L55
    def parse_args():
        parser = argparse.ArgumentParser()
        
    	  ...
    
        # Hyperparameters
        parser.add_argument("--batch_max_len",      type=int, default=81920)
        parser.add_argument("--epochs",             type=int,   default=5)
    
        ...
    
        # Parse known args
        args, unknown = parser.parse_known_args()
        return args
    
    • batch_max_lenper_device_train_token_numbersって言えます。Openchatからbatch_max_lenを設定するときに2048の倍数にしてくださいって言われたので、2048はmax_seq_lengthと同じことだと思いました。
    • トークン81920個を文章数に置き換えると40個になります。
      • 後で知ったのですがreadmeでは77824をおすすめしています。
    • そこでper_device_train_batch_sizeを8に、gradient_accumulation_stepsを5に設定してbatch_sizeを40にしました。
    • 性能は若干向上しましたがまだ再現したとは全くいえない性能でした。
  • Trial 9からはinstruction部分はマスキングでロスから除外し、weight_decay 0.1を追加しました。

    • それでも性能は再現はできませんでした。
  • Trial 10は per_device_train_batch_sizeを1に、learning_rate1e-7に設定[1]して実験を行いました。

    • GPUを8枚使用したためtotal_batch_sizeは教えてもらった通り8になりました。
    • 初めて再現に成功したと言える性能が出ました。
    • batch_sizelearning_rateの学習の時の重要性は非常に大きかったです。
      • batch_sizelearning_rateがこんなに小さくないといけないとは考えもつきませんでした。
      • データの量が小さいと一緒に小さくなるのかもしれないです → 後でデータの量が多くなったら実験をやってみたいです。
  • Trial 11ではinstruction部分をマスキングせずに学習に含めるようにしました。Trial 10ではinstruction部分をマスキングしてロスから除外しました。

    • ほかのハイパーパラメータは同じでした。
    • 性能がさらに上がりました。
      • 「多くの性能の良いオープンモデルの学習の時マスキングしない理由があるんだな」と思いました。
  • Trial 12ではTrial 10 で試したバッチサイズが通常のオープンモデルより小さかったので4倍にして学習してみました。

    • Fujikiさんのハイパーパラメータだけ性能が出るのか気になって実験をしました。
    • 性能はほんの少し下がりました。でもほぼ同じと言える程度でした。
  • Trial 13ではbatch_sizeとlinearにlearning_rateを大きくしてみました。

    • Trial 12でlearning_rateを調整せずにbatch_sizeだけを大きくしたことが性能の低下の原因だったのかなと思いました。
    • 性能がほぼ同じって言えないほど下がりました。
  • Trial 14ではbatch_sizeのSquare rootに比例するようにlearning_rate を大きくしてみました。

    • batch_sizeを4倍に設定したのでsquare root ruleに従ってlearning_rateを2倍に設定しました。
    • 効果がありました。同じ性能を得ることができました!
    • 性能が若干向上し、余ったGPUメモリを利用したため学習にかかる時間も大幅に短くなりました。
  • Trial 15では他のスケールでも学習がうまくいくのか実験してみました。

    • 今回はbatch_sizeを16倍に設定し、同様にsquare root ruleに従ってlearning_rateを4倍に設定して実験しました。
    • 性能はTrial 14より少し落ちましたが、ほぼ同じって言える結果が出ました。
      • Square root ruleがうまく機能することが分かりました。
      • batch_sizelearning_rateが関数的な関係に基づいて一緒に動いていることが分かりました。 つまり一方を決まれば、もう一方は自動的に決められることになります。
      • batch_sizeをGPUに最大に入る大きさに設定してlearning_rateだけ適切なサイズを探せばいいと思います。
  • Trial 18ではtrlConstantLengthDataset のオプションの一つのpackingが与える影響について実験してみました。

    • 今までの実験はtrlConstantLengthDatasetのオプションの中でpacking=Trueに設定していましたが、今回はpacking=Falseで実験してみました。
    • 意外にpacking=True で学習したモデルの性能が少し高かったです。
    • Trial 16とTrial 17ではTrial 15で学習したモデルで初期化し、packing=Trueでもう一度ファインチューニングしてみたんですが、性能が下がりました。
  • 最後に一番良い性能を見せたTrial 14の学習した時に使ったconfigを共有します。

    # configs/sft/full/a100_40g_gamma_14.yaml
    
    # Ported from https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/dpo/config_full.yaml
    # with slight modifications to add `wandb` logging, missing `warmup_ratio: 0.1`, Notus-related stuff, and also to make
    # if work in 8 x A100 40GB by adding `torch_dtype: bfloat16` and `use_flash_attention_2: true`
    
    # Model arguments
    model_name_or_path: stabilityai/japanese-stablelm-base-gamma-7b
    torch_dtype: auto
    use_flash_attention_2: true
    
    # Data training arguments
    dataset_mixer:
      {
        kunishou/databricks-dolly-15k-ja: 1.0,
        fujiki/japanese_hh-rlhf-49k: 1.0,
        fujiki/llm-japanese-dataset_wikinews: 1.0
      }
    dataset_splits:
      - train
    preprocessing_num_workers: 12
    
    # SFTTrainer arguments
    bf16: true
    do_eval: true
    evaluation_strategy: epoch
    eval_steps: 1
    gradient_accumulation_steps: 1
    gradient_checkpointing: true
    learning_rate: 2.0e-07
    log_level: info
    logging_steps: 10
    lr_scheduler_type: cosine
    max_seq_length: 2048
    num_train_epochs: 3
    optim: adamw_torch
    adam_beta1: 0.9
    adam_beta2: 0.999
    adam_epsilon: 1.0e-08
    weight_decay: 0.1
    output_dir: results/gamma-sft-full-trial14
    per_device_train_batch_size: 4
    per_device_eval_batch_size: 8
    push_to_hub: false
    save_strategy: epoch
    save_total_limit: 1
    seed: 42
    warmup_ratio: 0.1
    report_to:
      - tensorboard
    

Acknowledgement

Disclaimer

  • 私は日本語のネイティブではありません。 日本語を勉強中です。 間違いがあれば直していただけると助かります。

Discussion