😺

Intel MacでGemmaのファインチューニングをする

に公開

Intel MacでGemmaのファインチューニングする作業をしている方がいなかったので、テストでかいてみました。

max_seq_length=1024をコメント化していますが、動作する環境であればコメントを外してください。

pip install torch transformers datasets peft trl accelerate sentencepiece gradio bitsandbytes
finetuning
from pathlib import Path
import json
import traceback
import gc
import shutil

import gradio as gr
from datasets import Dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer
from transformers.trainer_utils import get_last_checkpoint
import psutil

MODEL_NAME = "google/gemma-2b-it"
OUTPUT_DIR = Path("output")
FINAL_CHECKPOINT_DIR = OUTPUT_DIR / "final_checkpoint"

try:
    if torch.backends.mps.is_available(): DEVICE = torch.device("mps"); print("✅ MPS Detected.")
    elif torch.cuda.is_available(): DEVICE = torch.device("cuda"); print("✅ CUDA Detected.")
    else: DEVICE = torch.device("cpu"); print("⚠️ MPS/CUDA not available. Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}"); DEVICE = torch.device("cpu")
    print("⚠️ Using CPU due to detection error.")

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1
)

def chat_format(example: dict) -> str:
    if 'instruction' not in example or 'response' not in example:
        raise ValueError("Data format error: 'instruction' and 'response' keys required.")
    return f"<bos><start_of_turn>user\n{example['instruction']}<end_of_turn>\n<start_of_turn>model\n{example['response']}<end_of_turn>"

print(f"Loading Tokenizer: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, padding_side="right")
if tokenizer.pad_token is None:
    print("Setting pad_token = eos_token"); tokenizer.pad_token = tokenizer.eos_token
else: print(f"Tokenizer pad_token: {tokenizer.pad_token}")

print(f"Loading Model: {MODEL_NAME}")
model = None
try:
    import accelerate
    print("Accelerate installed. Using device_map='auto'.")
    mps_memory_limit_gb = 6.0
    total_ram_gb = psutil.virtual_memory().total / (1024**3)
    cpu_buffer_gb = 4.0
    cpu_memory_gb = max(0, total_ram_gb - mps_memory_limit_gb - cpu_buffer_gb)
    max_memory = None; model_dtype = torch.float32
    if DEVICE.type == 'mps':
        max_memory = {'mps': f'{mps_memory_limit_gb:.1f}GiB', 'cpu': f'{cpu_memory_gb:.1f}GiB'}
        model_dtype = torch.float16
        print(f"✅ MPS detected. Using dtype={model_dtype}. Max memory: {max_memory}")
    elif DEVICE.type == 'cuda':
         model_dtype = torch.float16
         print(f"✅ CUDA detected. Using dtype={model_dtype}.")
    else: print(f"✅ CPU detected. Using dtype={model_dtype}.")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, trust_remote_code=True, device_map="auto",
        torch_dtype=model_dtype, max_memory=max_memory
    )
    print("Model loaded with device_map='auto'.")
    if hasattr(model, 'hf_device_map'): print(f"   Device map: {model.hf_device_map}")
    else: print(f"   Model loaded on single device: {model.device}")

except ImportError:
    print("⚠️ Accelerate not installed. Loading model manually.")
    try:
        model_dtype = torch.float16 if DEVICE.type == 'mps' else torch.float32
        print(f"   Attempting load with dtype={model_dtype} on {DEVICE}.")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, trust_remote_code=True, torch_dtype=model_dtype,
        ).to(DEVICE)
        print(f"Model loaded manually on {DEVICE}.")
    except RuntimeError as e_manual_load:
        if "out of memory" in str(e_manual_load).lower():
            print(f"❌ OOM Error during manual load ({DEVICE}, {model_dtype}).")
            print("   Consider using a smaller model or CPU. Check PYTORCH_MPS_HIGH_WATERMARK_RATIO (use with caution).")
        raise e_manual_load
    except Exception as e: print(f"Unexpected error during manual load: {e}"); raise e
except Exception as e: print(f"Model loading failed: {e}"); raise e

try:
    model = get_peft_model(model, peft_config)
    print("PEFT model prepared.")
    model.print_trainable_parameters()
except Exception as e: print(f"PEFT model preparation failed: {e}"); raise e

def finetune_from_json(json_file_obj, resume_training, progress=gr.Progress(track_tqdm=True)):
    training_started = False
    try:
        training_args = TrainingArguments(
            output_dir=str(OUTPUT_DIR),
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_strategy="steps",
            logging_steps=10,
            save_strategy="steps",
            save_steps=50,
            save_total_limit=2,
            logging_dir=str(OUTPUT_DIR / "logs"),
            report_to="none",
            no_cuda=True, fp16=False, bf16=False,
            ddp_find_unused_parameters=False,
        )
        print("TrainingArguments configured:")
        print(f"Output directory (checkpoints & final model): {training_args.output_dir}")
        print(f"Checkpoint save interval: {training_args.save_steps} steps")
        print(f"Log interval: {training_args.logging_steps} steps (Check console)")

        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        print(f"Searching for checkpoints in {training_args.output_dir}...")
        print(f"Latest checkpoint found: {last_checkpoint}")

        resume_path = None
        if resume_training:
            if last_checkpoint:
                print(f"✅ Resuming training from checkpoint: {last_checkpoint}")
                resume_path = last_checkpoint
                if json_file_obj:
                    gr.Warning("学習再開オプションが選択されたため、アップロードされたJSONファイルは無視されます。")
                    json_file_obj = None
            else:
                return f"❌ 再開オプションが選択されましたが、{training_args.output_dir} に有効なチェックポイントが見つかりません。"
        elif last_checkpoint:
             gr.Warning(f"⚠️ 注意: {training_args.output_dir} に既存のチェックポイント ({last_checkpoint}) があります。"
                       "今回は最初から学習を開始するため、これらのチェックポイントは学習中に上書きされる可能性があります。")

        dataset = None
        if not resume_path:
            if json_file_obj is None:
                return "❌ 最初から学習を開始するにはJSONファイルをアップロードしてください。"
            json_path = Path(json_file_obj.name)
            progress(0, desc="JSONファイルを読み込み中...")
            print(f"Loading JSON from: {json_path}")
            with json_path.open("r", encoding="utf-8") as f: data = json.load(f)

            if not isinstance(data, list): return f"❌ JSON must be a list, got {type(data)}."
            if not data: return "❌ JSON file is empty."
            if not all(isinstance(item, dict) and 'instruction' in item and 'response' in item for item in data):
                 return "❌ Invalid JSON data format. Each item must be a dict with 'instruction' and 'response'."

            print(f"Loaded {len(data)} records from JSON.")
            dataset = Dataset.from_list(data)
            print("Dataset created successfully.")
        else:
            print("Resuming training. Creating dummy dataset for Trainer initialization.")
            dummy_data = [{"instruction": "dummy", "response": "dummy"}]
            dataset = Dataset.from_list(dummy_data)

        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=training_args,
            formatting_func=chat_format,
        )
        print("SFTTrainer initialized.")

        progress(0.1, desc="ファインチューニング実行中... (詳細はコンソールを確認)")
        print("\n" + "="*20 + " Training Start " + "="*20)
        print("📊 進捗とLossはコンソールに出力されます。")
        print("⏸️ 中断したい場合は Ctrl+C で停止してください。次回再開できます。")
        training_started = True
        train_result = trainer.train(resume_from_checkpoint=resume_path)
        print("="*20 + " Training Finished " + "="*20)

        progress(0.9, desc="最終モデルを保存中...")
        FINAL_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        print(f"Saving final LoRA adapter and tokenizer to {FINAL_CHECKPOINT_DIR}")
        model.save_pretrained(str(FINAL_CHECKPOINT_DIR))
        tokenizer.save_pretrained(str(FINAL_CHECKPOINT_DIR))

        try:
            del trainer
            gc.collect()
            if DEVICE.type == 'mps': torch.mps.empty_cache()
            elif DEVICE.type == 'cuda': torch.cuda.empty_cache()
            print("Memory released.")
        except Exception as e_mem: print(f"Warning during memory release: {e_mem}")

        progress(1.0, desc="完了")
        return (f"✔︎ ファインチューニングが完了しました。\n"
                f"   最終モデル(LoRAアダプタ)は '{FINAL_CHECKPOINT_DIR}' に保存されました。\n"
                f"   トレーニング統計: {train_result.metrics}")

    except Exception as e:
        print(f"❌ Error during fine-tuning: {e}")
        print(traceback.format_exc())
        error_prefix = "❌ エラーが発生しました:" if not training_started else "❌ 学習中にエラーが発生しました:"
        resume_msg = ""
        if training_started and get_last_checkpoint(str(OUTPUT_DIR)):
             resume_msg = f"\n\nℹ️ 最後に保存されたチェックポイント ({get_last_checkpoint(str(OUTPUT_DIR))}) から再開を試みることができます。"

        return f"{error_prefix}\n{e}\n\nTraceback:\n{traceback.format_exc()}{resume_msg}"

def launch_ui():
    with gr.Blocks(title="Gemmaファインチューニング (Intel Mac)", theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"""
        # ✅ Gemma ファインチューニング GUI (Intel Mac版)

        学習用のJSONファイルをアップロードするか、前回の続きから再開してください。
        *   **JSON形式:** `[ {{"instruction": "指示", "response": "応答"}} ]` のリスト。
        *   **モデル:** Gemma ({MODEL_NAME}) LoRA + SFTTrainer
        *   **進捗確認:** 学習のステップ、Lossなどの**詳細な進捗は、このアプリを実行しているターミナル(コンソール)で確認**してください。
        *   **一時停止/再開:** 学習を中断したい場合は、コンソールで `Ctrl+C` を押してスクリプトを停止します。`save_steps` ごとにチェックポイントが `{OUTPUT_DIR}` に自動保存されるため、次回起動時に下の「途中から再開する」にチェックを入れて開始すれば、続きから再開できます。
        *   **注意 (速度):** Intel Mac ({DEVICE.type.upper()}) での学習は時間がかかることがあります。
        *   **注意 (max_seq_length):** 現在、ライブラリ互換性のため `max_seq_length` が設定されていません。メモリ不足のリスクや効率低下を防ぐため、可能であればターミナルで `pip install --upgrade trl` を実行して `trl` を更新し、コード内の `max_seq_length=1024,` のコメントを解除してください。
        *   **最終モデル:** 完了後、LoRAアダプタは `{FINAL_CHECKPOINT_DIR}` に保存されます。
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="学習用 JSON ファイル (新規開始時のみ)",
                    file_types=[".json"],
                    type="filepath"
                )
                resume_checkbox = gr.Checkbox(
                    label=f"利用可能なチェックポイント ({OUTPUT_DIR}) から学習を再開する",
                    value=False
                )

            with gr.Column(scale=2):
                 output_text = gr.Textbox(
                     label="結果 / ログ概要",
                     lines=15,
                     interactive=False,
                     placeholder="ここに結果メッセージが表示されます。詳細なログはコンソールを確認してください。"
                 )

        run_button = gr.Button("▶ ファインチューニング開始 / 再開", variant="primary")

        run_button.click(
            fn=finetune_from_json,
            inputs=[file_input, resume_checkbox],
            outputs=[output_text]
        )

    print("Gradio UI launching...")
    demo.queue().launch()

if __name__ == "__main__":
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    launch_ui()
train
[
  {
    "instruction": "日本の首都は?",
    "response": "日本の首都は東京です。"
  },
  {
    "instruction": "富士山の標高を教えてください。",
    "response": "富士山の標高は3,776メートルです。"
  }
]

Discussion