🐙

タンパク質のTransformerモデルを作ってパッケージにしてみる

に公開

🐊はじめに🐊

GWの空き時間に勢いで書いたものなので、温かい目で見てください。
タンパク質の機械学習モデル、特にアミノ酸配列をTransformerベースで埋め込む手法は最近とてもよく見られると思います。一方で、タンパク質を対象にしたTransformerの実装を実例で学べる機会はあんまりないかなっと思います[1]
いろいろ探しているうちに、丁度いいサイトを見つけました
https://open.substack.com/pub/ytian/p/building-transformer-models-for-proteins?utm_campaign=post&utm_medium=web

https://github.com/naity/protein-transformer/tree/main?tab=readme-ov-file

内容はThe Immune Epitope Database (IEDB) から生データを取得してキュレーションを行い、抗体がHIV-1またはSARS-CoV-2のどちらに結合するのかを分類する分類機を作るというものです。丁寧にTokenizeやPositional Encoding, Transformerレイヤーまで実装して、trainingやtuningができるようになっているというとても教育的なものです。

せっかくデータセットを取得したので、GWの空き時間を使って備忘録がてら次の遊びをしてみました。

  1. 埋め込みを事前学習済の言語モデル (ESM2) に置き換えてよくある分類機を実装してみました。
  2. 作った分類機をパッケージにまとめてみました。
    最近のPython環境に明るくないので、ぜひぜひ気軽にコメント・アドバイスなどいれてもらえたらと思います。

ESM2での分類機の実装

事前準備

環境はこちらのレポジトリをcloneしてインストールしたもので動くと思います。
https://github.com/wani-wani-wa/BCR_classifier.git

git clone https://github.com/wani-wani-wa/BCR_classifier.git
cd  BCR_classifier
pip install requirements-dev.txt

でインストールできます。
続いて、BCR_classifier/scripts/prepare_dataset.ipynbの手順でデータを準備します。
事前にデータbcr_full_v3.zipをダウンロードして、BCR_classifier/data_dirに保存して解答してbcr_full_v3.csvを置いておきます。その後はprepare_dataset.ipynbの指示に従って、train/validation/testを8:1:1で分けます。本来は配列相同性などをみてtrain/validation&testの間に似た配列が入らないようにsplitをするべきですが、今回は素朴にrandom splitをしています。

モデル実装

ノートブック全体はこちらで公開しています:notebook

Transformersのライブラリを使って、facebook/esm2_t6_8M_UR50Dのモデルを使うことにしました。軽量で扱いやすいですね。

from transformers import AutoTokenizer, AutoModel
import torch
from torch import nn

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

文章に対する分類機を作る際の方法として、
1.CLSの埋め込みを使う方法
2.配列全体の埋め込みの平均プーリングを使う方法
3.配列全体の最大値プーリングを使う方法, etc...
などがあると思います。タンパク質の機械学習の場合は2の方法がよく使われるのですが、今回は1の方法でシンプルな線形分類機を下層につけるだけのおもちゃモデルを試してみます。

class ESMClassifier(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.classifier = nn.Linear(320, 2)  
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        cls_token = outputs.last_hidden_state[:, 0, :] 
        logits = self.classifier(cls_token)
        return logits

このような簡単なモデルでもハイパラ調整なしに86%程度のバリデーションスコアが出ました。
Train Loss = 0.0356, Train Acc = 1.0000 | Val Loss = 0.3943, Val Acc = 0.8684

実装した分類機のパッケージ化

notebookを流せばモデルは実行できるのですが、自分が開発しやすくする目的や、他の人に公開して理解しやすくするためにパッケージにまとめることが必要です。そこで、今回作ってみた分類機をよくある構成のパッケージにまとめてみることにしました。ついでに生PyTorchよりもPyTorch Lightningの方が好きなのでPyTorch Lightningに書き換えてあります。
PyTorch Lightningについて詳しくはこちらの記事などを参考にしてください。
https://qiita.com/ground0state/items/c1d705ca2ee329cdfae4
また、HydraConfigと組み合わせて実験サイクルを回すトピックなどはこちらの記事が詳しいです。
https://zenn.dev/mixi/articles/13b8cf80afcd93

パッケージの全体像

パッケージ構成のイメージはこんな感じです。
protein_transformer/
├── __init__.py
├── data/
│ ├── __init__.py
│ ├── dataset.py # BCRDatasetクラス
│ └── datamodule.py # LightningDataModuleクラス
├── models/
│ ├── __init__.py
│ └── esm_classifier.py # LightningModule (モデル+学習ロジック)
├── config/
│ └── config.yaml # ハイパーパラメータ管理
├── train.py # 学習スクリプト (Lightning Trainer)
├── utils/
│ └── metrics.py # accuracy計算とか
tests/            # 各モジュールのテスト
├── test_dataset.py
├── test_model.py
.gitignore
requirements.txt
setup.py
README.md

パッケージの作成の流れ

だいたいこんな感じで進めていきます。

  1. githubにレポジトリを作成してcloneする
  2. cloneしたレポジトリに移動して、uvで環境構築を行う
  3. 要件を満たすようにテストを書く、テストが通るように開発をする
  4. CIワークフローを設定して3を繰り返す
  5. Train/Evalutaion

1. githubにレポジトリを作成してcloneする

Web UI で New repositoryを作成します。
リポジトリ名:BCR_classifier

License / .gitignore / README は空で OKですが、特に何も考えずにLicenseはMIT, .gitignoreはpythonの初期設定を使いました。
その後ターミナル上で

git clone https://github.com/wani-wani-wa/BCR_classifier

2. cloneしたレポジトリに移動して、uvで環境構築を行う

こちらの記事を参考にしました

cd BCR_classifier
uv init

ここまででpyproject.tomlなどのuvの設定などを内包するファイルができているので、下記のように書き換えます。

BCR_classifier/pyproject.toml
[project]
name = "protein-transformer-scratch"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.9"
dependencies = []

[tool.uv]
dev-dependencies = ["mypy", "notebook", "pandas", "pytest", "ruff", "pandas"]

[tool.ruff]
indent-width = 4
line-length = 88 # Same as Black.
exclude = [".ruff_cache", ".ruff.toml", ".ruff.lock"]
target-version = "py311"

[tool.ruff.lint]
select = [
    "F", # Flake8
    "B", # Black
    "I", # isort
    "E", # error
    "W", # warning
]
ignore = ["F401", "E501"]
fixable = ["ALL"]
unfixable = []

[tool.ruff.isort]
combine-as-imports = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
split-on-trailing-comma = true

[tool.ruff.format]
quote-style = "double"

[tool.ruff.lint.isort]
known-third-party = ["fastapi", "pydantic", "starlette"]

[tool.pytest.ini_options]
filterwarnings = [
    "ignore:.*Jupyter is migrating.*:DeprecationWarning",
]
addopts = "-vv --color=yes"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
dynamic = ["name", "version"]

[tool.hatch.build.targets.wheel]
packages = ["protein_transformer"]
uv sync
uv pip install -e .  # 開発モードでパッケージインストール
uv pip install "pytest<8.0.0" pytorch-lightning torch torchvision torchaudio transformers pandas scikit-learn tensorboard pytest-lazy-fixture
uv pip freeze > requirements-dev.txt

以上で開発環境が整えられました。詳しくは参考にした記事の方をご参照ください。

3. 要件を満たすようにテストを書く、テストが通るように開発をする (テスト駆動開発)

開発項目を順番に要件定義して、テストを書いていきます。
今回はデータの流れから、

  1. dataset.py
  2. datamodule.py
  3. esm_classifier.py
    の順番にテストを書きながら開発していって、最後にtrain.pyなどを書いていきます。

具体的な開発:dataset.pyを例に

dataset.pyの要件は次のようなものになります。

  1. 引数1: pandasのdf形式で、入力となるアミノ酸配列 ("sequence") とどちらの抗原を認識するのか ("label",事前に0,1でエンコード済) を受け取る
  2. 引数2: 受け取ったアミノ酸配列をtokenizerでトークン化する
  3. 出力: torch tensorの形式で、トークン化されたアミノ酸配列 ("input_ids") ・attention mask ("attention_mask") ・labels ("labels") を返す。
  4. テストで確認すべきことは、1,2の入力を受け取って、3の出力の"input_ids"や"attention_mask"が1で受け取ったアミノ酸配列と同じ長さのtorch tensorになっていることや、"labels"の内容の0/1が変わってしまっていないかなどです。

これらを要件にすると下記のように書けます。

tests/test_dataset.py
import pandas as pd
import pytest
import torch
from protein_transformer.data.dataset import BCRDataset


class DummyTokenizer:
    def __call__(self, sequence, padding, truncation, max_length, return_tensors):
        input_ids = [min(ord(char), 320) for char in sequence[:max_length]]
        attention_mask = [1] * len(input_ids)

        # padding dummy
        while len(input_ids) < max_length:
            input_ids.append(0)
            attention_mask.append(0)

        return {
            "input_ids": torch.tensor([input_ids]),
            "attention_mask": torch.tensor([attention_mask])
        }

@pytest.fixture
def sample_df():
    return pd.DataFrame({
        "sequence": ["ABCDE", "FGHIJ"],
        "label": [0, 1]
    })
# lazy_fixture is not suppoeter to new version of pytest, following is enough
# @pytest.mark.parametrize("tokenizer", [
#     DummyTokenizer(),
#     AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D"),
# ])


@pytest.mark.parametrize("tokenizer", [
    DummyTokenizer(),
    pytest.lazy_fixture("esm_tokenizer"),
])
def test_dataset_length(sample_df, tokenizer):
    dataset = BCRDataset(sample_df, tokenizer)
    assert len(dataset) == 2

@pytest.mark.parametrize("tokenizer", [
    DummyTokenizer(),
    pytest.lazy_fixture("esm_tokenizer"),
])
def test_dataset_item_format(sample_df, tokenizer):
    dataset = BCRDataset(sample_df, tokenizer, max_length=10)
    item = dataset[0]
    print(item, item["attention_mask"].sum().item())
    assert isinstance(item, dict)
    assert "input_ids" in item
    assert "attention_mask" in item
    assert "labels" in item

    assert isinstance(item["input_ids"], torch.Tensor)
    assert isinstance(item["attention_mask"], torch.Tensor)
    assert isinstance(item["labels"], torch.Tensor)

    assert item["input_ids"].shape == torch.Size([10])
    assert item["attention_mask"].shape == torch.Size([10])
    assert item["labels"].item() == 0

デコレーター (@) で修飾してfixtureの機能をたくさん使っているのですが、これはtestの前処理とmock作成をいい感じに行って、テストケースになる仮想的な変数のようなものを作成しています。詳しくはこちらを参考にしてみてください。mockでsample_dfを作成して、tokenizerはテストの動作を軽く簡潔にするためにdummy_tokenizerを作成してみました。でもやっぱりESM Tokenizerでの挙動もテストしたいよね!と思ったので結局ESM Tokenizerでの挙動もテストしています。ちょっとそれっぽいことをしたくて、pytest.lazy_fixtureを使って、conftest.pyにこんな感じで書いてfixtureをparametrizeに引き渡しています。

tests/conftest.py
import pytest
from transformers import AutoTokenizer


@pytest.fixture(scope="session")
def esm_tokenizer():
    return AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

ただ、今回のケースでは複雑ではなく再利用の頻度も少ないので、素直にparametrizeにベタ書きしたほうが良かったなと思います。

実際にこのテストを通るように書いてみコードがこちらです。このくらいのライトな実装だと実は先にコードを書いてからテストを書いています。 テストを書くと個人で開発する場合でも、コードをリファクタリングしたり新しい機能を追加したときの異常検知が捗るので、未来の自分を助けるためにも要所要所で書いておくと幸せになれます。

protein_transformer/data/dataset.py
import torch
from torch.utils.data import Dataset


class BCRDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=320):
        self.df = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sequence = self.df.iloc[idx]["sequence"]
        label = self.df.iloc[idx]["label"]

        tokens = self.tokenizer(sequence, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

ここまでできたら、pytestをしてみます。testsに移動して、

pytest -s test_dataset.py

と打つと、testが実行されます。すべて通ると

collected 4 items                                                                                                                                                                                                                                                                                                                                                                           

test_dataset.py::test_dataset_length[tokenizer0] PASSED
test_dataset.py::test_dataset_length[esm_tokenizer] PASSED
test_dataset.py::test_dataset_item_format[tokenizer0] PASSED
test_dataset.py::test_dataset_item_format[esm_tokenizer] PASSED

===================================================================================================================================================================================== 4 passed in 0.46s =====================================================================================================================================================================================

のような表示が見れるかと思います。今回は初手で通るものを書いていますが、基本的にはtestは通らなくて、通るようにbug fixをし続けていきます。同様に要件定義・テスト作成・開発のサイクルを繰り返して、2. datamodule.py、3. esm_classifier.pyを実装していきます。

4. CIワークフローを設定して3を繰り返す

3.で行ったテスト駆動開発を繰り返し行ったり、複数人で作業をする際にはCIのワークフローを生やすとテスト駆動開発が自動化できて便利です。Github ActionsでCIワークフローを流す設定を入れてみました。mainやdevにPRやpushしたときに流れる仕組みです。

BCR_classifier/.github/workflows/ci.yaml
name: CI

on:
  push:
    branches:
      - main
      - dev
  pull_request:
    branches:
      - main
      - dev

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
    # GitHub リポジトリのソースを取得
    - name: Checkout code
      uses: actions/checkout@v3

    # Python をセットアップ
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: "3.11"
        
    # uv をインストール
    - name: Install uv
      run: |
        curl -Ls https://astral.sh/uv/install.sh | sh
        echo "$HOME/.local/bin" >> $GITHUB_PATH

    # 仮想環境作成と依存関係インストール
    - name: Install dependencies
      run: |
        uv venv
        source .venv/bin/activate
        uv pip install -r requirements-dev.txt
        uv pip install -e .  # 開発モードでパッケージをインストール
        uv pip install pytest

    # テスト実行
    - name: Run tests
      run: |
        source .venv/bin/activate
        pytest -s tests/

requirements-devは作った環境をそのままpip freeze > requirements-dev.txtにして書き出したもので、テストを行うだけであれば不要なパッケージや過度なバージョン指定が入っていて動かなくなるリスクがあるので、本来であればもっと簡便でゆるい指定のものを指定すべきですが、今回は動いたのでヨシ!(๑•̀ㅂ•́)و✧としてます。CIでテストが全部通ると緑の文字がたくさん見れて目にもいいです!

5. Training/Evaluation

Training/Evaluationコードの作成

実際にモデルが開発できてきたら、trainingやevaluationのコードを書いて学習や評価を実行していきます。data_dirに前処理したデータを置いておき、次のようなtrain.pyを実装して実行します。

BCR_classifier/protein_transformer/train.py
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from transformers import EsmTokenizer

from protein_transformer.data.datamodule import BCRDataModule
from protein_transformer.models.esm_classifier import ESMClassifier


def main():
    train_df = pd.read_csv("../data_dir/bcr_train.csv")
    val_df = pd.read_csv("../data_dir/bcr_val.csv")
    test_df = pd.read_csv("../data_dir/bcr_test.csv")
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    datamodule = BCRDataModule(train_df, val_df, test_df, tokenizer, batch_size=32)

    model = ESMClassifier()

    tensorboard_logger = TensorBoardLogger(
        save_dir="logs/",
        name="protein_transformer",
    )

    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="auto",
        logger=[tensorboard_logger, CSVLogger(save_dir="logs/")],
    )

    trainer.fit(model, datamodule)

if __name__ == "__main__":
    main()

logはlogs以下にcsvとtensorboard用のlogが出ているので、次のコマンドでログを確認します。

 tensorboard --logdir ./protein_transformer/logs/

tensor board上でログが確認できると思います。

おまけ:HydraConfigでハイパラなどを書き換えられるようにしてみる

HydraCondigを使うと、configでハイパラの設定をして管理したり、実験ごとに自動で出力を分けてconfigも保存できるので複数のハイパラ探索をした結果をまとめて分析しやすくなります。train_hydra.pyをこんな感じに書いてみました。argsで刺していた部分がHydraConfigで刺さるようになっている感じです。

BCR_classifier/protein_transformer/train_hydra.py
import hydra
import pandas as pd
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from transformers import EsmTokenizer

from protein_transformer.data.datamodule import BCRDataModule
from protein_transformer.models.esm_classifier import ESMClassifier


@hydra.main(config_path="config", config_name="config", version_base="1.3")
def main(cfg: DictConfig):
    print(f"Running with config: {cfg}")

    train_df = pd.read_csv(cfg.data.train_path)
    val_df = pd.read_csv(cfg.data.val_path)
    test_df = pd.read_csv(cfg.data.test_path)

    tokenizer = EsmTokenizer.from_pretrained(cfg.model.pretrained_model_name)
    datamodule = BCRDataModule(train_df, val_df, test_df, tokenizer, batch_size=cfg.data.batch_size)

    model = ESMClassifier(lr=cfg.model.lr)

    tb_logger = TensorBoardLogger(save_dir=cfg.logging.log_dir, name="protein_transformer")
    csv_logger = CSVLogger(save_dir=cfg.logging.log_dir)

    trainer = pl.Trainer(
        max_epochs=cfg.trainer.max_epochs,
        accelerator=cfg.trainer.accelerator,
        logger=[tb_logger, csv_logger],
    )

    trainer.fit(model, datamodule)

if __name__ == "__main__":
    main()

config部分はこんな感じにしました。項目ごとに別ファイルに分けてかけるみたいだったので、dataやmodelを別ファイルにしてみました。

config/config.yaml
defaults:
  - data: data
  - model: model

trainer:
  max_epochs: 10
  accelerator: auto

logging:
  log_dir: logs/
config/data/data.yaml
train_path: ../data_dir/bcr_train.csv
val_path: ../data_dir/bcr_val.csv
test_path: ../data_dir/bcr_test.csv
batch_size: 32
config/model.yaml
pretrained_model_name: facebook/esm2_t6_8M_UR50D
lr: 1e-4

trainingを流すときは、こんな感じでパラメータの上書きもできます。

python train_hydra.py model.lr=5e-5 data.batch_size=32

batch sizeの部分を試しに複数のパラメータを探索したいときには、こんな感じで --multirun オプションを入れます。

python train_hydra.py --multirun model.lr=5e-5 data.batch_size=16,32,64

batch_sizeを16,32,64の3つの値で探索しています(間にスペースを入れるとエラーになるので注意だそうです)
参考
https://qiita.com/Isaka-code/items/3a0671306629756895a6

🐊おわりに🐊

今回はおもちゃモデルを通して実践的なデータセットに対するタンパク質の機械学習モデルのテスト駆動開発っぽいことやパッケージ化をしてみました。何かを動くものを作って遊びたい人のお役に立てたらわにもうれしいです!

脚注
  1. いい教材をご存知の方がいたらコメントで教えていただけるとうれしいです。 ↩︎

Discussion