👨‍💻

Kaggle Eedi 参加記録

に公開

はじめに

データアナリティクスラボ株式会社です。

今回はデータソリューション事業部のメンバーでKaggleの「Eedi - Mining Misconceptions in Mathematics」に参加しましたので、その取り組みや上位解法についてご紹介いたします。

参加メンバー

  • 宮澤:入社3年。マーケティング分野の分析・モデル構築に従事しており、社内では生成AIの研究活動を実施。
  • 力岡:入社2年半。金融分野の分析や生成AI関連の開発に取り組んでおり、社内では生成AIの研究活動を実施。
  • 平野:入社2年半。金融分野の分析を行っており、社内では生成AIの研究活動を実施。

結果

結果は168位 / 1449チームでした!
今回は参加メンバー全員があまり時間を作ることができず、振るわない結果となりました。

課題について

概要

今回は自然言語処理に関する課題でした。

Eedi - Mining Misconceptions in Mathematics

概要としては、選択式の算数の問題とそれに対する誤った回答が与えられ、その誤解の理由として適切なものを誤解理由リストから取得するタスクとなっていました。

簡単な例として以下のような問題があります。

https://www.kaggle.com/competitions/eedi-mining-misconceptions-in-mathematics/overview より引用

こちらの問題の正しい答えは A: 23 ですが、B: 13を回答した場合、どのように誤ったと考えられるでしょうか?

実際に13になるような計算手順を考えてみると、5 × 4 = 20 → 20 + 6 = 26 → 26 ÷ 2 = 13と計算していると考えられます。したがって適切な誤り理由は Carries out operations from left to right regardless of priority order. (優先順位に関係なく左から右へ計算している。)となります。

データ

与えられていたデータセットは以下の通りです。

train.csvには1,869件の問題が与えられていました。
また、誤解理由リストであるmisconception_mapping.csvは2,587件ありました。

https://www.kaggle.com/competitions/eedi-mining-misconceptions-in-mathematics/data より引用

評価指標

評価指標はMAP@25が使われました。(レコメンドなどのランキングの精度指標でよく使われます。)

評価指標を解釈すると、適切な誤解理由であると思われる25件のリストをランキングで出力し、正解がより高い順位にあるほどスコアが高くなる指標となっています。

https://www.kaggle.com/competitions/eedi-mining-misconceptions-in-mathematics/overview より引用

取り組み

上で述べたように、今回はメンバー全員があまり時間を作ることができなかったため、スコアの高い公開ノートブックをベースにして、一部手を加えるという解法で進めていました。

ベースライン(公開ノートブック)の解法

上位解法も含めてベースラインとして使われていたのが以下のようなTow-Towerのアプローチです。1段目で正しい誤解と予測される候補を抽出し、2段目で候補を並び替えてMAP@25のスコア向上を目指すという形です。

自分たちの取り組み

作業時間が限られていたため、主に1段目に焦点を当てて出来ることを考えて取り組みました。

合成データの生成

まず行ったことは合成データの生成です。提供された学習データには、一部のMisconceptionNameに関連するデータのみが含まれていたため、モデルの多様性と汎化性能を向上させる目的で、全2587種類のMisconceptionNameから、LLMを使用して以下の要素を含む合成データを作成しました。

ConstructName , SubjectName , QuestionText , CorrectAnswerText , WrongAnswerText

また、データ精度を高めるため、各MisconceptionNameを基に類似度の高いデータを抽出しました。この際、類似データの上位50件からランダムに15件をサンプリングし、以下のフォーマットでfew-shotの例として使用しました。

misconception: {MisconceptionName}
concept: {ConstructName}
subject: {SubjectName}
question: {QuestionText}
correctanswer: {CorrectAnswerText"}
wronganswer: {AnswerText}

実際に合成データ作成用に指定したプロンプトは以下の通りです。

SYSTEM_PROMPT = """Based on the reasons for the user’s incorrect answer, create an appropriate combination of a math problem and solution, and output it in JSON format.

# Requirements
- For each reason of incorrect understanding, structure the problem in the following format: concept, subject, question, wronganswer.
- The content of the problems should cover a range from elementary to high school math, incorporating knowledge applicable to different mathematical fields.
- Ensure the problems do not duplicate examples provided by the user, and instead include original problems.
- Output one set in JSON array format."""

USER_PROMPT = """Based on the reasons for the user's incorrect answers from the examples below, create the appropriate combination of applicable math questions and wronganswers.

# Examples
{examples}

# Misconception
{misconception}"""

この方法で、2,587種類のMisconceptionNameそれぞれについて10件ずつ問題を生成し、合計25,870件の合成データを作成しました。これらのデータは後続のモデルの学習に使用しました。

FlagEmbedingの実装

1段目では誤解理由の候補を抽出しますが、これを実現するモデルを学習するために、FlagEmbeddingというライブラリを使用しました。FlagEmbeddingは、BAAI(というAIの研究開発を行う民間非営利団体が開発したもので、LLMの検索機能を強化するためのプログラムを提供するプロジェクトです。

今回は、 FlagEmbedding/research/llm_dense_retriever というプロジェクト内のプログラムを使用して、Qwen/Qwen2.5-14B-Instruct-AWQを微調整し、LLMが誤答のパターンを識別し、誤りの原因を解析できるように追加学習を実施しました。(1st SolutionではFlagEmbedding/finetune/embedder/decoder_only が使われていました。コードの違いは確認していませんが、より適切なスクリプトがあったかもしれないと後で知りました。)

作業環境はGoogle ColaboratoryのProプランでA100 GPUを使用して学習しています。

データ形式

サンプルプログラムで使用されているサンプルデータを見ると、以下のような構造の学習データを用意する必要があることがわかります。

  • query: 検索用のクエリ
  • pos: 検索クエリに対して、検索結果として正解となるデータ
  • neg: 検索クエリに対して、検索結果として不正解となるデータ
  • category: カテゴリ(Nanでも可)
  • type: タイプ(normalでも可)
  • prompt: 検索用の指示プロンプト
  • pos_scores: 正解データのスコア値
  • neg_scores: 不正解データのスコア値

生成した合成データや提供されている学習データを、上記の学習フォーマットに従うように作成し直しました。参考として、実際に作成した学習データを以下のリンクに掲載しています。興味のある方はぜひご覧ください。

https://huggingface.co/datasets/rikioka/training_data_v1

ここで、クエリは以下のように問題種別を表すConstrcutNameやSubjectName、問題のテキスト、誤った回答などを用いて一つのテキストとしました。

# クエリのフォーマット
format = """This is a question related to {ConstructName} ({SubjectName}).

Question: {QuestionText}

Incorrect Answer: {MisAnswerText}
Incorrect Reason:
"""

プロンプトは以下のように設定しました。

prompt = "Get appropriate reasons for misconception for the question and misconception given below."

ポジティブサンプルは正解となる誤解理由であり、ネガティブサンプルはそれ以外の誤解理由です。ネガティブサンプルの選び方としては、小規模の事前学習済みEmbeddingモデルを用いてスコアが高いものを選んだり(比較的判別が難しいハードネガティブとして用いる)、ランダムに取得したりといくつか試しましたが、学習に時間がかかるため、どの選び方が最適かまではわかりませんでした。

学習の実行

学習の実行は非常にシンプルで、学習のための設定を与えてrun.pyを実行するだけで学習をスタートすることができます。(詳細はGithubのサンプルコードをご参照ください。)

私たちの実装においては以下のように設定しました。
上述したように、モデルはQwen/Qwen2.5-14B-Instruct-AWQを使っており、1件のポジティブサンプルに対して8件のネガティブサンプルを設定しました。AWQモデルを読み込むために autoawqをインストールしておく必要があります。

内部処理

コードを詳細に全て確認はしていませんが、ざっと確認したFlagEmbeddingの内部処理の仕組みについて簡単に説明しておきます。

大きな学習の仕組みは通常のEmbeddingモデルと同じであり、クエリに対してポジティブサンプルとのcos類似度を高く、ネガティブサンプルに対するcos類似度を小さくするように学習します。 各サンプルに対する教師スコアをデータとして与えているため、計算されたcos類似度と教師スコアの差分が損失となっており、それを最小化するようにコードが組まれているようでした。

クエリと各サンプルのcos類似度を求めるにはクエリや各サンプルの特徴ベクトルが必要になりますが、そのエンコード処理はmodeling.pyに書かれています。既存のLLMの多くはDecoder-onlyのモデルであるため、Encoderの出力を得るようには実装されていません。

そこで、modeling.pyのエンコード部分を見ると以下のようになっていました。

psg_out = self.model(**sub_features, return_dict=True, output_hidden_states=False)
p_reps = psg_out.last_hidden_state[:, -1, :]

これを見ると、おそらくバッチ内のサンプルに対してフォワードパスを通した時のlast_hidden_state (最後の隠れ層)における-1部分(シーケンスの最後のトークン)の隠れ状態を取得していると考えられます。つまり、Decoder-onlyのモデルで自己回帰的に処理してきた最後の状態をそのサンプルの特徴ベクトルとして使用していると読み取れました。

工夫した点

この処理を理解した時にはもうコンペ終盤であったため、私たちの取り組みとしてはこれをベースに精度が上げられる方法を考えました。

その一つが損失関数の変更です。上述した通り、デフォルトの実装ではクエリと各サンプルのcos類似度が、あらかじめ設定したスコアと近づくように学習されます。しかし実際にはどのサンプルがクエリとどの程度近いかは知ることができず(知っていたらそもそも学習する必要はなく)、暫定的にポジティブサンプルは0.995, ネガティブサンプルは0.05として設定していました。

ここで、教師スコアを用いない学習方法がないかを調べたところ、以下の論文が見つかりました。

https://arxiv.org/abs/2401.00368

本論文の主軸は損失関数ではないですが、使われている損失関数を確認したところ、InfoNCE loss が使われていました。この損失関数は数年前からすでに提唱されており、これやこの亜種も多くあり、対照学習でよく使われているようです。

\mathcal{L} = -\log \frac{\phi(q_{\text{inst}}^{+}, d^{+})}{\phi(q_{\text{inst}}^{+}, d^{+}) + \sum_{n_i \in \mathcal{N}} \phi(q_{\text{inst}}^{+}, n_i)}
\phi(q, d) = \exp\left(\frac{1}{\tau} \cos(h_q, h_d)\right)

今回はこのInfoNCE lossをFlagEmbeddingの損失関数としてカスタムして学習を行い、デフォルトの損失関数の結果と比較を行いました。(温度パラメータ \tau = 0.02

結果としては、InfoNCE lossを使ったことでprivate scoreが0.023上がりました。
デフォルトのスコアの設定(99.5など)がよくなかった可能性はありますが、その辺りを考えずに教師スコアなしで学習できるようになり、かつ精度としても向上したという意味では効果があったと考えます。

Public Score Private Score
デフォルト 0.327 0.354
InfoNCE loss 0.367 0.376

最終的にはこの2モデルと公開ノートブックで最も高かったモデル(同じくQwen2.5-14B-AWQに対するLoRA学習)をアンサンブルしたモデルが自分たちの最高スコアとなりました。(括弧内は(InfoNCE loss, デフォルト, 公開ノートブック)の加重平均の割合)

Public Score Private Score
公開ノートブック 0.482 0.448
アンサンブル (0.2, 0.2, 0.6) 0.460 0.455
アンサンブル (0.4, 0.4, 0.2) 0.395 0.404

時間がない中でしたが、結果的に最高スコアの公開ノートブックを超えられたのは一安心でした。

Embeddingモデルに対するLoRAの実装

FlagEmbeddingの存在を知る前は、既存のEmbeddingモデルを学習できないかを調査していました。(FlagEmbedingで学習したモデルの方が精度が高かったためこちらは途中で打ち切りました。)

手持ちのリソースでは、例えば約30MパラメータをもつBAAI/bge-small-en-v1.5 のようなモデルであればフルパラメータでファインチューニングすることができるかもしれませんが、やはり最近におけるモデルの大きさと精度の相関を踏まえるとより大きなEmbeddingモデルをファインチューニングしたいと考えていました。

そこで7B級のパラメータを持つEmbeddingモデルをLoRAで学習できないかを調査していましたが、記事や実装例がほとんど見つかりませんでした。

そこでLoRA学習に使われるpeftライブラリを直接見にいくと、feature_extraction についてのサンプルコードがありました。これをカスタムすることで、intfloat/e5-mistral-7b-instructなど7B級のEmbeddingモデルに対してLoRAを使って学習を行うことができました

プロンプトエンジニアリング

ここまではモデルのパラメータを微調整した学習についての取り組みでしたが、学習不要の取り組みとして、公開ノートブックのEmbeddingによる類似度抽出時に質問関連情報からEmbeddingを生成するプロンプトを改良するプロンプトエンジニアリングも試みました。

結論から言うと、ベースラインにしたプロンプトの完成度が高くほとんどスコア向上につながりませんでした。

<ベースラインのプロンプト>

PROMPT  = """Here is a question about{ConstructName}({SubjectName}).
Question:{Question}Correct Answer:{CorrectAnswer}Incorrect Answer:{IncorrectAnswer}You are a Mathematics teacher. Your task is to reason and identify the misconception behind the Incorrect Answer with the Question.
Answer concisely what misconception it is to lead to getting the incorrect answer.
No need to give the reasoning process and do not use "The misconception is" to start your answers.
There are some relative and possible misconceptions below to help you make the decision:

{Retrival}"""

問題文の内容を考慮した誤解理由が多かったため、以下の文言を追加したプロンプ多少スコアが向上しました。

Focus on:
1. The specific reasoning or process that would lead to the Incorrect Answer, based on the Question's context and setup.
2. How the Incorrect Answer might stem from misinterpreting the mathematical concept or procedure being tested.
3. The relationship between the Question's wording, Correct Answer, and Incorrect Answer to ensure the identified misconception aligns with the context.

上位解法

ここからは上位解法の取り組みについてまとめていきます。多様なアプローチと工夫がされており、非常に学びのある解法が多くありました。

本記事では、スコアの差を生んだと思われる以下の項目に焦点を当てて見ていきます。

  • 合成データの生成
  • Retieverの工夫
  • Rerankerの工夫
  • その他

本記事で引用・参照させていただいたソリューションは以下です。
(他の上位解法も様々な工夫がされており学びが多くありましたが、本記事では5位までのまとめとさせていただきました。)

パイプライン

前提として、上位解法では上で述べたTow-Towerのアプローチが使われていました。1段目で正しい誤解と予測される候補を抽出し、2段目で候補を並び替えてMAP@25のスコア向上を目指すという形です。

合成データの生成

まずは合成データの生成についてです。生成するデータセットはいくつかの目的に分けられると考えられました。

train.csvと同形式の問題・正解・誤り・誤解理由などを含むデータの生成

こちらは多くの参加者が取り組んでいたことだと思いますが、いかに高精度で品質の高いデータを作れるかという点では勝敗を大きく分けた部分であるとも言えると思います。特に1st Solutionで使われていたプロンプトが精巧に作られており驚きました。

1st Solution

  • Claude 3.5 Sonnetを使用してデータ生成。
  • 誤解理由に関連した参考例を追加するために事前にクラスタリングを実施。
  • 品質・難易度・トーン・言語が既存のデータと合うようにプロンプトを調整。
  • AnthropicsのMetapromptを参考にプロンプトエンジニアリング。
  • gpt-4oを用いたキュレーション(品質評価)
  • 合成データの生成はこちらのブログが参考になるとのこと。
データ生成のプロンプト
```
You will be generating Multiple Choice Questions (MCQs) that diagnose specific mathematical misconceptions. Here are the misconceptions you should focus on:
<misconceptions>
{cluster_misconceptions}
</misconceptions>
Here are reference MCQs that demonstrate how to effectively diagnose these misconceptions:
<reference_mcqs>
{reference_mcqs}
</reference_mcqs>
Your task is to generate {num_mcqs} new MCQs that diagnose misconceptions not already covered by the reference MCQs.
First, analyze the reference MCQs carefully:
1. For each reference MCQ, identify in your <analysis> tags:
   - Which misconception it targets
   - How the incorrect answers map to specific misconceptions
   - What makes the question effective at diagnosing the misconception
2. Note the style, difficulty level, and precision of language used
Then, in your <planning> tags:
- List which misconceptions still need coverage
- For each needed misconception, brainstorm mathematical contexts where it commonly appears
- Design questions where the misconception leads naturally to specific wrong answers
- Take notes on how you can craft new MCQs that adheres to the reference MCQs' style, difficulty level, and precision of language
Finally, generate new MCQs following these important guidelines:
- Make sure each incorrect answer maps clearly to exactly one misconception
- Use precise mathematical language matching the style of reference MCQs
- Make questions challenging enough that students must demonstrate real understanding
- Ensure wrong answers are plausible and stem from genuine misconceptions, not careless errors
- Use the exact wording of misconceptions as given in the misconceptions list
- Pay careful attention to subtle differences between the misconceptions and observe which one is the most appropriate for a given incorrect answer
- Keep the construct name and subject name as short as possible hiding the details of the misconception
- Questions should be of higher difficulty level than reference MCQs
```
キュレーションのプロンプト
```
You will analyze how well an incorrect answer reflects a suspected misconception in a mathematics problem. Your goal is to determine whether there is a clear, logical connection between the misconception and the wrong answer.
Here is the problem with both correct and incorrect answers. The suspected misconception is also provided:
<problem>
{PROBLEM_DATA}
</problem>
First, analyze the problem in your scratchpad:
<scratchpad>1. Solve the problem independently to verify the correct answer
2. Examine how someone holding the suspected misconception would approach the problem
3. Trace the logical path from misconception to incorrect answer
4. Identify any gaps or inconsistencies in this connection
</scratchpad>
Then provide your evaluation using this format:
<evaluation>1. Brief explanation of how the misconception could lead to the wrong answer
2. Score from 0-10 based on these criteria:
   - 10: Perfect alignment - wrong answer is direct result of misconception
   - 8-9: Strong alignment - clear logical path from misconception to answer
   - 5-7: Moderate alignment - connection exists but has some gaps
   - 1-4: Weak alignment - connection is unclear or requires assumptions
   - 0: No alignment - misconception does not explain wrong answer
</evaluation>
Important guidelines:
- Focus solely on the logical connection between misconception and wrong answer
- Do not speculate about other possible misconceptions
- Be specific about how the misconception leads to the error
- Flag and deduct scores if any assumptions are required to connect misconception to answer
- Consider whether a student with this misconception would consistently arrive at this wrong answer
```

2nd Solution

  • Few-shotで問題を生成。
  • 生成した問題に対してはQwen-mathで正しい答えと誤った答えを生成。
  • gpt-4o-miniを用いてキュレーション。5段階中2以上のものを使用。
  • 3回生成しており、3世代目のプロンプトはこちらのブログを参考にしている。

4th Solution

  • Qwen2.5-72B-Instruct-AWQ を使用してデータ生成。
  • train.csvに登場しない誤解理由を中心に作成した。

5th Solution

  • gemini-1.5-pro を使用してデータ生成。
  • train.csvに登場しない誤解理由を中心に作成した。
  • Few-shotを用いて生成した。その際にstella_en_1.5B_v5で類似度が高く算出されたものを例示として含めることでランダムサンプリングの時と比較してCVが向上した。

誤解理由の情報量を増やしたデータの生成

もとの誤解理由リストの各テキストは比較的短い文章が多かったため、LLMを用いて情報量を拡張することは有効であると思われ、上位解法でもその取り組みが見られるものがありました。

2nd Solution

  • llama3.1-70b-Instructqwen2.5-72b-Instruct を使用して誤解理由の情報量を拡張。
プロンプト
```
system_prompt_template = 'You are an excellent math teacher about to teach students of year group 1 to 14. The subject of your lesson includes Number, Algebra, Data and Statistics, Geometry and Measure. You will be provided a misconception that your students may have. Please explain the misconception in detail and provide some short cases when the misconception will occur. No need to provide the correct approach. The explanation should be in format of "Explanation: {explanation}"'

user_prompt_template = 'Misconception: {misconception}'
```

CoTによる推論過程を含むデータ

RetrieverおよびRerankerの精度を上げるためにモデルをCoTで推論するように微調整する取り組みも上位解法に見られました。これにはCoTの教師データ必要であるため、LLMを用いて合成データが生成されていました。

1st Solution

  • Claude 3.5 Sonnetを使用してデータ生成。
プロンプト
```
You will analyze a student's incorrect answer to identify the specific reasoning flaw that led to their error.
Your goal is to explain precisely how their misconception caused them to arrive at the wrong answer.
Here is the problem information:
<problem_data># Question: Simplify the expression: \[x \cdot y \cdot x\]
# Correct Answer: \(x^2y\)
# Incorrect Answer: \(x^2\)
# Primary Misconception: Ignores variables without explicit coefficients when multiplying
</problem_data>
Here are related misconceptions that are similar but do not explain this specific error as precisely:
<related_misconceptions>
- Thinks only like terms can be multiplied
- Fails to combine all instances of the same variable
- Incorrectly identifies an incomplete variable factor
- Does not understand how to multiply algebraic terms
</related_misconceptions>
First, examine all components of the problem carefully:
1. The problem statement and question asked
2. The correct answer and solution method
3. The student's incorrect answer
4. The primary misconception given
5. The related misconceptions that should be distinguished from the primary one
Then, reconstruct the student's likely thought process:
- Identify the exact point where their reasoning diverged from the correct solution path
- Note which specific mathematical operations or concepts they misapplied
- Connect their error directly to the stated primary misconception
- Verify that this explanation better fits the error than the related misconceptions
Write your analysis in <evaluation> tags, following this structure:
- Show the correct calculation first
- Show the incorrect calculations that demonstrate the error
- Explain the specific flaw in the student's reasoning
- Demonstrate how the misconception led to this particular error
- Distinguish from the related misconceptions
- Keep your explanation to 5-6 clear, non-repetitive sentences
- Focus solely on the reasoning that produced this specific error
Guidelines for writing your explanation:
- Do not restate the problem or name the misconception
- Be precise about the mathematical concepts involved
- Show exactly how the misconception led to the error
- Distinguish from related misconceptions
- Avoid repetition
- Stay focused on this specific error
```

2nd Solution

  • qwen2.5-32B-Instruct-AWQ を使用してデータ生成。
プロンプト
```
system_prompt_template = "You are an excellent math teacher about to teach students of year group 1 to 14. The detail of your lesson is as follows. Subject:{first_subject}, Topic: {second_subject}, Subtopic {third_subject}. Your students have made a mistake in the following question. Please explain the mistake step by step briefly and describe the misunderstanding behind the wrong answer at conceptual level. No need to provide the correct way to achieve the answer."

user_prompt_template = "Question: {question_text}\nCorrect Answer: {correct_text}\nWrong Answer of your students: {answer_text}\n\nExplanation: \nMisunderstanding: "
```

Retieverの工夫

1段目であるRetieverについてです。使っているモデルの違いや、注目した評価指標の違いなど、解法ごとに工夫がされていました。

1st Solution

  • e5-mistral-7b-instruct, bge-en-icl, Qwen2.5-14B のモデルのアンサンブルを使用。
  • 損失関数にMultipleNegativesRankingLossを使用して微調整。
  • スクリプトはFlagEmbedding/finetune/embedder/decoder_onlyをベースに使用。
  • 取りこぼしを減らすため、MAP@25ではなくRecallに注目。
  • 上位32件の候補を選択。さらに、上位候補のスコアから0.06以内の候補を追加で最大32件取得するという動的な閾値で候補を抽出。
  • ハードマイニング(難しいネガティブ例を集中的に学習)はRecall@32の向上に寄与しなかったため採用せず。
  • 1つのバッチ内に同じ正例を複数含めないように調整。

2nd Solution

  • Rerankingの推論時間も考慮して、Linq-Embed-Mistralと2つのQwen2.5-14Bのアンサンブルを使用。
  • Linq-Embed-Mistral の損失関数にはArcface lossを使用。
  • Qwen2.5-14B の損失関数にはMultipleNegativesRankingLossを使用。
  • キュレーション、誤解理由の拡張、CoTの出力結果の追加、最終トークンのプーリングが効いたとのこと。

4th Solution

  • Qwen2.5-14B-instructQwen2.5-32B-instruct の出力を連結して重複削除。
    • すべての誤解理由から上位25件、train.csvに登場しなかった誤解理由から上位15件をそれぞれ取得。これらから重複を削除したリストを取得。
  • Qwen2.5-32B-instruct-AWQ から生成した誤解理由テキストも入力に追加。
  • ポジティブサンプル1件につき47件のネガティブサンプルで1バッチを構成。
  • MAP@25だけでなくRecallに注目。
  • 損失関数はMultipleNegativesRankingLossを使用。
  • vllmをカスタムして埋め込み計算を高速化。

5th Solution

  • stella_en_1.5B_v5 を使用。
  • Qwen 2.5 32B Instruct を使って誤解理由を生成してRetriverの学習に加える(蒸留する)ことでCVが向上。
  • 損失関数はCachedMultipleNegativesRankingLossを使用。
  • 104件の候補を抽出。

Rerankerの工夫

2段目であるRerankingについてです。Rerankingは非常に重要でありスコアや順位に違いが出る部分であったように感じます。

1st Solution

  • 段階的なRerakingのアプローチを採用。
    • 第1段階: Qwen/Qwen2.5-14B (候補を上位8件に絞り込む, Pointwise)
    • 第2段階: Qwen/Qwen2.5-32B (候補を上位5件に絞り込む, Pointwise)
    • 第3段階: Qwen/Qwen2.5-72B (最終ランキング, Listwise)
  • CoTを用いた学習。外部推論の利用と内部推論への依存をうまく学習させるために微調整データセットのうち50%にのみCoTを使用
  • Few-shot Learningの使用。トレーニング中に少数のデモンストレーション例を提示。特にQwen2.5-14Bの性能向上に寄与。
  • ポジティブ例1件に対してネガティブ例24件をトレーニングに含めることで、モデルの精度を向上。
  • Pseudo Labelingの使用。Qwen2.5-Math-72BQwen2.5-72Bの2つのモデルを使用して合成データに擬似ラベルを付与。このデータを用いてQwen2.5-14BおよびQwen2.5-32Bモデルを再訓練。

2nd Solution

  • Listwiseでランク付け。
  • スライディングウィンドウを複数回適用してランキングを段階的に並び替える。
    • Window1: Qwen/Qwen2.5-14B-Instruct を用いて8位~17位を並び替え。
    • Window2: Qwen/Qwen2.5-72B-Instruct, meta-llama/Llama-3.3-70B-Instruct を用いて1位~10位を並び替え。
  • vllmでlogits_processorsを使用してランクづけ。(直接言及はなかったがディスカッションではこちらのNVIDIAのコードがあがっていた。)
    • 特定のトークンに+100を加えることで注目したいトークンの範囲の中でどのトークンの対数尤度が高いかを算出できるようにしたと読み取れる。

4th Solution

  • Qwen2.5-32B-Instruct-GPTQ-Int4 をベースにして3つのモデルを作成。
    • モデル1と2は異なるデータ(Fold)で作成されたLoRAをマージする形で作成。また、これれはポジティブサンプルとネガティブサンプルが1:9で学習された。
    • モデル3は合成データのうち2,500件を含めたデータで学習。ポジティブサンプルとネガティブサンプルは1:19で学習。
  • おそらくListwise。

5th Solution

  • Qwen2.5-32B-Instruct をLoRAで学習。その後GPTQで量子化してvllmで推論。
  • 大文字と小文字のアルファベットを各候補のインデックスとしてプロンプトに入力してランク付けを行った。1回の推論で52件を入力として2回(52*2=104件)の推論を実行。
  • どのアルファベットを正解として出力するかの単一トークンに対するロジットを使用してソート(ランク付け)する方法を採用。
  • Retriverで使った知識蒸留の入力をここでも使用してCVが向上。

その他の工夫

2nd solution

  • vllmのenabling_prefix_cacheをTrueにすることで推論時間を約10%節約できた。
  • https://github.com/intel/auto-round を使った量子化。
    • AutoAWQやAutoGPTQよりも使いやすく精度低下も最小限であった。vllmとの互換性もあり。

3rd Solution

  • テストデータにおける誤解理由の出現の予測
    • 誤解理由データの多くが提供されたtrain/testデータにおいて未確認であることから、未確認の誤解理由がPublic/Privateデータセットの大部分になっているという仮説を立てた。それに対して選択する誤解理由を確認/未確認それぞれのみでスコア検証を行いスコア比率が1:3と(大まかに)予想。
    • 上記の予想を踏まえて、未確認の誤解理由の確率合計が全体の75%を占めるよう、未確認の誤解理由には定数Cを乗算することでスコアが向上した。

4th Solution

  • 学習データに出てきたが誤解の予測のスコアを小さく見積もるために0.4をかけた。

おわりに

今回の課題は、単純な分類のような自然言語処理タスクとは違い、数学的なテキストを理解させるための工夫と、レコメンド的なランク付けのテクニックが求められるようなタスクでした。上位解法のなかでも異なった工夫が様々されており、非常に学びの多いコンペでした。細かい工夫はさておき、個人的に改めて重要だと感じたのは、CVとLBの相関をしっかりとウォッチすることで、trainとtestの分布の違いを疑うということと、見るべき評価指標を吟味すること(今回で言えば1段目でMAP@25だけでなくRecallを見ること)でした。これらは基本的なことではあると思いますが、計算リソースに関係なく誰もができることなので、改めてその重要性を感じさせられました。

今回は作業時間を確保しきれずに復習が中心となってしまいましたが、2025年も引き続き入賞を目指して頑張りたいと思います!

DAL Tech Blog

Discussion