🦑

スプラのブキ紹介文を自動生成してみた(GPT)

2022/10/10に公開

皆さんはスプラトゥーンやっていますでしょうか。

スプラ3では50種類以上のブキがありますが、スプラ2をやってた自分としては亜種がまったくないのでちょっと物足りないです。

そこで考えました。 ブキチに新ブキを作らせよう、と。

新ブキの作り方

スプラトゥーンでは、ブキを購入する前に必ずブキチのうんちく 紹介文を聞く必要があります。

今回はこのブキ紹介文をrinna社が無料で提供している日本語特化GPT2モデル rinna/japanese-gpt2-medium に食わせてファインチューニングしてみます。

結論から言うと、こんな感じで架空のブキ紹介文が無限に生成できます。

この武器はクアッドホッパーブラックシャンクでし!
素早く走ることで相手を追いこむことができるでし!
さらに、ジェットパックで遠くの相手もねらいうち!時にはダイタンな一撃をお見舞いするでし!
バランスの取れた構成なので、自分のスタイルに合った構成のナイフを見つけて欲しいでし!

ジェッパのクアッド、欲しくないですか???(違

ブキチの紹介文の取得

こちらのサイト様から取得させていただきました。

こんな感じ。基本的に語尾が「でし」です。

$ head train.txt
メインウェポンのボールドマーカーは射出口のパイを大きくすることで飛距離をぎせいにし、近接戦に特化したものでし
その高い攻撃力をイカして、相手を確実に倒しつつ前進!カーリングボムでさらに相手の領地に入りこめるでし!
そして、相手のふところにもぐりこんだところにスーパーチャクチでドカン!く~っ!これぞまさにベタ足インファイト!
近距離戦が大好きで仲間の進路を切り開くのが得意な使い手にかわいがって欲しいでし
メインウェポンのボールドマーカーネオはボールドマーカーに純正エンブレムがほどこされたモデルでし!
ぶっちゃけ中身はいっしょなので、どんどん相手の領地に入りこみ、サブウェポンのジャンプビーコンで仲間を呼びこむでし!
さらに、マルチミサイルで相手をあぶり出して仲間といっしょに一気に勝負を決めるでし!
接近戦ばかりではなく、仲間とも協力する戦い方が好きな使い手にかわいがって欲しいでし
メインウェポンのボールドマーカー7はボールドマーカーと同じ性能を確保しつつ幅広い戦い方を実現させたモデルでし!
サブのスプラッシュボムで中距離をカバーしウルトラハンコを投げれば遠距離もカバー出来るブキに仕上がったでし!

環境構築

まずは必要なライブラリをインストール

pip install git+https://github.com/huggingface/transformers@v4.22.2
pip install sklearn

ファインチューニングするために本家のソースコードをダウンロードします。pipでインストールしたバージョンと合わせること。

git clone https://github.com/huggingface/transformers -b v4.22.2

rinna社の学習済みモデル場合、TokenizerとしてAutoTokenizerではなく、T5Tokenizerを明示的に指定する必要があるので置き換えます。

sed -i 's/AutoTokenizer/T5Tokenizer/' ./transformers/examples/pytorch/language-modeling/run_clm.py

run_clm.pyを実行するのに必要なライブラリをまとめてインストールします。

pip install -r ./transformers/examples/pytorch/language-modeling/requirements.txt

学習

いざ学習

python ./transformers/examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path=rinna/japanese-gpt2-medium \
    --train_file=train.txt \
    --validation_file=train.txt \
    --do_train \
    --do_eval \
    --num_train_epochs=30 \
    --save_steps=10000 \
    --save_total_limit=3 \
    --per_device_train_batch_size=1 \
    --per_device_eval_batch_size=1 \
    --output_dir=output/ \
    --overwrite_output_dir=true \
    --use_fast_tokenizer=False

しばらくするとこんな感じで学習が完了します。

[INFO|trainer.py:2920] 2022-10-09 01:45:35,492 >> ***** Running Evaluation *****
[INFO|trainer.py:2922] 2022-10-09 01:45:35,492 >>   Num examples = 29
[INFO|trainer.py:2925] 2022-10-09 01:45:35,493 >>   Batch size = 1
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:04<00:00,  5.82it/s]
***** eval metrics *****
  epoch                   =       30.0
  eval_accuracy           =     0.9988
  eval_loss               =     0.0157
  eval_runtime            = 0:00:05.17
  eval_samples            =         29
  eval_samples_per_second =      5.605
  eval_steps_per_second   =      5.605
  perplexity              =     1.0158
[INFO|modelcard.py:443] 2022-10-09 01:45:40,817 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}, 'metrics': [{'name': 'Accuracy', 'type': 'accuracy', 'value': 0.9987865304884215}]}

自動生成

GPTの場合、仕様上文の冒頭は与えなければなりません。多様に出力してほしかったので「この武器は」からブキ紹介文を自動生成してもらいます。

model.generate()のパラメータで生成される文章がかなり変化します。

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True
model = AutoModelForCausalLM.from_pretrained("output/")
model.to(device)

input = tokenizer.encode("この武器は", return_tensors="pt", add_special_tokens=False).to(device)

with torch.no_grad():
  output_ids = model.generate(input,
    max_length=100, # 文章の最大帳
    min_length=60, # 文章の最小帳
    do_sample=True, # サンプリングするか Falseの方が自然な文章だが学習文章量が少ないので出力できなくなるケースも
    top_k=200, # 上位{top_k}個の文章を保持 多いほどブッ飛んだ文章ができがち
    top_p=0.95,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
  )

output = tokenizer.decode(output_ids.tolist()[0], skip_special_tokens=True)
print(output)

色々生成してみました!

例1

この武器は自動連射で相手にプレッシャーをかけながら前進し、
周囲を一気に塗り広げることも可能になるでし!
スーパーチャクチも組み合わせれば、
相手をまとめて仕留めることも可能になっているでし!
相手の領地に入り込んだり、仲間と協力して戦うのが好きな使い手にかわいがって欲しいでし!

スーパーチャクチが一網打尽できるスペシャルかは議論の余地がありますが、かなり自然な文章 ブキチの口調も反映されてます。

例2

この武器はクアッドホッパーブラックシャンクでし!
素早く走ることで相手を追いこむことができるでし!
さらに、ジェットパックで遠くの相手もねらいうち!時にはダイタンな一撃をお見舞いするでし!
バランスの取れた構成なので、自分のスタイルに合った構成のナイフを見つけて欲しいでし!

ジェットパック搭載クアッドホッパーブラックシャンク(?) 強そう

例3

この武器は、相手をけん制しながら自らもけん制しつつ進むことができるローリングボムと、
移動しながら広範囲を塗るマルチミサイルのセットでし!
スーパーショットも使えるので相手をまとめて仕留めるのが得意な使い手にかわいがって欲しいでし!
※こちらの商品は、ご注文後のメーカー取寄せとなりますので、納期が1~2ヶ月ほどかかります。
あらかじめご了承ください。

めっちゃ笑った。こういう学習外の文章が出るのも味がありますね。(?

例4

自分はスプラシューターが好きなんですが、スペシャルがジェットパックで来てほしいです。そこで 「スプラシューターは最強の武器でし!スペシャルがジェットパックでサブが」 まで書いて続きを自動生成してもらいました。

スプラシューターは最強の武器でし!スペシャルがジェットパックでサブがキューバンボムなのも注目でし!
スーパーセンサーで相手の位置をとらえたら、持ち前の射程とイカスフィアで一気に相手をセンメツするでし!
今にも世界を救えそうな気分になるでし~ぜひぜひ手にとってかわいがって欲しいでし!

ジェッパ+スーパーセンサー+イカスフィアの最強チーター武器が誕生してしまった…これは世界救えるわ

まとめ

  • huggingface/transformersrinna/japanese-gpt2-mediumを組み合わせるとほぼノーコードで学習出来る
  • 少ない文章でもブキチの口調が再現されるほどにはファインチューニング可能
  • スペシャル、サブが連続で出てしまっているケースがあったので、各々をラベル付けすればより自然なブキチが作れるかも?

宣伝

環境

  • Python 3.10
  • transformers v.4.22
  • pytorch 1.12.1+cu11
  • GCP NVIDIA Tesla P100x1 / n1-highmem-4

参考リンク

Discussion