🎉

【AI_6日目】Google Colaboratoryでファインチューニングをやってみる

2024/09/07に公開

こんにちは投資ロウトです。

背景

ファインチューニングができるようにしたい背景があります。
※先輩にこれやったら?というハンズオン的なものを教えて頂いたので、それに倣ってやってみるところから始めていきます。(他のサイトを実施するだけ)

https://note.com/npaka/n/n315c0bdbbf00

Llama 3

上記のリンクはLlama 3を実際に使ってみたいと思います。

# パッケージのインストール
!pip install -U transformers accelerate bitsandbytes
!pip install trl peft wandb
!git clone https://github.com/huggingface/trl
%cd trl

・bitsandbytes・・・LLM 向け 8bit 量子化ライブラリ

https://zenn.dev/syoyo/articles/3bde98e9972dea

実行するとエラーが出てしまった。

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.
ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.

依存関係が問題のようなので、pyarrowのダウングレードをしてみる。

pip install pyarrow==14.0.2

またエラーが発生した。

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.21.0 requires pyarrow>=15.0.0, but you have pyarrow 14.0.2 which is incompatible.

さらに調整

!pip install datasets==2.10.0

成功したので、下記をさらに実施。

!pip uninstall pyarrow datasets
!pip install pyarrow==14.0.2 datasets==2.10.0

・以下でアカウント開設をしていく。

https://huggingface.co/login?next=%2Fsettings%2Ftokens

・次にメールアドレスの認証を行う

・次にトークンを発行してみる

https://huggingface.co/settings/tokens

・開設リンクにはwrite権限が必要とあったので、write権限で実施アカウント作成の実施。

・次に以下のコマンドでHugging Faceにアクセスできるとのことで、やってみる。

!huggingface-cli login

こちらに先ほど発行したトークンを入れる

実施するとログインに成功したように見えます。

To login, huggingface_hub requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: write).
Cannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.
Token has not been saved to git credential helper.
Your token has been saved to /root/.cache/huggingface/token
Login successful

「trl/examples/sft.py」の編集。と言う記載がありますが、どこにそのファイルがあるかわかりませんでした。とりあえず下記のコマンドを実施してみる。

!ls -R

大量にファイルはあるんだよな・・・

ググってみると、下記のような編集方法を提示してくれておりました。

https://qiita.com/funatsufumiya/items/e455ab8d801af6e1415d

まずはGdriveをマウントする

from google.colab import drive
drive.mount('/content/drive')

そもそもGPUを選択し忘れていたので、一旦やり直す。元のリンクにあるように、「「GPU」の「A100」を選択。」で実施してみる。
※前回の記事で実施。

https://zenn.dev/doshirote/articles/c9b4bb3cd80b6f

こちらのツールを入れてみる。

うまくいかないので、bashで突破する

!cat ./examples/scripts/sft.py

下記ファイルが表示される

# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --dataset_text_field="text" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=64 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing

# peft:
python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --dataset_text_field="text" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=64 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing \
    --use_peft \
    --lora_r=64 \
    --lora_alpha=16
"""

import logging
import os
from contextlib import nullcontext

from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
from trl.env_utils import strtobool

TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"

    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset

from tqdm.rich import tqdm
from transformers import AutoTokenizer

from trl import (
    ModelConfig,
    RichProgressCallback,
    SFTConfig,
    SFTTrainer,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
)

tqdm.pandas()

if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    # Force use our print callback
    if TRL_USE_RICH:
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    raw_datasets = load_dataset(args.dataset_name)

    train_dataset = raw_datasets[args.dataset_train_split]
    eval_dataset = raw_datasets[args.dataset_test_split]

    ################
    # Optional rich context managers
    ###############
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
    )

    ################
    # Training
    ################
    with init_context:
        trainer = SFTTrainer(
            model=model_config.model_name_or_path,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            peft_config=get_peft_config(model_config),
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
        )

    trainer.train()

    with save_context:
        trainer.save_model(training_args.output_dir)

下記のコマンドを実施して上書きする

%%bash
sed -i 's|raw_datasets = load_dataset(args.dataset_name)|# データセットの読み込み\n    dataset = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozarinnemon", split="train")\n    dataset = dataset.filter(lambda example: example["category"] == "open_qa")|' ./examples/scripts/sft.py

sed -i 's|train_dataset = raw_datasets\[args.dataset_train_split\]|# プロンプトの生成\n    def generate_prompt(example):\n        messages = [\n            {\n                "role": "system",\n                "content": "あなたは日本語で回答するAIアシスタントです。"\n            },\n            {\n                "role": "user",\n                "content": example["instruction"]\n            },\n            {\n                "role": "assistant",\n                "content": example["output"]\n            }\n        ]\n        return tokenizer.apply_chat_template(messages, tokenize=False)\n\n    # textカラムの追加\n    def add_text(example):\n        example["text"] = generate_prompt(example)\n        return example\n\n    dataset = dataset.map(add_text)\n    dataset = dataset.remove_columns(["input", "category", "output", "index", "instruction"])\n\n    # データセットの分割\n    train_test_split = dataset.train_test_split(test_size=0.1)\n    train_dataset = train_test_split["train"]|' ./examples/scripts/sft.py

sed -i 's|eval_dataset = raw_datasets\[args.dataset_test_split\]|eval_dataset = train_test_split["test"]|' ./examples/scripts/sft.py

再度ファイルを表示してみる

!cat ./examples/scripts/sft.py

上手く書き換わっているように見えますね。

元々のリンクにあるように、学習させてみる。

# 学習
!python examples/scripts/sft.py \
    --model_name meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 500 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir Llama-3-Gozaru-8B-Instruct

上記を実施すると、

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct.
403 Client Error. (Request ID: Root=1-66dc05a0-48b0e51a4798a0eb2ed676bd;190463cf-1f7b-49c6-ba4f-ae5ccb358864)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct to ask for access.

ゲート付きのリポジトリで、アクセス権の申請が必要とのことみたいでした。今回はファインチューニングをすることが目的なので、他の公開リポジトリで試してみようと思います。

Llama 2は公開リポジトリなのか念の為検証

!python examples/scripts/sft.py \
    --model_name meta-llama/Llama-2-7b-hf \
    --dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 500 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir Llama-2-Gozaru-7B-Instruct

これもゲート付きだった。gpt2はどうだろう。。。

!python examples/scripts/sft.py \
    --model_name "gpt2" \
    --dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 500 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir GPT2-Gozaru-Instruct

先ほどとエラーが変わったが、データセットはあるんですよね・・・

ライブラリの互換性の可能性があるので、バージョンアップを実施。

!pip install --upgrade pyarrow datasets

そしてgpt-2で再度実行

!python examples/scripts/sft.py \
    --model_name "gpt2" \
    --dataset_name "ag_news" \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 500 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir GPT2-AGNews-Instruct

結果、以下のエラーが発生。

ValueError: Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating

tokenizer.chat_template はデフォルトでは設定されていないのが原因の可能性がある。

%%bash
sed -i 's/return tokenizer.apply_chat_template(messages, tokenize=False)/return tokenizer.apply_chat_template(messages, tokenize=False, chat_template="simple_chat")/' ./examples/scripts/sft.py

# 修正内容を確認するために出力します

再度gpt-2を実施。

色々データセットがないかと思って、ニュースを学ばせてしまいましたが、原因はそこではなかったので、元々学習させたいデータで学習させる

!python examples/scripts/sft.py \
    --model_name "gpt2" \
    --dataset_name "bbz662bbz/databricks-dolly-15k-ja-gozaru" \
    --dataset_text_field text \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --learning_rate 2e-4 \
    --optim adamw_torch \
    --save_steps 50 \
    --logging_steps 50 \
    --max_steps 500 \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --load_in_4bit \
    --report_to wandb \
    --output_dir GPT2-AGNews-Instruct

そしてgpt2で以下を実施。

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# トークナイザーとモデルの準備 (GPT-2ベース)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "./GPT2-AGNews-Instruct",
    device_map="auto",
    torch_dtype="auto",
)

# プロンプトの準備 (シンプルな形式に変更)
prompt = "あなたは日本語で回答するAIアシスタントです。まどか☆マギカでは誰が一番かわいい?"

推論を実施。

いいません。って何?笑 結構捻くれたLLMさんですね。。。

バグったカオスのものができてしまいました・・・

gpt2のモデルに問題があるのが、自分の学習の仕方に問題があるのか、、、今度Meta社のLlamaを検証してみたいところはありますよね・・・

一応meta社に申請だけはしてみました。

と一旦以上で学習を終えたいと思います。焦らずコツコツ一つずつ進んでいきたいと思います。色々な失敗を通して知ったこともあったので、大きな収穫ではあるかなと思います。ご精読ありがとうございました。

Discussion