SFTTrainerでdataset_text_fieldの引数を使う場合、学習データの構造次第では意図した学習が行われない可能性がある
はじめに
この記事では、私が個人的に遭遇したSFTTrainerでdataset_text_field
の引数を使う場合の問題とその回避策についてまとめます。
※この記事の情報は記事執筆時点(2024/11/10)のものであり、その後のライブラリ更新等で状況が変わっている可能性があります。
ざっくりしたまとめ
まず最初に問題と解決策をざっくりとまとめます。
SFTTrainerでdataset_text_field
の引数を使う場合、学習データにmessages
、conversations
、instruction
という名前のついた列が存在しないことを確認する必要があります。
これらの名前が付いた列が存在し、かつこれらの列のデータが特定の形式(messages
またはconversations
列が[{"role": str, "content": str}]
のようなリスト、instruction
列が{"prompt": str, "completion": str}
のような辞書)である場合、dataset_text_field
で指定したフィールドが学習に使われず、これらの列をフォーマットしたデータが優先的に学習に利用されてしまうという問題があります。
この問題の回避策は大きく以下の3通りあると考えられます。
-
dataset_text_field
を使わず、formatting_func
を使用して明示的に学習に使用するデータとフォーマット方法を指定する。 - データセットの
messages
、conversations
、instruction
列の名前をあらかじめ変更する。 -
dataset.remove_columns(["messages", "conversations", "instruction"])
などを使用して、これらの列をあらかじめデータセットから削除する。
以下、より詳細に解説します。
SFTTrainerによるLLMのSFTにおける学習対象フィールドの指定について
trlのSFTTrainerを使ってLLMのSFTを行う場合、データセット内の学習対象となるフィールドに関する情報をdataset_text_field
またはformatting_func
の引数で与えます。formatting_func
が指定されている場合はformatting_func
が優先され、formatting_func
がNone
の場合にdataset_text_field
が参照されます。
dataset_text_field
を使う場合
1. あらかじめフォーマットされたテキストを持つフィールドを学習対象として指定します。
例えば、以下のように元々あるデータセット内のquestion
とresponse
の列を元にtext
列を作成し、それをSFTTrainer
に渡します。
def formatting_prompts_func(examples, tokenizer):
questions = examples["question"]
responses = examples["response"]
texts = []
for question, response in zip(questions, responses):
messages = [
{"role": "user", "content": question},
{"role": "assistant", "content": response}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return { "text": texts }
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
train_dataset = train_dataset.map(
formatting_prompts_func,
batched=True,
fn_kwargs={'tokenizer': tokenizer}
)
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
dataset_text_field="text",
)
formatting_func
を使う場合
2. 学習データをフォーマットして学習対象となるテキストを作成するフォーマット用関数を定義し、それを引数として渡す方法です。
例えば、以下のように元々あるデータセット内のquestion
とresponse
の列を元に学習対象となるテキストを構成する関数を渡します。
def formatting_prompts_func(examples, tokenizer):
questions = examples["question"]
responses = examples["response"]
texts = []
for question, response in zip(questions, responses):
messages = [
{"role": "user", "content": question},
{"role": "assistant", "content": response}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return texts
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
formatting_func=formatting_prompts_func,
)
dataset_text_field
を使う場合の問題
本題に入ります。最初にも書きましたが、dataset_text_field
を使う場合、元の学習データセット内にmessages
、conversations
、instruction
という名前の付いた列があり、かつそれが特定の形式だと、この列をフォーマットしたデータが自動的に学習データとして使われてしまいます。
例えば、元のデータセット内にmessages
、input
、output
の3つの列があったとします。ここで、messages
とinput
、output
の中身は異なっており、あくまで学習はinput
とoutput
列のデータに対して行いたいと考えているとしましょう。また、messages
列の中はOpenAI messagesのようなrole
とcontent
を含んだ辞書のリスト形式になっているとします。(例:[{"role": "user", "content": "こんにちは"}, {"role": "assistant", "content": "こんにちは!"}]
)
この状態で、以下のようにdataset_text_field
を使ってinput
とoutput
を元に作成した学習対象フィールドを指定しても意図した通りに学習されません。この場合、データセット内にmessages
列が存在するため、これが優先して学習対象データの列として使われてしまい、text
列のデータが学習に使われません。
def formatting_prompts_func(examples, tokenizer):
# inputとoutputのキーを元にtextを構成する
inputs = examples["input"]
outputs = examples["output"]
texts = []
for input, output in zip(inputs, outputs):
messages = [
{"role": "user", "content": input},
{"role": "assistant", "content": output}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
texts.append(text)
return { "text": texts }
tokenizer = AutoTokenizer.from_pretrained("foo/bar")
train_dataset = load_dataset("hoge/piyo", split="train")
train_dataset = train_dataset.map(
formatting_prompts_func,
batched=True,
fn_kwargs={'tokenizer': tokenizer}
)
(中略)
trainer = SFTTrainer(
model,
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
dataset_text_field="text", # 学習対象としてtextキーを指定
)
問題が発生する原因
上記の問題が発生する原因をソースコードを読みながら解説します。
まず、学習対象となる実際のテキストはtrl/trainer/sft_trainer.py
中の以下の部分で指定されています。
①packing=Falseの場合
②packing=Trueの場合
どちらの場合でも、formatting_func
がNone
の場合にdataset_text_field
を利用し、そうでない場合はformatting_func
を利用するようになっています。
先述したdataset_text_field
を引数に渡す方法の場合、formatting_func
を指定しなければ値はNone
になっているので、一見問題ないように見えます。しかし、formatting_func
がNone
の場合、sft_trainer.py
の以下の部分の処理が走り、get_formatting_func_from_dataset()
関数が呼び出されます。
get_formatting_func_from_dataset()
関数はtrl/extras/dataset_formatting.py
の中で以下のように定義されています。
また、FORMAT_MAPPING
は以下のように定義されています。
get_formatting_func_from_dataset()
関数の実装を見ると分かるように、データセット内にmessages
、conversations
、instruction
という名前のついた列が存在し、かつそれが特定の形式になっている場合、内部でformatting_func
にここで定義された関数が設定され、元々None
だったformatting_func
の値がNone
ではなくなってしまいます。
その結果、formatting_func
がNone
の場合のみに使われるdataset_text_field
が利用されなくなり、messages
やconversations
、instruction
列をフォーマットしたデータが学習対象とされてしまい、想定通りの学習結果が得られないという問題が発生します。
回避策
解決策は最初に提示したような3つの方法が考えられます。
-
dataset_text_field
を使わず、formatting_func
を使用して明示的に学習に使用するデータとフォーマット方法を指定する。
formatting_func
がNone
でなければ上記の処理は行われずそのまま使われるので、こちらを使う方が確実です。 - データセットの
messages
、conversations
、instruction
列の名前をあらかじめ変更する。
データセット内に特定の名前の列が含まれているかどうかで判定されているので、これらの列名を変更しておけば判定に引っかかりません。 -
dataset.remove_columns(["messages", "conversations", "instruction"])
などを使用して、これらの列をあらかじめデータセットから削除する。
2番目の方法と同様に、あらかじめこれらの列をデータセットから削除しておけば条件に引っかからずこの挙動は発生しません。
まとめ
この記事ではSFTTrainerでdataset_text_field
の引数を使う場合の問題とその回避策についてまとめました。
根本的な解決策としては、formatting_func
を利用して学習対象のテキストを定義する方法が確実だと思います。
この記事が同様の問題に遭遇した方の一助になれば幸いです。
Discussion
有益な情報展開ありがとうございます!
自分もこちらの内容に当てはまる学習データを使っていたのでとても助かります・・・。
一応調べてみたところ、どうやら以下の時点(2024/10/4)で変更されたようです。。。
PR
https://github.com/huggingface/trl/pull/2078
差分
https://github.com/huggingface/trl/pull/2078/files#diff-67e157adfcd37d677fba66f610e3dfb56238cc550f221e8683fcfa0556e0f7ca
↓
「formatting_funcとdataset_text_fieldがない場合、get_formatting_func_from_dataset関数を呼び出す」から、
「formatting_funcがない場合、get_formatting_func_from_dataset関数を呼び出す」に変更されておりました。
それまではdataset_text_fieldがあれば問題なかったのですが・・・。
ありがとうございます。結構最近あった変更なんですね…。
私もこれに気付かずSFTをずっとしており500ドルくらい無駄にしました…。
500ドルは・・・きついですね😓
この記事のおかげで新たな犠牲者は減るかと思います。。。
私も変更に気付かずじまいだったかもしれないです・・・ありがとうございます🙏