自作データセットでWhisperをファインチューニングしたら、独自用語だらけのクラロワ実況でも使えるようになった:「ファインチューニング編」
前回のあらずじ
前の記事はこちら
ファインチューニング用の自作データセットをspreadsheetで作れたので、これを使ってファインチューニングしていきたい。
作った独自データセットでファインチューニングする
スプレッドシートで音声データと正解ラベルを作ることができたので、これを使ってファインチューニングします。
最初に使うライブラリをインストール
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
!pip install git+https://github.com/openai/whisper.git
!pip install -U gspread
!pip install gspread-dataframe
まずはスプレッドシートのデータセットをdataframeにします。
from google.colab import auth
auth.authenticate_user()
import gspread
from google.auth import default
creds, _ = default()
gc = gspread.authorize(creds)
ws = gc.open_by_url("spreadsheetのurl").worksheet("シート1")
records = ws.get_all_records()
df = pd.DataFrame(records)
まずは先ほどデータセット作成にもつかったgspreadを利用します。
spreadsheetのurlをコピペしてきて、そのurlをopen_by_url
の引数に入れて実行するとspreadsheetのオブジェクトが取れます。
とってきたシートオブジェクトのget_all_records
メソッドでシートの中身を辞書でとってくることが可能です。
今回の場合
[
{
"url":"音声データのdriveのurl",
"colab_path": "音声データのcolabのバス(/content/drive...)",
"sampling_rate":44000 # 音声データのサンプリングレート,
"correct":"正解テキスト",
"whisper": "whisperの推論テキスト"
},
...
]
みたいな配列が取れます。この形はそのままデータフレームにできるので、データフレーム化します。
その後トレーニング用のデータセットと検証用のデータセットを分け、colab_pathからHugginFaceのAudioオブジェクトを作ります。
msk = np.random.rand(len(merged_df)) < 0.7
train_dataset = Dataset.from_pandas(df[msk]).cast_column("colab_path", Audio(sampling_rate=16000)).rename_column("colab_path", "audio").remove_columns(["sampling_rate"])
validate_dataset = Dataset.from_pandas(df[~msk]).cast_column("colab_path", Audio(sampling_rate=16000)).rename_column("colab_path", "audio").remove_columns(["sampling_rate"])
まず
msk = np.random.rand(len(merged_df)) < 0.7
train_dataset = Dataset.from_pandas(df[msk])
validate_dataset = Dataset.from_pandas(df[~msk])
の部分で、トレーニング用と検証用を7対3になるように分けます。
さらに、音声データのパスからAudioの配列を作るのはcast_column
メソッドでできるのは前回と同じですが、今回学習用データがsampling_rateが16000じゃないとダメということで、なんとAudio(sampling_date=16000)とやるだけでいい感じにそのデータにしてくれます。
train_dataset = Dataset.from_pandas(df[msk]).cast_column("colab_path", Audio(sampling_rate=16000)).rename_column("colab_path", "audio").remove_columns(["sampling_rate"])
ちなみにこの時にじゃあサンプリングレートのデータいらないじゃんってなってcolumnからremoveしているのと、音声データの列をaudioという列名に変えたかったので.rename_column("colab_path", "audio")
もしています。便利。
これらをまとめてDatasetDict
に格納します。
datasets = DatasetDict({
"train": train_dataset,
"validate": validate_dataset
})
ここからはこの素晴らしいブログの通りにやってくだけです。ほとんど同じなのでブログ見ていただいた方がいいと思いますが、一応書いていきます。書かないのも不親切だと思うので一応・・・
最終的に必要なものは以下のコードです。
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=prepared_datasets["train"],
eval_dataset=prepared_datasets["validate"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
trainer.train()
必要なものとして、
- training_args: 学習のパラメータ
- prepared_dataset : 学習用にラベルづけ等データセット
- data_collator: 音響特徴量とラベルをそれぞれ前処理するためのクラス
- compute_metrics: 評価関数。今回ここは日本語の差を評価する必要があるため分かち書きなど必要。
なのでこれを作っていきます。
prepared_datasetの作り方
まずwhisperのtorknizeや音響特徴量抽出を行うprocessorのHuggingFaceバージョンを定義します。
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-large", language="Japanese", task="transcribe")
その後、先ほど作ったDatasetDict
をこのprocessor
で前処理するprepare_dataset
関数を作ります
def prepare_dataset(batch):
audio = batch["audio"]
# 音響特徴量抽出
batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# 正解のテキストをlabel idにエンコード
batch["labels"] = processor.tokenizer(batch["correct"]).input_ids
return batch
DatasetDict
オブジェクトはtrainとvalidateそれぞれに格納されてるデータセットに関数を適用するmap
関数をもつので、それでprepare_dataset関数をdatasetに適用します。
prepared_datasets = datasets.map(prepare_dataset, remove_columns=datasets.column_names["train"], num_proc=1)
data_collatorの作り方
公式がそのまま書いてくれてるのと、解説も詳しくしてくれています。参照
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int]
, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# 音響特徴量側をまとめる処理
# (一応バッチ単位でパディングしているが、すべて30秒分であるはず)
input_features \
= [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# トークン化された系列をバッチ単位でパディング
label_features = [{"input_ids": feature["labels"]} for feature in features]
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# attention_maskが0の部分は、トークンを-100に置き換えてロス計算時に無視させる
# -100を無視するのは、PyTorchの仕様
labels \
= labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# BOSトークンがある場合は削除
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
# 整形したlabelsをバッチにまとめる
batch["labels"] = labels
return batch
作ったクラスのコンストラクタにprocessorを格納します。
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
compute_metricsの作り方
評価関数を作るのですが、これだけHuggingFace公式では日本語の差分評価はできないので自作する必要があります。
ただ先ほどから何回も言っている素晴らしいブログがやり方を載せてくれてます
まず分かち書きのmecabを入れるためにginzaを入れます。
このqiitaで解説してますが、sortedcontainersのバージョンがcolabだとバグるのとpkg_resourcesを再度reloadしないといけないためランタイム再起動ではない方法でリロードさせてるそうです。
!pip install ginza==4.0.5 ja-ginza
!pip install sortedcontainers~=2.1.0
import pkg_resources, imp
imp.reload(pkg_resources)
import evaluate
import spacy
import ginza
metric = evaluate.load("wer")
nlp = spacy.load("ja_ginza")
ginza.set_split_mode(nlp, "C") # CはNEologdの意らしいです
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
# 分かち書きして空白区切りに変換
pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
ここで評価関数を作ります。分かち書きしたものに対して評価をするため、
pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
ここは公式では書いてない方法になってます。
training_argsの作り方
学習用パラメータを定義します。
ちなみにstep数が公式ブログの100分の1になっているのですが、とりあえずこれでやってみても十分精度よかったので精度悪かったら変える程度でもいいかも?
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./test", # change to a repo name of your choice
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
# warmup_steps=500, # Hugging Faceブログではこちら
warmup_steps=5,
# max_steps=4000, # Hugging Faceブログではこちら
max_steps=40,
gradient_checkpointing=True,
fp16=True,
group_by_length=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
# save_steps=1000, # Hugging Faceブログではこちら
save_steps=10,
# eval_steps=1000, # Hugging Faceブログではこちら
eval_steps=10,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=False,
)
学習実行!
あとは上で作ったもので学習します。
ちなみにこの時colabのプレミアムクラスのGPUじゃないとGPUのメモリ不足で落ちました。
多分1000円課金しないと厳しいです。
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=prepared_datasets["train"],
eval_dataset=prepared_datasets["validate"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
trainer.train()
モデルのsave
Hugging Faceのトレーニング済みモデルはめちゃくちゃ簡単にローカルに保存できます。
save_model
メソッドにローカルのパスを指定するだけです。
trainer.save_model("/content/drive/.../クラロワ/model")
これでプレミアムクラスのGPUを使って毎回学習させなくて済みます。経済的に安心ですね。
学習済みモデルでpredict。
実験をする時なのですが、結局processor等は元のが必要なので、学習時と全く同じことを行います。
prepare_dataset関数とdata_collatorクラスを作ってそれを用意します。
from transformers import WhisperProcessor
from dataclasses import dataclass
from typing import Any, Dict, List, Union
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = processor.tokenizer(batch["correct"]).input_ids
return batch
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int]
, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# 音響特徴量側をまとめる処理
# (一応バッチ単位でパディングしているが、すべて30秒分であるはず)
input_features \
= [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# トークン化された系列をバッチ単位でパディング
label_features = [{"input_ids": feature["labels"]} for feature in features]
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# attention_maskが0の部分は、トークンを-100に置き換えてロス計算時に無視させる
# -100を無視するのは、PyTorchの仕様
labels \
= labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# BOSトークンがある場合は削除
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
# 整形したlabelsをバッチにまとめる
batch["labels"] = labels
return batch
processor = WhisperProcessor.from_pretrained("openai/whisper-large", language="Japanese", task="transcribe")
prepared_datasets = test_dataset.map(prepare_dataset, remove_columns=test_dataset.column_names, num_proc=1)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
!pip install ginza==4.0.5 ja-ginza
!pip install sortedcontainers~=2.1.0
import pkg_resources, imp
imp.reload(pkg_resources)
mport evaluate
import spacy
import ginza
metric = evaluate.load("wer")
nlp = spacy.load("ja_ginza")
ginza.set_split_mode(nlp, "C") # CはNEologdの意らしいです
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# replace -100 with the pad_token_id
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
# we do not want to group tokens when computing the metrics
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
# 分かち書きして空白区切りに変換
pred_str = [" ".join([ str(i) for i in nlp(j) ]) for j in pred_str]
label_str = [" ".join([ str(i) for i in nlp(j) ]) for j in label_str]
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
ここまでは学習時と全く同じコードです。
ここだけ違います。
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("/content/drive/.../クラロワ/model")
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language = "ja", task = "transcribe")
model.config.suppress_tokens = []
modelの変数に入れるpretrainedのモデルは先ほど学習して保存したドライブのパスを入力します。
それを使ってあとはtrainerを作り、predictするだけです。
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=prepared_datasets,
eval_dataset=prepared_datasets,
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
# predict!!
prediction_output = trainer.predict(prepared_datasets)
pred_ids = prediction_output.predictions
processor.tokenizer.decode(pred_ids[0], skip_special_tokens=True)
結果がすごい
別動画の予測をしてみました。
正解
めっちゃしやすくてで迫撃にもアチャクイを当てられるでしょ だもうマジで環境でゴレとかにもまあポイズンウッドだから普通に強くてエリポンも別にディガーで潰せると三銃士が来ても勝てるロイホグ系もねゴーストアチャクイゴブリンウッドだからめっちゃ強いんですよ
元のWhisper
めっちゃしやすくてで迫撃にもあ着いを当てられるでしょ だもうマジで環境で5例とかにもはポイズングッドだから普通に強くてエリポンも別にリガーで潰せると30人が来ても勝てるロイホグ系もねゴーストアチャクイゴブリングッドだからめっちゃ強いんですよ
流石にゴレが5例になってたりディガーがリガーになってたり三銃士が30人になってたりします。
ファインチューニング後Whisper
めっちゃでしやすくてで迫撃にもアチャクイを当てられるでしょだからもうマジで環境でゴレとかにもポイズンウッドだから普通に強くてでエリポンも別にディガーで潰せると三銃士とか来ても勝てるロイホグ系もねゴーストアチャクイゴブリングッドだからめっちゃ強いんですよ
最後のゴーストアチャクイゴブリングッドのグッドはウッドが正解なのですが、それ以外全部完璧に文字起こしできている!!えぐい。
データセットの中にゴレとか1回程度しか出てなかったけど綺麗に読めてやがる。
まあ過学習とかありえるのかもだから詳しく調査していきたいけどとりあえず40分程度のデータセットでここまでなるのは普通にすごいんじゃないか。
結論
1時間いないの自作文字起こしデータセットで専門用語を理解するwhisperは割と作れる説があって、めちゃくちゃ未来が見える。
マジですごい。
Discussion