Gemma 2 2B 日本語ファインチューニング & TPUv3-8 + Kaggle Hub公開

2024/08/07に公開

このノートブックでは、Googleが新たにリリースした軽量ながらも高性能な言語モデル Gemma 2 2B を、日本語データセット databricks-dolly-15k-ja でファインチューニングする方法を紹介します。さらに、KaggleのTPU v3-8を活用することで、効率的な学習を実現します。ファインチューニング後、モデルをKaggle Hubにアップロードする手順までを解説します。

この記事は、大規模言語モデル(LLM)の学習に興味がある初心者の方々を対象としています。 各ステップで丁寧な解説を加え、コードブロックには詳細なコメントを付与することで、スムーズに理解を進められるように工夫しました。

環境設定

まずは必要なライブラリをインストールし、TPUを使用するための環境設定を行います。

# 必要なライブラリのインストール
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu
!pip install -q datasets kagglehub kaggle_secrets rich

# Keras JAXバックエンドの設定 (TPUを使用するため)
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"  # TPUメモリの使用効率を最大化

# 必要なモジュールのインポート
import jax
import keras
import keras_nlp
from datasets import load_dataset
import pandas as pd
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.box import ROUNDED

# Rich Consoleのインスタンスを作成
console = Console()

TPUの確認

Kaggleでは、TPUv3-8デバイスが提供されています。以下のコードで、TPUが正しく認識されているか確認しましょう。

# 利用可能なTPUデバイスの一覧を表示
jax.devices()

出力結果にTPUデバイスが表示されれば、TPUを使用する準備が整っています。

モデルのロードと分散設定

Gemma 2 2Bモデルをロードし、TPUでの分散学習のための設定を行います。

# デバイスメッシュの作成 (4つのTPUコアを使用)
device_mesh = keras.distribution.DeviceMesh(
    (1, 4),  # TPUコアの形状 (行, 列)
    ["batch", "model"],  # 分散する次元
    devices=keras.distribution.list_devices()[:4],  # 最初の4つのデバイスを使用
)

# レイアウトマップの設定 (モデルの重みをTPUコアにどのように配置するかを指定)
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

# モデル並列化の設定
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# 分散設定の適用
keras.distribution.set_distribution(model_parallel)

# Gemma 2 2Bモデルのロード (事前学習済みの重みを使用)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

解説:

  • デバイスメッシュ: TPUコアをどのように配置するかを定義します。ここでは1行4列のメッシュを作成し、4つのTPUコアを使用します。
  • レイアウトマップ: モデルの各層の重みをどのTPUコアに配置するかを指定します。
  • モデル並列化: モデルを複数のTPUコアに分散して学習するための設定を行います。

モデルの構造確認 (create_layer_tree)

モデルの構造を可視化するために、create_layer_tree 関数を使用します。

# モデルの構造をツリー形式で表示する関数
def create_layer_tree(layer):
    tree = Tree(f"[bold blue]{layer.name}[/bold blue]")
    
    important_attrs = ['_layers', 'transformer_layers', 'layer_norm', '_token_embedding']
    for attr, value in vars(layer).items():
        if attr in important_attrs:
            if isinstance(value, list):
                subtree = tree.add(f"[yellow]{attr}[/yellow]")
                for item in value:
                    subtree.add(f"[green]{item.name}[/green]: {item.__class__.__name__}")
            else:
                tree.add(f"[yellow]{attr}[/yellow]: [green]{value.name}[/green] ({value.__class__.__name__})")
        elif not attr.startswith('_') and not callable(value):
            tree.add(f"[cyan]{attr}[/cyan]: {value}")
    
    return tree

# GemmaDecoderBlockの内部構造を確認
for layer in gemma_lm.layers:
    console.print(Panel(create_layer_tree(layer), title=f"[bold red]{layer.__class__.__name__}[/bold red]", expand=True))

# 埋め込み層の確認
embedding_layer = gemma_lm.get_layer('token_embedding')
console.print(Panel(
    f"[bold magenta]Embedding Layer[/bold magenta]\n"
    f"Name: [green]{embedding_layer.name}[/green]\n"
    f"Shape: [yellow]{embedding_layer.embeddings.shape}[/yellow]",
    title="Embedding Layer Info",
    border_style="blue"
))

データの準備

日本語の指示応答データセット databricks-dolly-15k-ja を使用して、Gemma 2 2Bモデルをファインチューニングします。

# データセットの読み込み
dataset = load_dataset('kunishou/databricks-dolly-15k-ja')
df_databricks = pd.DataFrame(dataset['train'])
df_databricks = df_databricks[["instruction", "output"]]

# ファインチューニング用のデータ形式に変換
data = []
for _, row in df_databricks.iterrows():
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    data.append(template.format(instruction=row['instruction'], response=row['output']))

# データ量を制限 (オプション: リソースが少ない場合はデータ量を減らす)
data = data[:2500]

# データの内容を表示 (最初の3件)
for i, item in enumerate(data[:3], 1):
    console.print(Panel(Text(item), title=f"[bold blue]Data Item {i}[/bold blue]", expand=False))

解説:

  • databricks-dolly-15k-ja: 日本語の指示応答データセット。様々なタスクに対応する指示と応答のペアが含まれています。
  • データ形式の変換: Gemma 2 2Bモデルのファインチューニングに適した形式に変換しています。

ファインチューニング前の推論 (generate_and_display)

ファインチューニングを行う前に、generate_and_display 関数を使用して、モデルの現在の性能を確認しましょう。

# テキスト生成と表示を行う関数
def generate_and_display(prompt, max_length=512):
    generated_text = gemma_lm.generate(prompt, max_length=max_length)
    
    prompt_panel = Panel(
        Text(prompt, style="bold magenta"),
        title="[blue]Input Prompt[/blue]",
        border_style="blue",
        box=ROUNDED,
    )

    generated_md = Markdown(generated_text)
    
    token_count = len(generated_text.split())
    token_info = Text(f"\n\nGenerated {token_count} tokens.", style="italic cyan")

    output_panel = Panel(
        generated_md,
        title="[green]Generated Response[/green]",
        border_style="green",
        box=ROUNDED,
    )

    console.print(prompt_panel)
    console.print()
    console.print(output_panel)
    console.print(token_info)

# 関数を呼び出して結果を表示
generate_and_display("ヴァージン・オーストラリア航空はいつから運航を開始したのですか? ", max_length=512)

LoRA(Low-Rank Adaptation)の設定

LoRA (Low-Rank Adaptation) を使用することで、モデルのパラメータの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを実現します。

# LoRAの有効化 (rank=8: 学習可能なパラメータのランク)
gemma_lm.backbone.enable_lora(rank=8)

# モデルのコンパイル
gemma_lm.preprocessor.sequence_length = 512  # 入力シーケンス長を制限
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# モデルのサマリを表示
gemma_lm.summary()

解説:

  • LoRA: モデルの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを可能にする手法。
  • rank: LoRAで追加する学習可能なパラメータのランク。

ファインチューニングの実行

準備が整ったので、Gemma 2 2Bモデルのファインチューニングを実行します。

# ファインチューニングの実行
gemma_lm.fit(data, epochs=1, batch_size=4)  # epochs: 学習回数, batch_size: バッチサイズ

解説:

  • epochs: 学習を繰り返す回数。
  • batch_size: 1回の学習で使用するデータの数。

ファインチューニング後の推論

ファインチューニング後のモデルを使って、実際にテキストを生成してみましょう。

# ファインチューニング後の推論
generate_and_display("Instruction:\n日本の首都はどこですか?\n\nResponse:\n", max_length=512)

モデルの保存とKaggle Hubへのアップロード

ファインチューニングしたモデルを保存し、Kaggle Hubにアップロードします。

# モデルの保存先ディレクトリ
FINETUNED_MODEL_DIR = f"./gemma2_2_demo"

# モデル情報
MODEL_BASE = "gemma2_2b_demo"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 1

# ディレクトリ作成
os.makedirs(FINETUNED_MODEL_DIR, exist_ok=True)

# モデルの重みとトークナイザーのアセットを保存
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)

# Kaggle Secretsから認証情報を取得
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
KAGGLE_USERNAME = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME

# Kaggle Hubへのアップロード
import kagglehub
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}'
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')

解説:

  • モデルの保存: ファインチューニングしたモデルの重みとトークナイザーのアセットを保存します。
  • Kaggle Hubへのアップロード: kagglehub ライブラリを使用して、保存したモデルをKaggle Hubにアップロードします。

これで、Gemma 2 2Bモデルの日本語データセットでのファインチューニングとKaggle Hubへのアップロードが完了しました。お疲れ様でした!

📒ノートブック

https://www.kaggle.com/code/makimakiai/jp-gemma2-2b-tpu-fine-tuning-dollyja-kagglehub2

Discussion