Instruction tuningをやってみて分かった5つこと
今回の実験の目的ときっかけ
日本語のLLMを勉強していたところ、ふと「Baseモデルを自分が学習するのは難しいけどInstruction tunedモデルは学習出来そうなのでその性能を再現してみる価値があるのではないか?」と考えました。普段よく使う StableLM-instruct-gamma-7b
の性能の再現を目指して、以下のことを確認しようと決定しました。
- Base modelからのInstruction tuningを行い、公開されているInstructモデルの性能を再現する。
- その過程でどの要素が良い性能を作るために重要なのか確認する。
実験のセッティング
-
使用したtransformersのバージョンは以下の通りです。
transformers==4.35.0
trl==0.7.6
-
notus のコミットハッシュ
5cd49123b2fd8622a656afb441361e27db791898
を使用しました。 -
japanese-stablelm-instruct-gamma-7b
の性能を再現したいのでデータをそのまま使用しました。kunishou/databricks-dolly-15k-ja
fujiki/japanese_hh-rlhf-49k
fujiki/llm-japanese-dataset_wikinews
-
Chat templateも
japanese-stablelm-instruct-gamma-7b
のものをそのまま使用しました。chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '\n### 指示: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '\n### 応答: \n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '\n### 応答: \n' }}\n{% endif %}\n{% endfor %}"
-
以下はフォーマットされたデータの例です。
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 ### 指示: バナナの栄養素を教えてください。 ### 応答: 確かに、バナナにはあなたの健康にとって重要な多くの栄養素が含まれています。ビタミンB6、繊維、カリウム、およびビタミンA、C、およびEの配列もあり、マグネシウムと銅も含まれており、マンガン、カルシウム、マグネシウムの優れた供給源です。
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 ### 指示: SATスコアは、良い大学に入ろうとするときに最も重要なことですか? ### 応答: 大学は、テストスコアに加えて、多くの要因を検討しています。最も重要なことは、生徒が高校でさまざまな科目を取ることです。もちろん、テストのスコアは重要ですが、大学は他の資質も持っている候補者を探しています。
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。 ### 指示: Euxoa brevipennisは、どの蛾科に属するのでしょうか? ### 入力: Euxoa brevipennisは、1888年にSmithによって初めて記述されたNoctuidae科の蛾である。カナダでは、ブリティッシュコロンビア州、アルバータ州、サスカチュワン州に生息する。アメリカではユタ州、コロラド州、カリフォルニア州から記録されている。 ### 応答: Euxoa brevipennis mothは、ノクトウガ科に属するガです。
実験結果のサマリー
実施した実験の数が多いためまずは結果を要約し詳細は下で説明します。以下の表は各実験で学習したモデルをJP Language Model Evaluation Harness(effdbeaf742e74ea1787871e99272c12146ba346
)で評価した結果です。
- Instruction tuningの場合、最適な
learning_rate
がかなり小さくなることがあります。(T1 vs. T10)- 今回の実験で一番良い性能を出したモデルの学習の
learning_rate
は1e-7
でした。
- 今回の実験で一番良い性能を出したモデルの学習の
- 性能の良いオープンモデルのハイパーパラメータをそのまま使用しても良い結果が得られるという保証はありません。(T1 vs. T4 vs. T10)
- 英語モデルと日本語モデルの違いかもしれません。
- データの量の違いかもしれません。
- Instructionはマスキングしてロスから除外する必要ないです。(T10 vs. T11)
- 除外する場合、むしろ性能が落ちました。
-
Efficient Sequence Packing without Cross-contaminationは学習速度も上げつつ性能向上の効果もありました。(T11 vs. T18)
-
trl
のConstantLengthDataset
のpacking
オプションをTrue
にすると使用できます。 -
packing=True
で学習した後、packing=False
で追加学習すると性能がむしろ低下しました。 (T15 vs. T16, T17)
-
- Square root ruleの効果がありました。 (T11 vs. T14 vs. T15)
-
batch_size
とlearning_rate
が一緒に動くので、batch_size
をGPUに最大に入る大きさに設定してlearning_rate
を調整しながら良いハイパーパラメータを探すといいと思います。
-
- 実験結果はここにアップロードしました。
実験のヒストリー
-
最初の実験(Trial 1)は
alignment-handbook
のconfigを使用しました。-
num_train_epochs
を3に増やしました。コンフィグのままエポックを1に設定して学習してみましたが、性能が上がりませんでした。 - その後、
lr_min_ratio
を変えたり、weight_decay
を与えたり、instruction部分をマスキングしてロスから除外するなどの実験を行ってみましたが性能の改善はありませんでした。- Trial 4, Trial 5, Trial 9はその実験の中で比較的結果が良かったいくつかの例です。
-
-
Trial 4は
Openchat
を参考にしてbatch_size
を変更した実験です。# https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py#L24-L55 def parse_args(): parser = argparse.ArgumentParser() ... # Hyperparameters parser.add_argument("--batch_max_len", type=int, default=81920) parser.add_argument("--epochs", type=int, default=5) ... # Parse known args args, unknown = parser.parse_known_args() return args
-
batch_max_len
はper_device_train_token_numbers
って言えます。Openchat
からbatch_max_len
を設定するときに2048の倍数にしてくださいって言われたので、2048はmax_seq_length
と同じことだと思いました。 - トークン81920個を文章数に置き換えると40個になります。
- 後で知ったのですがreadmeでは77824をおすすめしています。
- そこで
per_device_train_batch_size
を8に、gradient_accumulation_steps
を5に設定してbatch_size
を40にしました。 - 性能は若干向上しましたがまだ再現したとは全くいえない性能でした。
-
-
Trial 9からはinstruction部分はマスキングでロスから除外し、
weight_decay
0.1を追加しました。- それでも性能は再現はできませんでした。
-
Trial 10は
per_device_train_batch_size
を1に、learning_rate
を1e-7
に設定[1]して実験を行いました。- GPUを8枚使用したため
total_batch_size
は教えてもらった通り8になりました。 - 初めて再現に成功したと言える性能が出ました。
-
batch_size
とlearning_rate
の学習の時の重要性は非常に大きかったです。-
batch_size
とlearning_rate
がこんなに小さくないといけないとは考えもつきませんでした。 - データの量が小さいと一緒に小さくなるのかもしれないです → 後でデータの量が多くなったら実験をやってみたいです。
-
- GPUを8枚使用したため
-
Trial 11ではinstruction部分をマスキングせずに学習に含めるようにしました。Trial 10ではinstruction部分をマスキングしてロスから除外しました。
- ほかのハイパーパラメータは同じでした。
- 性能がさらに上がりました。
- 「多くの性能の良いオープンモデルの学習の時マスキングしない理由があるんだな」と思いました。
-
Trial 12ではTrial 10 で試したバッチサイズが通常のオープンモデルより小さかったので4倍にして学習してみました。
- Fujikiさんのハイパーパラメータだけ性能が出るのか気になって実験をしました。
- 性能はほんの少し下がりました。でもほぼ同じと言える程度でした。
-
Trial 13では
batch_size
とlinearにlearning_rate
を大きくしてみました。- Trial 12で
learning_rate
を調整せずにbatch_size
だけを大きくしたことが性能の低下の原因だったのかなと思いました。 - 性能がほぼ同じって言えないほど下がりました。
- Trial 12で
-
Trial 14では
batch_size
のSquare rootに比例するようにlearning_rate
を大きくしてみました。-
batch_size
を4倍に設定したのでsquare root ruleに従ってlearning_rate
を2倍に設定しました。 - 効果がありました。同じ性能を得ることができました!
- 性能が若干向上し、余ったGPUメモリを利用したため学習にかかる時間も大幅に短くなりました。
-
-
Trial 15では他のスケールでも学習がうまくいくのか実験してみました。
- 今回は
batch_size
を16倍に設定し、同様にsquare root ruleに従ってlearning_rate
を4倍に設定して実験しました。 - 性能はTrial 14より少し落ちましたが、ほぼ同じって言える結果が出ました。
- Square root ruleがうまく機能することが分かりました。
-
batch_size
とlearning_rate
が関数的な関係に基づいて一緒に動いていることが分かりました。 つまり一方を決まれば、もう一方は自動的に決められることになります。 -
batch_size
をGPUに最大に入る大きさに設定してlearning_rate
だけ適切なサイズを探せばいいと思います。
- 今回は
-
Trial 18では
trl
のConstantLengthDataset
のオプションの一つのpacking
が与える影響について実験してみました。- 今までの実験は
trl
のConstantLengthDataset
のオプションの中でpacking=True
に設定していましたが、今回はpacking=False
で実験してみました。 - 意外に
packing=True
で学習したモデルの性能が少し高かったです。 - Trial 16とTrial 17ではTrial 15で学習したモデルで初期化し、
packing=True
でもう一度ファインチューニングしてみたんですが、性能が下がりました。
- 今までの実験は
-
最後に一番良い性能を見せたTrial 14の学習した時に使ったconfigを共有します。
# configs/sft/full/a100_40g_gamma_14.yaml # Ported from https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/dpo/config_full.yaml # with slight modifications to add `wandb` logging, missing `warmup_ratio: 0.1`, Notus-related stuff, and also to make # if work in 8 x A100 40GB by adding `torch_dtype: bfloat16` and `use_flash_attention_2: true` # Model arguments model_name_or_path: stabilityai/japanese-stablelm-base-gamma-7b torch_dtype: auto use_flash_attention_2: true # Data training arguments dataset_mixer: { kunishou/databricks-dolly-15k-ja: 1.0, fujiki/japanese_hh-rlhf-49k: 1.0, fujiki/llm-japanese-dataset_wikinews: 1.0 } dataset_splits: - train preprocessing_num_workers: 12 # SFTTrainer arguments bf16: true do_eval: true evaluation_strategy: epoch eval_steps: 1 gradient_accumulation_steps: 1 gradient_checkpointing: true learning_rate: 2.0e-07 log_level: info logging_steps: 10 lr_scheduler_type: cosine max_seq_length: 2048 num_train_epochs: 3 optim: adamw_torch adam_beta1: 0.9 adam_beta2: 0.999 adam_epsilon: 1.0e-08 weight_decay: 0.1 output_dir: results/gamma-sft-full-trial14 per_device_train_batch_size: 4 per_device_eval_batch_size: 8 push_to_hub: false save_strategy: epoch save_total_limit: 1 seed: 42 warmup_ratio: 0.1 report_to: - tensorboard
Acknowledgement
- Trial 10とそのあとからの実験はStability-AI JapanのFujikiさんに助けてもらいました。
- バッチサイズやランニングレートの設定を聞いて回答をしてもらうことができました。
- ありがとうございます 🙏
Disclaimer
- 私は日本語のネイティブではありません。 日本語を勉強中です。 間違いがあれば直していただけると助かります。
Discussion