📘

自宅のRTX3060で小さなLLMを自作してみた

に公開

こちらのイベントに参加するためのネタとして自宅のPC(RTX3060)で青空文庫のデータセットを使ってトークナイザーを自作しGPT-2アーキテクチャの42.1Mのモデルでの事前学習をやってみました。
https://aimeetup.connpass.com/event/367666/

(イベントがなかったら事前学習をやってみようとは思いもしなかったはずでこのような機会を設けてくださったぬこぬこさんには大変感謝しています!)

データセット

こちらの前処理済みの青空文庫のデータセットを使いました。
https://huggingface.co/datasets/globis-university/aozorabunko-clean

さらに今回は以下のように文字遣いを新字新仮名に絞って学習を行います

raw_dataset = load_dataset("globis-university/aozorabunko-clean", split="train")
filtered_dataset = raw_dataset.filter(lambda row: row["meta"]["文字遣い種別"] == "新字新仮名")

print(filtered_dataset)
print(f"Filtered dataset size: {filtered_dataset.num_rows:,} entries")

およそ1万件のデータになります。

Dataset({
    features: ['text', 'footnote', 'meta'],
    num_rows: 10246
})
Filtered dataset size: 10,246 entries

トークナイザー

青空文庫のデータのみで学習を行うということもあり既存のトークナイザーを使うより専用のものを使用した方が良いだろうと考え自作することにしました。
SentencePiece を直接呼び出し unigram モデルを学習しています。
一般的なCPUですが10分ほどで作成が完了しました。

spm.SentencePieceTrainer.train(
    sentence_iterator=corpus_iterator(filtered_dataset),

    model_prefix=str(OUTPUT_DIR / "aozora_spm_gpt2"),
    model_type="unigram",
    vocab_size=32000,
    character_coverage=0.9995,
    normalization_rule_name="nfkc",
    byte_fallback=True,

    unk_id=0, bos_id=-1, eos_id=-1, pad_id=-1,
    user_defined_symbols=["<|endoftext|>"],

    input_sentence_size=2_000_000,
    shuffle_input_sentence=True,
    train_extremely_large_corpus=True,
    num_threads=8,

    add_dummy_prefix=False,
    remove_extra_whitespaces=False,
    hard_vocab_limit=False,
)

パラメーターについて

  • model_type: unigram 互換のトークン列を維持しつつ日本語特有の語尾を細かく刻みたかったので unigram を選択。
  • vocab_size: 32,000 なら GPT-2 Small と同じスケールの埋め込み行列です。
  • character_coverage: 0.9995 にすることで歴史的仮名遣いなどの出現頻度が極端に低い文字を落としつつ、byte_fallback でどうしても必要な文字はバイト列として表現します。
  • input_sentence_size=2_000_000 に制限することで SentencePiece のサンプリングが安定し、毎回同じような語彙が得られます。
  • shuffle_input_sentence=True を立てて青空文庫特有の長大な作品でも偏りなくサブワードが抽出されるようにしました。
  • <|endoftext|> は GPT-2 の EOS と揃える目的で user_defined_symbols に追加し、学習データの文末にも常に付与しておくことで後段の Trainer が PAD と EOS を兼用できます。
  • normalization: normalization_rule_name="nfkc"byte_fallback=True の組み合わせで旧字体を正規化しつつ未知文字を落とさないようにしています。

結果

トークナイザーとして動かしてみた結果です。当然「ウェストミンスター」のような固有名詞には弱いですが。「となって」、「集まっている」などが一つの語彙として扱われているのはよさそうです。

--- Sample 1 ---
Original: 深いおどろきにうたれて、
Token IDs: [874, 11406, 264, 366, 15219, 258]
Pieces: ['深い', 'おどろき', 'に', 'う', 'たれて', '、']
Decoded: 深いおどろきにうたれて、
Unknown tokens: 0
--- Sample 2 ---
Original: 名高いウェストミンスターに
Token IDs: [15899, 14107, 2047, 10438, 7545, 264]
Pieces: ['名高い', 'ウェ', 'スト', 'ミン', 'スター', 'に']
Decoded: 名高いウェストミンスターに
Unknown tokens: 0
--- Sample 3 ---
Original: 真鍮や石の記念碑となって
Token IDs: [21674, 278, 3600, 5815, 9470, 1070]
Pieces: ['真鍮', 'や', '石の', '記念', '碑', 'となって']
Decoded: 真鍮や石の記念碑となって
Unknown tokens: 0
--- Sample 4 ---
Original: すべての王侯貴族が集まっているのをみれば、
Token IDs: [2200, 965, 6015, 5559, 262, 26889, 260, 266, 5117, 258]
Pieces: ['すべての', '王', '侯', '貴族', 'が', '集まっている', 'の', 'を', 'みれば', '、']
Decoded: すべての王侯貴族が集まっているのをみれば、
Unknown tokens: 0
--- Sample 5 ---
...
Token IDs: [1291, 12277, 3254, 268, 258, 18540, 268, 258, 340, 2529, 510, 259]
Pieces: ['今は', 'さげ', 'すみ', 'も', '、', 'ほこり', 'も', '、', '見', '栄', 'もない', '。']
Decoded: 今はさげすみも、ほこりも、見栄もない。
Unknown tokens: 0

モデル

RTX 3060 でも常用できる 8 層 512 次元の GPT-2 を構築しました。この構成では42.1M のパラメータ数になります。

パラメーターについて

  • n_embd=512 は RTX3060 の 12GB に収まるギリギリの線で、n_layer を 8 にすることで学習時間とのバランスをとりました。
  • n_ctx=256 は青空文庫の段落単位の文脈をカバーしつつ、バッチを 32 サンプル積める程度に抑えています。
  • BOS/EOS/PAD を <|endoftext|> に統一し、欠損を気にせずに Trainerdefault_data_collator を使えるようにしています。
  • use_cache=False:トレーニング時の再計算を許容してでも VRAM 消費を抑え、勾配爆発を起こしにくくしています。
config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=block_size,
    n_ctx=block_size,
    n_embd=512,
    n_layer=8,
    n_head=8,
    bos_token_id=endoftext_id,
    eos_token_id=endoftext_id,
    pad_token_id=endoftext_id,
)

model = GPT2LMHeadModel(config)
model.config.use_cache = False

事前学習

パラメーターについて

  • epoch=1:試しに1epochの学習を実施しました。
  • block_size=256:長編小説でも段落単位でまとまりが出る長さ。VRAMとの兼ね合いで妥協したところです。
  • per_device_train_batch_size=2 × gradient_accumulation_steps=16:デバイスあたり 2 サンプルしか積めない設定でも 32 サンプル相当で更新でき、バッチサイズ依存の学習安定性を確保できます。
  • learning_rate=5e-4 + warmup_steps=500:事前学習を 1Epoch で終えるため高めの学習率を採用しつつ、最初の 500step で少しずつ上げて暴走を防ぎます。
  • weight_decay=0.1:語彙を新規に学習しているので正則化を強めに入れ、まれなトークンでの overfitting を抑えています。
  • eval/save/logging_steps=500:コーパス全体で 3000step 前後なので 500step ごとにイベントを置くと、計 5 回程度のスナップショットで推移を追える計算です。
  • fp16=True:RTX3060 でも学習を高速化でき、VRAM のピークも 1〜1.2GB ほど下げられます。
tokenized = dataset_dict.map(
    encode_with_sentencepiece,
    batched=True,
    remove_columns=original_columns,
    desc="Tokenizing with SentencePiece",
)

lm_datasets = tokenized.map(
    group_texts,
    batched=True,
    desc="Grouping texts into fixed-size blocks",
)

training_args = TrainingArguments(
    output_dir=str(output_dir),
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    learning_rate=5e-4,
    warmup_steps=500,
    weight_decay=0.1,
    logging_steps=100,
    eval_steps=500,
    save_steps=500,
    evaluation_strategy="steps",
    save_strategy="steps",
    fp16=torch.cuda.is_available(),
    report_to=[],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=default_data_collator,
)

学習経過

1時間45分ほどで学習は完了しました。
Train Lossの推移としては以下のようになりました。

結果

青空文庫で学習させた甲斐があってか特に抽象的な問いを投げればそれらしいことを言うようになりました。

人生とは、全く同個の一段階の人間学の「神」とその統一の上に存在するものとして見られている。しかし私は、この思想的実在に於て、初めて、世界的なる存在史という形式を自覚し来った。而してそこに哲学が展開されたのである。

愛とは、その感情の直接関係にそなうものでなくてはならないのであって、したがってそれは単に単なる性質上の相違によることを意味するにすぎないと考えるのである。もしこの矛盾が原因するとすれば、自然はこの行為によって変化しないという結果になるというのである。

神とは、全く別の世界ではないのであります。この二つが何であるかというと、一つは一つに分れてその同一領域へ通ずること、また別のものとして同一範囲に存するのでありましょう。これに対する需要価値もそのいずれの時代に於て異るものであるかと存じます。

事前学習会参加者の中では断トツでパラメタ数が小さなモデルだったと思いますが出力に関しては結構面白がってもらえたのでうれしかったです。

試してみたい方向け

一応トークナイザーとモデルに関してはHugging Faceにアップロードしています。
https://huggingface.co/auhulu/aozora-gpt2

推論を実行できるColabも用意したので興味ある方はぜひ触ってみてください。
https://colab.research.google.com/drive/1MRJCf_ztdxJ2PHPblz23Bshgl2i0YvVW

Discussion