【AI_6日目】Google Colaboratoryでファインチューニングをやってみる
こんにちは投資ロウトです。
背景
ファインチューニングができるようにしたい背景があります。
※先輩にこれやったら?というハンズオン的なものを教えて頂いたので、それに倣ってやってみるところから始めていきます。(他のサイトを実施するだけ)
Llama 3
上記のリンクはLlama 3を実際に使ってみたいと思います。
# パッケージのインストール
!pip install -U transformers accelerate bitsandbytes
!pip install trl peft wandb
!git clone https://github.com/huggingface/trl
%cd trl
・bitsandbytes・・・LLM 向け 8bit 量子化ライブラリ
実行するとエラーが出てしまった。
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.
ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.
依存関係が問題のようなので、pyarrowのダウングレードをしてみる。
pip install pyarrow==14.0.2
またエラーが発生した。
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.21.0 requires pyarrow>=15.0.0, but you have pyarrow 14.0.2 which is incompatible.
さらに調整
!pip install datasets==2.10.0
成功したので、下記をさらに実施。
!pip uninstall pyarrow datasets
!pip install pyarrow==14.0.2 datasets==2.10.0
・以下でアカウント開設をしていく。
・次にメールアドレスの認証を行う
・次にトークンを発行してみる
・開設リンクにはwrite権限が必要とあったので、write権限で実施アカウント作成の実施。
・次に以下のコマンドでHugging Faceにアクセスできるとのことで、やってみる。
!huggingface-cli login
こちらに先ほど発行したトークンを入れる
実施するとログインに成功したように見えます。
To login, huggingface_hub requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? (Y/n) Y
Token is valid (permission: write).
Cannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.
git config --global credential.helper store
Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.
Token has not been saved to git credential helper.
Your token has been saved to /root/.cache/huggingface/token
Login successful
「trl/examples/sft.py」の編集。と言う記載がありますが、どこにそのファイルがあるかわかりませんでした。とりあえず下記のコマンドを実施してみる。
!ls -R
大量にファイルはあるんだよな・・・
ググってみると、下記のような編集方法を提示してくれておりました。
まずはGdriveをマウントする
from google.colab import drive
drive.mount('/content/drive')
そもそもGPUを選択し忘れていたので、一旦やり直す。元のリンクにあるように、「「GPU」の「A100」を選択。」で実施してみる。
※前回の記事で実施。
こちらのツールを入れてみる。
うまくいかないので、bashで突破する
!cat ./examples/scripts/sft.py
下記ファイルが表示される
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/sft.py \
--model_name_or_path="facebook/opt-350m" \
--dataset_text_field="text" \
--report_to="wandb" \
--learning_rate=1.41e-5 \
--per_device_train_batch_size=64 \
--gradient_accumulation_steps=16 \
--output_dir="sft_openassistant-guanaco" \
--logging_steps=1 \
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
--gradient_checkpointing
# peft:
python examples/scripts/sft.py \
--model_name_or_path="facebook/opt-350m" \
--dataset_text_field="text" \
--report_to="wandb" \
--learning_rate=1.41e-5 \
--per_device_train_batch_size=64 \
--gradient_accumulation_steps=16 \
--output_dir="sft_openassistant-guanaco" \
--logging_steps=1 \
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
--gradient_checkpointing \
--use_peft \
--lora_r=64 \
--lora_alpha=16
"""
import logging
import os
from contextlib import nullcontext
from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
from trl.env_utils import strtobool
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"
from rich.console import Console
from rich.logging import RichHandler
import torch
from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer
from trl import (
ModelConfig,
RichProgressCallback,
SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
tqdm.pandas()
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model init kwargs & Tokenizer
################
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets[args.dataset_train_split]
eval_dataset = raw_datasets[args.dataset_test_split]
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Training
################
with init_context:
trainer = SFTTrainer(
model=model_config.model_name_or_path,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
trainer.train()
with save_context:
trainer.save_model(training_args.output_dir)
下記のコマンドを実施して上書きする
%%bash
sed -i 's|raw_datasets = load_dataset(args.dataset_name)|# データセットの読み込み\n dataset = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozarinnemon", split="train")\n dataset = dataset.filter(lambda example: example["category"] == "open_qa")|' ./examples/scripts/sft.py
sed -i 's|train_dataset = raw_datasets\[args.dataset_train_split\]|# プロンプトの生成\n def generate_prompt(example):\n messages = [\n {\n "role": "system",\n "content": "あなたは日本語で回答するAIアシスタントです。"\n },\n {\n "role": "user",\n "content": example["instruction"]\n },\n {\n "role": "assistant",\n "content": example["output"]\n }\n ]\n return tokenizer.apply_chat_template(messages, tokenize=False)\n\n # textカラムの追加\n def add_text(example):\n example["text"] = generate_prompt(example)\n return example\n\n dataset = dataset.map(add_text)\n dataset = dataset.remove_columns(["input", "category", "output", "index", "instruction"])\n\n # データセットの分割\n train_test_split = dataset.train_test_split(test_size=0.1)\n train_dataset = train_test_split["train"]|' ./examples/scripts/sft.py
sed -i 's|eval_dataset = raw_datasets\[args.dataset_test_split\]|eval_dataset = train_test_split["test"]|' ./examples/scripts/sft.py
再度ファイルを表示してみる
!cat ./examples/scripts/sft.py
上手く書き換わっているように見えますね。
元々のリンクにあるように、学習させてみる。
# 学習
!python examples/scripts/sft.py \
--model_name meta-llama/Meta-Llama-3-8B-Instruct \
--dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
--dataset_text_field text \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--optim adamw_torch \
--save_steps 50 \
--logging_steps 50 \
--max_steps 500 \
--use_peft \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--load_in_4bit \
--report_to wandb \
--output_dir Llama-3-Gozaru-8B-Instruct
上記を実施すると、
OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct.
403 Client Error. (Request ID: Root=1-66dc05a0-48b0e51a4798a0eb2ed676bd;190463cf-1f7b-49c6-ba4f-ae5ccb358864)
Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted and you are not in the authorized list. Visit https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct to ask for access.
ゲート付きのリポジトリで、アクセス権の申請が必要とのことみたいでした。今回はファインチューニングをすることが目的なので、他の公開リポジトリで試してみようと思います。
Llama 2は公開リポジトリなのか念の為検証
!python examples/scripts/sft.py \
--model_name meta-llama/Llama-2-7b-hf \
--dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
--dataset_text_field text \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--optim adamw_torch \
--save_steps 50 \
--logging_steps 50 \
--max_steps 500 \
--use_peft \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--load_in_4bit \
--report_to wandb \
--output_dir Llama-2-Gozaru-7B-Instruct
これもゲート付きだった。gpt2はどうだろう。。。
!python examples/scripts/sft.py \
--model_name "gpt2" \
--dataset_name bbz662bbz/databricks-dolly-15k-ja-gozaru \
--dataset_text_field text \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--optim adamw_torch \
--save_steps 50 \
--logging_steps 50 \
--max_steps 500 \
--use_peft \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--load_in_4bit \
--report_to wandb \
--output_dir GPT2-Gozaru-Instruct
先ほどとエラーが変わったが、データセットはあるんですよね・・・
ライブラリの互換性の可能性があるので、バージョンアップを実施。
!pip install --upgrade pyarrow datasets
そしてgpt-2で再度実行
!python examples/scripts/sft.py \
--model_name "gpt2" \
--dataset_name "ag_news" \
--dataset_text_field text \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--optim adamw_torch \
--save_steps 50 \
--logging_steps 50 \
--max_steps 500 \
--use_peft \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--load_in_4bit \
--report_to wandb \
--output_dir GPT2-AGNews-Instruct
結果、以下のエラーが発生。
ValueError: Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating
tokenizer.chat_template はデフォルトでは設定されていないのが原因の可能性がある。
%%bash
sed -i 's/return tokenizer.apply_chat_template(messages, tokenize=False)/return tokenizer.apply_chat_template(messages, tokenize=False, chat_template="simple_chat")/' ./examples/scripts/sft.py
# 修正内容を確認するために出力します
再度gpt-2を実施。
色々データセットがないかと思って、ニュースを学ばせてしまいましたが、原因はそこではなかったので、元々学習させたいデータで学習させる
!python examples/scripts/sft.py \
--model_name "gpt2" \
--dataset_name "bbz662bbz/databricks-dolly-15k-ja-gozaru" \
--dataset_text_field text \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-4 \
--optim adamw_torch \
--save_steps 50 \
--logging_steps 50 \
--max_steps 500 \
--use_peft \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--load_in_4bit \
--report_to wandb \
--output_dir GPT2-AGNews-Instruct
そしてgpt2で以下を実施。
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# トークナイザーとモデルの準備 (GPT-2ベース)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
"./GPT2-AGNews-Instruct",
device_map="auto",
torch_dtype="auto",
)
# プロンプトの準備 (シンプルな形式に変更)
prompt = "あなたは日本語で回答するAIアシスタントです。まどか☆マギカでは誰が一番かわいい?"
推論を実施。
いいません。って何?笑 結構捻くれたLLMさんですね。。。
バグったカオスのものができてしまいました・・・
gpt2のモデルに問題があるのが、自分の学習の仕方に問題があるのか、、、今度Meta社のLlamaを検証してみたいところはありますよね・・・
一応meta社に申請だけはしてみました。
と一旦以上で学習を終えたいと思います。焦らずコツコツ一つずつ進んでいきたいと思います。色々な失敗を通して知ったこともあったので、大きな収穫ではあるかなと思います。ご精読ありがとうございました。
Discussion