LLMのDepth Up-Scalingを試す
1人ローカルLLMアドベントカレンダーの2日目です。
ローカルLLMやマルチモーダルモデルの学習やデータセット周りについて書いていく予定なので、興味がある方は明日以降も読んでいただけると嬉しいです!
要約
- Up-Scaling直後は性能が低下していそう
- 数百Mトークン程度の学習でも性能が回復した
- 数Bトークン程度の学習で元のモデルを超えられる?
目的
GENIACでELYZAさんが取り組まれていた、70B→120BへのDepth Up-Scalingを見て自分でもやりたくなってしまったので、小規模なモデルでの検証を行っていきます。
Depth Up-ScalingについてはELYZAさんの記事をご覧ください。
実施内容
個人で検証できるレベルということで、llm-jp/llm-jp-3-1.8bを使うことにしました。
Depth Up-Scaling
Up-Scalingにはmergekitを使用しました。
ELYZAさんは0-10,5-15,...,70-80と5層ずつずらして重ねていたので、それを参考に以下のように設定しました。
slices:
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [0, 6]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [3, 9]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [6, 12]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [9, 15]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [12, 18]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [15, 21]
- sources:
- model: 'llm-jp/llm-jp-3-1.8b'
layer_range: [18, 24]
merge_method: passthrough
dtype: bfloat16
これにより、1.8Bのモデルが2.96Bになりました。
結果
スケーリング後の学習
- 0.5BトークンをLoRA(r=128)で学習
- 0.3Bトークンをフルパラメータで学習
の2通りを試しています。
データは以下のものを使用しました:
ベンチマーク
Up-Scalingの効果を測るため、llm-jp-evalで4タスクの性能を測定してみました。
モデル | JCom | JMMLU | JSQuAD | NIILC | mean |
---|---|---|---|---|---|
llm-jp/llm-jp-3-1.8b | 0.16 | 0.3 | 0.5515 | 0.4359 | 0.3619 |
Up-Scalingモデル(3B) | 0.22 | 0.3 | 0.3506 | 0.297 | 0.2919 |
Up-ScalingモデルをLoRA(r=128)で0.5Bトークン学習 | 0.19 | 0.19 | 0.3884 | 0.3884 | 0.2892 |
Up-Scalingモデルをフルパラで0.3Bトークン学習 | 0.11 | 0.33 | 0.551 | 0.3886 | 0.3449 |
llm-jp/llm-jp-3-3.7b | 0.15 | 0.21 | 0.734 | 0.4955 | 0.3974 |
それぞれのベンチマークの結果についてですが、JCommonsenseQAとJMMLUの差は誤差レベルかなという印象です。
一方で、JSQuADとNIILCは自由記述式のタスクということで、それなりに差がついています。
記述式のタスクで考えると、0.3B程度でもUp-Scaling前のモデルの性能に回復していそうです。数B~数十Bトークン学習することで、元となったモデルを超えるようになるかもしれません。
まとめ
今回はモデルのDepth Up-Scalingと、その後の学習による性能回復について検証しました。
元となったモデルを超えるにはそれなりの規模の学習が必要になりそうですが、1から学習する必要がないことから、大規模モデルを効率よく開発できるのは良さそうですね。
Discussion