Eediコンペ上位解法まとめ
はじめに
この記事は2024年12月に終了したEedi - Mining Misconceptions in Mathematicsの上位解法をまとめたものです。私もこのコンペに参加しており結果はメダル圏外となってしまいましたが、解法の至る所にLLMが用いられており非常に学びの多いコンペであったと感じています。この記事では上位解法をまとめることで本コンペの集合知を共有できればと考えています。
コンペ概要
タスク
数学的な質問(Question)に対して1つの正解(Correct Answer)と3つの不正解(Incorrect Answer)で構成される多肢選択問題が与えられ、不正解の背景にある誤解(Misconception) を予測することが本コンペのタスクです。
例えば下記の例でBの13という選択肢を選んだ人は「四則演算の優先順位に関係なく左から右へ計算を実行する」という誤解があると考えられます。
データセットの構成
train/testデータ
Question: 数学の問題文です。
Answer: 学生が選択した回答で、正解および不正解の両方が含まれます。
Misconception: 学生の誤解を表すラベルで、trainデータにのみ存在します。
Subject: 問題の科目カテゴリを示すテキストデータです。
Construct: 問題の構造や内容を示すテキストデータです。
Misconceptionデータ
事前に2587件の誤解の候補が用意されており、各候補にはIDと説明が付与されています。
参加者はこれらの候補の中から各質問と誤回答のペアに最も適切な誤解を予測する必要があります。
詳細はこちら
評価指標
Mean Average Precision@25 (MAP@25) が採用されています。
この指標には以下の特徴があります
- 1つの質問-誤回答ペアに対して、最大25個の予測値を提出可能。
- 正しいMisconceptionが25個の予測の中に含まれているとスコアが加算される。
- 正しいMisconceptionが上位にあるほどスコアが高くなるため、予測の順位が重要。
参考として、予測値のランクとスコアの関係が以下の通りです。上位ほどスコアの伸びが大きく、例えば25位を23位に引き上げても+0.003ですが、3位を1位に引き上げると+0.66と大きく向上します。
なおPrivate LBの1位スコアが0.638であることから1位ソリューションの場合、提出した予測の多くが上位2つに正解が含まれているという精度感になっています。
提出時は以下のように1つのQuestion-Answerのペアに対して最大25個のMisconceptionを紐づける形式でsubmissionは作成されます。
QuestionId_Answer,MisconceptionId
1869_A,1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
1869_B,1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
...
基本解法
ここでは公開notebookでも利用されていた基本的な解法を紹介します。上位解法も基本構造は同じで 以下で紹介する2stageのパイプラインが使用されていました。
stage1: Retrieval
Stage1では 質問-誤回答ペア に関連性の高い Misconception を、すべてのMisconception候補から抽出するタスクを解きます。埋め込みモデルを用いて Query(質問、誤回答などのテキスト情報)と Misconception(誤解)の埋め込み(embedding)を取得し、その類似度を計算することでQueryに関連するMisconceptionを抽出します。
RetrieverのFine-Tuningについてはコンペ初期に以下の素晴らしいnotebookが公開されていました。sentence_transformers
ライブラリを利用して埋め込みモデルを学習する方法が紹介されています。埋め込みモデルの学習方法はいくつか存在しますが本notebookで紹介されたMultipleNegativesRankingLoss
による 対照学習 が上位解法でも採用されていました。
埋め込みモデルにはhugging faceのLBにあるembedding特化のモデルの他、テキスト生成用のLLMを埋め込みモデルとして利用するケースも見られました。テキスト生成用のLLMの方が種類が豊富で大規模なモデルも多く存在することから、Retrieval性能はテキスト生成用のLLMの方が高い 傾向があり、公開notebookや上位解法でもテキスト生成用のLLMが多く利用されていました。Fine-Tuningを行う場合は、メモリ節約のためモデルを量子化した上でLoRAを学習する QLoRA が採用されていました。
stage2: Reranking
stage2のrerankingではstage1で抽出したMisconception候補をもとに関連性の強い順に並び替えるタスクを解きます。reranking方法はいくつかバリエーションがあり詳細は後述しますが公開notebookでは以下のように複数のmisconceptionをpromptに与えて正しいmisconceptionを選択させるlist wise reranking手法が提案されていました。
またLLMの推論については非常に高速な推論が可能であるvllm
が採用されていました。
このようにRetrievalとRerankingを2stageに分ける解法は上位公開notebookで採用されていました。ただし公開notebookではFine-TuningなしのモデルがRerankerとして利用されていた一方で、上位解法の多くがQLoRAによるFine-Tuningを行っていました。
2stageでretrievalとrerankingを行う解法はmap@k指標を採用していた以下のコンペの上位解法でも使用されていることから主流な方法であると言えそうです。2stage制の意義については強力なrerankerを全てのMisconceptionの組み合わせに適用すると計算コストが大きいためretrieverである程度候補を絞ってから強力なrerankerで推論するためという理解をしています。
H&M Personalized Fashion Recommendations
Kaggle - LLM Science Exam
上位解法まとめ
上位解法の重要ポイントを以下の6つに分けてみました。
- LLMによるSyntheticデータ生成
- Retrieval
- Reranking
- Post Process
- LLMテクニック
- CV戦略
1. LLMによるSyntheticデータ生成
本コンペではLLMを用いたSynthetic(合成)データ生成が有効であり、生成されたデータをRetrieverやRerankerの学習に利用することでスコアを改善することができました。LLMによるSyntheticデータ生成が有効であった理由として以下の2点がありそうです。
- trainデータに存在しない未知のMisconceptionがtestデータに多く含まれている
- 性能の高いLLMを利用すれば、十分な精度と品質のデータ生成が可能である
質問-回答データ生成
未知のMisconceptionを学習するためにtrainデータにない未知のmisconceptionに関連する質問と誤回答をLLMを用いて生成しています。生成されるデータ品質を向上するためにチームごとにいくつか工夫がされています。
- Grouped Synthetic Data Generation[1]
- retriever/rerankerの出力から共起行列を作成し類似Misconceptionごとにクラスタを作成する
- 同一クラスタのMisconceptionをfew shot exampleとしてpromptに追加しデータ生成する
- LLMによるフィルタリング[1][2]
- 生成されたSyntheticデータをLLMに評価させスコアが低いデータは除外する
- 大きめのexample付与[4]
- promptにtrainデータからのexampleを多く含めると生成品質が向上したため100件のexampleをpromptに追加してデータを生成する
上位解法のSyntheticデータ生成で使用されたモデル一覧
モデル | 使用チーム |
---|---|
GPT-4o | 1,8,9,10 |
GPT-4o-mini | 2,7 |
Qwen2.5-72B | 4,6 |
Claude 3.5 Sonnet | 1 |
Qwen Math | 2 |
gemini-1.5-pro | 5 |
gemma 27B | 7 |
Qwen2.5 32B | 7 |
Misconception Augmentation
Misconceptionの埋め込みの表現力を上げるためLLMを用いてMisconceptionに対する説明を生成してそれをモデルへの入力に利用するアプローチです。[2]
このアプローチはMisconceptionの表現を豊かにするだけでなく事前計算しておけば提出時に追加の推論コストがかからない点も魅力的で面白いアプローチだなと思いました。
system_prompt_template = 'You are an excellent math teacher about to teach students of year group 1 to 14.
The subject of your lesson includes Number, Algebra, Data and Statistics, Geometry and Measure.
You will be provided a misconception that your students may have.
Please explain the misconception in detail and provide some short cases when the misconception will occur.
No need to provide the correct approach. The explanation should be in format of "Explanation: {explanation}"'
user_prompt_template = 'Misconception: {misconception}'
新規のMisconception生成
コンペ用に与えられた2587件のMisconceptionとは別の新しいMisconceptionをLLMを用いて生成しているチームも存在しました。意図を正しく掴めているかは分かりませんが、Misconception Augmentationと同様、Misconceptionの表現を豊かにするためと想像しています。
-
1位チーム[1]
4000件の新規Misconceptionを生成した。ただしそのまま追加すると既存のMisconceptionの言い換えとなりノイズとなるため、埋め込みモデルを利用して既存のMisconceptionと類似スコアが大きい(0.95~0.995)新規Misconceptionを除外する -
4位チーム[4]
few shot exampleをpromptに含めてQwen2.5-32B-instructでMisconceptionを生成する
prompt例
"""You are an expert in mathematics.
Refer to the examples below to identify and describe the misconception that led to the incorrect answer.
Example1
ConstructName: Recognise and use efficient methods for mental multiplication
SubjectName: Mental Multiplication and Division
Math problem: Tom and Katie are discussing ways to calculate\\( 21\\times 12\\) mentally. Tom does\\( 12\\times 7\\) and then multiplies his answer by\\( 3\\); Katie does\\( 21\\times 6\\) and then doubles her answer. Who would get the correct answer?
Incorrect answer: Only Katie
Misconception: Does not correctly apply the distributive property of multiplication
Example2
ConstructName: Multiply a decimal by an integer
SubjectName: Mental Multiplication and Division
Math problem:\\( 9.4\\times 50=\\)
Incorrect answer:\\( 4700\\)
Misconception: When multiplying a decimal by an integer, ignores decimal point and just multiplies the digits
ConstructName:{ConstructName}
SubjectName:{SubjectName}Math problem:{QuestionText}
Incorrect answer:{AnswerText}
Misconception:
"""
2.Retrieval
Retrievalパートは上位陣も基本解法とほとんど同じ構成でしたがSyntheticデータを活用したり複数モデルのアンサンブルを活用することでスコアを伸ばしていた印象です。またRetrievalパートはmap@25だけでなくrecall指標が重要であることからrecall指標を使ってモデルの評価を実施しているチームもいました。[1][4]
Fine-Tuningには基本解法で触れた対照学習が多く利用されていました。MultipleNegativesRankingLoss
による対照学習ではバッチ内の他の正例を負例として扱うバッチ内負例と、正例と類似度が高いハード負例と呼ばれる負例の2種類が利用されます。負例サイズは重要なパラメータであるためチームごとにチューニングが実施されていました。
- バッチ内負例6件とハード負例42件[4]
- バッチサイズ不明, ハード負例2件[5]
- バッチ内負例32件とハード負例5件[6]
- バッチサイズ不明, ハード負例4件[7]
その他の工夫としてCoTやMisconceptionをLLMに予測させてその出力をpromptに追加する手法も取られていました。(詳細は5.LLMテクニックを参照)
上位解法のretrieverで使用されたモデル一覧
モデル | 使用チーム |
---|---|
QWen/Qwen2.5-14B | 1,3,4,6,8,9,10,12 |
QWen/Qwen2.5-32B | 2,4,12 |
Salesforce-SFR-Embedding-2_R | 6,7,12 |
BAAI/bge-en-icl | 1 |
intfloat/e5-mistral-7b-instruct | 1 |
Qwen/QwQ-32B-preview | 2 |
Linq-AI-Research/Linq-Embed-Mistral | 2 |
dunzhang/stella_en_1.5B_v5 | 5 |
bge-multilingual-gemma2 | 6 |
gte-Qwen2-7B | 6 |
3.Reranking
個人的にはretriverよりもrerankerの改善が重要だったように感じています。
上位解法の多くがLLMをFine-Tuningしたものをrerankerとして利用していました。アプローチ別に紹介します。
Point wise reranking
質問-誤回答のペアに対して1つのMisconceptionに関する処理を実施するアプローチです。
与えられた単一のmisconceptionが正しいかどうかをyes/noで予測させそのlogits値をrerankingに利用しています。[1][7]
1位解法のprompt
1位解法の訓練コード
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
logits_yes = outputs.logits[:, -1, self.yes_loc] # Yes token logit at the last position [bs]
logits_no = outputs.logits[:, -1, self.no_loc] # No token logit at the last position [bs]
logits = logits_yes - logits_no # [bs]
logits = logits.reshape(-1, self.group_size)
labels = labels.to(logits.device).reshape(-1)
ce_loss = self.loss_fn(logits, labels)
List wise reranking
質問-誤回答のペアに対して複数のMisconceptionに関する処理を実施するアプローチです。[1][2][4][5][7]
list wiseはpoint wiseよりもpromptに含まれるmisconception情報が多く、かつ推論回数が少なくて済むメリットがあります。
1位解法のprompt
5位解法のvllm推論コード
max_logprobsとlogprobsのサイズを選択肢の数に合わせている
# Initialize model
model = LLM(
model=gptq_name,
trust_remote_code=True,
gpu_memory_utilization=0.99,
max_logprobs=52,
dtype="bfloat16",
max_model_len=2048,
enforce_eager=True,
)
# Generate predictions
sampling_params = SamplingParams(temperature=0.0, max_tokens=1, logprobs=52)
list wiseでlogitを得るためにはsingle token(token size=1)で出力されるようにする必要があるため、1桁の数字(0-9),小文字アルファベット(a-z),大文字アルファベット(A-Z)などを選択肢のIDとして設定してpromptに含めていました。[5][7]
5位のprompt内のmisconception候補例
A: Thinks sign for parallel lines just means opposite
B: Does not recognise the notation for parallel sides
.
.
.
y: Thinks that co-interior angles can lie on opposite sides of a transversal
z: Believes the gradient of perpendicular lines just have opposite signs<|im_end|>
またlist wiseアプローチの場合、misconception候補をprompt内に列挙する際の位置のバイアス(1つ目のmisconceptionが選ばれやすい等)を受ける傾向もあったようです。そこで以下の工夫がなされていました。
- 学習時にプロンプトに与えるmisconceptionの候補数Nを複数パターン用意する[2]
- 推論時にrerankerに与えるpromptの候補順そのままと逆順の平均を使う[2] [6]
Multi Stage Reranking
rerankingを複数回実行することで性能を改善しているチームもいました。
-
上位のrerankingほど大きいモデルを利用する[1]
- スコア向上のために重要な上位ランクの予測に対して適切にモデルリソースを割くためにretrieval順位に応じて段階的にLLMのモデルサイズを大きくしてrerankingを実行する
- スコア向上のために重要な上位ランクの予測に対して適切にモデルリソースを割くためにretrieval順位に応じて段階的にLLMのモデルサイズを大きくしてrerankingを実行する
-
sliding window形式で複数回推論[2]
- sliding window形式で複数回rerankingを実行することで精度を高める
- sliding window形式で複数回rerankingを実行することで精度を高める
上位解法のrerankerで使用されたモデル一覧
モデル | 使用チーム |
---|---|
Qwen/Qwen2.5-32B | 1,3,4,5,7,8,9,10,12 |
Qwen/Qwen2.5-72B | 1,2 |
Qwen/Qwen2.5-14B | 1,12 |
lama3.3-70B | 2 |
4.Post Process
未知のmisconceptionに対するrerankingスコア補正
testデータには未知のMisconceptionが多いと予想されていましたが、trainデータにない未知のMisconceptionは予測スコアが小さくなる傾向にあります。そこで未知のMisconceptionに対する予測スコアを補正する後処理をしてスコアを伸ばしているチームもいました。この後処理もコンペの性質上強力なテクニックだったようで、特に3位のチームはSyntheticデータを利用しておらず、この後処理でスコアが大きく改善したと言及があることから補正によってSyntheticデータによる未知Misconception対策を代替してるとも考えられそうです。
- testデータへtop1のmisconceptionにおいて未知のmisconceptionの割合が75%となるように未知のMisconceptionに対する予測スコアを線形探索で定数倍した[3]
- trainデータに存在するMisconceptionの予測スコアを0.4倍に下げる[4]
- 未知のMisconceptionの順位を上げる[7]
5.LLMテクニック
- 知識蒸留
- 知識蒸留とはモデルサイズが大きい教師モデルの知識をよりモデルサイズが小さい生徒モデルへ引き継がせる手法
- 質問と回答を教師LLMに入力してMisconceptionを予測させこれをretrieverやrerankerへのpromptに追加する[4][5][9]
- 質問と回答を教師LLMに入力してその回答が導かれる推論過程(CoT)を推論させこれをretrieverやrerankerのpromptに追加する[1][2]
- 1位はclaude3.5 SonnetのCoTを模倣するようにCoT生成用のQwenモデルをFine-Tuningしていた
- vllmで
enable_prefix_caching=True
- prefix_cachingを有効にすることで推論時間を短縮
- https://docs.vllm.ai/en/stable/automatic_prefix_caching/apc.html
- LoRAコンポーネントのアンサンブル[4]
- 複数のLoRAパラメータを平均することで推論回数は1回に抑えつつ複数モデルによるアンサンブル効果が期待できる
- vllmを用いたembedding処理の高速化[4]
- テキスト生成用のLLMを埋め込みモデルとして利用する場合そのままvllmで推論することはできないためvllm側の実装を修正して埋め込み計算をvllmで高速化した
-
intel/auto-roundによる量子化[2]
- AutoGPTQやAutoAWQと比較して使いやすくこ精度の低下も最小限でvllmと互換性がある
- Misconception embeddingの事前計算[7]
- アンサンブル時の計算時間を短縮するためにmisconception埋め込みを事前に計算しておく
- その埋め込みと各モデルのquery埋め込みの類似度をとることでmisconceptionの埋め込み計算を1回で済ませる
6.CV戦略
trainとtestで重複する質問が存在しないことからQuestionIdでGroupKFoldをする方法がdiscussionやnotebookでは利用されていましたが、この方法だと未知のmisconceptionの割合がvalidとtestで異なるためCVとLBの乖離が大きく相関も取りずらい状況でした。これを踏まえいくつかのチームでは以下のようなcv戦略を取っていたようです。
- ConstructIdによるGroupKFold[1]
- QuestionIDでGroupKFold。ただし未知のmisconceptionに対するスコアを個別で見る[4]
- SubjectIdに基づくGroupKFold[5]
- cv計算時に未知のmisconceptionの重みを調整[6]
おわりに
このコンペでは、すべての解法パートにLLMが活用されており、その有用性を強く実感することができました。特に、LLMを用いたSyntheticデータ生成は、今後のコンペでも有効な手法として注目すべきであり、ぜひ1つの選択肢として取り入れていきたいと感じました。
最後までお読みいただきありがとうございました。この記事を通じてこのコンペの面白さや知見が少しでも皆様に伝われば嬉しいです。
参考文献
[1]1st Place Detailed Solution
[2]2nd place solution
[3]3rd Place Solution (with Magic Boost)
[4]4th Place Solution
[5]5th Place Solution
[6]6th Place Solution
[7]Private 7th (Public 2nd) Place Solution Summary
[8]8th Place Solution
[9]Private 9th (Public 7th) Place Solution
[10]10st Place Solution Summary
[11]Private 11th (Public 9th) Place Solution Summary
[12]12th Place Solution
Discussion