Gemma 2 2Bモデルの分散ファインチューニング: TPUを活用した効率的な学習方法(kaggleノート付)
はじめに
こんにちは!今回は、Googleが新しくリリースしたGemma 2 2Bモデルを使って、TPU(Tensor Processing Unit)を活用した分散ファインチューニングを行う方法をご紹介します。この記事は、大規模言語モデル(LLM)の学習に興味がある初心者の方々を対象としています。
Gemma 2 2Bは、わずか20億のパラメータでありながら、驚くほど高性能な言語モデルです。今回は、このモデルをさらに特定のタスクに適応させるため、ファインチューニングを行います。
環境設定
まず、必要なライブラリをインストールし、環境を整えましょう。
# 必要なライブラリのインストール
!pip install -q -U keras-nlp tensorflow-text
!pip install -q -U tensorflow-cpu
# Keras JAXバックエンドの設定
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
# 必要なモジュールのインポート
import jax
import keras
import keras_nlp
このコードでは、Keras、KerasNLP、TensorFlow関連のライブラリをインストールし、KerasのバックエンドをJAXに設定しています。また、TPUメモリの使用効率を最大化するための設定も行っています。
TPUの確認
Kaggleでは、TPUv3-8デバイスが提供されており、各TPUコアは16GBのメモリを持っています。以下のコードでTPUの状態を確認しましょう。
import jax
jax.devices()
このコードを実行すると、利用可能なTPUデバイスの一覧が表示されます。
モデルのロードと分散設定
次に、Gemma 2 2Bモデルをロードし、分散学習の設定を行います。
# デバイスメッシュの作成
device_mesh = keras.distribution.DeviceMesh(
(1, 4), # 4つのTPUコアを使用
["batch", "model"],
devices=keras.distribution.list_devices()[:4], # 最初の4つのデバイスのみを使用
)
# レイアウトマップの設定
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# 重みの分散方法を指定
layout_map["token_embedding/embeddings"] = (model_dim, None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*attention_output/kernel"] = (model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear/kernel"] = (model_dim, None)
# モデル並列化の設定
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map,
batch_dim_name="batch",
)
# 分散設定の適用
keras.distribution.set_distribution(model_parallel)
# Gemma 2 2Bモデルのロード
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
このコードでは、以下の手順を実行しています:
- TPUコアを使用するためのデバイスメッシュを作成
- モデルの重みをTPUコア間で分散させるためのレイアウトマップを設定
- モデル並列化の設定を行い、分散学習を有効化
- Gemma 2 2Bモデルを事前学習済みの重みでロード
モデルの構造確認
モデルの構造を可視化するために、Richライブラリを使用します。以下のコードで、モデルの各層の詳細を確認できます。
from rich.console import Console
from rich.tree import Tree
from rich.panel import Panel
console = Console()
def create_layer_tree(layer):
tree = Tree(f"[bold blue]{layer.name}[/bold blue]")
important_attrs = ['_layers', 'transformer_layers', 'layer_norm', '_token_embedding']
for attr, value in vars(layer).items():
if attr in important_attrs:
if isinstance(value, list):
subtree = tree.add(f"[yellow]{attr}[/yellow]")
for item in value:
subtree.add(f"[green]{item.name}[/green]: {item.__class__.__name__}")
else:
tree.add(f"[yellow]{attr}[/yellow]: [green]{value.name}[/green] ({value.__class__.__name__})")
elif not attr.startswith('_') and not callable(value):
tree.add(f"[cyan]{attr}[/cyan]: {value}")
return tree
# GemmaDecoderBlockの内部構造を確認
for layer in gemma_lm.layers:
console.print(Panel(create_layer_tree(layer), title=f"[bold red]{layer.__class__.__name__}[/bold red]", expand=True))
# 埋め込み層の確認
embedding_layer = gemma_lm.get_layer('token_embedding')
console.print(Panel(
f"[bold magenta]Embedding Layer[/bold magenta]\n"
f"Name: [green]{embedding_layer.name}[/green]\n"
f"Shape: [yellow]{embedding_layer.embeddings.shape}[/yellow]",
title="Embedding Layer Info",
border_style="blue"
))
このコードを実行すると、モデルの各層の詳細な構造が美しく可視化されて表示されます。
ファインチューニング前の推論
ファインチューニングを行う前に、モデルの現在の性能を確認しましょう。以下のコードで、モデルに質問を投げかけることができます。
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.markdown import Markdown
from rich.box import ROUNDED
console = Console()
def generate_and_display(prompt, max_length=512):
generated_text = gemma_lm.generate(prompt, max_length=max_length)
prompt_panel = Panel(
Text(prompt, style="bold magenta"),
title="[blue]Input Prompt[/blue]",
border_style="blue",
box=ROUNDED,
)
generated_md = Markdown(generated_text)
token_count = len(generated_text.split())
token_info = Text(f"\n\nGenerated {token_count} tokens.", style="italic cyan")
output_panel = Panel(
generated_md,
title="[green]Generated Response[/green]",
border_style="green",
box=ROUNDED,
)
console.print(prompt_panel)
console.print()
console.print(output_panel)
console.print(token_info)
# 関数を呼び出して結果を表示
generate_and_display("How can I plan a trip to Europe?", max_length=512)
このコードを実行すると、モデルが生成した回答が美しく整形されて表示されます。
データの準備
ファインチューニングには、Databricks Dollyデータセットを使用します。このデータセットは、LLMの指示追従能力を向上させるために設計された高品質なプロンプト/レスポンスのペアを含んでいます。
import json
data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
for line in file:
features = json.loads(line)
if features["context"]:
continue
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
data.append(template.format(**features))
# 学習を高速化するためにデータを切り詰める
data = data[:2500]
このコードでは、JSONLファイルからデータを読み込み、必要な形式に整形しています。学習の高速化のため、データセットを最初の2500例に制限しています。
LoRA(Low-Rank Adaptation)の設定
ファインチューニングには、LoRA(Low-Rank Adaptation)という手法を使用します。LoRAは、モデルの大部分のパラメータを凍結し、少数の新しい学習可能なパラメータを追加することで、効率的なファインチューニングを可能にします。
# LoRAの有効化
gemma_lm.backbone.enable_lora(rank=8)
# モデルのコンパイル
gemma_lm.preprocessor.sequence_length = 512 # 入力シーケンス長を制限
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(learning_rate=5e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
このコードでは、LoRAを有効化し、モデルのコンパイルを行っています。入力シーケンス長を512トークンに制限することで、メモリ使用量を抑えています。
ファインチューニングの実行
いよいよファインチューニングを実行します。
gemma_lm.fit(data, epochs=1, batch_size=4)
このコードでは、準備したデータを使ってモデルをファインチューニングしています。エポック数は1、バッチサイズは4に設定しています。
ファインチューニング後の推論
ファインチューニング後のモデルの性能を確認しましょう。
generate_and_display("Instruction:\nHow can I plan a trip to Europe?\n\nResponse:\n", max_length=512)
このコードを実行すると、ファインチューニング後のモデルが生成した回答が表示されます。ファインチューニング前と比較して、回答の質が向上していることが期待できます。
まとめ
この記事では、Gemma 2 2Bモデルを使った分散ファインチューニングの方法を学びました。TPUを活用することで、効率的に大規模言語モデルを特定のタスクに適応させることができます。
ここで学んだ技術を応用することで、さまざまな自然言語処理タスクに対応できるモデルを作成することができます。さらなる改善のためには、以下のポイントを検討してみてください:
- より多くのデータでファインチューニングを行う
- ハイパーパラメータ(学習率、LoRAのランクなど)の調整
- より長いエポック数での学習
Gemma 2 2Bのような効率的なモデルを使うことで、限られたリソースでも高性能な言語モデルを構築・活用することが可能になります。ぜひ、自分のプロジェクトやアイデアに応用してみてください!
📒ノートブック
日本語版
英語版
参考サイト
<script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
Discussion