🔀

mergekit でモデルマージを試してみる

2024/06/11に公開

こんにちは、初めましての方は初めまして。株式会社 Fusic の瓦です。暑い。まだ六月なのに本当に暑いです。でも六月にクーラーを付けるのは負けな気がする… そう思って扇風機で我慢している今日この頃です。

暑いといえば、最近の LLM 系周辺ではモデルを組み合わせて性能を向上させるモデルマージという方法の話題が熱いです。モデルマージはその名の通り、「モデルをマージして(=くっつけて)新しいモデルを作る」方法を指します。例えば複数のモデルのパラメータの平均を新しいモデルのパラメータとして精度を向上させたり[1]、ベースモデルを対話できるように学習させた後、そのパラメータの差分を別のモデルに加えることで学習せずに対話能力が付与したり[2]といった方法が提案されています。

この記事では、そんなモデルマージを簡単に試せるライブラリ mergekit[3] を使ってモデルのマージを試してみようと思います。mergekit は様々なモデルをサポートしており、かつ色々なモデルマージの手法(MoE や、話題となっている sakana.ai の進化的アルゴリズムを用いた手法[4]など)が実装されています。またモデルマージに使用するモデルや手法を YAML で管理して mergekit-yaml コマンドで呼び出すだけでマージが手軽に行えたりもするかなり便利なライブラリなので、この記事でも使用してみたいと思います。

モデルマージを試してみる

llm-jp-1.3b のベースモデルでの実験

簡単に試してみるために、llm-jp-1.3b[5] のモデルを対象にモデルのマージを行ってみます。llm-jp-1.3b は NII のプロジェクトで開発された大規模言語モデルであり、パラメータ数が 1.3B ということでローカルでも手軽に扱えるのが特徴だと思っており、今回もローカルでパッと試したいのでこのモデルを選択しました。とりあえず mergekit が動かせるかを確かめるため、一旦ダウンロードしただけのモデルを二つ用意してマージしてみることにします。

まずは llm-jp-1.3b モデルを二つダウンロードします。

git clone https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0 llm-jp-1.3b-1
git clone https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0 llm-jp-1.3b-2

次に、mergekit の examples/linear.yml を参考にして、上でダウンロードした二つのモデルをマージする設定を書いた YAML ファイルを用意します。

llm-jp-linear.yaml
models:
  - model: llm-jp-1.3b-1
    parameters:
      weight: 0.7
  - model: llm-jp-1.3b-2
    parameters:
      weight: 0.3
merge_method: linear
dtype: float16

ここではマージ方法として linear を選択し、マージ対象のモデルとして llm-jp-1.3b-1llm-jp-1.3b-2 を指定しました。weight はマージするときの重みです(この設定では llm-jp-1.3b-1 * 0.7 + llm-jp-1.3b-2 * 0.3 として計算したパラメータを使って新しいモデルが出来る)他のマージ方法では引数が異なるので、適宜 examples の中のファイルや README を参照するとよいです。

以上でマージする準備は出来たので(めっちゃ簡単!)README に従って mergekit-yaml コマンドを使ってマージしてみます。

mergekit-yaml examples/llm-jp-linear.yaml ./llm-jp-1.3b-merged --cuda --lazy-unpickle --allow-crimes

無事にマージでき、llm-jp-1.3b-merged にモデルが出力されているかと思います。これでマージは終わりです。とっても簡単ですね!

学習させたモデルをマージさせてみる

上では学習していないモデルをとりあえずマージしてみました。しかし、本当にマージ出来ているのか(パラメータが変な値になっていたりしないのか)気になります。そこで、次は同じデータで学習させたモデルをマージして、精度が変化しないことを確認します。(※ 以下の実験は、すべて 1:1 の割合でモデルをマージしています)

今回確かめるデータとして、JNLI[6] を対象としました。理由としては LLM-jp の評価用リポジトリ[7]で簡単に用意できるためです。このデータセットを使って、LLM-jp の学習用リポジトリ[8]を使ってファインチューニングを行いました。学習したモデルを使い、これもテストデータで評価用リポジトリに従って評価してみます。JNLI はクラス分類問題なので、精度(一致率)を見ると 0.5785 となりました。推論結果自体の分析も面白いですが、今回はちゃんとマージが出来ているかを確かめることが目的なので、これ以上は深入りしません。次に JNLI で学習済みの同じモデルをマージしてみます。同じモデルを重みづけて足し合わせるため、基本的に精度は変化しないはずです。実際にマージしてみると精度は 0.5776 となりました。ちょっと落ちてる… とはいえほとんど変化していないので、マージが出来ていそうです。

次にタスクとして似ているデータでの精度の変化を確かめるために、JNLI データと JaNLI データ[9]で学習させたモデルをマージして、JSICK[10] データに対して推論し評価してみます。以下に結果の表を載せます。

学習データ 精度
JNLI 0.1668
JaNLI 0.0189
JSICK 0.6359
JNLI + JaNLI 0.3852

JNLI で学習させたモデルでは 16.68%、JaNLI で学習させたモデルでは 1.89% しか合っていませんでしたが、マージすることで 38.52% まで精度が向上しました。JSICK で学習させたモデルには遠く及びませんが、マージすることでかなり精度が向上できていることが確認できました!

まとめ

この記事ではマージを簡単に試せる mergekit を使ってみました。YAML ファイルを用意するだけで簡単にモデルマージが出来るので、学習するリソースがない場合でも新しいモデルを作ることが出来そうです。 この記事では使えるかを試してみたかったのでマージ方法として linear しか試していませんが、マージ方法は他に色々ある[11]ので、実際にモデルを作る際は色々試してみて精度の変化を確認し、精度の良いモデルを選ぶのがよいと思います。タスクによってはマージすることによって精度が上がったり下がったりすることもあるんじゃないかと思っていて、今後マージするタスクとテストデータのタスクの違いと精度の関係も探っていきたいと思います。

後に宣伝になりますが、機械学習でビジネスの成長を加速するために、Fusic の機械学習チームがお手伝いたします。機械学習のPoCから運用まで、すべての場面でサポートした実績があります。もし、困っている方がいましたら、ぜひ Fusic にご相談ください。お問い合わせからでも気軽にご連絡いただけます。また Twitter の DM からでも大歓迎です!

脚注
  1. Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time ↩︎

  2. Chat Vector: A Simple Approach to Equip LLMs with Instruction Following and Model Alignment in New Languages ↩︎

  3. https://github.com/arcee-ai/mergekit ↩︎

  4. https://arxiv.org/abs/2403.13187 ↩︎

  5. https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0/tree/main ↩︎

  6. JGLUE のタスクの内の一つ ↩︎

  7. https://github.com/llm-jp/llm-jp-eval ↩︎

  8. https://github.com/llm-jp/llm-jp-sft ↩︎

  9. https://github.com/verypluming/JaNLI ↩︎

  10. https://github.com/verypluming/JSICK ↩︎

  11. npaka さんの記事がまとまっていて分かりやすいです ↩︎

GitHubで編集を提案
Fusic 技術ブログ

Discussion