mergekit でモデルマージを試してみる
こんにちは、初めましての方は初めまして。株式会社 Fusic の瓦です。暑い。まだ六月なのに本当に暑いです。でも六月にクーラーを付けるのは負けな気がする… そう思って扇風機で我慢している今日この頃です。
暑いといえば、最近の LLM 系周辺ではモデルを組み合わせて性能を向上させるモデルマージという方法の話題が熱いです。モデルマージはその名の通り、「モデルをマージして(=くっつけて)新しいモデルを作る」方法を指します。例えば複数のモデルのパラメータの平均を新しいモデルのパラメータとして精度を向上させたり[1]、ベースモデルを対話できるように学習させた後、そのパラメータの差分を別のモデルに加えることで学習せずに対話能力が付与したり[2]といった方法が提案されています。
この記事では、そんなモデルマージを簡単に試せるライブラリ mergekit
[3] を使ってモデルのマージを試してみようと思います。mergekit
は様々なモデルをサポートしており、かつ色々なモデルマージの手法(MoE や、話題となっている sakana.ai の進化的アルゴリズムを用いた手法[4]など)が実装されています。またモデルマージに使用するモデルや手法を YAML で管理して mergekit-yaml
コマンドで呼び出すだけでマージが手軽に行えたりもするかなり便利なライブラリなので、この記事でも使用してみたいと思います。
モデルマージを試してみる
llm-jp-1.3b のベースモデルでの実験
簡単に試してみるために、llm-jp-1.3b
[5] のモデルを対象にモデルのマージを行ってみます。llm-jp-1.3b
は NII のプロジェクトで開発された大規模言語モデルであり、パラメータ数が 1.3B ということでローカルでも手軽に扱えるのが特徴だと思っており、今回もローカルでパッと試したいのでこのモデルを選択しました。とりあえず mergekit が動かせるかを確かめるため、一旦ダウンロードしただけのモデルを二つ用意してマージしてみることにします。
まずは llm-jp-1.3b モデルを二つダウンロードします。
git clone https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0 llm-jp-1.3b-1
git clone https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0 llm-jp-1.3b-2
次に、mergekit の examples/linear.yml
を参考にして、上でダウンロードした二つのモデルをマージする設定を書いた YAML ファイルを用意します。
models:
- model: llm-jp-1.3b-1
parameters:
weight: 0.7
- model: llm-jp-1.3b-2
parameters:
weight: 0.3
merge_method: linear
dtype: float16
ここではマージ方法として linear
を選択し、マージ対象のモデルとして llm-jp-1.3b-1
と llm-jp-1.3b-2
を指定しました。weight
はマージするときの重みです(この設定では llm-jp-1.3b-1
* 0.7 + llm-jp-1.3b-2
* 0.3 として計算したパラメータを使って新しいモデルが出来る)他のマージ方法では引数が異なるので、適宜 examples
の中のファイルや README を参照するとよいです。
以上でマージする準備は出来たので(めっちゃ簡単!)README に従って mergekit-yaml
コマンドを使ってマージしてみます。
mergekit-yaml examples/llm-jp-linear.yaml ./llm-jp-1.3b-merged --cuda --lazy-unpickle --allow-crimes
無事にマージでき、llm-jp-1.3b-merged
にモデルが出力されているかと思います。これでマージは終わりです。とっても簡単ですね!
学習させたモデルをマージさせてみる
上では学習していないモデルをとりあえずマージしてみました。しかし、本当にマージ出来ているのか(パラメータが変な値になっていたりしないのか)気になります。そこで、次は同じデータで学習させたモデルをマージして、精度が変化しないことを確認します。(※ 以下の実験は、すべて 1:1 の割合でモデルをマージしています)
今回確かめるデータとして、JNLI[6] を対象としました。理由としては LLM-jp の評価用リポジトリ[7]で簡単に用意できるためです。このデータセットを使って、LLM-jp の学習用リポジトリ[8]を使ってファインチューニングを行いました。学習したモデルを使い、これもテストデータで評価用リポジトリに従って評価してみます。JNLI はクラス分類問題なので、精度(一致率)を見ると
次にタスクとして似ているデータでの精度の変化を確かめるために、JNLI データと JaNLI データ[9]で学習させたモデルをマージして、JSICK[10] データに対して推論し評価してみます。以下に結果の表を載せます。
学習データ | 精度 |
---|---|
JNLI | 0.1668 |
JaNLI | 0.0189 |
JSICK | 0.6359 |
JNLI + JaNLI | 0.3852 |
JNLI で学習させたモデルでは 16.68%、JaNLI で学習させたモデルでは 1.89% しか合っていませんでしたが、マージすることで 38.52% まで精度が向上しました。JSICK で学習させたモデルには遠く及びませんが、マージすることでかなり精度が向上できていることが確認できました!
まとめ
この記事ではマージを簡単に試せる mergekit を使ってみました。YAML ファイルを用意するだけで簡単にモデルマージが出来るので、学習するリソースがない場合でも新しいモデルを作ることが出来そうです。 この記事では使えるかを試してみたかったのでマージ方法として linear
しか試していませんが、マージ方法は他に色々ある[11]ので、実際にモデルを作る際は色々試してみて精度の変化を確認し、精度の良いモデルを選ぶのがよいと思います。タスクによってはマージすることによって精度が上がったり下がったりすることもあるんじゃないかと思っていて、今後マージするタスクとテストデータのタスクの違いと精度の関係も探っていきたいと思います。
後に宣伝になりますが、機械学習でビジネスの成長を加速するために、Fusic の機械学習チームがお手伝いたします。機械学習のPoCから運用まで、すべての場面でサポートした実績があります。もし、困っている方がいましたら、ぜひ Fusic にご相談ください。お問い合わせからでも気軽にご連絡いただけます。また Twitter の DM からでも大歓迎です!
Discussion