🎸

Gemma2-2b: TPUを活用したファインチューニングとKagglehubへのアップロード

2024/08/03に公開

はじめに

こんにちは!今回は、Googleが新しくリリースしたGemma2-2bモデルを使って、TPU(Tensor Processing Unit)を活用したファインチューニングを行い、その結果をKagglehubにアップロードする方法をご紹介します。

このチュートリアルは、大規模言語モデル(LLM)の学習に興味がある初心者の方々を対象としています。ステップバイステップで進めていきますので、ぜひ最後までお付き合いください!

環境設定

まずは、必要なライブラリをインストールし、環境を整えます。

# 必要なライブラリのインストール
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu

# Keras JAXバックエンドの設定
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# 必要なモジュールのインポート
import jax
import keras
import keras_nlp

このコードでは以下のことを行っています:

  1. keras-nlptensorflow-textをインストール
  2. tensorflow-cpuをインストール(TPUを使用するため)
  3. KerasのバックエンドをJAXに設定
  4. TPUメモリの使用効率を最大化するための設定
  5. 必要なモジュールをインポート

TPUの確認

Kaggleでは、TPUv3-8デバイスが提供されています。各TPUコアは16GBのメモリを持っています。以下のコードでTPUの状態を確認しましょう。

# 利用可能なJAXデバイス(TPU)の確認
print(jax.devices())

このコードを実行すると、利用可能なTPUデバイスの一覧が表示されます。8つのTPUコアが見えるはずです。

モデルのロードと分散設定

次に、Gemma2-2bモデルをロードし、分散学習の設定を行います。

# デバイスメッシュの作成
device_mesh = keras.distribution.DeviceMesh(
    (1, 4),  # 4つのTPUコアを使用
    ["batch", "model"],
    devices=keras.distribution.list_devices()[:4],  # 最初の4つのデバイスのみを使用
)

# レイアウトマップの設定
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)

# 重みの分散方法を指定
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)

# モデル並列化の設定
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map,
    batch_dim_name="batch",
)

# 分散設定の適用
keras.distribution.set_distribution(model_parallel)

# Gemma2-2bモデルのロード
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")

このコードでは以下のことを行っています:

  1. TPUコアを使用するためのデバイスメッシュを作成
  2. モデルの重みをTPUコア間で分散させるためのレイアウトマップを設定
  3. モデル並列化の設定を行い、分散学習を有効化
  4. Gemma2-2bモデルを事前学習済みの重みでロード

これにより、大規模なモデルを効率的に扱うことができます。

モデルの構造確認

モデルの構造を可視化するために、Richライブラリを使用します。以下のコードで、モデルの各層の詳細を確認できます。

# Richライブラリのインポート
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel

console = Console()

# モデルの層構造を表示する関数
def create_layer_tree(layer):
    tree = Tree(f"[bold blue]{layer.name}[/bold blue]")
    
    important_attrs = ['_layers', 'transformer_layers', 'layer_norm', '_token_embedding']
    for attr, value in vars(layer).items():
        if attr in important_attrs:
            if isinstance(value, list):
                subtree = tree.add(f"[yellow]{attr}[/yellow]")
                for item in value:
                    subtree.add(f"[green]{item.name}[/green]: {item.__class__.__name__}")
            else:
                tree.add(f"[yellow]{attr}[/yellow]: [green]{value.name}[/green] ({value.__class__.__name__})")
        elif not attr.startswith('_') and not callable(value):
            tree.add(f"[cyan]{attr}[/cyan]: {value}")
    
    return tree

# GemmaDecoderBlockの内部構造を確認
for layer in gemma_lm.layers:
    console.print(Panel(create_layer_tree(layer), title=f"[bold red]{layer.__class__.__name__}[/bold red]", expand=True))

# 埋め込み層の確認
embedding_layer = gemma_lm.get_layer('token_embedding')
console.print(Panel(
    f"[bold magenta]Embedding Layer[/bold magenta]\n"
    f"Name: [green]{embedding_layer.name}[/green]\n"
    f"Shape: [yellow]{embedding_layer.embeddings.shape}[/yellow]",
    title="Embedding Layer Info",
    border_style="blue"
))

このコードを実行すると、モデルの各層の詳細な構造が美しく可視化されて表示されます。これにより、モデルの内部構造を理解することができます。

ファインチューニング前の推論

ファインチューニングを行う前に、モデルの現在の性能を確認しましょう。以下のコードで、モデルに質問を投げかけることができます。

# テキスト生成と表示のための関数
def generate_and_display(prompt, max_length=512):
    generated_text = gemma_lm.generate(prompt, max_length=max_length)
    
    prompt_panel = Panel(
        Text(prompt, style="bold magenta"),
        title="[blue]Input Prompt[/blue]",
        border_style="blue",
        box=ROUNDED,
    )

    generated_md = Markdown(generated_text)
    
    token_count = len(generated_text.split())
    token_info = Text(f"\n\nGenerated {token_count} tokens.", style="italic cyan")

    output_panel = Panel(
        generated_md,
        title="[green]Generated Response[/green]",
        border_style="green",
        box=ROUNDED,
    )

    console.print(prompt_panel)
    console.print()
    console.print(output_panel)
    console.print(token_info)

# 関数を呼び出して結果を表示
generate_and_display("How can I plan a trip to Europe?", max_length=512)

このコードを実行すると、モデルが生成した回答が美しく整形されて表示されます。これにより、ファインチューニング前のモデルの性能を確認できます。

データの準備

ファインチューニングには、Databricks Dollyデータセットを使用します。このデータセットは、LLMの指示追従能力を向上させるために設計された高品質なプロンプト/レスポンスのペアを含んでいます。

# 必要なライブラリのインポート
import json

# データの読み込みと前処理
data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        if features["context"]:
            continue
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# 学習を高速化するためにデータを切り詰める
data = data[:500]

このコードでは以下のことを行っています:

  1. JSONLファイルからデータを読み込む
  2. 必要な形式にデータを整形
  3. 学習の高速化のため、データセットを最初の500例に制限

これにより、ファインチューニングに使用するデータが準備されます。

LoRA(Low-Rank Adaptation)の設定

ファインチューニングには、LoRA(Low-Rank Adaptation)という手法を使用します。LoRAは、モデルの大部分のパラメータを凍結し、少数の新しい学習可能なパラメータを追加することで、効率的なファインチューニングを可能にします。

# LoRAの有効化
gemma_lm.backbone.enable_lora(rank=8)

# モデルのコンパイル
gemma_lm.preprocessor.sequence_length = 512  # 入力シーケンス長を制限
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# モデルの構造を表示
gemma_lm.summary()

このコードでは以下のことを行っています:

  1. LoRAを有効化(ランク8で設定)
  2. モデルのコンパイル(損失関数、オプティマイザ、メトリクスの設定)
  3. 入力シーケンス長を512トークンに制限(メモリ使用量を抑えるため)
  4. モデルの構造を表示

これにより、効率的なファインチューニングの準備が整います。

ファインチューニングの実行

いよいよファインチューニングを実行します。

# ファインチューニングの実行
history = gemma_lm.fit(data, epochs=1, batch_size=4)

# 学習の履歴を表示
print(history.history)

このコードでは以下のことを行っています:

  1. 準備したデータを使ってモデルをファインチューニング
  2. エポック数は1、バッチサイズは4に設定
  3. 学習の履歴を表示

これにより、モデルが特定のタスクに適応します。

ファインチューニング後の推論

ファインチューニング後のモデルの性能を確認しましょう。

# ファインチューニング後の推論
generate_and_display("Instruction:\nHow can I plan a trip to Europe?\n\nResponse:\n", max_length=512)

このコードを実行すると、ファインチューニング後のモデルが生成した回答が表示されます。ファインチューニング前と比較して、回答の質が向上していることが期待できます。

モデルの保存とKagglehubへのアップロード

# モデルの保存先ディレクトリの設定
FINETUNED_MODEL_DIR = f"./gemma_demo"

# モデル名の設定
MODEL_BASE = "gemma2_2b_demo"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 9

# ディレクトリの作成とモデルの保存
os.makedirs(FINETUNED_MODEL_DIR, exist_ok=True)
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)  # ファインチューニングされた重みを保存
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)  # トークナイザーのアセットを保存

このコードでは、以下のことを行っています:

  1. FINETUNED_MODEL_DIR:ファインチューニングしたモデルを保存するディレクトリを指定
  2. MODEL_BASEMODEL_NAME:モデルの基本名と完全な名前を設定
  3. FINETUNED_WEIGHTS_PATH:モデルの重みを保存するファイルパスを指定
  4. FINETUNED_VOCAB_PATH:トークナイザーの語彙ファイルを保存するパスを指定
  5. FRAMEWORK:使用しているフレームワーク(この場合は"jax")を指定
  6. VER:モデルのバージョン番号を設定

次に、os.makedirs()でディレクトリを作成し、gemma_lm.save_weights()でモデルの重みを保存、gemma_lm.preprocessor.tokenizer.save_assets()でトークナイザーのアセットを保存します。

# Kagglehubへのアップロード準備
import kagglehub
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
KAGGLE_USERNAME = user_secrets.get_secret("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("KAGGLE_KEY")
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME

# Kagglehubへのアップロード実行
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}'
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')

このコードでは、以下のことを行っています:

  1. kagglehubUserSecretsClientをインポートして、Kaggleの認証情報を取得
  2. Kaggleのユーザー名と APIキーを環境変数に設定
  3. アップロード用のハンドル(一意の識別子)を作成
  4. kagglehub.model_upload()を使用して、モデルをKagglehubにアップロード
    • handle:モデルの一意の識別子
    • FINETUNED_WEIGHTS_PATH:アップロードする重みファイルのパス
    • license_name:モデルのライセンス(ここではApache 2.0)
    • version_notes:バージョンに関する注記

このプロセスにより、ファインチューニングしたGemma2-2bモデルがKagglehubに公開され、他の人々も使用できるようになります。

まとめ

この記事では、Gemma2-2bモデルを使った分散ファインチューニングの方法と、その結果をKagglehubにアップロードする方法を学びました。以下が主なポイントです:

  1. TPUを活用した効率的な大規模言語モデルのファインチューニング
  2. LoRA(Low-Rank Adaptation)を使用した効率的な学習
  3. Databricks Dollyデータセットを用いた指示追従能力の向上
  4. ファインチューニングしたモデルの保存と共有

この手法を応用することで、さまざまな自然言語処理タスクに対応できるカスタムモデルを作成し、共有することができます。

さらなる改善のためには、以下のポイントを検討してみてください:

  1. より多くのデータでファインチューニングを行う
  2. ハイパーパラメータ(学習率、LoRAのランクなど)の調整
  3. より長いエポック数での学習
  4. タスク特化型のデータセットの使用

Gemma2-2bのような効率的なモデルを使うことで、限られたリソースでも高性能な言語モデルを構築・活用することが可能になります。ぜひ、自分のプロジェクトやアイデアに応用してみてください!

最後に、モデルを公開する際は、適切なライセンスを選択し、ethical AIの原則に従って責任ある使用を心がけましょう。

この記事が、皆さんの機械学習の旅に役立つことを願っています。Happy coding!

📒ノートブック

https://www.kaggle.com/code/makimakiai/jp-gemma2-2b-tpu-fine-tuning-kagglehub-upload

Discussion