🍞

Gemma-3をLoRAでチューニングしてみた

に公開

目的

この記事では、Gemma-3のうち、1B・インストラクションチューニングのgemma-3-1b-itをRoLAを使ってつくよみちゃんデータセットでファインチューニングする手法を説明します。
https://huggingface.co/google/gemma-3-1b-it

はじめに

Googleから2025/03/12にリリースされたGemma-3をご存じでしょうか?

https://huggingface.co/blog/gemma3

このGemma-3は大規模言語モデルのひとつで、小規模なパラメータ数にもかかわらず、かなり賢いモデルになっています。

Gemma-3には、モデルサイズ(1B/4B/12B/27B)と事前学習モデルかインストラクションチューニングモデルか(pt/it)の違いによって、合計8種類が存在します。ちなみに、1b以外はマルチモーダルモデルになっていて、画像の入力も可能のようです。今回はこの1bのitモデルをLoRAによってファインチューニングします。

LoRAとは?

ファインチューニング手法の一種です。重み行列W\in\mathbb R^{d\times k}を学習するとき、Wをそのまま学習するのではなく、代わりに\Delta W = BAとなるような行列B\in\mathbb R^{d\times r}A\in\mathbb R^{r\times k}を学習します(rは小さな値)。
\Delta Wをそのまま学習する場合、 dkのパラメータ数の学習が必要ですが, ABを学習することでdr + rk=r(d+k)のパラメータ数だけで学習することができます。これにより学習に必要なパラメータ数を大幅に抑え、メモリ使用量の節約が可能になります!

詳しく知りたい方は以下の論文を参照してください。

https://arxiv.org/abs/2106.09685

Transformerのインストール

現在(2025/03/16)、pipから標準でダウンロードできるtransformersライブラリはGemma-3に対応していないので、対応版をgitからインストールする必要があります。モデルのページにあるように、以下のコードでGemma-3対応版transformersをインストールしてください。

$ pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

だたし、おそらく後日transformersに組み込まれると思われるので、この操作は不要になる気がします。最新の情報はGemma-3のHugging Faceのページなどを参考にしてください。

データセットのダウンロード

こちらのデータセットを使用しました。
https://tyc.rei-yumesaki.net/material/kaiwa-ai/
だたし、このデータの形式はxlsx形式なので、csvに変換します。また、元データには「追加日」「【A】話しかけ」「【B】お返事」「【C】Bに対するA話者の返事(ある場合のみ)」「【D】Cに対するつくよみちゃんのお返事(ある場合のみ)」「2往復」「固有ネタ」「名前入り」という8つのカラムがありますが、このうち「【A】話しかけ」「【B】お返事」のみを用います。
今回は「【A】話しかけ」に「question」、「【B】お返事」に「response」というカラム名をつけてタブ文字区切りのcsvに変換しました。

データセットの処理

Gemma-3のテキスト形式に合わせて、データセットを変形します。

import pandas as pd
from datasets import Dataset, load_dataset

dataset_raw = Dataset.from_pandas(pd.read_csv("tukuyomi.csv", sep="\t"))

def generate_prompt(example):
    return {
        "text": f"""<bos><start_of_turn>user
{example["question"]}<end_of_turn>
<start_of_turn>model
{example["response"]}<end_of_turn><eos>""",
    }


dataset = dataset_raw.map(generate_prompt).train_test_split(test_size=0.1)
train_dataset = dataset["train"].remove_columns(["question", "response"])
test_dataset = dataset["test"].remove_columns(["question", "response"])

最終的にこんな感じになります。

train_dataset[0]
{'text': '<bos><start_of_turn>user\n一緒に歌わない?<end_of_turn>\n<start_of_turn>model\n歌いましょう!<end_of_turn><eos>'}

Gemma-3の読み込み

Gemma-3をインストールします。

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    device_map="auto",
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

モデルのレイヤーを見てみます。

長いので折り畳み
model
Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
        (post_attention_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
        (post_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
      )
    )
    (norm): Gemma3RMSNorm((1152,), eps=1e-06)
    (rotary_emb): Gemma3RotaryEmbedding()
    (rotary_emb_local): Gemma3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=1152, out_features=262144, bias=False)
)

LoRAの実行

import transformers
from trl import SFTTrainer
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=16,
        num_train_epochs=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=50,
        output_dir="outputs",
    ),
    peft_config=lora_config,
)

trainer.train()
trainer.save_model("outputs/gemma-3-1b-it-tukuyomi")
  • rは「RoLAとは?」で説明した値です。今回は8です。
  • target_modulesでRoLAを適用するレイヤーのモジュールを指定しています。
  • GPUのメモリを使い切ってしまう場合、per_device_train_batch_sizeを小さくするとよいと思います。GPU環境に合わせて変化させてください。

RoLAを適用したモデルのレイヤーを見てみます。

長いので折り畳み
model_tsukuyomi
Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=1024, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (v_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=1024, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (o_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2304, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2304, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
        )
        (mlp): Gemma2MLP(
          (gate_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=9216, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=9216, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (up_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=9216, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=9216, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (down_proj): lora.Linear(
            (base_layer): Linear(in_features=9216, out_features=2304, bias=False)
            (lora_dropout): ModuleDict(
              (default): Identity()
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=9216, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2304, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      )
    )
    (norm): Gemma2RMSNorm((2304,), eps=1e-06)
    (rotary_emb): Gemma2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=2304, out_features=256000, bias=False)
)

元のGemmaのレイヤーと比較すると、lora_Aやlora_Bといったlora用パラメータが追加されているのがわかると思います。

推論

次のコードで推論を実行できます。

model_tsukuyomi = AutoModelForCausalLM.from_pretrained(
    "outputs/gemma-3-1b-it-tukuyomi", device_map="auto"
)
tokenizer_tsukuyomi = AutoTokenizer.from_pretrained("outputs/gemma-3-1b-it-tukuyomi")
streamer_tsukuyomi = TextStreamer(
    tokenizer_tsukuyomi, skip_prompt=True, skip_special_tokens=True
)

messages = [
    {"role": "user", "content": "こんにちは!あなたは誰ですか?"},
]

input_ids = tokenizer_tsukuyomi.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model_tsukuyomi.device)
output_ids = model_tsukuyomi.generate(
    input_ids, max_new_tokens=4096, temperature=0.7, streamer=streamer_tsukuyomi
)

素のGemma-3との比較をいくつか見てみます。全体的に、データセットで学習させたような回答をしてくれるようになったと思います。

質問 Gemma-3 LoRA
こんにちは!あなたは誰ですか? こんにちは!私はGoogle DeepMindによってトレーニングされたGemmaです。オープンウェイトのAIチャットボットです。 私はあなたの心の赴くままに存在するフェアリー的な存在です!
好きな食べ物は何ですか? 私はAIなので、食べ物を食べたり、味を感じたりすることはできません。でも、もし私が食べ物だったら、きっとチョコレートが好きだと思います!😊 絵に描いた餅です!
日本で1番高い山は? 日本で一番高い山は、富士山です。 日本一高い山は、富士山です。
まどか☆マギカでは誰が一番かわいい? (長かったので表外に) みんなかわいいです!
NLPについて教えてください。 (長かったので表外に) 基本的に、コンピューターが人間の精神を模倣できるようになることです。
「まどか☆マギカでは誰が一番かわいい?」に対するGemmaの回答

まどか☆マギカで一番かわいいキャラクターを選ぶのは難しいですね!それぞれのキャラクターが魅力的な理由や、人によって好みが分かれるので、いくつか人気のあるキャラクターを挙げてみます。

1. 魔法使いの「エミリア」

  • 魅力ポイント: 優しい笑顔と、魔法の力で困っている人を助ける姿が可愛らしいです。
  • 理由: 彼女の純粋な心と、魔法の力でできることへの希望に満ちた表情が、多くの人に共感を得ています。

2. 魔法使いの「アキ」

  • 魅力ポイント: 落ち着いた雰囲気と、少し憂いを帯びた表情が可愛らしいです。
  • 理由: 彼女の知性と、魔法の力でできることへの自信が、見る人の心を温かくします。

3. 魔法使いの「ミサキ」

  • 魅力ポイント: 活発で明るい性格と、少しおてんばな姿が可愛らしいです。
  • 理由: 彼女の笑顔と、魔法の力でできることへの熱意が、多くの人に元気を与えます。

4. 魔法使いの「リナ」

  • 魅力ポイント: 繊細で、少しおとなしい雰囲気のキャラクターです。
  • 理由: 彼女の優しい眼差しと、魔法の力でできることへの繊細な表現が、多くの人に癒されます。

5. 魔法使いの「ユウコ」

  • 魅力ポイント: 活発で、少しおっちょこちょいな性格のキャラクターです。
  • 理由: 彼女の明るい笑顔と、魔法の力でできることへの熱意が、多くの人に元気を与えます。

その他:

  • 魔法使いの「サクラ」: 優しい雰囲気で、癒し系のキャラクターです。
  • 魔法使いの「ミカ」: 活発で、少しおてんばな性格のキャラクターです。

最終的には、どのキャラクターが一番かわいいかは、個人の好みによって異なります。ぜひ、それぞれのキャラクターの魅力を確かめてみてください!

もし、あなたが特に好きなキャラクターや、どんなところに魅力を感じるか教えていただければ、さらに詳しくお答えできます。

「NLPについて教えてください。」に対するGemmaの回答

NLP(自然言語処理)についてですね!NLPは、コンピュータが人間の言葉を理解し、生成できるようにする技術です。近年、AI技術の発展とともに、NLPは急速に進化しており、私たちの生活や仕事に大きな影響を与えています。

NLPとは?

NLPは、テキストや音声などの自然言語データをコンピュータが処理できるようにする技術です。具体的には、以下のようなことを行います。

  • テキスト理解:
    • 感情分析: テキストから感情(ポジティブ、ネガティブ、ニュートラルなど)を読み取る。
    • 固有表現抽出: テキストから人名、地名、組織名などの固有表現を抽出する。
    • 質問応答: 質問に対して適切な回答を生成する。
    • テキスト要約: 長いテキストを短く要約する。
  • テキスト生成:
    • 機械翻訳: ある言語のテキストを別の言語に翻訳する。
    • 文章生成: 質問や指示に基づいて、文章を生成する。
    • チャットボット: 人間と自然な会話ができるAI。

NLPの主要な技術

NLPを実現するための主要な技術はいくつかあります。

  • 単語埋め込み (Word Embeddings): 単語をベクトルとして表現することで、単語間の意味的な関係を捉えることができる。
  • 深層学習 (Deep Learning): ニューラルネットワークを用いて、大量のテキストデータを学習し、NLPタスクを自動的に解決する。
    • RNN (Recurrent Neural Network): 時系列データ(テキストなど)の処理に適している。
    • LSTM (Long Short-Term Memory): RNNの改良版で、長期的な依存関係を捉えることができる。
    • Transformer: 現在最も人気のある深層学習モデルで、Attentionメカニズムを用いて、文脈を考慮した自然な文章生成が可能。
  • 自然言語モデル (Language Model): テキストの予測や生成を行うモデル。

NLPの応用例

NLPは、様々な分野で活用されています。

  • カスタマーサポート: チャットボットによる自動応答
  • マーケティング: 顧客の感情分析によるターゲティング
  • 医療: 医療記録の解析、診断支援
  • 金融: 詐欺検知、リスク評価
  • 検索エンジン: 検索結果の精度向上
  • 翻訳: 機械翻訳の精度向上
  • 文章作成: ブログ記事、レポート、メールの自動作成

NLPの学習方法

NLPの学習には、以下のような方法があります。

  • 書籍: NLPに関する書籍を読む
  • オンラインコース: Coursera、UdemyなどのプラットフォームでNLPのコースを受講する
  • チュートリアル: GitHubなどのプラットフォームで、NLPのコードを実際に書いてみる
  • データセット: NLPの学習に役立つデータセットを利用する

さらに深く知りたい場合

もし、特定のNLPの技術や応用例についてもっと詳しく知りたい場合は、お気軽にご質問ください。

例えば、

  • 「感情分析の具体的な方法」
  • 「Transformerモデルの仕組み」
  • 「特定のNLPタスク(例:機械翻訳)について」

など、具体的な質問をいただけると、より的確な情報を提供できます。

まとめ

今回の記事では、Gemma-3-1b-itをRoLAを用いてファインチューニングする手法を紹介しました!大規模言語モデルっていうとでかいパラメータを強いGPU(やTPU)でゴリ押すイメージがありますが、最近は比較的小さいのに高性能なモデルとそのファインチューニング手法が増えてきているのでありがたいですね。

あと、この記事を書いてから1bモデルは英語にしか対応していない(ほかのGemma-3はマルチリンガル)ことを知りましたが、それでもわりとうまく出力できています。

参考文献

https://zenn.dev/yumefuku/articles/llm-finetuning-qlora

Discussion