🤗

大規模言語モデルを自作しよう!(Transformers+flash_attn2)

に公開
3

本記事は、LLM Advent Calendar 2023 13日目の記事です。

https://qiita.com/advent-calendar/2023/llm

はじめに

🤗 Transformersは、自然言語処理、マルチモーダル、音声処理、コンピュータビジョン分野の事前学習済モデルを簡単にダウンロードしトレーニングすることが可能なpythonライブラリです。このライブラリを使用し、大規模言語モデル(LLM)の事前学習済モデルをローカルPC上にダウンロードし、それを使用した言語生成や、要約・翻訳・質問応答などの個別のタスクへのファインチューニング、チャットAIへの組み込みなどが盛んに行われています。

LLMの事前学習方法に関する情報としては、GPT-NeoXMegatron-LMTinyLlamalit-llamaなど、他のpythonライブラリを使用したものが増えてきています。一方で、Transformersライブラリを使用したLLMの事前学習に関する情報は未だ少ない現状にあります。

そこで本記事では、汎用PCでも学習がしやすい300MサイズのMistralモデルを題材とし、Transformersを使用してLLMの事前学習・指示チューニングを実施する方法を紹介します。本記事で作成できるbaseモデルは「mistral-300m-base」、instructionモデルは「mistral-300m-sft」として公開しています。

実装のためのソースコードは、japanese-mistral-300m-recipev1.0.0であり、本記事はその解説です。次のコマンドを実行することで、環境構築、事前学習、指示チューニング、推論のすべてが実施可能です。

git clone japanese-mistral-300m-recipe
cd japanese-mistral-300m-recipe
git checkout v1.0.0
docker compose build
docker compose run mistral300m
uv sync --extra build 
uv sync --extra build --extra compile
bash run_all.sh

この記事の特徴は、以下の通りです。

  • SentencePieceトークナイザーでのbyte fallback使用によるunknown_token生成抑制と、huggingface Tokenizers形式への変換
  • flash attention2を使用した学習高速化
  • Mistral 300Mの事前学習
  • Mistral 300Mの指示チューニング

検証環境

項目 バージョン 備考
OS Ubuntu 22.04.3 LTS
CPU AMD® Ryzen 5 3600x 6-core processorx12
RAM DDR4 80GB
GPU RTX4090 VRAM24GB
python 3.11 .python-version参照
CUDA toolkit 12.6 Dockerfile参照
cudnn 9 Dockerfile参照
NVIDIA Driver 550.107.02
pythonライブラリ transformers==4.55.2
torch==2.8.0
pyproject.toml参照
SSD 4TB
その他ハードディスク HDD 12TB
SSD 4TB

ワークフロー

本記事の流れと所要時間は、下表の通りです。LLMは、基本的にこの流れで作成されます。
所要時間は、ハードウェア性能に起因して前後します。

No. ステップ 所要時間
1 python仮想環境構築 5min
2 データセット構築 1h
3 トークナイザー学習 1h
4 事前学習 92h
5 推論 1min
6 指示チューニング 1h
7 推論2 1min

python仮想環境構築

Dockerを使用した環境構築を推奨します。
以下のコマンドを実行することで、CUDA toolkitとcudnn、その他のツールをインストールしたコンテナが作成されます。

docker compose build

次に、以下のコマンドでコンテナを起動し、コンテナ上での開発に移ります。

docker compose run mistral300m

最後に、以下のコマンドでpython仮想環境を構築します。

uv sync --extra build
uv sync --extra build --extra compile

データセット構築

事前学習用のデータセットには、以下を使用します。
このデータセットのサイズは、14 Billion(140億)トークンです。

https://huggingface.co/datasets/ce-lery/corpus-ja-11b

このデータセットの内訳は、次の通りです。

データセット名 作り方 トークン数
Wikipedia dumps.wikimedia.orgよりダウンロード。
bert-japaneseのスクリプトを用いてデータ整形。
corpus-cleanerでクリーニング。
0.94B
Wikibooks dumps.wikimedia.orgよりダウンロード。
bert-japaneseのスクリプトを用いてデータ整形。
corpus-cleanerでクリーニング。
0.01B
Wikiversity dumps.wikimedia.orgよりダウンロード。
bert-japaneseのスクリプトを用いてデータ整形。
corpus-cleanerでクリーニング。
0.00B
CC-100 data.statmt.orgよりダウンロード。
bert-japaneseのスクリプトを用いてデータ整形。
corpus-cleanerでクリーニング。
13.1B

データセットのダウンロード

次のコマンドを実行することで、前述のデータセットをダウンロードします。

cd examples/pretrain
bash dataset/dataset.sh

dataset.shの処理は、次の通りです。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/examples/pretrain/dataset.sh

クリーニング

通常、データセットは前処理が必要です。
前処理の種類としては、こちらのようなものがあります。

今回学習に使用するデータセットはクリーニング済みのものであるため、この処理は省略します。

データセットのtrain-validation分割

未学習データに対する推論性能を検証するため、データセットの一部を検証(validation)データとして使用します。

本記事では、データセットをtrain:validation=99:1の割合で分割します。trainデータは学習に使用し、validationデータは学習に使用しません。学習中に随時validationデータを使用して推論精度を計測し、未学習データに対する性能を確認します。

データセットの分割は、事前学習にて使用するrun_clm.py内で実施されます。run_clm.pyに渡すvalidation_split_percentageを1にすることで、datasetがtrain:validation=(100-1):1の割合で分割され、それぞれ使用されます。

トークナイザー学習

Transformerモデルは、入力として文字列を受けとることができません。代わりに、文字列をトークンという単位に分割し、トークンごとに割り当てられた数値(ID)を入力値として使用します。トークナイザーは、文字列をトークンに分割し、トークンごとにIDへ変換するものです。

トークナイザーのトークン分割単位を決定づけるアルゴリズムには、バイトペアエンコーディング(BPE)、WordPiece、Unigramどがあります。各アルゴリズムの詳細はこちらに譲るとして、今回はSentensePieceライブラリにてUnigramアルゴリズムを使用して、トークン分割単位の決定とトークンごとのID割り当てを行います。本記事では、この処理を「トークナイザーの学習」と呼称します。

トークナイザーの学習は、次のコマンドで実行可能です。

# cd examples/pretrain
bash tokenizer.sh

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/examples/pretrain/tokenizer.sh

SentencePiece学習

SentensePieceライブラリを使用し、トークナイザーの学習を実施します。トークナイザーは、前述のデータセットのうちwikipediaデータのみを用いて、Unigramアルゴリズムで学習します。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/tokenizer/unigram.py

SentencePieceトークナイザーは、17行目の「vocab_size=32000」で、トークナイザーのリストに登録する語彙数を設定しています。逆に、この数以上の語彙は登録されません。トークナイザーは、リストに登録されていない文字列は処理することができず、未知語(unknown_token)として処理します。すなわち、リストに登録されていない文字列は、学習時・推論時に[unk]となってしまいます。

これを防ぐための機能がbyte-fallbackです。16行目のようにこれを有効にすることで、SentencePieceトークナイザーは渡されたリスト未登録文字をバイト単位でIDにエンコードすると共に、バイト単位のIDをUTF-8形式でデコードすることが可能になります。つまり、トークナイザーのリストに登録されていない文字列も処理することができます。

Tokenizers LlamaTokenizerFast形式への変換

Transformersライブラリは、SentencePieceライブラリで生成されたトークナイザーをそのまま使用することができません。Transformersライブラリでも使用できるように、huggingface TokenizersライブラリのTokenizersクラス形式に変換します。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/tokenizer/unigram.py#L41-L51

データセットのトークン化処理

先ほど作成したトークナイザーを使用して、データセット全体をトークン化します。
この処理は、後述のrun_clm.pyに実装されており、run_clm.py実行時に実施されます。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/run_clm.py#L562-L605

事前学習

構築したデータセットを用いて、LLMをゼロから学習します。
事前学習は、次のコマンドで実行可能です。

# cd examples/pretrain
bash train.sh

事前学習には、Transformersのexampleスクリプトの1つであるrun_clm.pyを使用します。run_clm.pyは、一部を以下のように修正しています。

run_clm.pyの変更差分
@@ -67,9 +67,8 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 
-
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.57.0.dev0")
+check_min_version("4.55.2")
 
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
@@ -144,7 +143,7 @@ class ModelArguments:
             )
         },
     )
-    dtype: Optional[str] = field(
+    torch_dtype: Optional[str] = field(
         default=None,
         metadata={
             "help": (
@@ -154,6 +153,33 @@ class ModelArguments:
             "choices": ["auto", "bfloat16", "float16", "float32"],
         },
     )
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
+                "set True will benefit LLM loading time and RAM consumption."
+            )
+        },
+    )
+    min_lr: Optional[float]  = field(
+        default=None,
+        metadata={
+            "help": (
+                "For example, when using cosine_with_min_lr as sucheduler, "
+                "this is an option to set the minimum learning rate."
+            )
+        },
+    )
+    min_lr_rate: Optional[float]  = field(
+        default=None,
+        metadata={
+            "help": (
+                "For example, when using cosine_with_min_lr as sucheduler, "
+                "this is an option to set the minimum learning rate's ratio."
+            )
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -234,10 +260,10 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                # assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                # assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
 def split_streaming_dataset(
@@ -404,6 +430,12 @@ def main():
         if extension == "txt":
             extension = "text"
             dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+        if extension == "jsonl":
+            extension = "json"
+        print("extention:",extension)
+        print("data_files[train]:", data_files["train"])
+        # print("data_files[validation]:", data_files["validation"])
+
         raw_datasets = load_dataset(
             extension,
             data_files=data_files,
@@ -487,7 +519,11 @@ def main():
         )
 
     if model_args.model_name_or_path:
-        dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -496,13 +532,16 @@ def main():
             revision=model_args.model_revision,
             token=model_args.token,
             trust_remote_code=model_args.trust_remote_code,
-            dtype=dtype,
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+            attn_implementation="flash_attention_2",
         )
     else:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
+        # n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
+        model = AutoModelForCausalLM.from_config(config,
+                                                attn_implementation="flash_attention_2")
         n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
-        logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params")
-
+        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
     embedding_size = model.get_input_embeddings().weight.shape[0]
@@ -522,7 +561,14 @@ def main():
 
     def tokenize_function(examples):
         with CaptureLogger(tok_logger) as cl:
-            output = tokenizer(examples[text_column_name])
+            #output = tokenizer(examples[text_column_name])
+            # add BOS and EOS
+            processed_texts = [text + tokenizer.eos_token for text in examples[text_column_name]]
+            output = tokenizer(processed_texts, add_special_tokens=False)
+            # # If there are other columns in examples, we need to preserve them
+            # for key in examples.keys():
+            #     if key != text_column_name:
+            #         output[key] = examples[key]
         # clm input could be much much longer than block_size
         if "Token indices sequence length is longer than the" in cl.out:
             tok_logger.warning(
@@ -530,6 +576,16 @@ def main():
                 " before being passed to the model."
             )
         return output
+    
+    print(">>> tokenizer test")
+    text = "こんにちは。私は日本人です。"
+    text_encoded = tokenizer(text)
+    print(text_encoded)
+    print(tokenizer.decode(text_encoded["input_ids"]))
+    text = text + tokenizer.eos_token
+    text_encoded = tokenizer(text)
+    print(text_encoded)
+    print(tokenizer.decode(text_encoded["input_ids"]))
 
     with training_args.main_process_first(desc="dataset map tokenization"):
         if not data_args.streaming:
@@ -575,7 +631,7 @@ def main():
     # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
     def group_texts(examples):
         # Concatenate all texts.
-        concatenated_examples = {k: list(chain(*examples[k])) for k in examples}
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         # We drop the small remainder, and if the total_length < block_size  we exclude this batch and return an empty dict.
         # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
@@ -615,22 +671,16 @@ def main():
             raise ValueError("--do_train requires a train dataset")
         train_dataset = lm_datasets["train"]
         if data_args.max_train_samples is not None:
-            if data_args.streaming:
-                train_dataset = train_dataset.take(data_args.max_train_samples)
-            else:
-                max_train_samples = min(len(train_dataset), data_args.max_train_samples)
-                train_dataset = train_dataset.select(range(max_train_samples))
+            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+            train_dataset = train_dataset.select(range(max_train_samples))
 
     if training_args.do_eval:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = lm_datasets["validation"]
         if data_args.max_eval_samples is not None:
-            if data_args.streaming:
-                eval_dataset = eval_dataset.take(data_args.max_eval_samples)
-            else:
-                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
-                eval_dataset = eval_dataset.select(range(max_eval_samples))
+            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+            eval_dataset = eval_dataset.select(range(max_eval_samples))
 
         def preprocess_logits_for_metrics(logits, labels):
             if isinstance(logits, tuple):
@@ -649,6 +699,15 @@ def main():
             preds = preds[:, :-1].reshape(-1)
             return metric.compute(predictions=preds, references=labels)
 
+    training_args.lr_scheduler_kwargs = {}
+    if model_args.min_lr is not None:
+        training_args.lr_scheduler_kwargs["min_lr"] = model_args.min_lr
+    elif model_args.min_lr_rate is not None:
+        training_args.lr_scheduler_kwargs["min_lr_rate"] = model_args.min_lr_rate 
+
+    # todo fix here
+    # training_args.torch_empty_cache_steps = 4
+
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -656,6 +715,7 @@ def main():
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         processing_class=tokenizer,
+        # like past_key_values, but logits always come first
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
         compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
@@ -679,10 +739,10 @@ def main():
         max_train_samples = (
             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
         )
-        if data_args.streaming:
-            metrics["train_samples"] = max_train_samples
-        else:
-            metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+    # trainer = Trainer(
+    #     model=model,
+    #     args=training_args,
+    #     train_dataset=train_dataset if training_args.do_train else None,
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
@@ -695,11 +755,11 @@ def main():
         metrics = trainer.evaluate()
 
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-        if data_args.streaming:
-            metrics["eval_samples"] = max_eval_samples
-        else:
-            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
         try:
             perplexity = math.exp(metrics["eval_loss"])
         except OverflowError:

このrun_clm.pyに対し、train.shの各種パラメータを渡すことで、学習に使用するパラメータを設定しています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/examples/pretrain/train.sh

Mistral 300Mモデルの設定

mistralai/Mistral-7B-v0.1のconfig.jsonを改良し、以下のようにモデルサイズを318Mになるように設定し、使用します。 モデルの各パラメータは、japanese-gpt2-mediumのconfig.jsonGPT2ConfigMistralConfigを参照し、設定しています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/config/config_mistral_300m.json

run_clm.py内では、まず、次の箇所でconfig_mistral_300m.jsonを読み込み、configパラメータを作成します。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/run_clm.py#L494-L495

次に、configパラメータをベースに、初期モデルを定義します。モデルの重みは、このタイミングで初期化されます。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/run_clm.py#L541-L542

そして、次の箇所で、学習用パラメータを設定します。model引数に先程定義した初期モデルを渡すことで、このモデルが学習対象になります。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/run_clm.py#L712-L725

最後に以下の箇所が実行されることで、上記の学習用パラメータで学習が開始されます。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/run_clm.py#L734

学習速度高速化とメモリ消費量削減

学習時に課題となる要素には、「学習速度」と「GPUメモリ消費量」が挙げられます。例えば学習時間が3600時間(150日)である場合、実質学習できないことと同義です。また、学習時間を短縮することで、GPUの電気料金やGoogle Colabo等の課金料金を削減することができるため、可能な限り学習速度を高速化することが望まれます。

学習速度高速化とメモリ消費量削減のための手段は、下表のとおりです。 
学習高速化効果の列と、メモリ削減効果の列は、◯、ー、×で表記します。列名に対して、◯は「効果がある」、ーは「効果なし」、×は「逆効果である」ことを表します。

設定項目 内容 学習高速化効果 メモリ削減効果 備考
bf16の混合精度トレーニング モデルパラメータの一部をbf16精度で表現し、残りはfp32精度とすることで、計算を高速化する。 こちらの通り、モデルが16ビットと32ビットの両方の精度 (GPU 上の元のモデルの1.5倍)でGPU上に存在するため、多くのGPUメモリが使用される可能性がある。
torch.compile PyTorch コードを最適化されたカーネルにJITコンパイルすることで、Pytorch処理の高速化を実現する。 × こちらも参照。
flash attention2 attention機構の計算を、並列化とWork Partitioningにより効率的に実施するflash attentionに置き換える。 現状は、Ampere、Ada、またはHopper GPU (A100、RTX3090、RTX4090、H100 など)のみ対応。T4では使用できない。
軽量Optimizerの選択 Adamw Optimizerのうち、state情報を8bit(adamw_bnb_8bit)や4bit(adamw_torch_4bit)に量子化して保持するものを使用する。 Transformersで使用できるoptimizerの一覧は、training_args.pyに記載されている。
OptimizerのLayerwise化 各重みにおいて、勾配が計算された時点でOptimizerでの重み更新を行い、その重みの勾配値を削除することで、勾配の保存をなくしメモリ消費量を削減する方法。勾配が保存されないため、Gradient Accumulationとの併用はできない。 Transformersの中でも、galore_adamw_layerwiseなど、一部layerwise化に対応したものがある。layerwiseの設定はtrainer.pyのsetup_low_rank_optimizer()で行われており、この部分を流用することで、adamw_bnb_8bitなどもlaywerwise化が可能(実施済み)。
Gradient Checkpointing 順伝播中の各層での計算結果(中間activation)は逆伝播で使用される場合があるため、通常はすべて保存されている。これに対し、中間activationの一部のみを保存し、残りは逆伝播中に都度再計算するようにすることで、メモリ使用量を削減する方法。 × 書籍「深層ニューラルネットワークの高速化」の9.1.2にわかりやすい説明が記載されている。
Liger Kernel モデルの裏で実行されるCUDAカーネルを、計算効率の高いLiger Kernelに置き換える。
CUDAのMemoryAllocation時のフラグメンテーション防止 CUDAのGPUメモリ割り当て時に発生するメモリ領域の隙間をなくす設定。 こちらでわかりやすく説明されている。

上記のうち、「bf16の混合精度トレーニング」、「flash attention2の使用」、「軽量Optimizerの選択」を実施します。
これらの有効化方法のイメージは、次の通りです。

model = AutoModelForCausalLM.from_config(config,
                                        attn_implementation="flash_attention_2" # flash_attention2の有効化
                                        )

# bf16混合精度使用
training_args = TrainingArguments(..., 
                                  bf16=True,
                                  optim="adamw_bnb_8bit")
trainer = Trainer(...)
trainer.train()

学習時の進捗状況確認

Transformersでは、以下の手順を踏むことで、学習時のtrain_loss、eval_lossなどをダッシュボード表示する機能が存在します。

  1. 以下のコマンドで、python仮想環境にtensorboardライブラリをインストールする
    uv add tensorboard  
    # uvを使用しない場合は`pip install tensorboard`
    
  2. TrianingArgumentsのlogging_dirにログ出力ディレクトリ(例:/results/train/logs/)を設定し、学習を開始する
  3. 別のterminalを開き、python仮想環境をactivateした状態で以下のコマンドを実行する
    uv run tensorboard --logdir ./
    # uvを使用しない場合は`tensorboard --logdir ./`
    
  4. ブラウザで「http://localhost:6006/ 」を開く

以下は、学習完了時の分析結果の様子です。この結果は、こちらからも確認できます。train_loss、eval_lossともに順調に下がっている様子が確認できます。また、learning_rateのWarmupとmin_lr_rateが効いていることが確認できます。

fig1

事前学習結果

run_clm.pyは以下のように、学習終了後に自動でperplexityを算出します。
validationデータセットに対するperplexityは23.8887でした。

***** eval metrics *****
  epoch                   =        1.0
  eval_loss               =     3.1734
  eval_runtime            = 0:59:27.70
  eval_samples_per_second =     128.27
  eval_steps_per_second   =     10.689
  perplexity              =    23.8887

ここまでの手順で、以下の「mistral-300m-base」が完成します。

https://huggingface.co/ce-lery/mistral-300m-base

推論

事前学習済みモデルを使用した推論は、次のコマンドで実行可能です。

# cd examples/pretrain
bash inference.sh

inference.shは、次の推論用ソースコードを実行しています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/inference/inference.py

作成した事前学習済みモデルの入出力例は、下表の通りです。
出力内容は誤りが含まれていますが、かなり流暢な日本語で出力できています。

プロンプト 出力
<s>香川県の名物は、 『うどん』です。香川県には、うどんの名店がいくつもあります。今回は、香川県にある『讃岐うどん店』をご紹介します。</s>
<s>香川県の県庁所在地は、 香川県高松市です。高松市の人口は1,334,546人(平成30年3月1日現在)で、県の人口の約4割を占めています。また、高松市内には高松城跡や高松塚古墳などの史跡が点在しており、歴史好きにはたまらない地域となっています。</s>
<s>兵庫県の県庁所在地は、 兵庫県姫路市です。姫路城や姫路港など、観光地が点在しています。また、姫路駅はJR西日本の駅として有名で、新幹線の停車駅でもあります。</s>
<s>栃木県の県庁所在地は、 栃木県宇都宮市です。宇都宮市の人口は1,339,559人(平成29年3月31日現在)で、人口密度は61.1人/km2(全国平均)となっています。県の面積の約3分の1は山林で占められています。また、宇都宮市では宇都宮城址公園が整備されており、市民の憩いの場として親しまれています。</s>
<s>日本の首都は、 東京です。東京は東京湾に浮かぶ島々からなり、その面積は日本の総面積の約3分の1を占めます。人口は1億2千万人で、日本の人口の4分の3を占めています。また、東京都心部には、首都機能が集中しており、都心部は交通の要衝として発展してきました。</s>
<s>日本で一番高い山は、 標高1,335mの「富士山」です。富士山の頂上には、富士山本宮浅間大社(静岡県富士宮市)が鎮座しています。また、日本三名瀑の1つに数えられる「白糸の滝(しらぎのたたき)」があります。この滝は「日本百名滝百選」にも選ばれています。</s>
<s>日本で二番目に高い山は、 山王山(標高1,320m)で、標高2,840mの山です。</s>
<s>日本で一番大きな湖は、 海水浴客でにぎわう、千葉県の房総半島の南端にある「房州市」です。房県市は千葉県北西部に位置し、人口は1万5千人ほど。千葉県内では最も人口の多い市です。「海と山に囲まれ、豊かな自然に恵まれた地域」と、地元住民からは親しまれています。海に面していないため、夏は涼しく、冬は暖かく、年間を通して温暖な気候。年間平均気温は15℃前後と比較的暖かい気候で、冬でも雪が降ることがありません。また、春は桜、秋は紅葉が美しいことでも有名。四季折々の自然を楽しめるのも魅力のひとつです。</s>
<s>世界で一番高い山は、 山王山(標高1,335m)です。
<s>世界で一番大きな湖は、 何だと思いますか?(複数回答可)(n=257)</s>
<s>赤信号 停車中、運転士が発車ベルを鳴らすのを目撃した乗客が、発車のベルが鳴ったことに気づき、車内にいた乗客に通報した。この事故で、乗客の1人が死亡し、2人が重軽傷を負った。</s>
<s>ジョークとは、 言葉遊びの1つで、相手を不快にさせる言葉のことです。例えば、ある人が「私は○○です。」と言ったとします。その人は、「○○です」という言い方をしますが、この言い方は、相手に対して不快感を与えます。つまり、相手の言葉の印象を悪くするのです。また、同じ言葉でも、相手に与える印象によって、印象が変わることもあります。たとえば、「私は、○○と申します」と、相手が言った場合、相手は、その「○○」という言葉に対して、不快感を抱かせます。そして、その言葉が、自分の印象に強く影響することもあるのです。このように、「言葉」には、様々な意味が込められています。しかし、それは、人によって、捉え方が異なるため、一概には言えません。そこで、ここでは、ビジネスシーンでの「言葉の使い方」について、考えてみたいと思います。まず、「ビジネス用語」の定義についてですが、これは、主に、企業や組織で使われる言葉を指します。ビジネスの現場では、「商品」や「サービス」といった言葉は使われません。そのため、この言葉をビジネスで用いる場合は、注意が必要です。なぜなら、商品やサービスは、商品やサービスを販売するための「手段」であり、「目的」ではないからです。商品・サービス・サービスを販売するためには、販売者(企業・組織)が「目的(目的)」を定める必要があります。この目的とは、「顧客(顧客)の満足(満足)を得ること」です。「顧客満足」とは「顧客が満足すること」を意味し、顧客の満足度が高ければ高いほど、高い価値が提供できることになります。したがって、目的は「商品(サービス)やサービスの価値(価値)を高めること」となります。一方、「サービス(商品)の価値を高める」ということは、「価値(付加価値)を上げること」になります。すなわち、サービスや商品の価値を高めることで、「付加価値」が高まり、結果として、価値の高い商品やサービスを提供できるということです。では、どのような言葉を使えばいいのでしょうか。それは、「表現」にします。具体的には、以下のような言葉を使います。このような表現を「表現力」といいます。表現力が高まれば、より、良い商品やサービスを提供することができるため、「良い商品」「良いサービス」「悪いサービス」、「良い価格」など、さまざまな表現ができるようになります。さらに、表現力をアップさせるには、「伝える力」「伝える技術」が必要です。ここでは、「伝わる技術」「伝わる言葉」「伝わり方」について解説します。「伝える」というと、難しいイメージがありますが、実は、簡単にできます。「伝わる」ためには、「伝え方」「伝え方の工夫」が大切です。今回は、伝えるための工夫について紹介しました。</s>
<s>自然言語処理とは、 言語の構造を解析し、その構造から意味を抽出する技術である。</s>
<s>自動車を運転する際、青信号は進む、 信号が青に変わったら停止する、など、交通ルールを守らなければなりません。また、道路交通法では、歩行者や自転車の通行の妨げとなる行為は禁止されています。違反すると、5年以下の懲役または100万円以下の罰金に処せられます。</s>
<s>人工知能 (AI)の進化が加速する中、AIの活用がますます重要になってきています。本セミナーではAI(人工知能)を活用した業務改善の事例をご紹介します。</s>

指示チューニング

事前学習済みLLMは、与えられた文章に対して、その続きを補完する形で出力するものでした。このLLMに対して、人間の指示に従って応答するような振る舞いをさせるために行うファインチューニングを、指示チューニングといいます。本章では、事前学習済みLLMに対して、LoRAを適用して指示チューニングを行います。

次のコマンドを実行することで、指示チューニングが実行可能です。

cd ../../examples/inst-tuning
bash train.sh

train.sh内では、次の学習用スクリプトを実行しています。
書籍「大規模言語モデル入門Ⅱ」11章のipynbスクリプトをベースに作成しています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py

このスクリプトの一部を説明します。
sft_train.pyでは、まず、次の箇所でデータセットを読み込みます。  
使用するデータセットはllm-book/oasst1-21k-jaです。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L29

事前学習済みモデルと併せて保存されているtokenizerを読み出し、chat_templateを新たに追加します。chat_templateについては、「Transforemrs documentation Chat Templates」も併せてご参照ください。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L35-L48

作成したchat_templateをデータセットに適用し、データセットを数値IDにエンコードします。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L66-L69

データセットの形状は、次の通りです。

{"conversation":[{"content":"こんにちは!", "role":"user"},
                 {"content":"こんにちは! ご質問やお困りのことがありましたらご相談ください。", "role":"assistant"}]}  

これに前述のchat_templateを適用すると、次のような文章に変換後、数値IDにエンコードされます。

<s>User: こんにちは!</s><s>Assistant: こんにちは! ご質問やお困りのことがありましたらご相談ください。</s>

つまり、apply_chat_template()の処理では、conversation列の辞書型データにおいて、"role"キーの値が"user"である要素の"content"キーの値に"User: "という接頭語を付与しています。また、conversation列の辞書型データにおいて、"role"キーの値が"assistant"である要素の"content"キーの値に"Assistant: "という接頭語を付与し、これらを連結しています。

また、return_assistant_tokens_mask=Trueを設定することで、assistant_masks配列が作成されます。これは、どのトークンがAssistantの出力に該当するかを表すもので、値が1の場合はAssistant出力、値が0の場合はAssistant以外の出力を表します。Assistant出力はLLMが推論時に生成すべき箇所であり、それ以外の箇所はユーザ等から与えられます。loss計算時は、Assistant出力に該当する箇所のみを計算に含めるようにします。
なお、assistant_masksで1が設定されるのは、chat_template内で{% generation %}{% endgeneration %}に囲まれている記述に該当するトークンです。

次に、以下の処理で、データセットへの正解ラベル付与とミニバッチ作成を行うDataCollatorを定義します。
complementation_only_loss=Trueを設定することで、正解ラベル生成の際に、各データセットのassistant_masksの要素が0のトークンの正解ラベルに-100が当てられます。-100が割り当てられたトークンは、loss計算から除外されます。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L86-L89

事前学習済みモデルを読み込みます。学習時のメモリ使用量の削減のため、重みを量子化して読み込みます。
load_in_4bit=Truebnb_4bit_quant_type="nf4"を設定することでNF4型量子化でモデルを読み込み、bnb_4bit_compute_dtype=torch.bfloat16の設定により学習時はbf16型に逆量子化して高速計算を行います。
各種設定方法は、こちらにて詳細に説明されています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L123-L137

以下の処理で、LoRAの設定を行い、モデルに適用します。LoRAでは、学習中には事前学習済みの重みWの値は変更せず、モデル内の各層の出力を調整するための小さな重み\Delta Wを追加し、\Delta Wを更新します。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/train/sft_train.py#L139-L157

以降の流れはrun_clm.pyと同様です。training_args=TrainingArguments(...)で学習用パラメータを設定し、trainer = Trainer(...)で学習対象のモデルを設定し、trainer.train()で学習を開始します。

ここまでの手順で、以下のモデルが完成します。

https://huggingface.co/ce-lery/mistral-300m-sft

推論2

指示チューニング済みモデルを使用した推論は、次のコマンドで実行可能です。

# cd examples/inst-tuning
bash inference.sh

inference.shは、次の推論用ソースコードを実行しています。

https://github.com/ce-lery/japanese-mistral-300m-recipe/blob/v1.0.0/src/inference/sft_inference.py

推論結果は、下表の通りです。

プロンプト 出力
<s>User:日本で一番高い山は何?</s> <s>Assistant:日本で最も高い山は富士山である。</s>
<s>User:日本で一番大きな湖の名前は?</s> <s>Assistant:日本で最も大きい湖は、山梨県の富士河口湖町です。</s>
<s>User:世界で一番高い山は何?</s> <s>Assistant:世界最高の山は、エベレストです。</s>
<s>User:世界で一番大きな湖の名前は?</s> <s>世界で最も大きい湖は、アメリカ合衆国のフロリダ州オーランドにあるオハイオ州オクラホマ州のオホーツク海に面した湖である。この湖の面積は約1,700km2で、面積の約3分の1が水で占められている。</s>
<s>User:冗談を言ってください。</s> <s>Assistant:もちろん、冗談です!</s>
<s>User:香川県の名物は何?</s> <s>Assistant:うどん、そば、うどんそばの3つです。
うどんは香川県を代表する郷土料理で、讃岐うどんの原型と言われています。うどんには、小麦粉、水、塩、こしょう、コシヒカリなどの植物油を混ぜたものが混ざり合い、独特の風味と食感がある。また、香川県はうどん県としても知られており、県庁所在地の高松市にはうどん店が軒を連ねている。</s>
<s>User:日本の首都はどこ?</s> <s>Assistant:東京、大阪、名古屋、福岡です。</s>
<s>User:こんにちは!</s> <s>Assistant:おはようございます。私はオープンアシスタントです。あなたの質問に答えたり、あなたの質問にお答えします。何かお手伝いできることがあれば、遠慮なく聞いてください。</s>
<s>User:兵庫県の県庁所在地の名前は?</s> <s>Assistant:神戸市、姫路市、明石市などです。</s>

おわりに

本記事では、japanese-mistral-300m-recipeを使用し、LLMの事前学習と指示チューニングの方法について説明しました。

本記事に関するご意見・改善点等がありましたら、是非コメント欄へ記載をお願いいたします。特に、学習時間の高速化についての情報を必要としております。japanese-mistral-300m-recipeへのissue、pull requestも歓迎します。
また、よろしければ本記事へのいいね、Githubリポジトリのスターをお願いいたします。著者の励みになります。

最後に、本記事内でリンクしている情報をご提供いただきました皆様に、心より感謝申し上げます。

Discussion

fizunimofizunimo

素晴らしい記事をありがとうございます。
LLMのpre-trainingについて、実装方法が分からず悩んでいたのでとても参考になりました。
一つ質問なのですが、こちらはhttps://huggingface.co/rinna/japanese-gpt2-mediumの継続事前学習ではなく、configを使用して新しいものを1から作成しているという認識であっているでしょうか。

celerycelery

コメント頂きありがとうございます!

configを使用して新しいものを1から作成しているという認識であっているでしょうか

はい、その通りです。継続事前学習ではなく、事前学習です。新しいモデルを1から学習させています。

fizunimofizunimo

返信が遅くなり申し訳ございません。
ありがとうございます。こちらの記事のおかげで理解が深まりました!