国交省の建設工事事故データベースでSwallow-13bをファインチューニングして事故対策提案AIを作る
要約
- 国土交通省の建設工事事故データベースを使ってSwallow-13bをLoRAでファインチューニングすることで、事故の対策提案タスクに特化したLLMを作成しました
- 出力結果には(素人目線で)そこまで違和感がなく、事故対策検討の補助に活用できそう
背景
先日、国土交通省が建設工事事故データベースのデータを公開しました。
建設現場での事故は重大災害に繋がりやすいこともあり、万が一がないように十分な対策を行っていく必要があります。今回国交省が公開したデータは各社の事故防止に役立つようにという思いを込めて公開されたもので、過去4年間に発生した約1600件の事故に対して、事故の経緯と状況、要因、その後の対策がまとめられています。
これをLLMに学習させることで、漏れのない対策検討を実現できるんじゃないかと思い、今回の実験を行うことにしました。
実験概要
データセットには上述の通り、国交省の建設工事事故データベースを、モデルには13Bモデルの中でも性能が良さそうな東工大のSwallow-13b-instruct-hfを使用します。
実験の手順は下記の通りです。
- データセットから事故の経緯・状況、要因、対策を抽出
- 事故の経緯・状況+要因から事故後の対策を生成するようにLoRAでSFT
- テストデータに対して事故後の対策を生成して検証
データセットについては、経緯・状況や要因、対策などが空のレコードが含まれていましたので、それらを取り除いた1150件(内、学習データ1035件)を使用しました。
学習時の設定
学習時の諸々の設定は下記の通りです。
プロンプトテンプレート
# プロンプトテンプレートの準備
def generate_prompt(data_point):
result = f"""以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。リクエストを適切に完了するための回答を記述してください。
### 指示:
下記の建設現場における事故に至る経緯、事故の状況、事故の要因から、事故発生後の対策を記述してください。
### 入力:
事故に至る経緯と事故の状況: {data_point["事故に至る経緯と事故の状況"]}
事故の要因(背景も含む): {data_point["事故の要因(背景も含む)"]}
### 応答:
{data_point["事故発生後の対策"]}"""
return result
モデルのロード
model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
load_in_8bit=True,
device_map="auto",
)
LoRAのパラメータ
lora_config = LoraConfig(
r= 8,
lora_alpha=16,
target_modules=['v_proj', 'up_proj', 'down_proj', 'o_proj', 'q_proj', 'k_proj', 'gate_proj'],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
# モデルの前処理
model = prepare_model_for_int8_training(model)
# LoRAモデルの準備
model = get_peft_model(model, lora_config)
# 学習可能パラメータの確認
model.print_trainable_parameters()
学習
# トレーナーの準備
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
args=transformers.TrainingArguments(
num_train_epochs=3,
learning_rate=3e-4,
logging_steps=logging_steps,
save_strategy="epoch",
output_dir=output_dir,
report_to="none",
save_total_limit=3,
push_to_hub=False,
auto_find_batch_size=True
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# 学習の実行
model.config.use_cache = False
trainer.train()
model.config.use_cache = True
結果
3列目がデータベースに書かれていた実際の対策、4列目が今回のモデルの出力になります。
考察
主観的な評価にはなりますが、出力結果にそこまで違和感はありませんでした。
建設業はインターネットに掲載されている情報が少ない印象があり、Webデータを中心に学習しているLLMは性能を十分に発揮できないんじゃないかと思っていたため、それらしい出力がされたのは驚きでした。
もちろんLLMの出力だけで事故対策を行うには不十分だと思いますが、対策の抜け漏れをなくすための補助としては使えるんじゃないかと素人目には感じました。
今回のモデルの課題としては、
- 出力の後半が一般的な対策に寄ってしまい、事故特有の対策を十分に提案できていない
- 工種などの情報を与えていないため、事故の要因や状況からしか対策の提案を行えない
などが挙げられるかなと思います。
これらは学習方法やデータセットを工夫することで改善が見込めるのではないかと考えています。
個人レベルではデータがこれ以上ないため、今回以上のことができませんが、ゼネコンなど自社で事故のデータを保有している企業がこういった取組を行うことで、事故による被害が少しでも減っていくことを願っています。
Discussion