💽

ところでSFT Trainerに渡したdatasetってどうなるの?

2024/03/05に公開

年度末いかがお過ごしですか。(時候の挨拶)
社内で「ロリババアといえば」みたいな話をしていたのですが、真っ先に「ゆのは」を思い浮かべた自分は限界ヲタク。
冬になるとイチゴサンデー7つくらい食べないといけない謎の焦燥感に駆られる皆さん、こんにちは。

SFT Trainer内のdatasetの取り扱い

どうも、限界ヲタクの@ken11です。この季節になるとローソンで流れるone more time, one more chanceで発作を起こします。嘘です。

今回は唐突ですがSFT Trainerの話、特にdatasetの話をしたいと思います。

Dataset format support

公式ドキュメントを見ると、SFT TrainerにはDataset format supportというのがあります。
これはなにかというと、「ある特定のフォーマットでdatasetを渡してくれたらよしなにやっとくよ」という機能です。

一方で、 formatting_func という機能もあります。
これは上述のDataset format supportのフォーマットに従っていないdatasetを使う場合などに、「あなたの好きなフォーマッターでいい感じにできるようにするよ」という機能です。

後者は非常にわかりやすく、公式ドキュメント上の例を引用させていただくと

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

このような形で、フォーマット用の関数を用意してあげればそれに従ってdatasetをいい感じにしてくれるわけです。
この機能はdatasetのフォーマットがSFT Trainerでサポートされているかどうかなどを気にする必要がなく、自分で好きに編集できるので、多用している方も多いのではないでしょうか?
実用上、自前でdatasetを用意する際に必ずしもサポートされているフォーマットに簡単にできるわけではないこともあると思うので、わりと好まれる機能に思います。

また、前者のサポートされているフォーマットを使った場合、そもそもその後SFT Trainer内でなにが起きるのかわからないから自前でformatting_func使うというケースもあるのではないでしょうか。
僕もそこが気になったので、今回はこのサポートされているフォーマットのデータを投げた場合なにが起きるのかを確認したいと思います。

サポートされているフォーマット

そもそもSFT Trainerでサポートされているフォーマットはどんなものでしょうか?
いまサポートされているのは ChatMLinstruction という2つの形式です。
こちらも公式ドキュメントから引用すると、前者が

{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}

後者が

{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}

というフォーマットです。
大きく異なる2つのフォーマットがサポートされているので、それぞれSFT Trainer内では大きく違った形で扱われるのではないか?と思っていました。
でもどうやらそういうわけではなさそうです。

datasetの行方

formatting_func がない場合(つまりサポートされたdatasetフォーマットを使っていて formatting_func を指定しない場合)はここget_formatting_func_from_dataset という関数で formatting_func を取得しています。
そうなんですね、サポートされているフォーマットの場合、デフォルトでformatting_funcが用意されているというだけのことのようです。
内部的には結局ここで予め用意されている formatting_func でフォーマットされるわけです。

ではその予め用意されている formatting_func がどういうことをしているのか見ていきましょう。
この formatting_func の実体は dataset_formatting.py 内にあります。

https://github.com/huggingface/trl/blob/main/trl/extras/dataset_formatting.py

先ほどの get_formatting_func_from_dataset 内で、まずフォーマットで分類されます。
ChatML なのか instruction なのかをここで判断し、それぞれに用意された formatting_func を返します。

ChatML 形式の場合は conversations_formatting_function という関数から

def format_dataset(examples):
    if isinstance(examples[messages_field][0], list):
        output_texts = []
        for i in range(len(examples[messages_field])):
            output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
        return output_texts
    else:
        return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)

このような formatting_func が返され、
instruction 形式の場合は instructions_formatting_function という関数から

def format_dataset(examples):
    if isinstance(examples["prompt"], list):
        output_texts = []
        for i in range(len(examples["prompt"])):
            converted_sample = [
                {"role": "user", "content": examples["prompt"][i]},
                {"role": "assistant", "content": examples["completion"][i]},
            ]
            output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
        return output_texts
    else:
        converted_sample = [
            {"role": "user", "content": examples["prompt"]},
            {"role": "assistant", "content": examples["completion"]},
        ]
        return tokenizer.apply_chat_template(converted_sample, tokenize=False)

このような formatting_func が返されます。

この実装内容を見ていくと、実は最終的な出力(フォーマッティングされたdataset内容)はそれぞれの形式で大きく変わらないことがわかります。

実際にChatMLとinstructionで実行した結果

実際の挙動を見てみましょう。

ChatML

from trl.extras.dataset_formatting import instructions_formatting_function, conversations_formatting_function

# 先ほどのformatting_funcを呼び出す
formatting_func = conversations_formatting_function(tokenizer, 'messages')

formatting_func({"messages": [{"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]})
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

What's the capital of France? [/INST] ... </s><s>[INST] What's the capital of France? [/INST] ... </s>

instruction

from trl.extras.dataset_formatting import instructions_formatting_function, conversations_formatting_function

# 先ほどのformatting_funcを呼び出す
formatting_func = instructions_formatting_function(tokenizer)

formatting_func({"prompt": "<prompt text>", "completion": "<ideal generated text>"})
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

<prompt text> [/INST] <ideal generated text> </s>

tokenizerがLlamaの場合、それぞれこのような形になりました。全く別の ChatMLinstruction というフォーマットなので、もっと異なる形の出力になるのかと思っていましたが、実際には同じ形で出力されました。
また、この最終的な出力フォーマットは formatting_func 内で apply_chat_template している結果なので、tokenizerがchat templateを持っていない場合に予期せぬ動作となる可能性がある点は注意が必要だと思います。

たとえばLlamaのtokenizerを使っていても、学習データのフォーマットは ### 指示:\n\n### 入力: \n\n### 応答:\n のような日本語化されたalpaca形式を期待している場合などは、datasetがサポートされている形式だからといってそのままSFT Trainerに渡すと期待された形にならないというのがわかります。
この場合、tokenizer側のchat templateを編集するか、あるいは自前で formatting_func を用意することが期待されます。

まとめ

というわけで、普段何気なくつかっているSFT Trainer内の、サポートされたdatasetの取り扱いについて見てみました。
ここまでわかると、どういったときにSFT Trainerに丸投げできてどういったときに自分で formatting_func を書くとよいのかよくわかりますね。
tokenizerのchat templateに満足いかないときは自分で formatting_func を書くことが重要そうです。

中を見ればなんてことないんですが、ドキュメントの内容だけだと判断できないこともたくさんあるので、こうやって実際に見ると面白いですね。
僕はコード読むのって謎解きみたいで好きです()

近年のLLMの開発はライブラリが充実していて簡単に学習を実行することができます。また、莫大なデータ量を投入することもあり、あまり細部まで気にしないで実行してしまうケースもあります。しかし、意外とデフォルトで用意されているものでは実は自分がやりたいことを実現できていなかったり、細部でズレていたりすることもあるので、今回のように実装内容を見てライブラリ内を理解していくことはとても重要と思っています。

そんな感じで、基本に忠実、丁寧なLLM開発に興味がある方、ご応募お待ちしております!

SpiralAIテックブログ

Discussion