OpenCALM(CyberAgentのLLMモデル)をファインチューニングして、イナババ怪文書自動生成AIを改良した
速報的に取り組んだテーマなので、かなり雑。
近日中にしっかり検討します。
5/20 : 追記(cyberagent/open-calm-largeでの推論)
動機
Rinna社の方からrinna/japanese-gpt-neox-3.6bという日本語特化型LLMが発表された。また対話型のモデルのrinna/japanese-gpt-neox-3.6b-instruction-sftという対話型のモデルも同時発表され、話題になっている。
この発表のインパクトが強くて少し影が薄くなっているが、実はほぼ同時期にOpenCALMからも日本語特化LLMモデルがリリースされている。こちらは種類が豊富であり、パラメータ数が160Mの小規模モデルから7bの比較的大きなモデルまで様々なモデルがあり、こちらも良い選択肢だと感じる。
そしてさらにうれしいことに、こちらのモデルは中規模モデル(cyberagent/open-calm-medium)の時点で絵文字に対応している。例えば、絵文字を含んだ文をファインチューニングしていないOpenCLAMに導入すると、以下のような出力が得られる。
入力 : 楽しみ!😁
出力 : 楽しみ!😁\n\n今日、明日と海の日🎶\n\n海の日も、今日はいいお天気☀︎\n\n今日も1日、良い日になりますように😊\n\n#海の日 #海の日海の家\n#海の日海の家\n#海の日海の家\n#海の日海の日\n#海'
GPT特有の文の繰り返しが発生してしまっているが、これはmodel.generateの引数をいじっていないのでご愛嬌である。
前々回の記事で作成した怪文書イナババ怪文書生成AIはrinna/japanese-gpt2-mediumをベースにしており、元の怪文書に含まれている絵文字が<unk>に置き換えられてしまうという欠点があった。しかし、このモデルをベースに使えば、絵文字を含んだ真のイナババ怪文書を生成できる可能性がある。
そこで、今回は、パラメータ数400Mのcyberagent/open-calm-mediumをファインチューニングし、イナババ怪文書を自動生成するAIを作成した。
ファインチューニング
hugging faceのいいところは、使用するモデルを変更した場合の手軽さである。
以下にファインチューニングに用いたコードを示す。
前々回の記事で作成したコードのAutoModelForCausalLM.from_pretrained()の中に指定するモデル名を変更する以外はほとんど何も書き換えていない。
ファインチューニング用コード
import csv, json, mojimoji, re, os, sys, emoji, pickle
import numpy as np
from tqdm import tqdm
from pykakasi import kakasi
#torchの読み込み前に環境変数を固定
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import evaluate
from transformers import AutoModelForCausalLM, AutoTokenizer, T5Tokenizer
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead
class Zupposhi_maker:
#GPT2のモデル名
gpt_model_name = "cyberagent/open-calm-medium"
#データセットのcsvファイル
csv_path = "path/to/csv_file"
csv_enc = "utf-8-sig"
#教師データをGPT用のファイルとして出力する際のパス
gpt2train_path = "path/to/train_data_file"
#文章の最大長
min_len = 32
max_len = 100
#学習回数,バッチサイズ
Nepo = 100
bsize = 8
#途中経過を表示する学習回数の間隔
logging_steps = 200
#モデルを保存する間隔
save_freq = 100000
#結果の出力先
odir = "path/to/ouput_dir/"
#予測時のパラメータ
top_k = 40 #top-k検索の閾値
top_p = 1 #top-pの閾値
num_text = 1 #出力する文の数
temp = 1.0
repeat_ngram_size = 1
#推論にCPUを使用するか
use_cpu = True
def __init__(self, ft_path = None, isTrain = True):
"""コンストラクタ
コンストラクタ。モデルをファイルから読み込む場合と,
新規作成する場合で動作を分ける.
Args:
ft_path : ファインチューニングされたモデルのパス.
Noneを指定すると
train : 学習を行うか
Returns:
なし
"""
print("GPU is available : {}".format(torch.cuda.is_available()))
#モデルの設定
self.__SetModel(ft_path)
#教師データの読み込み
if isTrain:
self.__LoadDataSet()
def __SetModel(self, ft_path = None):
"""GPT2の設定
GPT2のTokenizerおよびモデルを設定する.
ユーザー定義後と顔文字も語彙として認識されるように設定する.
Args:
ft_path : ファインチューニング済みのモデルを読み込む
何も指定しないとself.gpt_model_nameの事前学習モデルを
ネットからダウンロードする.
Returns:
なし
"""
#GPT2のTokenizerのインスタンスを生成
self.tokenizer = AutoTokenizer.from_pretrained(
self.gpt_model_name
)
#self.tokenizer.do_lower_case = True # 今回はrinna社のモデルではないので必要なし。
#モデルの読み込み
if ft_path is not None:
self.model = AutoModelForCausalLM.from_pretrained(ft_path, device_map="auto")
else:
self.model = AutoModelForCausalLM.from_pretrained(self.gpt_model_name, device_map="auto")
def __LoadDataSet(self):
"""データセットのロード
怪文書データセットの読み込み
Args:
csv_name (string) : csvファイルのファイル名
Rtest (float) : テスト用のデータの割合
Returns:
なし
"""
#csvファイルを読み込む
data = []
with open(self.csv_path, "r", encoding = self.csv_enc) as f:
reader = csv.reader(f, delimiter = ",")
for row in reader:
#空行またはコメント行なら読み飛ばす
if(row[0] == "\n"):
continue
if(row[0][0] == "#"):
continue
#怪文書なら,読み取り結果をリストに保存
if(int(row[0]) == 1):
data.append([row[1], row[2], row[3]])
#教師データの成形と,絵文字の抽出
with open(self.gpt2train_path, "w", encoding = "utf-8-sig") as f:
for row in tqdm(data):
ret = self.__TextCleaning(row)
To, Body, From = ret[0], ret[1], ret[2]
#手紙の宛名+送り主から本文を予測するタスクを行う.
#もしも送り主が空欄でなくて、かつ末尾に句読点や感嘆符が付いていないなら"。"をつける
if (From != "") & (not (From.endswith( (".", ".", "。", "!", "!", "?", "?") ))):
From = From + "。" #末尾に"。"や"!", "?"が付いていないなら区切る。
if (Body != "") & (not (Body.endswith( (".", ".", "。", "!", "!", "?", "?") ))):
Body = Body + "。" #末尾に"。"や"!", "?"が付いていないなら区切る。
if (To != "") & (not (To.endswith( (".", ".", "。", "!", "!", "?", "?") ))):
To = To + "。" #末尾に"。"や"!", "?"が付いていないなら区切る。
#テキストを学習用の形式に編集
text = To + Body + From
#text = "".join(tokens).replace('▁', '')
print(text)
f.write(text + "\n")
def __TextCleaning(self, texts):
"""テキストの前処理をする
テキストの前処理を行う.具体的に行うこととしては...
・全角/半角スペースの除去
・半角数字/アルファベットの全角化
"""
#半角スペース,タブ,改行改ページを削除
texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]
#半角/全角を変換
texts = [mojimoji.zen_to_han(t, kana=False) for t in texts]
return texts
def TrainGPT2(self):
"""GPT2のファインチューニング
GPT2の設定とファインチューニングをする
"""
#データセットの設定
train_dataset = TextDataset(
tokenizer = self.tokenizer,
file_path = self.gpt2train_path,
block_size = self.max_len #文章の長さを揃える必要がある
)
#データ入力についての設定
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm= False
)
#学習についての設定
os.makedirs(self.odir + "gpt2-ft", exist_ok=True) #結果出力先のディレクトリがあれば作成
training_args = TrainingArguments(
output_dir=self.odir + "gpt2-ft",
overwrite_output_dir=True,
num_train_epochs=self.Nepo,
per_device_train_batch_size=self.bsize,
logging_steps=self.logging_steps,
save_steps=self.save_freq
)
#上記の設定をtransformerのTrainerクラスに適用
trainer = Trainer(
model =self.model,
args=training_args,
data_collator = data_collator,
train_dataset = train_dataset
)
#学習開始
print("start ... ")
trainer.train()
print("finish!")
print("saving...")
#モデルをCPU/GPUのどちらかに移す
if self.use_cpu: #推論時にCPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model.to(device)
#モデルを保存する
trainer.save_model()
print("finish!")
def GenLetter(self, prompt):
"""怪文書の生成
GPT2で怪文書を生成する.
promptに続く文章を生成して出力する
Args:
prompt : 文章の先頭
Retunrs:
生成された文章のリスト
"""
#文章をtokenizerでエンコード
x = self.tokenizer.encode(prompt, return_tensors="pt")
if self.use_cpu: #CPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = x.to(device)
#gptによる推論
with torch.no_grad():
y = self.model.generate(
x,
min_length=self.min_len, # 文章の最小長
max_length=self.max_len, # 文章の最大長
do_sample=True, # 次の単語を確率で選ぶ
top_k=self.top_k, # Top-Kサンプリング
top_p=self.top_p, # Top-pサンプリング
temperature=self.temp, # 確率分布の調整
no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか
num_return_sequences=self.num_text, # 生成する文章の数
pad_token_id=self.tokenizer.pad_token_id, # パディングのトークンID
bos_token_id=self.tokenizer.bos_token_id, # テキスト先頭のトークンID
eos_token_id=self.tokenizer.eos_token_id, # テキスト終端のトークンID
early_stopping=True
)
# 特殊トークンをスキップして推論結果を文章にデコード
res = self.tokenizer.batch_decode(y, skip_special_tokens=True)
return res
具体的な変更点は...
- 事前学習モデルの読み込み部分をAutoModelForCausalLM.from_pretrained(cyberagent/open-calm-medium)に変更する。(当たり前)
- AutoModelForCausalLM.from_pretrained()の引数の一つでrinna社のモデルだとuse_fast = Falseを指定しろとの記述が公式からあったが、こちらのモデルでは必要ない。
結果
以下に生成した例を示す。generateのパラメータは前回と同じ。
見たところファインチューニングはできているが、改善の余地ありである。
生成例1
入力 : 🐱ニコバンバン🐱かわいいのら😁。ドュッピ!イナババだっピ!
出力 : 🐱ニコバンバン🐱かわいいのら😁。ドュッピ!イナババだっピ!許さぬぞ...!ぺしみちのタンザックに載せられてしまったアザラシトモ。-追記-ズッポシかるた<ほ>弾むような笑顔が素敵な村の代表です モモチへ★タムリさん便り、ごきげん
生成例2
入力 : おいでませピータン!村CEOのイナババです。
出力 : おいでませピータン!村CEOのイナババです。ようこそ、お越し下さいました!!当館はガラクタの山なのでござる・・・。(続く)。こんにちは😁イナバンバビタワゾー🐱プンプイ💀セール中だね🎵見に来てくれてありがとう😊実は、、、この服3980ベルで😂通販してるのんで新品未開封なの
文章が少々めちゃくちゃだが、絵文字は大量に出力されたことと、ズッポシ村やどうぶつの森に特有の名詞(「ぺしみち」、「モモチ」、「ズッポシかるた」、「ベル」など)は入っていることからより大きなモデルを長時間ファインチューニングすれば光明はありそう。
largeモデルならGoogle Colabの方でA100やV100を使えば動かせると思うので、そちらにもチャレンジしてみたい。
5/20追記 : cyberagent/open-calm-largeをファインチューニング
パラメータ数830Mのlargeモデルも試してみた。学習回数100およびバッチサイズ4で回したところ、Google ColabのV100で20分ほどかかった。
出力結果を以下に示す。
生成例1
入力 : 🐱ニコバンバン🐱かわいいのら😁。ドュッピ!イナババだっピ!
出力 : 🐱ニコバンバン🐱かわいいのら😁。ドュッピ!イナババだっピ!お手紙、送ったよ🕺ぜひ着てみてね。...あっ、「ツバクロ」って口癖も似合いそうだから覚えておいてほしいな...💀ハニワを送ります🎵ダサ男改造委員たぬきち
サラちゃんへ。★
入力 : おいでませピータン!村CEOのイナババです。
出力 : おいでませピータン!村CEOのイナババです。イナババンにはいつもお世話になっております😁この服、着てくれると嬉しいです。...あっ、(・Д・)ノ?この流れからしてヘンな予感がするんだけど・・・大丈夫かねぇ・・・?とりあえず、「ダサ男改造委員」に任命します!!
🐱ニコバーかわいい博物館へ🕺。「モ
ところどころ変な部分や文脈が通っていない部分もあるがちゃんとイナババ怪文書になっている。
ただ、動かしてみた印象として、CPUオンリーの場合のrinna/japanese-gpt2-mediumに比べるとかなりモデルの読み込み・推論が遅い(特にモデルの読み込みがストレス)。「ファインチューニングしたモデルをWebで公開する」みたいな用途だと、GPUが使える環境を用意しないと厳しいかも。
参考文献
npaka氏「Google Colab で OpenCALM-7B を試す」
Discussion