🐺
LLM開発にSLAの原則を適応してTDDで進める
はじめに
LLMを自前でファインチューニングするにあたり
Trainingの実装をSLAの原則を取り入れて快適にした話
環境
- Python
- HuggingFace
SLA Principle
一般的にはSLA(Single Level of Abstraction)もしくはSLAP(Single Level of Abstraction Principle)と呼ばれている原則
各メソッドは、単一の抽象レベルで記述することが大切だよと伝えてくれています
Robert C. Martin (a.k.a Uncle Bob) の人気書籍『Clean Code』(2009年)にも記載されている
複雑さを別の抽象レベルに押し上げ、少なくとも読むことができるものをメソッドにすることで
メソッド内部のコードが同じ抽象度になりそのメソッドが何をしているのか理解しやすくなります。
例
バリデーションの流れを抽象的にして箇条書する
- 指定した文字数であること
- 大文字英字+小文字英字が含まれていること
- 半角数字が含まれていること
実装例
# バリデーションの抽象度を揃える場合
def validate(value):
return ValidateClass.totalLength(value) and \
ValidateClass.hasAlpaBigAndSmall(value) and \
ValidateClass.hasNumber(value)
抽象度が揃っていない場合
def validate(value):
# lengthが10文字以上16文字以下
totalLength = 10 <= len(value) <= 16
# 大文字英字と小文字英字が含まれている
hasAlpaBigAndSmall = any(c.isupper() for c in value) and any(c.islower() for c in value)
# 数字が含まれている
hasNumber = any(c.isdigit() for c in value)
return totalLength and hasAlpaBigAndSmall and hasNumber
Trainingのロジック水準を考える
Trainingの流れを抽象的にして箇条書きする
例
- modelのロード
- tokenizerのロード
- datasetの準備
- SFTConfigの設定値を取得
- formattingの設定を取得
- collatorの作成
- trainerの設定
- trainの実行
実装例
箇条書きした流れを実装したイメージ
model = load_model(model_name)
tokenizer = load_tokenizer(model_name)
train_dataset, eval_dataset = prepare_dataset(
dataset_name,
dataset_revision,
tokenizer,
)
args = get_args(
output_dir,
hub_model_id,
)
formatting_prompts_func = get_formatting()
collator = create_collator(tokenizer, -100)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=args,
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
抽象度を揃えないと全体理解までの道のりが長くなる
from transformers import AutoTokenizer
# ... 他import
##### tokenizer #####
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# ... 他の設定値も同様に記載していく
trainer = SFTTrainer(
...
tokenizer=tokenizer
)
TDD
SLAPの原則を取り入れることでもう一つ恩恵があり
それはTDDで進めることができること
特に設定値の取得において引数で受け取るような関数を作成した時
いくつかテストケースが出てくる場合があります
# SFTConfigのmax_stepsはnum_train_epochsを上書きする
args = get_args(max_steps, num_train_epochs)
# 引数がない場合エラー
# どっちも渡されたらエラー
# max_stepsだけ渡されたらmax_stepsだけセット
# etc...
抽象的にした関数は一つ一つの関数は責務が小さくなりテストが容易に試せます
TDDは着実にコミットできるので安心して進められます
TDDについてはこちら
代表的なレイヤードアーキテクチャを採用せずとも抽象度を揃えることだけで
時間を空けて自身が見直した時や他メンバーが見た時にも
何をしているのか理解しやすく手助けをしてくれます。
Discussion