📑

KenLMのPerplexityで日本語Instructionデータセットの品質を評価できないか試す

2024/01/21に公開

モチベーション

LLMの開発は一般的に、

  1. 事前学習(知識の習得)
  2. Supervised Fine-Tuning(指示に対して適切な回答を行う能力の習得)
  3. DPOやPPO(人間が好む応答をする能力の習得。行われない場合もある)

という流れで行われます。

Supervised Fine-Tuningでは基本的にInstruction Tuningという手法が用いられ、この学習で使われるのがInstructionデータセットになります。
Instructionデータセットはプロンプトとそれに対する出力がセットになっているデータセットで、LLMに対して知識の使い方を教え込む重要なものになります。

LIMA論文では、Instruction Tuningに高品質な1000件のデータを用いた結果、GPT-4に匹敵する性能を得ることができたと主張されています。

しかし、日本語特化のLLMを作ろうとしたときの大きな課題になるのが、この高品質なInstructionデータが少ないということです。
公開されている日本語のInstructionデータセットは英語で書かれたものをGoogle翻訳やDeepLで翻訳したものがほとんどで、日本語として不自然なデータや日本文化に合わないようなデータが含まれているのが現状です。
これを1件1件見て削除や修正していくのがベストではあると思いますが、数万、数十万とあるデータを人力で評価していくのは心が荒みます。

ということで、この「Instructionデータセットの品質評価」を機械的に行えないか、というのが今回試したいことになります。

実験概要

  1. KenLMでkunishou/databricks-dolly-15k-jaのInstructionのPerplexityを計算
  2. データセット全体のうち、Perplexityの値が小さかった(日本語として自然な)30%と大きかった(日本語として不自然な)30%を抽出
  3. それぞれからランダムに1035レコードを取得したサブセットを3つずつ作成
  4. 1035件のデータを使ってcyberagent/calm-2をLoRAチューニング×3×2
  5. Stability-AI/lm-evaluation-harnessでJCommonsenseQA,MARC-ja,JSQuADを評価

実験詳細

KenLMでのスコアリング

事前学習のデータセットについて調べていた際に拝見したこちらの記事を参考にさせていただき、KenLMで計算されるPerplexityを使ってみることにします。

https://zenn.dev/syoyo/articles/529ce949121ca4

まず、kunishou/databricks-dolly-15k-jaのデータは下記のような形式をしています。

{
    'category': 'closed_qa',
    'instruction': 'ヴァージン・オーストラリア航空はいつから運航を開始したのですか?',
    'index': '0',
    'input': 'ヴァージン・オーストラリア航空(Virgin Australia Airlines Pty Ltd)はオーストラリアを拠点とするヴァージン・ブランドを冠する最大の船団規模を持つ航空会社です。2000831日に、ヴァージン・ブルー空港として、2機の航空機、1つの空路を運行してサービスを開始しました。20019月のアンセット・オーストラリア空港の崩壊後、オーストラリアの国内市場で急速に地位を確立しました。その後はブリスベン、メルボルン、シドニーをハブとして、オーストラリア国内の32都市に直接乗り入れるまでに成長しました。',
    'output': 'ヴァージン・オーストラリア航空は、2000831日にヴァージン・ブルー航空として、2機の航空機で単一路線の運航を開始しました。'
}

categoryが指示されているタスクを表し、instructionが指示、outputがそれに対する回答、inputは必要に応じて用意される参照文になります。

KenLMでPerplexityを算出する対象が悩ましいのですが、

  • タスクによりinputを持たないレコードが存在している
  • outputは分類や一問一答系のタスクの場合は名詞だけや端的な文章になりやすい一方、説明を求められるようなタスクではしっかりとした文章になりやすいと思われる

という2点から、すべてのデータを統一的に評価できるinstructionのみに絞って算出しました。
ただし、このあとの処理でPerplexityの良し悪しの判断をcategoryごとに分けて行ったことを考えると、instruction + outputを対象に算出するのが妥当だったようにも思います。

Perplexityの算出語、上述の通り、categoryごとにPerplexityが高かった30%と低かった30%を抽出しました。

df_large_30 = df.groupby('category').apply(lambda x: x.nlargest(int(len(x)*0.3), 'score'))
df_small_30 = df.groupby('category').apply(lambda x: x.nsmallest(int(len(x)*0.3), 'score'))

ここで、Perplexityが大きかったデータと小さかったデータを見比べてみます。
※inputとoutputが長文の場合には適当なところでカットしてます

Perplexityが小さかったデータ

category instruction input output Perplexity
closed_qa サー・ヘンリー・クリントン将軍とは サー・ヘンリー・クリントン将軍(KB、1730年4月16日 - 1795年12月23日)は、1772年から1795年にかけ... サー・ヘンリー・クリントン将軍(KB、1730年4月16日 - 1795年12月23日)は、1772年から1795年にかけ... 173.37655124180176
general_qa 宇宙への答えはあるのか? 小説『銀河ヒッチハイク・ガイド』によると、人生、宇宙、すべての答えは42だそうです。 226.3177359142668
open_qa 人は何を描くのか? 絵を描くことは、誰にでもできる活動です。 芸術家は職業として絵を描いていますが、一般の人も絵を... 402.8290339525847

Perplexityが大きかったデータ

category instruction input output Perplexity
closed_qa Linux セカンドステージブートローダ GNU GRUB、rEFInd、BOOTMGR、Syslinux、NTLDR、iBootなどの第2段ブートローダは、それ自体はオペ... 2ステージブートローダは、実際には2つのブートローダが互いの後に構成され... 3901373.393502721
brainstorming ベストベンガルシンガー キショール・クマール 3901373.393502721
creative_writing アメリカズカップカムバック 2013年9月、サンフランシスコ湾で、スポーツ界を代表するカムバックが行われた。... 3901373.393502721

Perplexityが大きかったものについてはinstructionが指示を与えるような文章になっておらず、品質としては決して良いものとは言えないように思います。
一方で、Perplexityが小さかったデータについては指示を与える文章になっているようです。
instructionの評価としてはある程度機能しているように思うので、このまま実験を続けていきます。

学習に使うサブセットの作成

単純なランダムサンプリングをseed値を変えて実行し、低Perplexityグループと高Perplexityグループのそれぞれで3組ずつのサブセットを作成しました。

def sample_df(df, seed):
    sampled_df = df.apply(lambda x: x.sample(frac=0.23, random_state=seed))
    sampled_df = sampled_df.reset_index(drop=True)
    return sampled_df

(30%ずつに分けたときにはcategoryの割合を変えなかったのに、ここではそれを考慮していないのは完全にミスです...)

LoRAでの学習

ベースモデルはcyberagent/calm2-7bとして、1035レコードのデータを2epoch学習させました。
プロンプトテンプレートは下記の通りです。

# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。リクエストを適切に完了するための回答を記述してください。

### 指示:
{data_point["instruction"]}

### 入力:
{data_point["input"]}

### 応答:
{data_point["output"]}"""
    else:
        result = f"""以下に、あるタスクを説明する指示があります。リクエストを適切に完了するための回答を記述してください。

### 指示:
{data_point["instruction"]}

### 応答:
{data_point["output"]}"""
    return result

その他、学習に関係するパラメータは下記の通りです。

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)
lora_config = LoraConfig(
    r= 8,
    lora_alpha=16,
    target_modules=['gate_proj', 'v_proj', 'q_proj', 'k_proj', 'up_proj', 'down_proj', 'o_proj'],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
save_steps = 200
logging_steps = 20

args =transformers.TrainingArguments(
    num_train_epochs=2,
    learning_rate=3e-4,
    do_eval=False,
    logging_steps=logging_steps,
    save_strategy="steps",
    save_steps=save_steps,
    output_dir=output_dir,
    report_to="wandb",
    run_name=lora_name,
    save_total_limit=3,
    push_to_hub=False,
    auto_find_batch_size=True,
)

# トレーナーの準備
trainer = transformers.Trainer(
    model=model,
    train_dataset=dataset,
    args=args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

評価

評価にはStability-AI/lm-evaluation-harnessを使用させていただきました。

https://github.com/Stability-AI/lm-evaluation-harness

評価指標はJCommonsenseQA(3-shot), MARC-ja(3-shot), JSQuAD(2-shot)の3つとしています。

結果

モデルごとの結果が下記の表になります

dataset jcommonsenseqa marc-ja jsquad_em jsquad_f1
低1 0.7632 0.8539 32.4178 40.8709
低2 0.7703 0.9079 23.4354 30.3483
低3 0.7900 0.8811 21.3868 27.9210
高1 0.7900 0.7839 29.2211 38.4331
高2 0.7873 0.8958 23.2778 29.4238
高3 0.7641 0.8633 32.0351 42.1141

そして、3つのモデルのスコアを平均したものが下記の表になります。

dataset jcommonsenseqa marc-ja jsquad_em jsquad_f1 mean
低平均 0.7745 0.8810 25.7467 33.0467 0.6376
高平均 0.7805 0.8477 28.1780 36.6570 0.6366

MARC-jaでは高品質であると想定していたPerplexityが低かったデータセット群で学習したモデルが上回りましたが、JCommonsenseQAとJSQuADでは下回る結果となりました。
また、meanはJCommonsenseQA,MARC-ja,JSQuAD(jsquad_em/100)を平均したものですが、ほとんど差がありませんでした。

結論

KenLMでPerplexityを計算することでInstructionデータセットの品質評価が行えるのではないかと仮定して実験を行いましたが、instruction部分のみを対象としたフィルタリングではほとんど差がないという結果になりました。

すべてのデータに共通している部分だからということでinstructionのみを対象にしましたが、outputも評価対象に含めることも必要かもしれません。

また、Perplexityでの判定以外にも方法はあると思うので、思いつき次第いろいろと試していこうと思います。

ToDo

inputとoutputも含めたPerplexityでのフィルタリング

KenLM以外のスコアリング方法の検討

高品質テキストと低品質テキストの分類器を作成して使えないか実験中

1035レコードのサブセットの作成方法の検討

embedding→k-meansで多様なデータをサンプリングできないか

Discussion