🐜

LLMのレイヤーを削除して小型化するDepth-only pruningをやってみる

2024/12/04に公開

1人ローカルLLMアドベントカレンダーの4日目です。

ローカルLLMやマルチモーダルモデルの学習やデータセット周りについて書いていく予定なので、興味がある方は明日以降も読んでいただけると嬉しいです!
https://qiita.com/advent-calendar/2024/local-llm

要約

  • Pruning直後は支離滅裂な出力になる
  • 性能の回復にはそれなりの学習が必要そう

目的

NVIDIAが公開しているnvidia/Llama-3.1-Minitron-4B-Depth-Base
はLlama-3.1-8Bからレイヤーを半分削除しても性能を維持できていたというのに興味があったので、やってみようと思います。

https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/

実施内容

個人で検証できるレベルということで、llm-jp/llm-jp-3-1.8bを使うことにしました。

Depth Up-Scaling

Pruningにはmergekitを使用しました。

https://github.com/arcee-ai/mergekit

NVIDIAの事例に習って、最終層の1つ手前から半分を削除してみます。

slices:
- sources:
  - model: 'llm-jp/llm-jp-3-1.8'
    layer_range: [0, 11]
- sources:
  - model: 'llm-jp/llm-jp-3-1.8'
    layer_range: [23, 24]

merge_method: passthrough
dtype: bfloat16

これにより、1.8Bのモデルが1.1Bになりました。

ちなみに、本家と同じように削除するレイヤーを一つずつずらして日本語wikipediaのデータでのlossを計算してみると以下のようになりました。

削除する層 loss
0-11 13.15
1-12 11.38
2-13 11.90
3-14 8.94
4-15 7.64
5-16 8.65
6-17 7.72
7-18 9.17
8-19 7.71
9-20 8.16
10-21 9.73
11-22 10.68
12-23 10.26

結果

スケーリング後の学習

まずは、llm-jp/llm-jp-3-1.8bを1.1BにPruningしたモデルを2.4Bトークンでフルパラ学習してみました。

データは以下のものを使用しています:

学習後のモデルはこちらになります
https://huggingface.co/Kendamarron/Depth-Pruned-llm-jp-3-1.1b-steps4624

出力

モデル 出力(太字が入力)
Pruning直後 こんにちは、 SEOBAR� ---?..&#�---[]\ueb2c![...)!\?"!!!"\nOPÃ Verna,Ã relevant&#� ["a Vernrx[.........!-- :)..."?)"). ["!!]!!!!!!!![--@...\n"…]� [[\~!!_]`[]**????^X Vern�SHXXnXMX=D'
2.4Bトークン学習後 こんにちは、 私の住んでいる県の県税庁が発行する「県民税と市町村の税金」のページを閲覧中でした。\n「県民税と市町村の税金」のページでは、「市町村で行っている税金に関する内容について紹介しています」と書いてあります。これは、課税対象となる市区町村ではどのようになっているのかという点について、住民投票に委ねられているものであることから、そのための国民投票対策として行われているものです。\nこの記事を見ると、住民投票は都道府県ごとに決められた地域によって行われる

最初は支離滅裂だった出力が、2.4Bトークンの学習で日本語を話せるようになっています。

ベンチマーク

Pruningの効果を測るため、llm-jp-evalで4タスクの性能を測定してみました。

モデル JCom JMMLU JSQuAD NIILC
llm-jp/llm-jp-3-1.8b 0.16 0.3 0.5515 0.4359
Pruning Model(1.1B) 0.15 0.34 0.1284 0.0774
llm-jp/llm-jp-3-3.7b 0.15 0.21 0.734 0.4955

選択式のタスクは差がないように見えますが、実際の出力を見るとすべての問題で同じ記号を選択していました。

また、自由記述式のタスクであるJSQuADとNIILCについてもfew shotの回答をそのまま出力するなど、問題に回答しているとは言えないような状態でした。

Up-Scalingと違って、ある程度まとまったトークン数を学習させないといけないのかもしれません...

まとめ

今回はモデルのDepth-only pruningと、その後の学習による性能回復について検証しました。

1兆規模のトークンを学習させる必要がないだけで、数十~数百Bトークンは必要になりそうですね

Discussion