松尾研LLM講座レポート
今回の最終課題では、最終的に以下の成績を取り、優秀賞をいただくことが出来ました。
- 予選最終スコア 3.40点 29位
- 優秀賞U-29部門 3位
- 優秀賞U-18部門 1位
- 優秀賞コントリビューション部門 4位
ここではモデルの開発についてまとめました。
TL;DR
DeL-TaiseiOzaki/Tengentoppa-sft-v1.0とAratako/Magpie-Tanuki-8B-annotated-96kの一部をデータセットとし、llm-jp-sftでgoogle/gemma-2-27bをSFTした。また、entropixを使用し、Q5_K_Mで推論した。
自己紹介
大阪在住のとあるSSH指定校の高校一年生。LLMの他に乗り物・交通が趣味。
過去の作品
- 文化祭で実施したカジノの通貨管理システム
- 生徒会のHP
- LLMをFTした翻訳機
- DiffLlama
AI・LLMに興味を持った道のり
- 元々サーバーに興味がありLinuxサーバーでWebサーバーやSamba、Minecraftサーバーなどを立てて遊んでいた。既存のサービスをOSSを使って真似していた。(NextCloudやRocketChatなど)
- 一時期AIに興味を持ちPytorchチュートリアルをやっていた
- 自分や妹の声を使った合成音声を作ろうとした。(Tacotron2)
- 2022年11月30日ChatGPT発表。初日か次の日に触った。
- 同じように真似をしたくてrinna/japanese-gpt-1bをPCで動かした。
- npakaさんのFT解説などで時々FTすることがあった。
- LOCAL AI HACKATHON #000の告知ツイートでローカルLLMに向き合う会(Discord)をしり、入った。いろんなLLMの話題にふれ、時折自分でもやってみることもあった。
モデルの開発
使用した元モデルはgoogle/gemma-2-27bです。llama.cppにて5bit量子化すると24GBのGPUでも動くパラメータサイズです。
初期の構想
1. LoRAでSFT
llm-jp-sftを使用してDeL-TaiseiOzaki/Tengentoppa-sft-v1.0でLoRA SFTします。
2. DPOデータセットの作成
こちらの記事を参考に新しくTanuki-8Bを使ってMagpieでinstructionを生成します。その後Qwen/Qwen2.5-32B-Instructとgoogle/gemma-2-27b-it、cyberagent/calm3-22b-chatの出力をchosen、開発中のモデルの出力をrejectedとしてDPOデータセットを作成します。これはweblab-GENIAC/Tanuki-8x8B-dpo-v1.0を参考としました。
3.DPO学習
llm-jp-dpoを使用して学習します。
2と3を3回ほど繰り返す予定でした。
4.gguf化
日本語imatrixを使ってキャリブレーションすることで日本語性能をできるだけ維持して量子化できるそうです。Q5_K_Mに量子化します。参考: https://note.com/npaka/n/nbd1348500a28
5.entropixで推論
entropixというLLMのデコーディング方式や温度などを毎トークン自動設定する推論方法を使い、推論します。実装はGooglefanさんの実装を用いました。効果はモデルによって異なるそうです。以下はElyza-tasks-100のスコアです。(GPT-4o-mini採点)
モデル名 | entropixあり | entropixなし |
---|---|---|
gemma-2-2b-jpn-it | 3.03 | 3.13 |
gemma-2-9b-it | 3.92 | 3.86 |
gemma-2-27b-it | 3.98 | 3.95 |
llm-jp-3-13b-instruct | 2.46 | 2.94 |
llm-jp-3-3.7b-instruct | 2.35 | 2.66 |
llm-jp-3-1.8b-instruct | 1.74 | 2.18 |
できたこと
1. LoRAでSFT
RunpodのA100 80GB PCIeにて合計10時間ほど学習しました。ただ、ハイパーパラメータをこだわることができなかったので一度最適なパラメータを見つける練習をしたいです。また、オプティマイザーの使い方もあまりよろしくなく、途中で中断したなどでスケジュールがうまくできなかったです。以下が学習曲線です。上二つのグラフがDeL-TaiseiOzaki/Tengentoppa-sft-v1.0、一番下のがAratako/Magpie-Tanuki-8B-annotated-96kの一部を学習させました。
2. DPOデータセットの作成
Qwen/Qwen2.5-32B-Instructとgoogle/gemma-2-27b-it、cyberagent/calm3-22b-chatを使用し、vLLMで推論しました。それぞれのモデルで3000件作成しました。一番大きなモデルQwen/Qwen2.5-32B-InstructでA100 80GB PCIe使用時にoutput 1000tok/sくらいでした。
しかし、DPOしてみると、あまり精度が出なかったため、路線を変更し、nitky/Llama-3.1-SuperSwallow-70B-Instruct-v0.1一つのモデルにすることにしました。
4.gguf化 5.entropixで推論
前述した通り、entropixはモデルによって効果に差があり、逆転することもあるので、盲目的に使用せず、確認すべきだったと思います。
できなかったこと
3.DPO学習
RLHFと違ってDPOはSFTと同じメモリ消費量だと考えていましたが、参照モデルという、あまりにも強化学習でモデルが離れないようにする基準のモデルの分、よりメモリを消費するようです。そのため、A100一枚では学習できませんでした。
分散学習
開発初期にRunpodのA40を8枚使用して学習の確認をしていましたが、Deepspeed Zero3を使用するとモデルのロードが終わらない問題に遭遇して、結果1枚でGemma-2-27bが載るA100 80GBを利用しました。
開発の総評など
やらなくて/できなくて後悔したことがたくさんあります。例えばハイパーパラメータを全くいじらずにネットの別モデルの学習用の物を流用したり、目視で出力を確認しなかったりです。このレポートを書きながら他の方のレポートを読んでいると、もう一度やり直したいくらいです。ただ、これを糧にして今後も研究開発をしていけたらなと思います。貴重な機会、松尾研や東京大学、登壇者様、寄附者様、コミュニティの皆さん、ありがとうございました。
Discussion