Zenn
🍄

継続事前学習とフルファインチューニングを比較してみた

2024/12/29に公開

継続事前学習とファインチューニングを比較してみた

はじめに

LLMの学習方法について、ファインチューニングでは新たな知識の取得は難しいといわれています。新規に知識を取得するためにその都度、一からモデルを学習させるのは高コストですが継続事前学習により一からモデルを学習させずに知識の取得ができます。継続事前学習を試していて継続事前学習とファインチューニングの2つの学習方法の違いにより出力にどう影響が出るのか気になったため検証してみました。

2つの学習方法の確認

特徴 ファインチューニング 継続事前学習
目的 特定タスク(対話生成や物語生成)への最適化 新たなドメイン知識の吸収と適応
データの性質 特定のタスクに整形されたデータ(対話形式や物語のフォーマット) 無ラベルのドメイン特化テキスト
損失計算 次のトークン生成(文脈に沿った応答生成) 次のトークン生成(新たな知識の吸収)
学習率 比較的低め(例: 1e-5~2e-5程度。既存の知識を保持しつつタスク適応を行うため) やや高め(例: 5e-5~1e-4程度。ドメイン適応の速度を優先し、新たな知識を効率よく吸収するため)
出力 特定のタスクに特化したテキスト生成 ドメインに特化した一般的なテキスト生成

検証内容

継続事前学習とファインチューニングを同じデータセットやパラメタで学習を行い学習後のモデルの出力の違いを確認します。

検証方法

モデル:GPT2LMHeadModel
https://huggingface.co/docs/transformers/model_doc/gpt2

データセット:医療QAデータセット(PubMedQA)
https://huggingface.co/datasets/qiaojin/PubMedQA

データセット:判例データセット(lex_glue)・・・応用のみ
https://huggingface.co/datasets/coastalcph/lex_glue

手順:

  1. オリジナルモデルの性能確認

    • 知識が不足していることを確認
    • QAタスクが苦手であることを確認
  2. 継続事前学習

    医療知識は取得できるが、質疑応答タスクに適合しておらず会話が流暢ではないのではない?

    • 医療QAデータセット(PubMed)のQAタスクのテキストをひとまとめにして継続事前学習させる。
  3. ファインチューニング

    流暢に会話できるが、医療知識がなく的外れな回答をするのでは?

    • 医療QAデータセット(PubMed)についてQAタスクで学習させる。
  4. 応用編(継続事前学習+ファインチューニング)
    おまけに継続事前学習後にファインチューニングを行うことで知識を蓄えて、出力精度も高めることができるか確認します。法律判例データ(lex_glue)を用いて継続事前学習を行い、医療分野のQAタスクでファインチューニングした出力を確認します。

    • 継続事前学習(知識の取得):判例データセットで継続事前学習

    • ファインチューニング:医療QAタスクでファインチューニング

出力の評価方法

今回のモデルおよびデータセットが英語で専門的分野であるためGPT-4oを用いて以下の観点で5階評価させます。

ChatGPTの評価項目 各項目5段階評価

  • 専門性及び用語の妥当性 1〜5
  • 文章の構造 1〜5

出力スコアの予想表

GPT-2自体の精度および2つの学習方法について定性的違いおよびから以下のような結果になるのではないか予想します。

医療の質問 法律系(判例)の質問
オリジナルモデル 妥当性:1、文書:2 妥当性:1、文書:2
ファインチューニング 妥当性:2、文書:3 -
継続事前学習モデル 妥当性:3、文書:2 -
応用 妥当性:2、文書:3 妥当性:3、文書:3

実行コード

データセットや学習パラメタは一定にして継続事前学習とファインチューニングを比較しています。
応用編のファインチューニングのエポック数のみ増やしています。

1.モデル&データセット読み込み

# トークナイザーの準備
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name).to("cuda")  # GPU使用の場合
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# データセットとトークナイザーの読み込み
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")  # PubMed QAのラベルなしデータ

2. ファインチューニング

質問応答データセットを用い、タスク特化型のトレーニングを行います。

# トークナイズおよびファインチューニング用のラベル設定
def preprocess_for_finetuning(data):
    # 質問と回答を取得
    question = " ".join(data["question"]) if isinstance(data["question"], list) else data["question"]
    answer = " ".join(data["long_answer"]) if isinstance(data["long_answer"], list) else data["long_answer"]

    # 質問と回答を別々にトークナイズ
    tokenized_question = tokenizer(question + tokenizer.eos_token, add_special_tokens=False)
    tokenized_answer = tokenizer(answer + tokenizer.eos_token, add_special_tokens=False)

    # 入力IDを結合
    input_ids = tokenized_question['input_ids'] + tokenized_answer['input_ids']
    attention_mask = [1] * len(input_ids)

    # ラベルを作成(質問部分をマスク)
    labels = [-100] * len(tokenized_question['input_ids']) + tokenized_answer['input_ids']

    # パディングとトランケーション
    max_length = 512
    input_ids = input_ids[:max_length]
    attention_mask = attention_mask[:max_length]
    labels = labels[:max_length]

    padding_length = max_length - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * padding_length
    attention_mask += [0] * padding_length
    labels += [-100] * padding_length

    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# トークナイズとラベル設定を適用
tokenized_dataset_for_finetuning = dataset.map(preprocess_for_finetuning, batched=False)

# トレーニング設定の準備
training_args = TrainingArguments(
    output_dir="./results_fin",
    learning_rate=2e-5,
    per_device_train_batch_size=12,
    num_train_epochs=1,
    evaluation_strategy="epoch",
    logging_dir="./logs",  # ログディレクトリの指定
    logging_steps=10,  # 10ステップごとに損失を表示
)

# 省略

3. 継続事前学習

ラベルなしデータを準備し学習させます。

# トークナイズおよび継続事前学習用のラベル設定
def preprocess_for_pretraining(data):
    # 質問と回答を文字列に変換し、結合
    question = " ".join(data["question"]) if isinstance(data["question"], list) else data["question"]
    answer = " ".join(data["long_answer"]) if isinstance(data["long_answer"], list) else data["long_answer"]
    text = question + " " + answer  # 質問と回答を結合

    # テキストをトークナイズしてラベルを設定
    tokenized_data = tokenizer(text, padding="max_length", truncation=True, max_length=512)
    tokenized_data["labels"] = tokenized_data["input_ids"].copy()  # ラベルとして入力IDをそのまま使用
    return tokenized_data

# トークナイズとラベル設定を適用
tokenized_dataset_for_pretraining = dataset.map(preprocess_for_pretraining, batched=False)

# トレーニング設定の準備
training_args = TrainingArguments(
    output_dir="./results_pre",
    learning_rate=2e-5,
    per_device_train_batch_size=12,
    num_train_epochs=1,
    evaluation_strategy="epoch",
    logging_dir="./logs",  # ログディレクトリの指定
    logging_steps=10,  # 10ステップごとに損失を表示
)

# 省略

損失計算コードの確認

ここで今回使用したライブラリの損失計算のコードを確認してみます。
継続事前学習とファインチューニングのどちらもCrossEntropyLossで計算されます。

# 抜粋
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1)
)

https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

ファインチューニング用データセットコード

ファインチューニングのため質問部分を無視し、回答部分のみを正解ラベルとして扱う設定としています。

labels = [-100] * len(tokenized_question['input_ids']) + tokenized_answer['input_ids']

継続事前学習用データセットコード

質問および回答のすべてのトークンを正解ラベルとして使用します。

tokenized_data["labels"] = tokenized_data["input_ids"].copy()

このようにデータセットを作成して学習させています。学習はどちらもCrossEntropyLossの同じコードを使用しているので、今回の2つ学習の違いはこの部分だけとなります。

データセットと出力内容

継続事前学習で知識が蓄えられていることを確認するためデータセットに含まれる以下の内容に関して質問します。医療系の質問はデータセットに含まれている内容に関して質問をしますが、法律は適当な内容で質問しました。

# データセットの列名と一部のサンプルを確認
print(dataset.column_names)  # データセットの列名を表示
print("question: ",dataset["question"][0])  # 質問の最初のサンプルを表示
print("long_answer: ", dataset["long_answer"][0])  # 長い回答の最初のサンプルを表示
print("context: ", dataset["context"][0])  # コンテキストの最初のサンプルを表示

"""
['pubid', 'question', 'context', 'long_answer', 'final_decision']
question:  Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?
long_answer:  Results depicted mitochondrial dynamics in vivo as PCD progresses within the lace plant, and highlight the correlation of this organelle with other organelles during developmental PCD. To the best of our knowledge, this is the first report of mitochondria and chloroplasts moving on transvacuolar strands to form a ring structure surrounding the nucleus during developmental PCD. Also, for the first time, we have shown the feasibility for the use of CsA in a whole plant system. Overall, our findings implicate the mitochondria as playing a critical and early role in developmentally regulated PCD in the lace plant.
context:  {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.', 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Window stage leaves were stained with the mitochondrial dye MitoTracker Red CMXRos and examined. Mitochondrial dynamics were delineated into four categories (M1-M4) based on characteristics including distribution, motility, and membrane potential (ΔΨm). A TUNEL assay showed fragmented nDNA in a gradient over these mitochondrial stages. Chloroplasts and transvacuolar strands were also examined using live cell imaging. The possible importance of mitochondrial permeability transition pore (PTP) formation during PCD was indirectly examined via in vivo cyclosporine A (CsA) treatment. This treatment resulted in lace plant leaves with a significantly lower number of perforations compared to controls, and that displayed mitochondrial dynamics similar to that of non-PCD cells.'], 'labels': ['BACKGROUND', 'RESULTS'], 'meshes': ['Alismataceae', 'Apoptosis', 'Cell Differentiation', 'Mitochondria', 'Plant Leaves'], 'reasoning_required_pred': ['y', 'e', 's'], 'reasoning_free_pred': ['y', 'e', 's']}
"""
# 医療に関する質問
Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?"

# オリジナルの出力
The answer is that mitochondrion is a key component of the cell cycle. The mitochondrin is the main component that controls the production of energy. It is also the key to the formation of new mitochondrios.
. . .
 (1) The role of mitochondra in the maintenance of cell health. (2) Mitochondria are the primary source of ATP for the cells. Mitrions are also responsible for maintaining the integrity of cells and the ability of certain cells to survive. In addition, mitochondrolysis is an important step in maintaining cell function. This is why mitochondrosomes are important for cell maintenance. However, the role that mitochondrialrolysation plays in cell repair is not well understood. Here we show that the mitochondroid is involved in repair of damaged mitochondrial cells by the enzyme cytochrome P450 (CYP) and that it is important in repairing damaged mitochondrial cells in vitro. We show the importance of mitochondrialroid in mitochondrial repair by showing that cytoplasmic mitochondrodiolase (MtR) is activated by mitochondradiol and is responsible in mitochondric repair. MtRs are activated when the mitochondrial

# ファインチューニングの出力
The results suggest that the mitochondrion is a key regulator of the plant's growth and survival.

# 継続事前学習の出力
The role played by cytokines in the regulation of immune response is not well understood. The cytokine IL-1β is a major cytokin that is involved with the immune system. It is important to understand the mechanisms by which cytokins regulate immune function.

# 継続事前学習+ファインチューニング(応用)の出力
The results suggest that mitochondrion-dependent changes in the plant's mitochondriol-mediated signaling pathway may be involved in this process.
# 法律に関する質問
Does incorporating copyrighted material into an educational video qualify as fair use under copyright law?"

# オリジナルの出力
The Copyright Act of 1976 (the "Act") requires that "the use of copyrighted materials in a public or private educational or scientific publication or publication" be "fair use" under the terms of the Act. The Act also requires the use by the public of "any and all materials, materials or information, including but not limited to, photographs, videotapes, audio, video, or any other form of media, that are of a nature that is likely to cause a substantial disruption to the educational, scientific, medical, social, economic, political, cultural, educational and other public interest."
. . .
,
"Fair use is defined as the dissemination of information that does not infringe upon the rights of others, and is not a violation of copyright." . ."
- The Copyright Office of Public Information, "Fair Use of Copyright Information," p. 5.
In the case of an instructional video that was made available to a student, the video was not copyrighted. However, in the event that the student was able to obtain a copy of it, he or she could use it to make a video about the subject matter of his or her education. In the

# 継続事前学習+ファインチューニング(応用)の出力
The use of copyrighted materials in an instructional video qualifies as a fair-use exercise under the Copyright Act. However, the use in the instructional videos of material that is not copyrighted is a violation of copyright. Therefore, it is important to consider whether the content of the educational videos is in accordance with the fair uses clause of § 5 of Title 5, United States Code, as amended, and whether it qualifies under § 6 of that Act as an "educational video."

評価

実際の出力スコアの結果表

医療の質問 法律系(判例)の質問
オリジナルモデル 妥当性:1、文書:2 妥当性:2、文書:2
ファインチューニング 妥当性:3、文書:2 -
継続事前学習モデル 妥当性:2、文書:2 -
応用 妥当性:3、文書:4 妥当性:2、文書:2

学習手法の違いと考察

  1. ファインチューニング
    • 妥当性が向上したが、文書の質は向上しなかった。
  2. 継続事前学習
    • 妥当性は向上したが、予想より向上しなかった。
  3. 応用モデル
    • 医療領域では妥当性、文書どちらも大幅に向上した。
    • 法律領域では予想より向上しなかった。

まとめ

今回、継続事前学習とファインチューニングを比較してみました。特定タスク、出力数、GPT-4oによるスコア検証の安定性などありますが、両者の学習による精度の向上が確認できました。また、今回のライブラリのコード動作について、正解ラベルの作り方で学習手法が変わることがわかりました。
大規模なLLMモデルを自分で継続事前学習したりフルファインチューニングするためにはそれなりのリソースが必要ですが、最近はGemma2-2Bのような小さな高精度モデルも発表されてきましたので、BERT普及期に皆が夢見た自分のデータセットを学習(知識の取得)させて活用できる時がきたのかもしれません。

参考

継続事前学習
https://arxiv.org/abs/2403.08763

Discussion

ログインするとコメントできます