🐡

axolotlを使ったLLMのファインチューニング

2024/12/11に公開

はじめに

この記事では、axolotlというツールを使ってLLMをファインチューニングする方法について解説します。

公式のリポジトリはこちらです。

https://github.com/axolotl-ai-cloud/axolotl

axolotlとは

axolotlは、LLMのファインチューニングを素早く簡単に行うためのツールです。特に海外の個人LLM開発者によく利用されています。

このツール自体が特殊な学習方法等を提供しているというわけではなく、内部的にはHugging Faceのtrlやaccelerateを利用しています。あくまでファインチューニングのためのデータ処理や学習の設定・実行・管理などを一元的・簡便に行うためのツールです。

このツールは以下のような特徴を持っています。

  • LlamaやGemmaなど主要なモデルのファインチューニングに一通り対応
  • 学習のパラメータやデータセット等の設定を全て単一のyamlファイルにまとめることで、簡単にカスタマイズ可能・設定の配布も可能に
  • フルファインチューニング、PEFT双方に対応
  • SFTだけでなく、DPOやORPO、SimPOなど様々なPreference Optimization手法に対応
  • Liger KernelUnslothなどの軽量化手法に対応
  • DeepspeedやFSDPを通じたマルチGPU/マルチノードでの学習に対応
  • 様々な形式のデータセットを簡単に対話形式に変換して学習することが可能

以下、このaxolotlを使ったLLMのファインチューニングの方法を簡単に解説します。

環境準備

実際の使い方の解説に入る前に、axolotlを動かすための環境を準備します。
複数の方法がありますが、ここではpipを使った方法とDockerを使った方法の2つを解説します。

*Windows上での動作には対応していないので、Windows PCで動かしたい方はWSLまたはDockerを利用してください。

1. pipを使う方法

CUDAやPyTorch等の環境設定が通常通り済んでいる環境であれば、以下のようにリポジトリをcloneしpipからインストールするだけで利用可能になります。

git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl

pip3 install packaging ninja
pip3 install -e '.[flash-attn,deepspeed]'

2. Dockerを使う方法

axolotlai/axolotlというDockerイメージが公式から提供されているので、これを利用します。以下は公式のREADMEにあるサンプルの引用です。

docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest

また、一部のGPUレンタルサービスではDockerイメージを通じて簡単にaxolotlを利用することが出来ます。
例えばrunpodでは以下のようにtemplateから選択することで簡単に環境構築が可能です。


runpod上にあるaxolotlと名前の付いたDockerイメージ一覧。公式のものであるaxolotlai/axolotl-cloud:main-latestの利用を推奨します。

学習の流れ

axolotlを利用したLLMのファインチューニングは基本的に以下の3ステップで行われます。

  1. 学習の設定を記述するyamlファイルの準備
  2. 学習データの加工などの前処理
  3. 学習の実行

今回は、Googleのgemma-2-2bを元に、SFTと*SimPOを適用してInstruction Tuningを行う方法をデモとして解説します。

*公式のドキュメントには記載はないですが、CPOやSimPOにも対応しています。ドキュメンテーションも兼ねてSimPOを例にして解説します。

SFTにはRunpodのRTX 4090x2を、SimPOにはRTX 4090x4を使い、axolotlai/axolotl-cloud:main-latestのイメージ上で学習を行います。

コンテナ起動後、実際の学習に進む前に、以下のようにHugging Faceとwandbの認証を済ませておいてください。

# write権限のあるtokenを利用してHFにログイン(学習後のモデルアップロードに必要)
huggingface-cli login
# wandbにログイン(wandbに学習ログを残したい場合)
wandb login

axolotlを利用したLLMのSFT

まず第一段階としてSFTを行う流れを解説します。

1. 学習の設定を記述するyamlファイルの準備

学習手法やハイパーパラメータ、利用するデータセットやその前処理など学習に関する各種設定をまとめて記述するyaml形式のconfigファイルを準備します。

configファイルの例は公式のexamplesに様々なものがあるので、これを参考に設定してください。また、下記リンク先に各種オプションの詳細説明があります。(なお、ドキュメンテーションが追いついておらず記載されていないオプションもあります)

https://axolotl-ai-cloud.github.io/axolotl/docs/config.html

今回のSFTでは以下のようなyamlファイルをsft-gemma-2-2b.ymlという名前で作成し利用します。

学習用のconfigファイル
sft-gemma-2-2b.yml
# 学習のベースモデルに関する設定
base_model: google/gemma-2-2b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# 学習後のモデルのHFへのアップロードに関する設定
hub_model_id: Aratako/gemma-2-2b-axolotl-sft-v1.0
hub_strategy: "end"
push_dataset_to_hub:
hf_use_auth_token: true

# Liger Kernelの設定(学習の軽量・高速化)
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_cross_entropy: false
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true

# 量子化に関する設定
load_in_8bit: false
load_in_4bit: true

# SFTに利用するchat templateの設定
chat_template: gemma

# 学習データセットの前処理に関する設定
datasets:
  - path: kanhatakeyama/ramdom-to-fixed-multiturn-Calm3
    split: 20240806filtered[0:10000]
    type: chat_template
    field_messages: messages
    message_field_role: role
    message_field_content: content
  - path: llm-jp/magpie-sft-v1.0
    split: train[0:10000]
    type: chat_template
    field_messages: conversations
    message_field_role: role
    message_field_content: content
  - path: Aratako/magpie-qwen2.5-32b-reasoning-100k-formatted
    split: train[0:10000]
    type: chat_template
    field_messages: conversations
    message_field_role: role
    message_field_content: content

# データセット、モデルの出力先に関する設定
shuffle_merged_datasets: true
dataset_prepared_path: /workspace/data/sft-data
output_dir: /workspace/data/models/gemma-2-2b-axolotl-sft-v1.0

# valid datasetのサイズ
val_set_size: 0.05

# LoRAに関する設定(フルファインチューニングしたい場合は全て空欄にする)
adapter: qlora
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# wandbに関する設定
wandb_project: axolotl
wandb_entity: aratako-lm
wandb_watch:
wandb_name: sft-lora-1
wandb_log_model:

# 学習に関する様々な設定
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

gradient_accumulation_steps: 16
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
learning_rate: 3e-4

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false
early_stopping_patience:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_strategy: steps
save_steps: 50
save_total_limit: 2

warmup_steps: 10
eval_steps: 50
eval_batch_size: 1
eval_table_size:
eval_max_new_tokens:
debug:
deepspeed: /workspace/axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.01
fsdp:
fsdp_config:
special_tokens:
  pad_token: <pad>

特にデータセットの前処理については形式ごとに様々なオプションがあり複雑なため、下記リンクから公式のドキュメントをよく読むことを推奨します。

https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/

2. 前処理の実行

データセットのchat templateを使った加工や、事前のモデルダウンロード等の学習に必要となる前処理を実行します。
以下のようなコマンドで実行可能です。

python -m axolotl.cli.preprocess sft-gemma-2-2b.yml --debug

この処理を事前に行うことで、次のステップで学習をスムーズに始めることができます。また、--debugオプションをつけると、以下のようにchat templateで加工したデータセットに正常にloss maskが適用されているかの確認もできるので便利です。


debugオプションを付けた結果の出力。赤字部分がinstruction部分でloss maskがかかっており、completion onlyな学習になっていることが一目でわかる

3. 学習の実行

前処理した結果をもとに、実際の学習を行います。
以下のようなコマンドで実行可能です。マルチGPU環境の場合、DeepspeedやFSDPを使うことで分散学習が可能です。

accelerate launch -m axolotl.cli.train sft-gemma-2-2b.yml --deepspeed deepspeed_configs/zero3_bf16.json

実際の学習の進行はそのままコンソールに表示されます。また、設定したwandb上でも確認可能です。


train lossの推移

学習が完了すると、HF上にモデルがアップロードされます。上記の例ではこちらのリポジトリにアップロードされました。

https://huggingface.co/Aratako/gemma-2-2b-axolotl-sft-v1.0

*なお、私が試したタイミングではライブラリのバージョン不整合からか学習開始時にAttributeError: 'AdamW' object has no attribute 'optim_bits'というエラーが発生しました。これについてはaccelerateのバージョンを1.2.0から1.1.0にダウングレードすることで対応しました。

*余談ですが、このステップはデータセット周りの設定を変えなければ前処理をやり直さずとも何回もやり直せます。学習を試してCUDA OOM等が発生した場合、micro_batch_size等の変更であれば前処理の再実行は不要です。

4. 推論のテスト

学習が完了したモデルはHugging Faceにアップロードされるのでそこからダウンロードして推論のテストをしても良いですが、axolotlではそのままローカルで簡単に推論のテストが出来ます。

以下のようなコマンドで推論テスト用のプログラムを起動します。

accelerate launch -m axolotl.cli.inference sft-gemma-2-2b.yml --lora-model-dir="/workspace/data/models/gemma-2-2b-axolotl-sft-v1.0"

コンソールが入力待ちの状態になるので、テスト用の入力を行います。

入力サンプル

こんにちは!

出力サンプル(出力が止まらないと思いますが、実際の利用時は<end_of_turn>をeos_tokenとして出力をストップするので問題はありません)

こんにちは。何かお手伝いが必要ですか?ご質問や情報を提供していただければ、喜んでお手伝いします。どんなことでもお答えできるかもしれませんので、どうぞお気軽にお知らせください。<end_of_turn>

5. LoRA adapterのマージ

LoRAを使って学習した場合にはHFにLoRA adapterのみがアップロードされます。これは別の場所でマージしても良いのですが、axolotl側でもマージが可能です。ここでは、adapterをマージしてHF上にモデルを再度アップロードしてみます。

adapterのマージは以下のようなコマンドで実行可能です。

python -m axolotl.cli.merge_lora sft-gemma-2-2b.yml --lora-model-dir="/workspace/data/models/gemma-2-2b-axolotl-sft-v1.0"

マージが完了すると、saving merged model to: /workspace/data/models/gemma-2-2b-axolotl-sft-v1.0/mergedというような出力がされます。ここにマージ後の重みが保存されます。

READMEが保存されていないのでコピーしておき、huggingface-cli upload-large-folderを使って丸ごとHF上にアップロードしておきます。

cp /workspace/data/models/gemma-2-2b-axolotl-sft-v1.0/README.md /workspace/data/models/gemma-2-2b-axolotl-sft-v1.0/merged

huggingface-cli upload-large-folder Aratako/gemma-2-2b-axolotl-sft-v1.0-merged --repo-type=model /workspace/data/models/gemma-2-2b-axolotl-sft-v1.0/merged

完成したモデルはこちらにアップロードされました。

https://huggingface.co/Aratako/gemma-2-2b-axolotl-sft-v1.0-merged

axolotlを利用したLLMのPreference Optimization

次に、第二段階としてSimPOを使ったPreference Optimizationを行う流れを解説します。

1. 学習の設定を記述するyamlファイルの準備

SFTの場合と同様に、学習手法やハイパーパラメータ、利用するデータセットやその前処理など学習に関する各種設定をまとめて記述するyaml形式のconfigファイルを準備します。

ほとんどの項目はSFT時の設定と同様ですが、一部違う部分もあります。また、一部設定項目はPreference Optimizationの手法(DPO、ORPO、SimPOなど)によってそれぞれ異なります。一部手法の設定項目については下記リンクに記載があります。

https://axolotl-ai-cloud.github.io/axolotl/docs/rlhf.html

また、Preference Optimizationの際にはデータセットの加工に関する設定もSFT時とは異なります。
データセットの加工について、既存実装を使ったものではgemma2の形式に上手く対応できなかったので独自で簡易的なものを実装します。

まず、以下のような内容のpythonファイルをgemma.pyという名前で保存します。

gemma.py
def custom(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For My Customized Data
    """

    def transform_fn(sample):
        sample[
            "prompt"
        ] = f"<bos><start_of_turn>user\n{sample['prompt']}<end_of_turn>\n<start_of_turn>model\n"
        sample["chosen"] = f"{sample['chosen']}<end_of_turn>"
        sample["rejected"] = f"{sample['rejected']}<end_of_turn>"
        return sample

    return transform_fn

これをsrc/axolotl/prompt_strategies/dpo/gemma.pyに配置することで利用可能になります。

今回のSimPOでは以下のようなyamlファイルをsimpo-gemma-2-2b.ymlという名前で作成し利用します。

学習用のconfigファイル
simpo-gemma-2-2b.yml
# 学習のベースモデルに関する設定
# ベースモデルには先ほどSFTしたモデルを指定
base_model: Aratako/gemma-2-2b-axolotl-sft-v1.0-merged
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# 学習後のモデルのHFへのアップロードに関する設定
hub_model_id: Aratako/gemma-2-2b-axolotl-simpo-v1.0
hub_strategy: "end"
push_dataset_to_hub:
hf_use_auth_token: true

# Liger Kernelの設定(学習の軽量・高速化)
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_cross_entropy: false
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true

# 量子化に関する設定
load_in_8bit: false
load_in_4bit: true

# chat templateはSFTの時から変更しない
chat_template: tokenizer_default

# SimPO関連の設定
rl: simpo
rl_beta: 2.0
simpo_gamma: 1.0

# cpo_alphaを設定するとCPO_SimPOになる
cpo_alpha: 0.05

# 学習データセットの前処理に関する設定
# type: gemma.customが先ほど追加したgemma.pyの実装を指す
datasets:
  - path: Aratako/aya-ja-evol-instruct-calm3-dpo-masked-formatted
    type: gemma.custom
    split: "train[0:10000]"

# データセット、モデルの出力先に関する設定
shuffle_merged_datasets: true
dataset_prepared_path: /workspace/data/simpo-data
output_dir: /workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0

# LoRAに関する設定(フルファインチューニングしたい場合は全て空欄にする)
adapter: qlora
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# wandbに関する設定
wandb_project: axolotl
wandb_entity: aratako-lm
wandb_watch:
wandb_name: simpo-lora-1
wandb_log_model:

# 学習に関する様々な設定
sequence_len: 1024
# sample packingはPOに非対応
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 5e-6

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false

gradient_checkpointing: false
early_stopping_patience:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_strategy: steps
save_steps: 50
save_total_limit: 2

warmup_steps: 10
eval_steps: 50
eval_batch_size: 1
eval_table_size:
eval_max_new_tokens:
debug:
deepspeed: /workspace/axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.01
fsdp:
fsdp_config:
special_tokens:
  pad_token: <pad>

2. 前処理の実行

SFTの場合と同様に、以下のようなコマンドで前処理を実行します。

python -m axolotl.cli.preprocess simpo-gemma-2-2b.yml --debug

この処理により、以下のようにPrompt/Chosen/Rejectedの部分がそれぞれどのように加工されたかが分かりやすく表示されます。


debugオプションを付けた結果の出力。Prompt/Chosen/Rejectedそれぞれが異なる色で分かりやすく表示されている

3. 学習の実行

前処理した結果をもとに、実際の学習を行います。
SFTの時と同様に以下のようなコマンドで実行可能です。

accelerate launch -m axolotl.cli.train simpo-gemma-2-2b.yml --deepspeed deepspeed_configs/zero3_bf16.json

実際の学習の進行はそのままコンソールに表示されます。また、設定したwandb上でも確認可能です。


train lossの推移

学習が完了すると、HF上にモデルがアップロードされます。上記の例ではこちらのリポジトリにアップロードされました。

https://huggingface.co/Aratako/gemma-2-2b-axolotl-simpo-v1.0

4. 推論のテスト

SFT後と同じくローカルでaxolotlを使って推論のテストをします。

ここで、学習のconfig中のchat_templatetokenizer_defaultだと上手く起動できなかったので、一時的にgemmaに変えておきます。
その後、以下のようなコマンドで推論テスト用のプログラムを起動します。

accelerate launch -m axolotl.cli.inference simpo-gemma-2-2b.yml --lora-model-dir="/workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0"

コンソールが入力待ちの状態になるので、テスト用の入力を行います。

入力サンプル

こんにちは!

出力サンプル(出力が止まらないと思いますが、実際の利用時は<end_of_turn>をeos_tokenとして出力をストップするので問題はありません)

こんにちは、何か特に特定についてお尋ねや情報を探していることがあれば、どんなことでもお話しして差し支えありませんか?どんな分野やテーマに関しても助けてお手伝いできるよう頑張ります。また日常の話から専門知識まで幅広く対応できそうですので、どうぞお気軽にどうぞ。まずはどんなトピックが気になるのか教えてください。<end_of_turn>

出力の精度はあまりよくなっていませんが、テスト推論は出来ました。

5. LoRA adapterのマージ

SFTの時と同様にadapterをマージしてHF上にモデルを再度アップロードしてみます。

adapterのマージは以下のようなコマンドで実行可能です。

python -m axolotl.cli.merge_lora simpo-gemma-2-2b.yml --lora-model-dir="/workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0"

マージが完了すると今回の例では/workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0/mergedにマージ後の重みが保存されます。

READMEが保存されていないのでコピーしておき、huggingface-cli upload-large-folderを使って丸ごとHF上にアップロードしておきます。

cp /workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0/README.md /workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0/merged

huggingface-cli upload-large-folder Aratako/gemma-2-2b-axolotl-simpo-v1.0-merged --repo-type=model /workspace/data/models/gemma-2-2b-axolotl-simpo-v1.0/merged

完成したモデルはこちらにアップロードされました。

https://huggingface.co/Aratako/gemma-2-2b-axolotl-simpo-v1.0-merged

まとめ

この記事では、axolotlというツールを使ってLLMをファインチューニングする方法について解説しました。

実際にある程度利用してみるとバグやおかしな挙動も見つかりますが、configファイルを用意するだけで準備が完了するので使ってみるとかなり扱いやすく感じると思います。また、configファイルが配布されていることも多いので、先駆者様の設定を真似しやすいのも魅力的に感じます。

もしこの記事を読んで興味が湧いたという方がいればぜひ手元で試してみてください。

Discussion