🌏

KDDCup参戦記④ モデルマージ編

2024/08/15に公開

はじめに

データアナリティクスラボ株式会社の宮澤です。
KDDCup参戦記シリーズの第4弾はモデルマージの取り組みについてご紹介します。

なぜモデルマージに取り組んだのか

このコンペにおいて私たちのチームでは、主にファインチューニング(PEFT)に取り組んでいました。データの品質向上、データ量の増減、LoRAのハイパラの変更など、様々な試みをしましたが大きく精度を向上させることができず、徐々に上位陣との差を広げられていっていました。

そこで、コンペも終盤に差し掛かった頃でしたが、思い切ってファインチューニング(PEFT)以外のアプローチを取るしかないのではと思い、モデルマージにトライしてみようと思いました。

他の選択肢としては①より大きなモデルを訓練する、②フルパラメータチューニング、③継続事前学習)、④アライメント(RLHF, DPO)、⑤モデルマージ などを考えましたが、①, ②, ③は手元の計算リソースで実現不可能であり、④は今回の選択問題の精度を上げるのには適していないと考え、除外しました。残った⑤モデルマージは、7B級のモデルであればNVIDIA L4もしくはA100程度のGPUが1枚あれば実装できそうだということがわかりました。

また、この時点までの検証で「数値推論」の問題の精度が低いということが考えられていたため、算数能力を持ったモデルと一般的な言語処理能力が高いモデルをマージさせることで、スコアが上がるのではないかと考え、勉強の意味も込めてチャレンジしてみました。

モデルマージとは

LLMにおけるモデルマージの手法には大きく2つあります。
一つが「重みレベルのマージ」です。イメージとしては以下のような形で、同じアーキテクチャのモデルを複数用意し、各層のパラメータ同士を平均するなどして一つの重みにすることを指します。

(Sakana AI 進化的アルゴリズムによる基盤モデルの構築 より引用)

二つ目が「レイヤーレベルのマージ」です。イメージとしては以下のような形で、モデルの層を重ね合わせていって別のマージモデルを作ることを指します。

(Sakana AI 進化的アルゴリズムによる基盤モデルの構築 より引用)

進化的モデルマージとは

進化的モデルマージとは、Sakana AIから提案されたマージ手法です。この手法は、モデルをマージする際のハイパーパラメータの設定が経験に基づく職人技となっていた課題に対して提案されたものであり、ハイパーパラメータの探索を進化アルゴリズムを用いて行い最適化する手法です。

モデルマージおよび進化的モデルマージについての詳細は、弊社ホームページ内のJOURNALにて解説しておりますので、詳細を知りたい方はこちらをご覧ください。(実装方法も記載しています。)

https://dalab.jp/archives/journal/llm-merge-evolve/

実装

ここからはコンペのために行った実装コードを紹介していきます。
実装はGCPのVertexAIインスタンスで、NVIDIA L4 × 1枚の環境です。

ライブラリ

実装にはmergekitというライブラリを使用しました。このライブラリではモデルマージの手法がいくつかサポートされており、その中の一つの mergekit evolveで進化的モデルマージを使うことができます。

git clone https://github.com/arcee-ai/mergekit.git

設定ファイルの準備

まずはじめにいくつかの設定ファイルの準備が必要です。
設定ファイルの配置は以下の通りです。各ファイルの中身については下で説明します。

mergekit
  - evol_merge_config.yaml
  - xxxxxxx
  - xxxxxxx
  - workspace
    - eval_tasks
        - multi_choice.yaml
        - process_multi.py
        - retrieval.yaml
        - process_ret.py

1. マージの設定

マージをするにあたってのハイパーパラメータなどを設定します。
evol_merge_config.yamlというファイルを作り、以下を書き込みます。

genome:
    models:
      - meta-math/MetaMath-Mistral-7B
      - Nexusflow/Starling-LM-7B-beta
      - BAAI/Infinity-Instruct-3M-0613-Mistral-7B
    merge_method: dare_ties
    base_model: mistralai/Mistral-7B-v0.1
    layer_granularity: 4 # sane default
    allow_negative_weights: true # useful with task_arithmetic
tasks:
  - name: multiple_choice_kdd
    weight: 0.875
  - name: retrieval_kdd
    weight: 0.125

ベースはにはMistral-7Bを用いて、meta-math/MetaMath-Mistral-7B(算数データでSFTしたモデル)、Nexusflow/Starling-LM-7B-beta(RLAIF(Reinforcement Learning from AI Feedback (RLAIF))で学習したモデル)、BAAI/Infinity-Instruct-3M-0613-Mistral-7B(大量のInstructionデータでSFTしたモデル)をマージ対象としました。各モデルの能力を活かしつつ最適な回答をしてくれるモデルになることを期待しました。

マージの手法としてはDARETIESを組み合わせたdare_tiesを選択しました。

layer_granularityは分割する層数を決めるもので、Mistral-7Bは32層であるため、 32層 / 4 = 8ブロックに分割されます。マージメソッドがdare_tiesの場合、マージのパラメータはdensityweightというものがあるため、最適化する対象のパラメータは、3モデル * 8ブロック * 2パラメータ = 48パラメータとなります。

タスクは下記で説明しますが、今回のコンペ(Track 2)では約9割が4択から1つを選ぶmultiple_choiceという問題で、約1割が10~15の選択肢から複数選択するretrievalと呼ばれる問題であったため、tasksに2つの最適化タスクを重みづけして設定しています。

2. 最適化タスクデータの準備

megekitではlm-evaluation-harnessがバックグラウンド処理されるようになっていますが、その処理ではHuggingFaceにあるデータを取得して最適化に用いることができるようになっています。

そのため、予めHuggingFaceにmulti-choiceとretrievalのデータセットをアップロードしておきました。入力をinput_fieldに、正解ラベルをoutput_fieldに入れた単純なデータセットです。

3. 最適化タスクの設定

上記ディレクトリ構造のmergekit/workspace/eval_tasksの中にあるファイルを作っていきます。ここでは2つの最適化タスクに対して、設定ファイル(.yaml)と処理定義ファイル(.py)を作ります。

mutiple_choice

設定ファイル(mutiple_choice.yaml)は以下のように設定しました。
taskの名称はevol_merge_config.yamlに記載したものと一致させる必要があります。
doc_to_textは入力プロンプトの処理の設定であり、process_resultsは評価スコアの設定です。これらはprocess_multi.pyの中で定義しています。

task: multiple_choice_kdd
dataset_path: user_name/train_multi
output_type: generate_until
training_split: train
validation_split: train
test_split: train
doc_to_text: !function process_multi.doc_to_text
doc_to_target: ""
process_results: !function process_multi.score
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
generation_kwargs:
  do_sample: false
  temperature: 0
  max_gen_toks: 1
metadata:
  version: 1.0

処理ファイル(process_multi.py)は以下のように設定しました。
ゼロショットでは「回答選択肢を答える」ということが難しいと考え、Few Shotで設定しておきました。

# インポート
import re
import random

# プロンプトを整形する関数
def doc_to_text(doc) -> str:
    input_field = doc["input_field"]

    instruction_prompt = "[INST] You are an assistant with excellent knowledge of shopping. Please answer multiple-choice questions about the shopping and products. Answer only one of the options 0, 1, 2, or 3.\n\n"
    few_shot = """Example :

Which of the following product categories best complement the product type electric toothbrush?
0. toothpaste
1. headphones
2. smartphone
3. book
Answer: 0

The product 'Quaker Oatmeal, Maple & Brown Sugar, 1.5 Ounce (Pack of 10)' appears on an e-commerce website. What is the total weight of the oatmeal?
0. 12 ounce
1. 15 ounce
2. 20 ounce
3. 10 ounce
Answer: 1

Given the product "Samsung Galaxy S21", as a result, PersonX feels
0. frustrated
1. disappointed
2. bored
3. excited
Answer: 3

Which of the following product categories best complement the product type digital camera?
0. frying pan
1. memory card
2. office chair
3. hair dryer
Answer: 1

The product 'Organic Fuji Apples, 16 lbs (Pack of 4)' appears on an e-commerce website. What is the weight of each pack of apples?
0. 1 lbs
1. 2 lbs
2. 4 lbs
3. 8 lbs
Answer: 2

"""

    input_field_1 = input_field.replace("\nAnswer: ",  " [/INST]\nAnswer: ")
    input_field_2 = input_field_1.replace("\nOutput: ",  " [/INST]\nAnswer: ")
    input_field_3 = input_field_2.replace("\nOutput (answer in three comma-separated numbers): ", " [/INST]\nAnswer: ")
    prompt_all = f"{instruction_prompt}{few_shot}{input_field_3}"
    return prompt_all

# 評価をする関数
def score(doc, results):
    pred = results[0] # 予測結果
    output_field = doc["output_field"] # 正解

    # 予測結果
    print("予測:", pred)
    
    if pred == output_field: # 正解の場合
        score = 1
    else: # 不正解の場合
        score = 0

    return {"acc": score}

retrieval

設定は概ね同じなので説明は割愛します。
以下にコードのみ掲載しておきます。

retireval.yaml
task: retrieval_kdd
dataset_path: user_name/train_ret
output_type: generate_until
training_split: train
validation_split: train
test_split: train
doc_to_text: !function process_ret.doc_to_text
doc_to_target: ""
process_results: !function process_ret.score
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
generation_kwargs:
  do_sample: false
  temperature: 0
  max_gen_toks: 7
metadata:
  version: 1.0
process_ret.py
# インポート
import re

# プロンプトを整形する関数
def doc_to_text(doc) -> str:
    input_field = doc["input_field"]

    instruction_prompt = "[INST] You are an assistant with excellent knowledge of shopping. Please answer multiple-choice questions about the shopping and products. Answer three comma-separated numbers.\n\n"
    few_shot = """Example :

A user on an online shopping website has just purchased a product 'Women's Waterproof Hiking Boots - Brown, Size 7'. The following numbered list contains 15 products. Please select 3 products from the list that the user may also purchase.
Product List:
1. North Face Women's Venture 2 Jacket
2. Apple iPhone 14 Pro
3. KitchenAid Stand Mixer
4. Sony Bravia 65 Inch TV
5. Gucci Leather Handbag
6. Dyson V11 Vacuum Cleaner
7. Samsung Galaxy Tab S8
8. Bose QuietComfort 35 Headphones
9. Canon EOS R5 Camera
10. Vitamix Professional Series Blender
11. Smartwool Women's PhD Outdoor Light Crew Socks
12. Rolex Submariner Watch
13. Nespresso Vertuo Coffee Maker
14. Peloton Bike
15. Osprey Women's Tempest 20 Hiking Backpack
Answer: 1,11,15

You are given a user review given to a(n) laptop product. You are also given a numbered list of ten aspects.
Please choose three aspects from the list that are covered by the review.
You should ONLY output three numbers, separated by comma. Do not generate explanations or other texts.
Review:
This laptop is very fast and has a long battery life. The screen resolution is also excellent.
Aspect List:
1. battery life
2. cooking speed
3. screen resolution
4. color
5. weight
6. operating system
7. camera quality
8. durability
9. sound quality
10. processing speed
Answer: 1,3,10

You are a helpful shop assistant. A user would like to buy the product 'Lavazza Super Crema Whole Bean Coffee Blend, Medium Espresso Roast, 2.2 Pound (Pack of 1)'. Please select the products that the user may also buy from the following numbered list.
Product List:
1. Organic Stevia in the Raw, 800 Count
2. Twinings English Breakfast Tea, 100 Count
3. KitchenAid Stand Mixer
4. Ghirardelli Chocolate Baking Chips
5. Breville Smart Grinder Pro
6. Hario V60 Paper Coffee Filters
7. Green Mountain Dark Magic Coffee, Keurig K-Cup Pods, 12 Count
8. French Press Coffee Maker
9. Maxwell House Original Roast Ground Coffee, 30.6 oz
10. Starbucks Classic Hot Cocoa Mix
11. Cuisinart Coffee Maker
12. Bodum Chambord Milk Frother
13. De'Longhi Espresso Machine
14. Jura Cool Control Milk Cooler
15. Folgers Classic Roast Instant Coffee, 8 oz
You should output 3 numbers that correspond to the selected products. There should be a comma separating every two numbers. Only respond with the results. Do not say any word or explanations.
Answer: 5,8,13

"""

    input_field_1 = input_field.replace("\nAnswer: ",  " [/INST]\nAnswer: ")
    input_field_2 = input_field_1.replace("\nOutput: ",  " [/INST]\nAnswer: ")
    input_field_3 = input_field_2.replace("\nOutput (answer in three comma-separated numbers): ", " [/INST]\nAnswer: ")
    prompt_all = f"{instruction_prompt}{few_shot}{input_field_3}"
    return prompt_all

# 評価をする関数
def score(doc, results):
    pred = results[0] # 予測結果
    output_field = doc["output_field"] # 正解

    # 予測結果
    print("予測:", pred)

    try:
        # 数字以外の文字をスペースに変換し、数字のリストに変換して最初の3つだけを取る
        correct_numbers = set(int(num) for num in ''.join(c if c.isdigit() else ' ' for c in output_field).split()[:3])
        predicted_numbers = set(int(num) for num in ''.join(c if c.isdigit() else ' ' for c in pred).split()[:3])

        # ヒット数と全体の数を計算
        hits = len(correct_numbers & predicted_numbers)
        total = len(correct_numbers)
        
        # ヒット率を計算
        hit_rate = hits / total if total > 0 else 0

    except ValueError:  # 数字に変換できない文字が入力された場合
        hit_rate = 0
    
    return {"acc": hit_rate}

ここまで準備できたら実行に進んでいきます。

実装

環境構築はクリーンな環境にpip install -e .[evolve,vllm]をすれば問題なく完了します。私はGCPのVertexAIを使っていたため、元から入っているライブラリとの依存関係エラーで少しハマりました。大きなところとしては、setuptoolsのバージョンが合わず、69.5.1にダウングレードすることで解消されました。

マージの実行は簡単で、以下のように引数を設定してコマンド実行すればマージを開始することができます。

!mergekit-evolve ./evol_merge_config.yaml \
    --storage-path ./workspace/evol_merge_storage \
    --task-search-path ./workspace/eval_tasks \
    --in-memory \
    --merge-cuda \
    --wandb \
    --wandb-project mergekit-evolve \
    --wandb-entity your_wandb_entity \
    --max-fevals 1000

結果

学習推移

今回はマージの対象モデルが3つのパターン(緑)と2つのパターン(橙)で実行しました。
wandbでの学習推移は以下の通りでした。

進化アルゴリズム(CMA-ES)が使われるため、初めに生成された子孫の数が一つ一つのステップ数となっています。例えば橙のほうでは子孫が15であったため、15パターンでの評価が終わった時点でのベストスコアがプロットされ、そこで新たな15の子孫に世代交代されて評価が始まる、という流れになっています。

序盤は順調にベストスコアが上がっていましたが、200ステップ目あたりからベストスコアが横ばいになり上がらなくなりました。

Evolutionary Optimization of Model Merging Recipesには、Parameter Spaceでのマージの実験ではOptunaで1,000ステップ学習したと述べられていたため、1,000ステップを考えていましたが、コンペ終盤であったこともあり、350ステップあたりで打ち切ることとしました。

ちなみに、1ステップあたり15分~17分ほどかかっていたため、350ステップに到達するまでに3.5日ほど経過していました。

※CMA-ESのイメージは以下のような形であり、多変量正規分布の平均と分散を変化させながら最適解に近づくように探索するアルゴリズムです。

Wikipedia CMA-ESより引用)

サブミットスコア

wandbのベストスコアの縦軸を見るとわかる通り、スコアは大きくは上がっておらず、マージをしてみたものの、サブミットスコアはさほど期待できませんでした。
実際にサブミットしてみると、スコアは以下の通りでした。(Starling-LM-7B-betaはマージ対象のモデルの一つ。)

Model multiple_choice score retireval score
Starling-LM-7B-beta 0.599457 0.453634
Green 0.609494 0.418546
Orange 0.597068 0.434837

考察

今回の結果を踏まえた考察です。

1. パラメータ数を大きくしすぎた もしくは 最適化ステップ数が不足していた可能性

上記の通り、層の分割サイズlayer_granularityを4として3モデルをマージする場合、最適化するパラメータは48になります。予備実験で数パターン試した際に、このパラメータ数が多いほどCMA-ESの子孫の数は多くなる傾向が見えていました。今回は15ステップあたりで世代交代がされていたため、350ステップでも世代交代の数は23回ほどになります。48のパラメータを最適化するにあたって、この回数が十分であったかどうかは疑問が残る点であり、まだまだ足りなかった可能性があるのではないかと考えられました。同じ処理時間で世代交代回数を増やすためにlayer_granularityをもう少し大きくしておく(分割サイズが大きくなるのでパラメータが減る)ほうがよかったかもしれないと考えました。

しかし、この最適化に必要なステップ数は自分が調査不足であったため、今後の宿題にしたいと思います。

参考として、Sakana AIによる実験はこちらの記事では、「最終的なモデルは、数百世代にわたる進化の中で最も優れた性能を発揮する(学習セットで高いスコアを獲得した)モデルです。」と書かれていました。

2. ある1つのタスクだけの精度向上の目的にモデルマージという手段が適していなかった可能性

今回はコンペのTrack2のスコアを高めるため、英語で与えられる4択問題の性能を高めるためにモデルマージという手段を試してみましたが、結果としてマージ前のモデルの精度とほとんど変わりませんでした。

ここでEvolutionary Optimization of Model Merging Recipesの実験などを振り返ると、日本語モデル + 英語数学モデル → 日本語数学タスクの性能向上のように、ドメインとタスクの掛け合わせによる性能向上がされていました。

Chat Vector: A Simple Approach to Equip LLMs with Instruction Following and Model Alignment in New Languagesのような手法でも、継続事前による言語ドメイン知識を付けたモデルとChat Vectorのマージによってドメイン×タスクの性能向上が報告されています。

一方で今回は「Mistralベースで汎用的な性能が高いモデル と 算数データセットでMistralをファインチューニングしたモデル をマージして、数値推論を含む4択問題の性能を上げよう」という試みでした。最適化タスクが1トークン(回答の選択肢)を出力するという単純なタスクのみであったことや、言語ドメインも英語のみであったことから、マージすることによる能力掛け合わせという利点が十分に得られず、マージ元のモデルをファインチューニングする以上の性能は出せなかったのではないかと感じました。またMetaMath-Mistral-7BはGSM8Kは多段的に数値推論するモデルとしてチューニングされていたようだったので、このようにファインチューニングされたときの形式に合わせてマージの最適化タスクも設定ししないと、ファインチューニングモデルの数値推論の能力をうまく引き出せないという可能性も考えられました。

また、補足ですが、今回実現したかったことはどちらかと言えば、3モデルでトークンの生成確率を出力してアンサンブルするという方法が近いように感じていました。ただし、こちらはコンペの環境に3つのモデルが載りきらないため早々に断念していました。

おわりに

本記事ではKDD Cupでのモデルマージの取り組みについて紹介しました。結果としてスコアが大きく上がるモデルを作ることはできませんでしたが、マージの理論や実装について理解することができました。モデルマージは学習や作業コストが少なく高性能なモデルを作ることができると期待される手法ですが、その研究はまだ発展途上であるように感じます。新しいマージの手法は研究が活発になっているため、この技術については引き続き追っていこうと思います。

お読みいただきありがとうございました。

関連記事

DAL Tech Blog

Discussion