Gemma 2 2B 日本語ファインチューニング & TPUv3-8 + Kaggle Hub公開
このノートブックでは、Googleが新たにリリースした軽量ながらも高性能な言語モデル Gemma 2 2B を、日本語データセット databricks-dolly-15k-ja でファインチューニングする方法を紹介します。さらに、KaggleのTPU v3-8を活用することで、効率的な学習を実現します。ファインチューニング後、モデルをKaggle Hubにアップロードする手順までを解説します。
この記事は、大規模言語モデル(LLM)の学習に興味がある初心者の方々を対象としています。 各ステップで丁寧な解説を加え、コードブロックには詳細なコメントを付与することで、スムーズに理解を進められるように工夫しました。
環境設定
まずは必要なライブラリをインストールし、TPUを使用するための環境設定を行います。
# 必要なライブラリのインストール
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu
!pip install -q datasets kagglehub kaggle_secrets rich
# Keras JAXバックエンドの設定 (TPUを使用するため)
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0" # TPUメモリの使用効率を最大化
# 必要なモジュールのインポート
import jax
import keras
import keras_nlp
from datasets import load_dataset
import pandas as pd
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.box import ROUNDED
# Rich Consoleのインスタンスを作成
console = Console()
TPUの確認
Kaggleでは、TPUv3-8デバイスが提供されています。以下のコードで、TPUが正しく認識されているか確認しましょう。
# 利用可能なTPUデバイスの一覧を表示
jax.devices()
出力結果にTPUデバイスが表示されれば、TPUを使用する準備が整っています。
モデルのロードと分散設定
Gemma 2 2Bモデルをロードし、TPUでの分散学習のための設定を行います。
# デバイスメッシュの作成 (4つのTPUコアを使用)
device_mesh = keras.distribution.DeviceMesh(
(1, 4), # TPUコアの形状 (行, 列)
["batch", "model"], # 分散する次元
devices=keras.distribution.list_devices()[:4], # 最初の4つのデバイスを使用
)
# レイアウトマップの設定 (モデルの重みをTPUコアにどのように配置するかを指定)
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)
# モデル並列化の設定
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map,
batch_dim_name="batch",
)
# 分散設定の適用
keras.distribution.set_distribution(model_parallel)
# Gemma 2 2Bモデルのロード (事前学習済みの重みを使用)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
解説:
- デバイスメッシュ: TPUコアをどのように配置するかを定義します。ここでは1行4列のメッシュを作成し、4つのTPUコアを使用します。
- レイアウトマップ: モデルの各層の重みをどのTPUコアに配置するかを指定します。
- モデル並列化: モデルを複数のTPUコアに分散して学習するための設定を行います。
モデルの構造確認 (create_layer_tree)
モデルの構造を可視化するために、create_layer_tree
関数を使用します。
# モデルの構造をツリー形式で表示する関数
def create_layer_tree(layer):
tree = Tree(f"[bold blue]{layer.name}[/bold blue]")
important_attrs = ['_layers', 'transformer_layers', 'layer_norm', '_token_embedding']
for attr, value in vars(layer).items():
if attr in important_attrs:
if isinstance(value, list):
subtree = tree.add(f"[yellow]{attr}[/yellow]")
for item in value:
subtree.add(f"[green]{item.name}[/green]: {item.__class__.__name__}")
else:
tree.add(f"[yellow]{attr}[/yellow]: [green]{value.name}[/green] ({value.__class__.__name__})")
elif not attr.startswith('_') and not callable(value):
tree.add(f"[cyan]{attr}[/cyan]: {value}")
return tree
# GemmaDecoderBlockの内部構造を確認
for layer in gemma_lm.layers:
console.print(Panel(create_layer_tree(layer), title=f"[bold red]{layer.__class__.__name__}[/bold red]", expand=True))
# 埋め込み層の確認
embedding_layer = gemma_lm.get_layer('token_embedding')
console.print(Panel(
f"[bold magenta]Embedding Layer[/bold magenta]\n"
f"Name: [green]{embedding_layer.name}[/green]\n"
f"Shape: [yellow]{embedding_layer.embeddings.shape}[/yellow]",
title="Embedding Layer Info",
border_style="blue"
))
データの準備
日本語の指示応答データセット databricks-dolly-15k-ja を使用して、Gemma 2 2Bモデルをファインチューニングします。
# データセットの読み込み
dataset = load_dataset('kunishou/databricks-dolly-15k-ja')
df_databricks = pd.DataFrame(dataset['train'])
df_databricks = df_databricks[["instruction", "output"]]
# ファインチューニング用のデータ形式に変換
data = []
for _, row in df_databricks.iterrows():
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(instruction=row['instruction'], response=row['output']))
# データ量を制限 (オプション: リソースが少ない場合はデータ量を減らす)
data = data[:2500]
# データの内容を表示 (最初の3件)
for i, item in enumerate(data[:3], 1):
console.print(Panel(Text(item), title=f"[bold blue]Data Item {i}[/bold blue]", expand=False))
解説:
- databricks-dolly-15k-ja: 日本語の指示応答データセット。様々なタスクに対応する指示と応答のペアが含まれています。
- データ形式の変換: Gemma 2 2Bモデルのファインチューニングに適した形式に変換しています。
ファインチューニング前の推論 (generate_and_display)
ファインチューニングを行う前に、generate_and_display
関数を使用して、モデルの現在の性能を確認しましょう。
# テキスト生成と表示を行う関数
def generate_and_display(prompt, max_length=512):
generated_text = gemma_lm.generate(prompt, max_length=max_length)
prompt_panel = Panel(
Text(prompt, style="bold magenta"),
title="[blue]Input Prompt[/blue]",
border_style="blue",
box=ROUNDED,
)
generated_md = Markdown(generated_text)
token_count = len(generated_text.split())
token_info = Text(f"\n\nGenerated {token_count} tokens.", style="italic cyan")
output_panel = Panel(
generated_md,
title="[green]Generated Response[/green]",
border_style="green",
box=ROUNDED,
)
console.print(prompt_panel)
console.print()
console.print(output_panel)
console.print(token_info)
# 関数を呼び出して結果を表示
generate_and_display("ヴァージン・オーストラリア航空はいつから運航を開始したのですか? ", max_length=512)
LoRA(Low-Rank Adaptation)の設定
LoRA (Low-Rank Adaptation) を使用することで、モデルのパラメータの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを実現します。
# LoRAの有効化 (rank=8: 学習可能なパラメータのランク)
gemma_lm.backbone.enable_lora(rank=8)
# モデルのコンパイル
gemma_lm.preprocessor.sequence_length = 512 # 入力シーケンス長を制限
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(learning_rate=5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# モデルのサマリを表示
gemma_lm.summary()
解説:
- LoRA: モデルの大部分を凍結し、少数の学習可能なパラメータを追加することで、効率的なファインチューニングを可能にする手法。
- rank: LoRAで追加する学習可能なパラメータのランク。
ファインチューニングの実行
準備が整ったので、Gemma 2 2Bモデルのファインチューニングを実行します。
# ファインチューニングの実行
gemma_lm.fit(data, epochs=1, batch_size=4) # epochs: 学習回数, batch_size: バッチサイズ
解説:
- epochs: 学習を繰り返す回数。
- batch_size: 1回の学習で使用するデータの数。
ファインチューニング後の推論
ファインチューニング後のモデルを使って、実際にテキストを生成してみましょう。
# ファインチューニング後の推論
generate_and_display("Instruction:\n日本の首都はどこですか?\n\nResponse:\n", max_length=512)
モデルの保存とKaggle Hubへのアップロード
ファインチューニングしたモデルを保存し、Kaggle Hubにアップロードします。
# モデルの保存先ディレクトリ
FINETUNED_MODEL_DIR = f"./gemma2_2_demo"
# モデル情報
MODEL_BASE = "gemma2_2b_demo"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 1
# ディレクトリ作成
os.makedirs(FINETUNED_MODEL_DIR, exist_ok=True)
# モデルの重みとトークナイザーのアセットを保存
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)
# Kaggle Secretsから認証情報を取得
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
KAGGLE_USERNAME = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME
# Kaggle Hubへのアップロード
import kagglehub
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}'
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')
解説:
- モデルの保存: ファインチューニングしたモデルの重みとトークナイザーのアセットを保存します。
-
Kaggle Hubへのアップロード:
kagglehub
ライブラリを使用して、保存したモデルをKaggle Hubにアップロードします。
これで、Gemma 2 2Bモデルの日本語データセットでのファインチューニングとKaggle Hubへのアップロードが完了しました。お疲れ様でした!
📒ノートブック
Discussion