KenLMのPerplexityで日本語Instructionデータセットの品質を評価できないか試す
モチベーション
LLMの開発は一般的に、
- 事前学習(知識の習得)
- Supervised Fine-Tuning(指示に対して適切な回答を行う能力の習得)
- DPOやPPO(人間が好む応答をする能力の習得。行われない場合もある)
という流れで行われます。
Supervised Fine-Tuningでは基本的にInstruction Tuningという手法が用いられ、この学習で使われるのがInstructionデータセットになります。
Instructionデータセットはプロンプトとそれに対する出力がセットになっているデータセットで、LLMに対して知識の使い方を教え込む重要なものになります。
LIMA論文では、Instruction Tuningに高品質な1000件のデータを用いた結果、GPT-4に匹敵する性能を得ることができたと主張されています。
しかし、日本語特化のLLMを作ろうとしたときの大きな課題になるのが、この高品質なInstructionデータが少ないということです。
公開されている日本語のInstructionデータセットは英語で書かれたものをGoogle翻訳やDeepLで翻訳したものがほとんどで、日本語として不自然なデータや日本文化に合わないようなデータが含まれているのが現状です。
これを1件1件見て削除や修正していくのがベストではあると思いますが、数万、数十万とあるデータを人力で評価していくのは心が荒みます。
ということで、この「Instructionデータセットの品質評価」を機械的に行えないか、というのが今回試したいことになります。
実験概要
- KenLMでkunishou/databricks-dolly-15k-jaのInstructionのPerplexityを計算
- データセット全体のうち、Perplexityの値が小さかった(日本語として自然な)30%と大きかった(日本語として不自然な)30%を抽出
- それぞれからランダムに1035レコードを取得したサブセットを3つずつ作成
- 1035件のデータを使ってcyberagent/calm-2をLoRAチューニング×3×2
- Stability-AI/lm-evaluation-harnessでJCommonsenseQA,MARC-ja,JSQuADを評価
実験詳細
KenLMでのスコアリング
事前学習のデータセットについて調べていた際に拝見したこちらの記事を参考にさせていただき、KenLMで計算されるPerplexityを使ってみることにします。
まず、kunishou/databricks-dolly-15k-jaのデータは下記のような形式をしています。
{
'category': 'closed_qa',
'instruction': 'ヴァージン・オーストラリア航空はいつから運航を開始したのですか?',
'index': '0',
'input': 'ヴァージン・オーストラリア航空(Virgin Australia Airlines Pty Ltd)はオーストラリアを拠点とするヴァージン・ブランドを冠する最大の船団規模を持つ航空会社です。2000年8月31日に、ヴァージン・ブルー空港として、2機の航空機、1つの空路を運行してサービスを開始しました。2001年9月のアンセット・オーストラリア空港の崩壊後、オーストラリアの国内市場で急速に地位を確立しました。その後はブリスベン、メルボルン、シドニーをハブとして、オーストラリア国内の32都市に直接乗り入れるまでに成長しました。',
'output': 'ヴァージン・オーストラリア航空は、2000年8月31日にヴァージン・ブルー航空として、2機の航空機で単一路線の運航を開始しました。'
}
categoryが指示されているタスクを表し、instructionが指示、outputがそれに対する回答、inputは必要に応じて用意される参照文になります。
KenLMでPerplexityを算出する対象が悩ましいのですが、
- タスクによりinputを持たないレコードが存在している
- outputは分類や一問一答系のタスクの場合は名詞だけや端的な文章になりやすい一方、説明を求められるようなタスクではしっかりとした文章になりやすいと思われる
という2点から、すべてのデータを統一的に評価できるinstructionのみに絞って算出しました。
ただし、このあとの処理でPerplexityの良し悪しの判断をcategoryごとに分けて行ったことを考えると、instruction + outputを対象に算出するのが妥当だったようにも思います。
Perplexityの算出語、上述の通り、categoryごとにPerplexityが高かった30%と低かった30%を抽出しました。
df_large_30 = df.groupby('category').apply(lambda x: x.nlargest(int(len(x)*0.3), 'score'))
df_small_30 = df.groupby('category').apply(lambda x: x.nsmallest(int(len(x)*0.3), 'score'))
ここで、Perplexityが大きかったデータと小さかったデータを見比べてみます。
※inputとoutputが長文の場合には適当なところでカットしてます
Perplexityが小さかったデータ
category | instruction | input | output | Perplexity |
---|---|---|---|---|
closed_qa | サー・ヘンリー・クリントン将軍とは | サー・ヘンリー・クリントン将軍(KB、1730年4月16日 - 1795年12月23日)は、1772年から1795年にかけ... | サー・ヘンリー・クリントン将軍(KB、1730年4月16日 - 1795年12月23日)は、1772年から1795年にかけ... | 173.37655124180176 |
general_qa | 宇宙への答えはあるのか? | 小説『銀河ヒッチハイク・ガイド』によると、人生、宇宙、すべての答えは42だそうです。 | 226.3177359142668 | |
open_qa | 人は何を描くのか? | 絵を描くことは、誰にでもできる活動です。 芸術家は職業として絵を描いていますが、一般の人も絵を... | 402.8290339525847 |
Perplexityが大きかったデータ
category | instruction | input | output | Perplexity |
---|---|---|---|---|
closed_qa | Linux セカンドステージブートローダ | GNU GRUB、rEFInd、BOOTMGR、Syslinux、NTLDR、iBootなどの第2段ブートローダは、それ自体はオペ... | 2ステージブートローダは、実際には2つのブートローダが互いの後に構成され... | 3901373.393502721 |
brainstorming | ベストベンガルシンガー | キショール・クマール | 3901373.393502721 | |
creative_writing | アメリカズカップカムバック | 2013年9月、サンフランシスコ湾で、スポーツ界を代表するカムバックが行われた。... | 3901373.393502721 |
Perplexityが大きかったものについてはinstructionが指示を与えるような文章になっておらず、品質としては決して良いものとは言えないように思います。
一方で、Perplexityが小さかったデータについては指示を与える文章になっているようです。
instructionの評価としてはある程度機能しているように思うので、このまま実験を続けていきます。
学習に使うサブセットの作成
単純なランダムサンプリングをseed値を変えて実行し、低Perplexityグループと高Perplexityグループのそれぞれで3組ずつのサブセットを作成しました。
def sample_df(df, seed):
sampled_df = df.apply(lambda x: x.sample(frac=0.23, random_state=seed))
sampled_df = sampled_df.reset_index(drop=True)
return sampled_df
(30%ずつに分けたときにはcategoryの割合を変えなかったのに、ここではそれを考慮していないのは完全にミスです...)
LoRAでの学習
ベースモデルはcyberagent/calm2-7bとして、1035レコードのデータを2epoch学習させました。
プロンプトテンプレートは下記の通りです。
# プロンプトテンプレートの準備
def generate_prompt(data_point):
if data_point["input"]:
result = f"""以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。リクエストを適切に完了するための回答を記述してください。
### 指示:
{data_point["instruction"]}
### 入力:
{data_point["input"]}
### 応答:
{data_point["output"]}"""
else:
result = f"""以下に、あるタスクを説明する指示があります。リクエストを適切に完了するための回答を記述してください。
### 指示:
{data_point["instruction"]}
### 応答:
{data_point["output"]}"""
return result
その他、学習に関係するパラメータは下記の通りです。
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_8bit=True,
device_map="auto",
)
lora_config = LoraConfig(
r= 8,
lora_alpha=16,
target_modules=['gate_proj', 'v_proj', 'q_proj', 'k_proj', 'up_proj', 'down_proj', 'o_proj'],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
save_steps = 200
logging_steps = 20
args =transformers.TrainingArguments(
num_train_epochs=2,
learning_rate=3e-4,
do_eval=False,
logging_steps=logging_steps,
save_strategy="steps",
save_steps=save_steps,
output_dir=output_dir,
report_to="wandb",
run_name=lora_name,
save_total_limit=3,
push_to_hub=False,
auto_find_batch_size=True,
)
# トレーナーの準備
trainer = transformers.Trainer(
model=model,
train_dataset=dataset,
args=args,
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
評価
評価にはStability-AI/lm-evaluation-harnessを使用させていただきました。
評価指標はJCommonsenseQA(3-shot), MARC-ja(3-shot), JSQuAD(2-shot)の3つとしています。
結果
モデルごとの結果が下記の表になります
dataset | jcommonsenseqa | marc-ja | jsquad_em | jsquad_f1 |
---|---|---|---|---|
低1 | 0.7632 | 0.8539 | 32.4178 | 40.8709 |
低2 | 0.7703 | 0.9079 | 23.4354 | 30.3483 |
低3 | 0.7900 | 0.8811 | 21.3868 | 27.9210 |
高1 | 0.7900 | 0.7839 | 29.2211 | 38.4331 |
高2 | 0.7873 | 0.8958 | 23.2778 | 29.4238 |
高3 | 0.7641 | 0.8633 | 32.0351 | 42.1141 |
そして、3つのモデルのスコアを平均したものが下記の表になります。
dataset | jcommonsenseqa | marc-ja | jsquad_em | jsquad_f1 | mean |
---|---|---|---|---|---|
低平均 | 0.7745 | 0.8810 | 25.7467 | 33.0467 | 0.6376 |
高平均 | 0.7805 | 0.8477 | 28.1780 | 36.6570 | 0.6366 |
MARC-jaでは高品質であると想定していたPerplexityが低かったデータセット群で学習したモデルが上回りましたが、JCommonsenseQAとJSQuADでは下回る結果となりました。
また、meanはJCommonsenseQA,MARC-ja,JSQuAD(jsquad_em/100)を平均したものですが、ほとんど差がありませんでした。
結論
KenLMでPerplexityを計算することでInstructionデータセットの品質評価が行えるのではないかと仮定して実験を行いましたが、instruction部分のみを対象としたフィルタリングではほとんど差がないという結果になりました。
すべてのデータに共通している部分だからということでinstructionのみを対象にしましたが、outputも評価対象に含めることも必要かもしれません。
また、Perplexityでの判定以外にも方法はあると思うので、思いつき次第いろいろと試していこうと思います。
ToDo
inputとoutputも含めたPerplexityでのフィルタリング
KenLM以外のスコアリング方法の検討
高品質テキストと低品質テキストの分類器を作成して使えないか実験中
1035レコードのサブセットの作成方法の検討
embedding→k-meansで多様なデータをサンプリングできないか
Discussion