お家のパソコンでESM2のLoRAをする
初めて記事を書いてみたので、至らない点が多いかと思いますが暖かく見守ってください。
🐊はじめに🐊
タンパク質の言語モデル(protein language models: pLMs)は、自然言語処理 (Natural Language Processing: NLP) の技術をタンパク質のアミノ酸配列に適用して、タンパク質の機能予測やタンパク質同士の結合予測、タンパク質の構造予測など様々なタスクを行える機械学習モデルです。
タンパク質の言語モデルでは、UniProtどの大規模なデータベース内のデータ用いて、アミノ酸位置文字を1単語として扱ってタンパク質のアミノ酸配列を文章のように学習していきます。近年ではTransformerやそのエンコーダー部分を派生させたBERTと呼ばれるアーキテクチャのものが一般になっています。これらのタンパク質の言語モデルの活用に向けた学習は、一般的に事前学習とファインチューニングという2つの過程に別れます。タンパク質の言語モデルに関してはこちらの総説が日本語でとても詳しくわかりやすく書いてあります。
事前学習は、タンパク質の言語モデルがまず一般的な知識を学習する段階です。この段階では、大規模なタンパク質配列データを使用して、モデルに基本的なパターンや構造的な関係を学ばせます。大切な点として、大規模なアミノ酸配列のデータだけから人の手でキレイに整えたデータを用いなくても自ら学習できる (自己教師あり学習) 点です。この事前学習されたモデルはhugging faceなどを通じて、簡単に活用できるようになっています。
ファインチューニングは、事前学習によって得た一般的な知識を、特定のタスクに適応させるプロセスです。この段階では、人が手でキュレーションした特定のタスクに関連するラベル付きデータを用いてモデルを微調整します。
今回はこのファインチューニングについて、LORA (Low Rank Adaptation) というものをやってみます。LORAはファインチューニングの際に、モデルの各層に低ランクの行列を挿入して低ランク行列をチューニング対象とすることで、計算コストやメモリ使用量を減らす手法です(wikipedia)。具体的に取り組む内容は、こちらのESMBindというタンパク質のリガンドなどの結合部位の予測タスクを、お家のGPU (RTX-3060, 12GB) で回してみたいと思います。
環境 | バージョン等 |
---|---|
GPU | RTX-3060 |
Driver | 535.183.01 |
CUDA | 12.2 |
Python環境
記事内のPython環境がPython3.8の古めの環境だったのと、今回使わない付属のライブラリーが大量に入っていたので、下記のシンプルな環境で行いました。mambaかsolverでlibmambaを使ったら解決できると思います。とりあえず最新の環境で動けばいいやの気持ちでつくったので、バージョン依存は厳しく入れてません。
name: esm2_lora_py312
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- python=3.12.7
- pytorch-gpu>=2.0.0
- pip
- pip:
- accelerate
- datasets
- peft
- pyyaml
- scikit-learn
- transformers
- tensorflow
- tf-keras
minicondaをインストール してからこちらのenv.yamlを用意したら、置いてあるディレクトリで
conda env create -f conda-env.yml
でconda環境を作ります。うまく行かない場合は
conda install conda-forge::mamba
mamba env create -f conda-env.yml
を試すとできるかもしれません。
環境ができたら
conda activate esm2_lora_py312
でconda環境をactivateします。
上のenv.yamlで作った環境をexport envしたとき、私の環境はこんな感じでした。
export_env.yaml
name: esm2_lora_py312
channels:
- nvidia
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_kmp_llvm
- bzip2=1.0.8=h4bc722e_7
- ca-certificates=2024.12.14=hbcca054_0
- cpython=3.12.7=py312hd8ed1ab_0
- cuda-cudart=12.6.77=0
- cuda-cudart_linux-64=12.6.77=0
- cuda-nvrtc=12.6.85=0
- cuda-nvtx=12.6.77=0
- cuda-version=12.6=3
- cudnn=9.3.0.75=cuda12.6
- filelock=3.16.1=pyhd8ed1ab_1
- gmp=6.3.0=hac33072_2
- gmpy2=2.1.5=py312h7201bc8_3
- jinja2=3.1.5=pyhd8ed1ab_0
- ld_impl_linux-64=2.43=h712a8e2_2
- libabseil=20240722.0=cxx17_h5888daf_1
- libblas=3.9.0=26_linux64_openblas
- libcblas=3.9.0=26_linux64_openblas
- libcublas=12.6.4.1=0
- libcufft=11.3.0.4=0
- libcufile=1.11.1.6=0
- libcurand=10.3.7.77=0
- libcusolver=11.7.1.2=0
- libcusparse=12.5.4.2=0
- libexpat=2.6.4=h5888daf_0
- libffi=3.4.2=h7f98852_5
- libgcc=14.2.0=h77fa898_1
- libgcc-ng=14.2.0=h69a702a_1
- libgfortran=14.2.0=h69a702a_1
- libgfortran5=14.2.0=hd5240d6_1
- libhwloc=2.11.2=default_h0d58e46_1001
- libiconv=1.17=hd590300_2
- liblapack=3.9.0=26_linux64_openblas
- liblzma=5.6.3=hb9d3cd8_1
- liblzma-devel=5.6.3=hb9d3cd8_1
- libmagma=2.8.0=h566cb83_2
- libmagma_sparse=2.8.0=h0af6554_0
- libnsl=2.0.1=hd590300_0
- libnvjitlink=12.6.85=0
- libopenblas=0.3.28=pthreads_h94d23a6_1
- libprotobuf=5.28.2=h5b01275_0
- libsqlite=3.47.2=hee588c1_0
- libstdcxx=14.2.0=hc0a3c3a_1
- libstdcxx-ng=14.2.0=h4852527_1
- libtorch=2.5.1=cuda126_hebb32c0_306
- libuuid=2.38.1=h0b41bf4_0
- libuv=1.49.2=hb9d3cd8_0
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.13.5=h0d44e9d_1
- libzlib=1.3.1=hb9d3cd8_2
- llvm-openmp=19.1.6=h024ca30_0
- markupsafe=3.0.2=py312h178313f_1
- mkl=2024.2.2=ha957f24_16
- mpc=1.3.1=h24ddda3_1
- mpfr=4.2.1=h90cbb55_3
- mpmath=1.3.0=pyhd8ed1ab_1
- nccl=2.23.4.1=h2b5d15b_3
- ncurses=6.5=he02047a_1
- networkx=3.4.2=pyh267e887_2
- openssl=3.4.0=hb9d3cd8_0
- pip=24.3.1=pyh8b19718_2
- python=3.12.7=hc5c86c4_0_cpython
- python_abi=3.12=5_cp312
- pytorch=2.5.1=cuda126_py312h1763f6d_306
- pytorch-gpu=2.5.1=cuda126ha999a5f_306
- readline=8.2=h8228510_1
- setuptools=75.6.0=pyhff2d567_1
- sleef=3.7=h1b44611_2
- sympy=1.13.3=pyh2585a3b_104
- tbb=2021.13.0=hceb3a55_1
- tk=8.6.13=noxft_h4845f30_101
- typing_extensions=4.12.2=pyha770c72_1
- wheel=0.45.1=pyhd8ed1ab_1
- xz=5.6.3=hbcc6ac9_1
- xz-gpl-tools=5.6.3=hbcc6ac9_1
- xz-tools=5.6.3=hb9d3cd8_1
- pip:
- absl-py==2.1.0
- accelerate==1.2.1
- aiohappyeyeballs==2.4.4
- aiohttp==3.11.11
- aiosignal==1.3.2
- astunparse==1.6.3
- attrs==24.3.0
- certifi==2024.12.14
- charset-normalizer==3.4.0
- datasets==3.2.0
- dill==0.3.8
- flatbuffers==24.12.23
- frozenlist==1.5.0
- fsspec==2024.9.0
- gast==0.6.0
- google-pasta==0.2.0
- grpcio==1.68.1
- h5py==3.12.1
- huggingface-hub==0.27.0
- idna==3.10
- joblib==1.4.2
- keras==3.7.0
- libclang==18.1.1
- markdown==3.7
- markdown-it-py==3.0.0
- mdurl==0.1.2
- ml-dtypes==0.4.1
- multidict==6.1.0
- multiprocess==0.70.16
- namex==0.0.8
- numpy==2.0.2
- opt-einsum==3.4.0
- optree==0.13.1
- packaging==24.2
- pandas==2.2.3
- peft==0.14.0
- propcache==0.2.1
- protobuf==5.29.2
- psutil==6.1.1
- pyarrow==18.1.0
- pygments==2.18.0
- python-dateutil==2.9.0.post0
- pytz==2024.2
- pyyaml==6.0.2
- regex==2024.11.6
- requests==2.32.3
- rich==13.9.4
- safetensors==0.4.5
- scikit-learn==1.6.0
- scipy==1.14.1
- six==1.17.0
- tensorboard==2.18.0
- tensorboard-data-server==0.7.2
- tensorflow==2.18.0
- termcolor==2.5.0
- tf-keras==2.18.0
- threadpoolctl==3.5.0
- tokenizers==0.21.0
- tqdm==4.67.1
- transformers==4.47.1
- tzdata==2024.2
- urllib3==2.3.0
- werkzeug==3.1.3
- wrapt==1.17.0
- xxhash==3.5.0
- yarl==1.18.3
実際に動かしてみる
必要なライブラリをimportします。
import os
import numpy as np
import torch
import torch.nn as nn
import pickle
import xml.etree.ElementTree as ET
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
roc_auc_score,
matthews_corrcoef
)
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
TrainingArguments,
Trainer
)
from datasets import Dataset
from accelerate import Accelerator
# Imports specific to the custom peft lora model
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
続いて、helper関数を用意します。ロス関数内で後述のウェイト調整を行っているところが工夫されている点でしょうか。
# Helper Functions and Data Preparation
def truncate_labels(labels, max_length):
"""Truncate labels to the specified max_length."""
return [label[:max_length] for label in labels]
def compute_metrics(p):
"""Compute metrics for evaluation."""
predictions, labels = p
predictions = np.argmax(predictions, axis=2)
# Remove padding (-100 labels)
predictions = predictions[labels != -100].flatten()
labels = labels[labels != -100].flatten()
# Compute accuracy
accuracy = accuracy_score(labels, predictions)
# Compute precision, recall, F1 score, and AUC
precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
auc = roc_auc_score(labels, predictions)
# Compute MCC
mcc = matthews_corrcoef(labels, predictions)
return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
def compute_loss(model, inputs):
"""Custom compute_loss function."""
logits = model(**inputs).logits
labels = inputs["labels"]
loss_fct = nn.CrossEntropyLoss(weight=class_weights)
active_loss = inputs["attention_mask"].view(-1) == 1
active_logits = logits.view(-1, model.config.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
return loss
続いて、データセットの読み込みを行います。今回は、短時間で学習を試してみたいだけだったので、データセットを1/10にして学習を行います (このサイズにすると1~2時間で学習が終わります) 。事前に、hugging faceからこちらのpklファイルをダウンロードしておきます。
# Load the data from pickle files (replace with your local paths)
with open("train_sequences_chunked_by_family.pkl", "rb") as f:
train_sequences = pickle.load(f)
with open("test_sequences_chunked_by_family.pkl", "rb") as f:
test_sequences = pickle.load(f)
with open("train_labels_chunked_by_family.pkl", "rb") as f:
train_labels = pickle.load(f)
with open("test_labels_chunked_by_family.pkl", "rb") as f:
test_labels = pickle.load(f)
max_sequence_length = 1000
# train_sequencesとtest_sequencesの1/10を使用
train_sequences_subset = train_sequences[:len(train_sequences) // 10] # 1/10に縮小
test_sequences_subset = test_sequences[:len(test_sequences) // 10] # 1/10に縮小
# ラベルのトランケート
train_labels_subset = truncate_labels(train_labels[:len(train_labels) // 10], max_sequence_length) # 1/10に縮小してからラベルのトランケート
test_labels_subset = truncate_labels(test_labels[:len(test_labels) // 10], max_sequence_length) # 1/10に縮小してからラベルのトランケート
続いて、トークン化を行います。Hugging FaceのAutoTokenizerを使用して、facebook/esm2_t12_35M_UR50D用のトークナイザーをロードして、数値ベクトル化します。配列の長さをパディングで揃えたり、長い配列に対してはmax_sequence_length (=1000) まで長さを揃える操作が入っています。
# Tokenization
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
train_tokenized = tokenizer(train_sequences_subset, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences_subset, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels_subset, max_sequence_length)
test_labels = truncate_labels(test_labels_subset, max_sequence_length)
train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
続いて、クラスウェイトの計算を行います。これはデータセットのうち、正例と負例のバランスが均等でない場合にウェイトを導入することで学習のバランスを調節するために行っています。具体的にはロス関数内のnn.CrossEntropyLoss(weight=class_weights)で重みが調整されます。
# Compute Class Weights
classes = np.array([0, 1])
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
accelerator = Accelerator()
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
ウェイトをつけたTrainerクラスを実装します。
# Define Custom Trainer Class
class WeightedTrainer(Trainer):
def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
outputs = model(**inputs)
loss = compute_loss(model, inputs)
return (loss, outputs) if return_outputs else loss
細かいtrainingの設定を含んだtraining用の関数を書きます。
configの内容を見ていきます。lora_alphaはLoRAのウェイトの強さを決めるパラメータで1を基準に大きさを変えていきます。今回GPUメモリが12GBなので、バッチサイズを12から3に減らしています。それに合わせて、lrをもとの設定から1/4倍しています。
trainingでf1スコアがいいものを採用することにして、モデル・ログはbest_model_esm2_t12_35M_lora_{timestamp}というディレクトリに保存しています。メーカーなどに所属していると業務でwandbが使えないので、学習経過はログに保存してtensorboardで可視化するようにしています。
def train_function_no_sweeps(train_dataset, test_dataset):
# Set the LoRA config
config = {
"lora_alpha": 1, #try 0.5, 1, 2, ..., 16
"lora_dropout": 0.2,
# "lr": 5.701568055793089e-04,
"lr": 5.701568055793089e-04/4,
"lr_scheduler_type": "cosine",
"max_grad_norm": 0.5,
"num_train_epochs": 3,
# "per_device_train_batch_size": 12,
"per_device_train_batch_size": 3,
"r": 2,
"weight_decay": 0.2,
# Add other hyperparameters as needed
}
# The base model you will train a LoRA on top of
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
# Define labels and model
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
# Convert the model into a PeftModel
peft_config = LoraConfig(
task_type=TaskType.TOKEN_CLS,
inference_mode=False,
r=config["r"],
lora_alpha=config["lora_alpha"],
target_modules=["query", "key", "value"], # also try "dense_h_to_4h" and "dense_4h_to_h"
lora_dropout=config["lora_dropout"],
bias="none" # or "all" or "lora_only"
)
model = get_peft_model(model, peft_config)
# Use the accelerator
model = accelerator.prepare(model)
train_dataset = accelerator.prepare(train_dataset)
test_dataset = accelerator.prepare(test_dataset)
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# Training setup
training_args = TrainingArguments(
output_dir=f"esm2_t12_35M-lora-binding-sites_{timestamp}",
learning_rate=config["lr"],
lr_scheduler_type=config["lr_scheduler_type"],
gradient_accumulation_steps=1,
max_grad_norm=config["max_grad_norm"],
per_device_train_batch_size=config["per_device_train_batch_size"],
per_device_eval_batch_size=config["per_device_train_batch_size"],
num_train_epochs=config["num_train_epochs"],
weight_decay=config["weight_decay"],
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
push_to_hub=False,
logging_dir=f"esm2_t12_35M-lora-binding-sites_{timestamp}",
logging_first_step=False,
logging_steps=200,
save_total_limit=7,
no_cuda=False,
seed=8893,
fp16=True,
)
# Initialize Trainer
trainer = WeightedTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=tokenizer,
data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
compute_metrics=compute_metrics
)
# Train and Save Model
trainer.train()
save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)
最後に、trainingを回していきます
train_function_no_sweeps(train_dataset, test_dataset)
trainingの結果
trainingの過程はtensorboardで可視化できます。{timestamp}はtrainingを始めた日時になっているので適宜置き換えます。
tensorboard --logdir ./esm2_t12_35M-lora-binding-sites_{timestamp}
右のtraining lossが下がっていることがわかります。
結果のmetricsはこんな感じでした。
metrics | eval/accuracy | eval/auc | eval/f1 | eval/loss | mcc | eval/precision | eval/recall |
---|---|---|---|---|---|---|---|
Value | 0.9636 | 0.8083 | 0.4153 | 0.7399 | 0.429 | 0.3059 | 0.6466 |
3 epoch回していますが、ほとんどの項目が1 epoch目で頭打ちになっていることがわかります。Accuracyが96%とかなり高く出ていますが、f1, precision, recallなどをみるとそれなりに問題がありそうなことがわかります。特にprecisionが低いことから、false positiveの数が多いという問題点がわかります。一方でrecallは少しだけいい数字なので、positiveデータはある程度拾ってこれるモデルになっていそうです。Accuracyがかなり高いのは、アミノ酸配列のうち結合部位がごく一部のため、多くのアミノ酸をnegativeと判別しておけば正解しやすいというデータの問題にあります。他の記事をみた感じだとこういったデータを使ったようです。プロトンのドナー・アクセプターやリガンド結合部位や活性化サイトが混合しているデータセットなのでしょうか。
作ったモデルでinference
実際に作ったモデルを使ってinferenceして遊んでみます。
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
自分で作ったモデルとベースモデルをロードします。{timestamp}はtrainingを始めた日時になっているので適宜置き換えます。
model_path = "./esm2_t12_35M-lora-binding-sites_{timestamp}/checkpoint-{your checkpoint}"
base_model_path = "facebook/esm2_t12_35M_UR50D"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
モデルの構成をみてみます。
loaded_model.eval()
予測したいアミノ酸配列をいれて、モデルに入れるためにトークン化します。
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
予測を行います。最初にモデルを読み込んで、先程トークン化したインプットを入れます。結果のロジットのうち0番目と1番目の数値で確率が高い方を採用して、結合サイトのラベルとして0か1で返します。
# Run the model
with torch.no_grad():
logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
予測結果のラベルが0なら結合サイトではない、1なら結合サイトとラベル定義します。
# Define labels
id2label = {
0: "No binding site",
1: "Binding site"
}
結果をprintします。
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
if token not in ['<pad>', '<cls>', '<eos>']:
print((token, id2label[prediction]))
('M', 'No binding site')
('A', 'No binding site')
('V', 'No binding site')
('P', 'No binding site')
('E', 'No binding site')
('T', 'No binding site')
('R', 'No binding site')
('P', 'No binding site')
('N', 'No binding site')
('H', 'No binding site')
('T', 'Binding site')
...
おわりに
今回はお家のGPUでESM2を使ってLoRAをして遊んでみました。お家で誰でも簡単にモデルを作れる時代になって楽しかったです。一方で、素朴にAccuracyを見るだけだと問題点を見落としてしまったり、いいデータをとっていい問題設定をすることが大事なのかしらという感想を持ちました。
今回のコードやパラメータはhugging faceの記事に準拠していますが、実践的には大きいバッチサイズで大きいデータセットを使って、いろんなパラメータを検討する必要があると思います。遊んでみたあとにLoRAのパラメータなどの詳細な検討をしている記事を見つけたので、下に貼りました。ご興味のある方は参考にしてください。
Discussion