💬
decoder onlyのデータセットに関して
学習方法
- pretrain(事前学習)
- インストラクションチューニング
CausalLMの前提
- CausalLM modelsの場合、labels を 質問応答モデルのように考えてはいけない。
代わりに、labels と input_ids は形状が同じになるように作成 する必要があります。
数学的に表すと、インデックス k の labels トークンを予測するために、input_ids のインデックス 0 から k-1 までのトークンを使用
pred_token_k = model(input_ids[:k])
ポイント1
- 損失 (loss)の算出はlabels[k] の真の値と、モデルが予測した pred_token_k を比較 であること。
予測値 (pred_token_k)
モデルが予測した確率分布であり、各トークンに対する確率が割り当てられる。
真のラベル (labels[k])
1 × v のベクトルとして表せる。
実際の正解トークンには 1 を、それ以外のトークンには 0 を設定 することで、適切な形に整える
ポイント2
- 損失を計算するために、 CausalLMは 内部的に labels を1つ右にシフト してから、クロスエントロピー損失を計算すること
なので
基本的の以下の構造をとる。
{"input_ids":instruction + output, "labels":instruction + output} # the HF model will take care of the shift + 1
ポイント3
CausalLM モデルでは、通常 すべての過去のトークンにアテンションできるようにするため、基本的に、attention_mask をすべて 1 にする
インストラクションチューニングの場合
instruction masking
ラベル (labels) の instruction 部分を -100 に置き換える手法がある。
インストラクション部分の損失計算を無視 (-100) することで、モデルが output の学習に集中できるようにする。
こうすることで、モデルが適切に completion (出力部分) を学習し、インストラクション部分の誤った予測に引っ張られないようにすることができます
引用:https://wandb.ai/capecape/alpaca_ft/reports/How-to-Fine-tune-an-LLM-Part-3-The-HuggingFace-Trainer--Vmlldzo1OTEyNjMy
{"input_ids":instruction + output, "labels":[-100]*len(instruction) + output}
コード例
def tokenize_function(example):
# label等の作成
MAX_LENGTH = int(512)
prompt = tokenizer.encode("ここにインストラクション", add_special_tokens=False)
response = tokenizer.encode("ここに回答" + tokenizer.eos_token, add_special_tokens=False)
full_context_inputs_ids = (prompt + response)[-MAX_LENGTH:]
sample = {
"input_ids": full_context_inputs_ids,
"attention_mask" : [1] * (len(full_context_inputs_ids)),
"labels": ([-100] * len(prompt) + response)[-MAX_LENGTH:],
}
pad_length = MAX_LENGTH - len(full_context_inputs_ids)
if pad_length >= 0:
sample["input_ids"] = sample["input_ids"] + [tokenizer.pad_token_id] * pad_length
sample["attention_mask"] = sample["attention_mask"] + [0] * pad_length
sample["labels"] = sample["labels"] + [-100] * pad_length
return sample
}
- ラベルにおいて、クロスエントロピー損失を計算する際に-100をつけると無視される。
- これらのパディングトークンに対してモデルが何を予測するかは重要ではない。そのため、パディングトークンにも-100を設定すると良さげか?
pretrain時
<追記予定>
参考/引用
Discussion