SFTTrainerでdataset_text_fieldの引数を使う場合、学習データの構造次第では意図した学習が行われない可能性がある
はじめに
この記事では、私が個人的に遭遇したSFTTrainerでdataset_text_fieldの引数を使う場合の問題とその回避策についてまとめます。
※この記事の情報は記事執筆時点(2024/11/10)のものであり、その後のライブラリ更新等で状況が変わっている可能性があります。
ざっくりしたまとめ
まず最初に問題と解決策をざっくりとまとめます。
SFTTrainerでdataset_text_fieldの引数を使う場合、学習データにmessages、conversations、instructionという名前のついた列が存在しないことを確認する必要があります。
これらの名前が付いた列が存在し、かつこれらの列のデータが特定の形式(messagesまたはconversations列が[{"role": str, "content": str}]のようなリスト、instruction列が{"prompt": str, "completion": str}のような辞書)である場合、dataset_text_fieldで指定したフィールドが学習に使われず、これらの列をフォーマットしたデータが優先的に学習に利用されてしまうという問題があります。
この問題の回避策は大きく以下の3通りあると考えられます。
-
dataset_text_fieldを使わず、formatting_funcを使用して明示的に学習に使用するデータとフォーマット方法を指定する。 - データセットの
messages、conversations、instruction列の名前をあらかじめ変更する。 -
dataset.remove_columns(["messages", "conversations", "instruction"])などを使用して、これらの列をあらかじめデータセットから削除する。
以下、より詳細に解説します。
SFTTrainerによるLLMのSFTにおける学習対象フィールドの指定について
trlのSFTTrainerを使ってLLMのSFTを行う場合、データセット内の学習対象となるフィールドに関する情報をdataset_text_fieldまたはformatting_funcの引数で与えます。formatting_funcが指定されている場合はformatting_funcが優先され、formatting_funcがNoneの場合にdataset_text_fieldが参照されます。
1. dataset_text_fieldを使う場合
あらかじめフォーマットされたテキストを持つフィールドを学習対象として指定します。
例えば、以下のように元々あるデータセット内のquestionとresponseの列を元にtext列を作成し、それをSFTTrainerに渡します。
def formatting_prompts_func(examples, tokenizer):
questions = examples["question"]
responses = examples["response"]
texts = []
for question, response in zip(questions, responses):
messages = [
{"role": "user", "content": question},
{"role": "assistant", "content": response}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return { "text": texts }
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
train_dataset = train_dataset.map(
formatting_prompts_func,
batched=True,
fn_kwargs={'tokenizer': tokenizer}
)
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
dataset_text_field="text",
)
2. formatting_funcを使う場合
学習データをフォーマットして学習対象となるテキストを作成するフォーマット用関数を定義し、それを引数として渡す方法です。
例えば、以下のように元々あるデータセット内のquestionとresponseの列を元に学習対象となるテキストを構成する関数を渡します。
def formatting_prompts_func(examples, tokenizer):
questions = examples["question"]
responses = examples["response"]
texts = []
for question, response in zip(questions, responses):
messages = [
{"role": "user", "content": question},
{"role": "assistant", "content": response}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return texts
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
formatting_func=formatting_prompts_func,
)
dataset_text_fieldを使う場合の問題
本題に入ります。最初にも書きましたが、dataset_text_fieldを使う場合、元の学習データセット内にmessages、conversations、instructionという名前の付いた列があり、かつそれが特定の形式だと、この列をフォーマットしたデータが自動的に学習データとして使われてしまいます。
例えば、元のデータセット内にmessages、input、outputの3つの列があったとします。ここで、messagesとinput、outputの中身は異なっており、あくまで学習はinputとoutput列のデータに対して行いたいと考えているとしましょう。また、messages列の中はOpenAI messagesのようなroleとcontentを含んだ辞書のリスト形式になっているとします。(例:[{"role": "user", "content": "こんにちは"}, {"role": "assistant", "content": "こんにちは!"}])
この状態で、以下のようにdataset_text_fieldを使ってinputとoutputを元に作成した学習対象フィールドを指定しても意図した通りに学習されません。この場合、データセット内にmessages列が存在するため、これが優先して学習対象データの列として使われてしまい、text列のデータが学習に使われません。
def formatting_prompts_func(examples, tokenizer):
# inputとoutputのキーを元にtextを構成する
inputs = examples["input"]
outputs = examples["output"]
texts = []
for input, output in zip(inputs, outputs):
messages = [
{"role": "user", "content": input},
{"role": "assistant", "content": output}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return { "text": texts }
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
train_dataset = train_dataset.map(
formatting_prompts_func,
batched=True,
fn_kwargs={'tokenizer': tokenizer}
)
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
dataset_text_field="text", # 学習対象としてtextキーを指定
)
問題が発生する原因
上記の問題が発生する原因をソースコードを読みながら解説します。
まず、学習対象となる実際のテキストはtrl/trainer/sft_trainer.py中の以下の部分で指定されています。
①packing=Falseの場合
②packing=Trueの場合
どちらの場合でも、formatting_funcがNoneの場合にdataset_text_fieldを利用し、そうでない場合はformatting_funcを利用するようになっています。
先述したdataset_text_fieldを引数に渡す方法の場合、formatting_funcを指定しなければ値はNoneになっているので、一見問題ないように見えます。しかし、formatting_funcがNoneの場合、sft_trainer.pyの以下の部分の処理が走り、get_formatting_func_from_dataset()関数が呼び出されます。
get_formatting_func_from_dataset()関数はtrl/extras/dataset_formatting.pyの中で以下のように定義されています。
また、FORMAT_MAPPINGは以下のように定義されています。
get_formatting_func_from_dataset()関数の実装を見ると分かるように、データセット内にmessages、conversations、instructionという名前のついた列が存在し、かつそれが特定の形式になっている場合、内部でformatting_funcにここで定義された関数が設定され、元々Noneだったformatting_funcの値がNoneではなくなってしまいます。
その結果、formatting_funcがNoneの場合のみに使われるdataset_text_fieldが利用されなくなり、messagesやconversations、instruction列をフォーマットしたデータが学習対象とされてしまい、想定通りの学習結果が得られないという問題が発生します。
回避策
解決策は最初に提示したような3つの方法が考えられます。
-
dataset_text_fieldを使わず、formatting_funcを使用して明示的に学習に使用するデータとフォーマット方法を指定する。
formatting_funcがNoneでなければ上記の処理は行われずそのまま使われるので、こちらを使う方が確実です。 - データセットの
messages、conversations、instruction列の名前をあらかじめ変更する。
データセット内に特定の名前の列が含まれているかどうかで判定されているので、これらの列名を変更しておけば判定に引っかかりません。 -
dataset.remove_columns(["messages", "conversations", "instruction"])などを使用して、これらの列をあらかじめデータセットから削除する。
2番目の方法と同様に、あらかじめこれらの列をデータセットから削除しておけば条件に引っかからずこの挙動は発生しません。
まとめ
この記事ではSFTTrainerでdataset_text_fieldの引数を使う場合の問題とその回避策についてまとめました。
根本的な解決策としては、formatting_funcを利用して学習対象のテキストを定義する方法が確実だと思います。
この記事が同様の問題に遭遇した方の一助になれば幸いです。
Discussion
有益な情報展開ありがとうございます!
自分もこちらの内容に当てはまる学習データを使っていたのでとても助かります・・・。
一応調べてみたところ、どうやら以下の時点(2024/10/4)で変更されたようです。。。
PR
https://github.com/huggingface/trl/pull/2078
差分
https://github.com/huggingface/trl/pull/2078/files#diff-67e157adfcd37d677fba66f610e3dfb56238cc550f221e8683fcfa0556e0f7ca
↓
「formatting_funcとdataset_text_fieldがない場合、get_formatting_func_from_dataset関数を呼び出す」から、
「formatting_funcがない場合、get_formatting_func_from_dataset関数を呼び出す」に変更されておりました。
それまではdataset_text_fieldがあれば問題なかったのですが・・・。
ありがとうございます。結構最近あった変更なんですね…。
私もこれに気付かずSFTをずっとしており500ドルくらい無駄にしました…。
500ドルは・・・きついですね😓
この記事のおかげで新たな犠牲者は減るかと思います。。。
私も変更に気付かずじまいだったかもしれないです・・・ありがとうございます🙏