🦁

skorchを使ってPyTorchモデルを学習してみた

に公開

今回はskorchを使ってPyTorchのモデルを学習してみました。skorchを利用すると、scikit-learnと同じ使い勝手でモデルを学習できるようになります。

skorchとは?

先ほども書いたように、skorchを利用するとscikit-learnと互換性がある記述方式でPyTorchのモデルを学習できます。後ほどサンプルを見ながら進めますが、PyTorchで定義したモデルをskorchに受け渡して学習に利用できます。

https://github.com/skorch-dev/skorch

早速使ってみる

今回はGitHubに乗っているサンプルを元に使ってみます。

環境構築

uvを利用して以下のように環境を構築します。

uv init skorch_tutorial -p 3.12
cd skorch_tutorial
uv add skorch torch scikit-learn

シンプルな分類モデルの学習

それでは分類モデルの学習をしてみます。コードの全体かんはGitHubにあるように以下のコードを利用します。

train_classifier.py
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)

まずは分類に利用するデータを作成します。今回はscikit-learn上で提供されている分類モデル生成のためのmake_classificationを利用します。データ型はデフォルトではnp.float64ですがここではnp.float32にキャストしています。

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

次にPyTorchを利用して分類モデルを実装します。モデルの実装は通常のPyTorchモデルの実装と同じ要領で問題ありません。入力から出力までLinear層とReLUおよびSoftmaxを組み合わせたシンプルなものとなっています。

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X

最後にモデル学習部分の実装です。skorch.NeuralNetClassifierにモデルを実装したクラスを指定し、epoxhや学習率を指定します。NeuralnetClassifierが準備できたらscikit-learnと同じようにnet.fit(X, y)のようにfit関数を呼び出すと学習が実行されます。

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)

学習が完了したらnet.predict_proba(X)とすることで推論結果を取得することができます。

y_proba = net.predict_proba(X)

早速このコードを実行してみましょう。今回はエポック数は10に設定しているので、全てのデータを10回利用してモデルが学習されました。

uv run train_classifier.py

# 結果
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7012       0.5550        0.6913  0.0193
      2        0.6872       0.6300        0.6855  0.0047
      3        0.6799       0.6600        0.6816  0.0047
      4        0.6752       0.6700        0.6775  0.0044
      5        0.6703       0.6700        0.6741  0.0048
      6        0.6639       0.6950        0.6665  0.0046
      7        0.6558       0.6850        0.6609  0.0045
      8        0.6518       0.7200        0.6549  0.0044
      9        0.6412       0.7250        0.6482  0.0044
     10        0.6314       0.7450        0.6423  0.0047

学習の細かなチューニングなどはしていませんが、モデルの定義だけでモデルが学習されるのはとても楽です。実際、PyTorch + Lightningでモデルを学習しようとすると以下のようなことをする必要があり結構大変です。

https://zenn.dev/akasan/articles/2b625606090524

まとめ

今回はskorchを使ってPyTorchのモデルを学習してみました。本格的にモデルを学習する場合はskorchよりLightningなどを組み合わせた実装になるかと思いますが、最低限の実装で検証のベースとなるモデルを作ってみたい、または定義したモデルをとりあえず学習してみて学習ができそうか判断する分にはとても有用だと思います。また、scikit-learnと互換性があるので、クロスバリデーションやパイプラインに組み込むこともできるようになっています。ぜひ興味がある方は一度使ってみてはいかがでしょうか。

Discussion