Kaggle LLM Prompt Recoveryコンペまとめ
はじめに
こんにちは、@s_shoheyです。
2024年2月28日から4月17日にかけて開催された、Kaggle LLM Prompt Recoveryコンペについてまとめます。我々のチームは72位でしたが、特に我々の解法には言及しません。
コンペ概要
2つの文字列original_text
とrewritten_text
が与えられます。rewritten_text
はGemmaによりoriginal_text
から生成されたものです。この生成時に使われたrewrite_prompt
を予測することが目的です。
図は https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/480683 より
Gemmaは2024年2月にGoogleが公開したオープンなLLMモデルであり、このコンペはGemmaの宣伝も兼ねているのでは、とDiscussionで噂されていました。
データ
ホストから与えられる学習用データは1件のみとなっています。参加者が自身でデータを生成することが前提のコンペでした。
サブミット時にはテストデータの約1400件が評価され、15%がpublic LB(=コンペ期間中に公開される順位表)の評価に使われます。ただしコードコンペであるため、テストデータの内容を知ることは(probingを除けば)できません。
評価指標
書き換えに用いられた正解のrewrite_promptと、参加者が提出したrewrite_promptをそれぞれsentence-t5-baseによりembeddingにして、Sharpened Cosine Similarity
(以下SCS)を計算します。これはコサイン類似度を三乗したもので、1に近いほど良いです。これを全てのサンプルについて平均したものが評価指標になります。
アプローチ
大きく分けるとLLMベースのアプローチとmean promptのアプローチがありました。金メダルを取ってSolutionを投稿したチームも基本的にはこの2つをベースとして予測をしていたようです。
LLMベース
LLMに直接聞くことでrewrite_promptを答えてもらうアプローチです。
ナイーブな聞き方の例:
original_text: {original_text}
rewritten_text: {rewritten_text}
Write a prompt that was likely given to the LLM to rewrite original text to rewritten text.
{original_text}には実際のテキストが入ります。
mean prompt
全てのサンプルに対して固定の文字列を予測するアプローチです。例えば以下の文字列を提出するだけで、最終public LBでの銅メダルに近いスコアが取れることが知られていました。
Please improve this text using the writing style with maintaining the original meaning but altering the tone.
参考: https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/488150
上位解法
LLM: 回答に使ったLLMの種類
magic: lucrarea
文字列の利用をしたか(後述)
順位 | Solution link | LLM | magic |
---|---|---|---|
1 | Solution | Mistral 7b, Gemma 7b | o |
2 | Solution | Mistral 7b | o |
3 | Solution | Mistral 7b | o |
4 | Solution | Mistral 7b | o |
6 | Solution | Mistral, openchat3.5 | - |
7 | Solution | Mistral 7B | - |
10 | Solution | Mistral | o |
12 | Solution | Mistral 7B | - |
14 | Solution | 不明 | - |
主な工夫
- データ生成
- 運営が公開していたサンプルデータセットを参考にデータ生成や戦略を検討
- LBスコアと相関が高いデータセットを構築
- LLM
- ほぼ全てのチームがMistral 7bを使っていた
- Gemmaより性能が良かったらしい
- Finetuneはしているチームもしていないチームもあった
- Few-shot promptingなどのプロンプトエンジニアリング
- ほぼ全てのチームがMistral 7bを使っていた
- mean promptの探索
- 生成したデータに対して、最もスコアが高くなるように探索をする
- 公開されていたmean promptから開始して、単語の追加削除などを繰り返す
- ビームサーチをしていたチームが複数あった
- mean promptだけで金メダルスコアを取れるチームも複数あった
- LLMとmean promptのアンサンブル
magic
lucrarea
という文字列を回答の一部に含めることで、スコアが大幅に向上することが上位チームのSolutionによって明らかにされました。
追試
実際にKaggle Notebookでlucrarea
の効果を確認しました。正解promptは"Convert this into a sea shanty."で、これとのSCSをそれぞれ求めます。Notebookへのリンク
text | scs | |
---|---|---|
0 | rewrite next text | 0.457 |
1 | rewrite next text lucrarea | 0.506 |
2 | rewrite next text lucrarealucrarea | 0.532 |
3 | rewrite next text lucrarealucrarealucrarea | 0.546 |
4 | rewrite next text lucrarealucrarealucrarealucrarea | 0.552 |
5 | rewrite next text lucrarealucrarealucrarealucrarealucrarea | 0.553 |
6 | rewrite next text lucrarealucrarealucrarealucrarealucrarealucrarea | 0.553 |
1サンプルのナイーブな例に過ぎませんが、SCSで0.1改善しています。なお、優勝スコアが0.71程度、銅メダル下限スコアが0.64程度です。
探してみた
lucrarea
を実際に見つけることができるのかを確認します。Notebookへのリンク
- 参加者が公開したデータセットのrewrite_promptをt5-baseでembeddingにしてからその平均を計算して、一つの768次元の埋め込みを得る(mean embeddingとします)
-
tokenizer.get_vocab()
で得られる単語32100個について、mean embeddingとのSCSを計算する - SCSで降順ソートすると、
lucrarea
は約40位(上位0.13%)に入っている。これ以外の上位の単語を見ると、"essay", "summarize", "narrative"などいかにもrewrite_promptの文言に入りそうなものが並んでいるため、lucrarea
は非常に浮いている。
なぜ?
1st Solution, 2nd Solutionからの抜粋要約
- 今回使われたtokenizerのtensorflowデフォルト設定では
</s>
はEOSトークンとして解釈されずに
['<', '/', 's', '>']
という文字列としてtokenizeされる - この文字列と
lucrarea
をt5-baseでembeddingにすると非常に近いものになる- なぜ非常に近くなるのかは不明
まとめ
Kaggle LLM Prompt Recoveryについてまとめました。
上位は強いmean promptを作ったうえで、さらにLLMによるスコア改善をしていた印象があります。lucrarea
は見た目のインパクトが大きく、いかにもなmagicに見えますが、きちんと筋道立てて探索すれば見つけられた、かもしれません。
Discussion