♾️

遅延評価と機械学習

2023/09/27に公開
6

最近「なぜ関数プログラミングは重要か」という文書の存在を知りました。関数型プログラミング界隈ではかなり有名な文書のようだったので私も読んでみたのですが、話題の一つとして「遅延評価がプログラムのモジュール化を可能にし、生産性を高める」という話が事例とともに説明されており、とても勉強になりました。まだまだ理解しきれてはいませんが……

本記事では、「なぜ関数プログラミングは重要か」に触発された私が、試しに機械学習のパイプライン構築に遅延評価を適用してみた事例を紹介します。読者のターゲットは普段Pythonで機械学習に触れているデータサイエンティストの方です。本記事を通して、遅延評価を使うと機械学習の学習処理ような「停止条件を満たすまでforループを回す」系の処理をうまくモジュール化できることを実感していただければ幸いです。一方で、例えばC#のLINQやJavaのStream APIなど (私はよく知らない……) を既に使いこなしていらっしゃる方にとっては今さらな内容かなと思います。

遅延評価とは

(いきなり実用的じゃなさそうな話で読む気失せると思いますが、負けないでください)

遅延評価とは、プログラミング言語における式 (関数呼び出しなど) の評価戦略のひとつです。評価戦略には以下の2種類があります。

  • 正格評価
    • 呼び出された時点で式を評価する。
  • 遅延評価
    • 呼び出された時点ではいったん保留しておき、値が必要になった時点で式を評価する。

遅延評価だからできることの例としてよく挙がるのが、無限リストです。以下のコードは無限リストの例で、「自然数を無限に含むリストを作り、そのリストの先頭5つを取得する」という処理を実装したものです。正格評価を採用している言語だと、無限リストを作る処理を真面目に実行しようとして無限ループに陥ってしまいます。一方で遅延評価を採用している言語だと、無限リストを作る処理はいったん保留しておき、後段の処理で先頭5つが必要だとわかってから先頭5つを生成するのに必要な分だけリストを作る処理を実行します。

def n(x):
    return [x] + n(x + 1)

n1 = n(1)
print(n1[0:5])

Pythonは正格評価なので無限リストは作れないのですが、無限リストに近いものを作る機能なら備わっています。それはジェネレーターです。以下のコードは、先ほどの「自然数を無限に含むリストを作り、そのリストの先頭5つを取得する」という処理をジェネレーターを用いて実装したものです。自然数の無限リストを作る代わりに自然数を無限に生成する関数を作って必要な回数だけ呼び出すことにした、とみなすことができます。

def n(x):
    while True:
        yield x
        x += 1

n1 = n(1)
print([n1.__next__() for _ in range(5)])
[1, 2, 3, 4, 5]

なお、本記事では遅延評価で実現できることのうち「無限リストから順に要素を取り出すこと」に特に関心があり、これは今説明した通りジェネレーターで実現できるため、これ以降はジェネレーターをPythonで遅延評価を実現するための仕組みだとみなして話を進めます。

ここまでで、遅延評価を初めて知った人は「いや、遅延評価なら無限リスト作れますって言われても、そもそも作りたくないのだが……」と思うことでしょう。私もずっとそう思っていました。しかし、我々が意識できていないだけで、実は「無限リストを作って、その先頭を取り出す」という処理を我々は日常的に実装しています。その最たる例が、機械学習の学習処理です。学習処理は「無限に続く学習ステップ (=無限リストに相当) を、目的関数が収束した時点で打ち切る (=無限リストの先頭を取り出すことに相当)」という処理であることから、「無限リストを作って、その先頭を取り出す」という処理の一例だと捉えることができます。そのため、学習処理は遅延評価と無限リストを使っても実装できるのです。むしろ、遅延評価と無限リストを使って実装したほうが、forループを使って実装するよりモジュール性の高いプログラムになります。そのことを、これから実際に実装しながら見ていきたいと思います。

セットアップ

本記事では遅延評価を使った実装の例として、MNISTの手書き数字画像の判別モデルを作るパイプラインをPyTorchで実装していきます。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

まず、MNISTのデータセットを読み込み、データローダーを作成します。データローダーはデータセットからミニバッチを生成するジェネレーターです。

※MNISTのデータセットの詳細は本記事の趣旨から外れるので割愛します。

from torchvision import datasets, transforms

train_dataset = datasets.MNIST(
    root = './mnist',
    train = True,
    transform = transforms.ToTensor(),
    download = True
)
valid_dataset = datasets.MNIST(
    root = './mnist', 
    train = False,
    transform = transforms.ToTensor(),
    download = True
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = 32,
    shuffle = True
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,     
    batch_size = 32,
    shuffle = False
)

次に、手書き数字画像を判別する10クラス分類モデルを定義します。ここでは、3層の全結合層から成るニューラルネットワークを使うことにします。

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

遅延評価を使って実装した場合

準備が整ったので、遅延評価を使ってパイプラインを実装していきましょう。

以下の順に実装していきます。

  1. 無限ミニバッチ生成器を実装する
  2. パラメーターを更新する
  3. 損失を出力する
  4. バリデーションデータに対するメトリクスを求める
  5. Early stoppingを実装する
  6. パイプでつなぐ (※2023-09-28追記)

無限ミニバッチ生成器を実装する

天下り的ですが、まずは無限にミニバッチを生成するジェネレーターを実装します。

import itertools

def sample(g):
    for epoch in itertools.count():
        for step, batch in enumerate(g):
            yield epoch, step, batch

これだけだと無限ループに陥ってしまうので、適当なエポック数で停止させるジェネレーターも実装しておきます。停止処理は「ミニバッチを生成するジェネレーターを受け取り、指定されたエポック数に達した場合は打ち切りつつ、それ以外の場合は受け取ったジェネレーターをバイパスするジェネレーター」として実装します。

def epoch(g, epochs):
    for epoch, step, batch in g:
        if epoch >= epochs:
            break
        yield epoch, step, batch

これらのジェネレーターを使って2エポック分のミニバッチを生成するパイプラインを実装すると、次のようになります。

g = train_dataloader
g = sample(g)
g = epoch(g, epochs = 2)

history = [(epoch, step, batch) for epoch, step, batch in g]

これだけだと遅延評価のありがたみはまだ感じられないと思いますが、注目してほしいのは、遅延評価を使うことで「ミニバッチを生成する処理」と「適当なエポック数で停止させる処理」を別々の関数としてモジュール化できた、という点です。このトリックを使ってパイプラインに必要な処理をモジュールとしてどんどん実装していきますので、この後の実装を追うことで遅延評価の威力を実感できるはずです。

パラメーターを更新する

次に、ニューラルネットワークのパラメーターを更新する処理を実装します。更新処理は「ミニバッチを生成するジェネレーターを受け取り、ミニバッチの学習結果を返すジェネレーター」として実装します。

def train(g):
    model = Model()
    criterion = nn.CrossEntropyLoss() 
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.01) 
    model.train()

    for epoch, step, (x, y) in g:
        optimizer.zero_grad()
        logit = model(x)
        loss = criterion(logit, y)
        loss.backward()
        optimizer.step()
        yield epoch, step, (x, y, model, loss.item())

では、パイプラインに更新処理を組み込んでみましょう。組み込み方は簡単で、sampleepoch の間に train を挟むだけです。

g = train_dataloader
g = sample(g)
g = train(g)  # 更新処理を追加
g = epoch(g, epochs = 2)

history = [(epoch, step, batch) for epoch, step, batch in g]

損失を出力する

先ほど作ったパイプラインだと学習の様子がわかりませんね。そこで、学習データに対する損失を出力する処理を追加しましょう。損失を出力する処理は「ミニバッチの学習結果を返すジェネレーターを受け取り、損失を出力しつつ、受け取ったジェネレーターをバイパスするジェネレーター」として実装します。また、ただ損失を出力するのではなく、指定した頻度で出力する機能も実装します。

def print_loss(g, freq):
    loss_history = []
    for epoch, step, (x, y, model, loss) in g:
        loss_history.append(loss)
        if (step + 1) % freq == 0:
            print(
                f'epoch: {epoch:2d}, '
                f'step: {step:4d}, '
                f'train_loss: {np.array(loss_history).mean():.6f}'
            )
        yield epoch, step, (x, y, model, loss)

では、パイプラインに組み込みましょう。train を組み込んだときと同様に、trainepoch の間に print_loss を挟むだけで組み込むことができます。

g = train_dataloader
g = sample(g)
g = train(g)
g = print_loss(g, freq = 100)  # 学習データに対する損失を表示する処理を追加
g = epoch(g, epochs = 2)

history = [(epoch, step, batch) for epoch, step, batch in g]
epoch:  0, step:   99, train_loss: 0.663152
epoch:  0, step:  199, train_loss: 0.532888
epoch:  0, step:  299, train_loss: 0.469407
epoch:  0, step:  399, train_loss: 0.426044
epoch:  0, step:  499, train_loss: 0.398290
epoch:  0, step:  599, train_loss: 0.385324
epoch:  0, step:  699, train_loss: 0.376233
epoch:  0, step:  799, train_loss: 0.363760
epoch:  0, step:  899, train_loss: 0.354995
epoch:  0, step:  999, train_loss: 0.348061
epoch:  0, step: 1099, train_loss: 0.338226
epoch:  0, step: 1199, train_loss: 0.331887
epoch:  0, step: 1299, train_loss: 0.324841
epoch:  0, step: 1399, train_loss: 0.318441
epoch:  0, step: 1499, train_loss: 0.312837
epoch:  0, step: 1599, train_loss: 0.307243
epoch:  0, step: 1699, train_loss: 0.302320
epoch:  0, step: 1799, train_loss: 0.297000
epoch:  1, step:   99, train_loss: 0.290979
epoch:  1, step:  199, train_loss: 0.286124
epoch:  1, step:  299, train_loss: 0.282953
epoch:  1, step:  399, train_loss: 0.278730
epoch:  1, step:  499, train_loss: 0.276113
epoch:  1, step:  599, train_loss: 0.274349
epoch:  1, step:  699, train_loss: 0.270321
epoch:  1, step:  799, train_loss: 0.268394
epoch:  1, step:  899, train_loss: 0.265803
epoch:  1, step:  999, train_loss: 0.262449
epoch:  1, step: 1099, train_loss: 0.261021
epoch:  1, step: 1199, train_loss: 0.258586
epoch:  1, step: 1299, train_loss: 0.256030
epoch:  1, step: 1399, train_loss: 0.254241
epoch:  1, step: 1499, train_loss: 0.252580
epoch:  1, step: 1599, train_loss: 0.252040
epoch:  1, step: 1699, train_loss: 0.250729
epoch:  1, step: 1799, train_loss: 0.249429

損失が出力されるようになりました!

バリデーションデータに対するメトリクスを求める

学習データに対する損失だけでなく、バリデーションデータに対するメトリクスも出力したいですよね? ということで、バリデーションデータに対するメトリクスを求める処理を実装しましょう。ここではメトリクスとして、目的関数と同じクロスエントロピー損失を使うことにします。バリデーションデータに対するメトリクスを求める処理は「ミニバッチの学習結果を返すジェネレーターを受け取り、ミニバッチの学習結果とバリデーションデータに対するメトリクスを返すジェネレーター」として実装します。

def eval(g, dataloader, freq):
    criterion = nn.CrossEntropyLoss() 
    for epoch, step, (x, y, model, loss) in g:
        if (step + 1) % freq == 0:
            model.eval()
            with torch.no_grad():
                logit_val = []
                y_val = []
                for batch_x_val, batch_y_val in dataloader:
                    batch_logit_val = model(batch_x_val)
                    logit_val.append(batch_logit_val)
                    y_val.append(batch_y_val)
                logit_val = torch.cat(logit_val)
                y_val = torch.cat(y_val)
                loss_val = criterion(logit_val, y_val)
                loss_val = loss_val.item()
        else:
            loss_val = float('nan')
        yield epoch, step, (x, y, model, loss, loss_val)

また、先ほど作った print_loss にバリデーションデータに対するメトリクスを表示する処理を追加します。

def print_loss(g, freq):
    loss_history = []
    for epoch, step, (x, y, model, loss, loss_val) in g:
        loss_history.append(loss)
        if (step + 1) % freq == 0:
            print(
                f'epoch: {epoch:2d}, '
                f'step: {step:4d}, '
                f'train_loss: {np.array(loss_history).mean():.6f}, '
                f'valid_loss: {loss_val:.6f}'
            )
        yield epoch, step, (x, y, model, loss, loss_val)

では、パイプラインに組み込みましょう。先ほど組み込んだ print_loss の代わりに、eval と 新しい print_loss を組み込めば完成です。

g = train_dataloader
g = sample(g)
g = train(g)
g = eval(g, dataloader = valid_dataloader, freq = 100)  # バリデーションデータに対するメトリクスを求める処理を追加
g = print_loss(g, freq = 100)
g = epoch(g, epochs = 2)

history = [(epoch, step, batch) for epoch, step, batch in g]
epoch:  0, step:   99, train_loss: 0.708380, valid_loss: 0.377931
epoch:  0, step:  199, train_loss: 0.549197, valid_loss: 0.342479
epoch:  0, step:  299, train_loss: 0.483882, valid_loss: 0.314955
epoch:  0, step:  399, train_loss: 0.440803, valid_loss: 0.264604
epoch:  0, step:  499, train_loss: 0.413406, valid_loss: 0.296241
epoch:  0, step:  599, train_loss: 0.393419, valid_loss: 0.279318
epoch:  0, step:  699, train_loss: 0.379148, valid_loss: 0.230853
epoch:  0, step:  799, train_loss: 0.362535, valid_loss: 0.219792
epoch:  0, step:  899, train_loss: 0.349287, valid_loss: 0.226868
epoch:  0, step:  999, train_loss: 0.341394, valid_loss: 0.258043
epoch:  0, step: 1099, train_loss: 0.335757, valid_loss: 0.275199
epoch:  0, step: 1199, train_loss: 0.327363, valid_loss: 0.279877
epoch:  0, step: 1299, train_loss: 0.320532, valid_loss: 0.199107
epoch:  0, step: 1399, train_loss: 0.315832, valid_loss: 0.229835
epoch:  0, step: 1499, train_loss: 0.311150, valid_loss: 0.267595
epoch:  0, step: 1599, train_loss: 0.307990, valid_loss: 0.245136
epoch:  0, step: 1699, train_loss: 0.303149, valid_loss: 0.232942
epoch:  0, step: 1799, train_loss: 0.298048, valid_loss: 0.192749
epoch:  1, step:   99, train_loss: 0.291517, valid_loss: 0.220366
epoch:  1, step:  199, train_loss: 0.287115, valid_loss: 0.228257
epoch:  1, step:  299, train_loss: 0.282764, valid_loss: 0.199374
epoch:  1, step:  399, train_loss: 0.279637, valid_loss: 0.206661
epoch:  1, step:  499, train_loss: 0.276635, valid_loss: 0.219089
epoch:  1, step:  599, train_loss: 0.272761, valid_loss: 0.220430
epoch:  1, step:  699, train_loss: 0.269227, valid_loss: 0.187363
epoch:  1, step:  799, train_loss: 0.266293, valid_loss: 0.278255
epoch:  1, step:  899, train_loss: 0.263997, valid_loss: 0.191688
epoch:  1, step:  999, train_loss: 0.261281, valid_loss: 0.194956
epoch:  1, step: 1099, train_loss: 0.258783, valid_loss: 0.208358
epoch:  1, step: 1199, train_loss: 0.256781, valid_loss: 0.207703
epoch:  1, step: 1299, train_loss: 0.255177, valid_loss: 0.173686
epoch:  1, step: 1399, train_loss: 0.253368, valid_loss: 0.203147
epoch:  1, step: 1499, train_loss: 0.251830, valid_loss: 0.243142
epoch:  1, step: 1599, train_loss: 0.249535, valid_loss: 0.196261
epoch:  1, step: 1699, train_loss: 0.247223, valid_loss: 0.182852
epoch:  1, step: 1799, train_loss: 0.245790, valid_loss: 0.220971

バリデーションデータに対するメトリクスが表示されるようになりました!

Early stoppingを実装する

バリデーションデータに対するメトリクスを計算できるようにしたので、その結果を使ってearly stoppingも実装しましょう。Early stoppingは「ミニバッチの学習結果とバリデーションデータに対するメトリクスを返すジェネレーターを受け取り、メトリクスがしばらく更新されなかった場合は打ち切りつつ、それ以外の場合は受け取ったジェネレーターをバイパスするジェネレーター」として実装します (手抜き実装なので、更新処理を打ち切るだけでベストパラメーターを返す処理は入っていません、悪しからず。。)

def early_stopping(g, steps):
    stop = 0
    min_loss_val = float('inf')
    for epoch, step, (x, y, model, loss, loss_val) in g:
        if stop >= steps:
            break
        elif min_loss_val > loss_val:
            min_loss_val = loss_val
            stop = 0
        else:
            stop += 1
        yield epoch, step, (x, y, model, loss, loss_val)

パイプラインに組み込みましょう。今回は epoch の後ろに early_stopping を差し込みます。こうすることで、early stoppingがなかなか作動しなかった場合でも epoch で指定したエポック数に到達した時点で更新処理が停止するようになります。逆にearly stoppingが作動するまで更新処理を止めたくない場合は、パイプラインから epoch を削除します。こういった抜き差しが簡単にできるのは、遅延評価を使って各処理を別々の関数としてモジュール化したおかげです。

g = train_dataloader
g = sample(g)
g = train(g)
g = eval(g, dataloader = valid_dataloader, freq = 100)
g = print_loss(g, freq = 100)
g = epoch(g, epochs = 2)
g = early_stopping(g, steps = 300)  # early stoppingを追加

history = [(epoch, step, batch) for epoch, step, batch in g]
epoch:  0, step:   99, train_loss: 0.665783, valid_loss: 0.361112
epoch:  0, step:  199, train_loss: 0.505192, valid_loss: 0.317636
epoch:  0, step:  299, train_loss: 0.451494, valid_loss: 0.326704
epoch:  0, step:  399, train_loss: 0.420570, valid_loss: 0.360294
epoch:  0, step:  499, train_loss: 0.398997, valid_loss: 0.230503
epoch:  0, step:  599, train_loss: 0.377945, valid_loss: 0.333187
epoch:  0, step:  699, train_loss: 0.365281, valid_loss: 0.256404
epoch:  0, step:  799, train_loss: 0.353939, valid_loss: 0.244301

バリデーションデータに対するメトリクスが改善しなくなった時点でパラメーターの更新処理が止まるようになりました!

パイプでつなぐ

(本項は2023-09-28に追記しました)

これまで実装してきた関数はすべて「第1引数でiterableを受け取り、iterableを返す」という形になっていました。このような形の関数を連続的に呼び出す際に使えるのがpipeというライブラリです。pipeを使うと、その名の通りパイプ | を使ってシェルコマンドのように処理を連続的に呼び出すことができます。さっそく使ってみましょう!

まず、pipeをインストールします。

pip install pipe

次に、pipeをインポートします。ここでは、自作関数をパイプでつなげる形へ変換するのに必要な @Pipe デコレーターをインポートします。

from pipe import Pipe

そして、これまで実装してきた関数をパイプでつなげる形へ変換します。変換方法は簡単で、以下のように関数定義の冒頭に @Pipe デコレーターを置くだけです。

@Pipe
def early_stopping(g, steps):
    ()

これで準備は完了です! では、パイプラインをパイプで書き直してみましょう。

g = (
    train_dataloader
    | sample()
    | train()
    | eval(dataloader = valid_dataloader, freq = 100)
    | print_loss(freq = 100)
    | epoch(epochs = 2)
    | early_stopping(steps = 300)
)

history = [(epoch, step, batch) for epoch, step, batch in g]

パイプによって前の処理の出力が後続の処理の入力になることが強調され、よりわかりやすく、すっきりとしたコードになりました! (ちょっと好き嫌いが分かれそうですが……)

遅延評価を使わずに実装した場合

これまで実装してきた処理を遅延評価を使わずに実装するとどうなるか見てみましょう。

以下に実装例を示します。普段からPyTorchをお使いの方にとっては見慣れたお決まりのコードですが、ここまでお見せしてきた遅延評価を使った実装を見比べるとかなりごちゃついていることがわかりますね。

# 最大エポック数
epochs = 2
# バリデーションデータに対するメトリクスを求める頻度
freq = 100
# early stoppingのステップ数
steps = 300

# 学習用の変数
model = Model()
criterion = nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
model.train()

# 損失出力用の変数
loss_history = []

# early stopping用の変数
stop = 0
min_loss_val = float('inf')

for epoch in range(epochs):
    for step, (x, y) in enumerate(train_dataloader):
        # 学習
        optimizer.zero_grad()
        logit = model(x)
        loss = criterion(logit, y)
        loss.backward()
        optimizer.step()

        # 損失の平均を出力するために、損失の履歴を残しておく
        loss = loss.item()
        loss_history.append(loss)

        if (step + 1) % freq == 0:
            # バリデーションデータに対するメトリクスを求める
            model.eval()
            with torch.no_grad():
                logit_val = []
                y_val = []
                for batch_x_val, batch_y_val in valid_dataloader:
                    batch_logit_val = model(batch_x_val)
                    logit_val.append(batch_logit_val)
                    y_val.append(batch_y_val)
                logit_val = torch.cat(logit_val)
                y_val = torch.cat(y_val)
                loss_val = criterion(logit_val, y_val)
                loss_val = loss_val.item()

            # 損失を出力する
            print(
                f'epoch: {epoch:2d}, '
                f'step: {step:4d}, '
                f'train_loss: {np.array(loss_history).mean():.6f}, '
                f'valid_loss: {loss_val:.6f}'
            )
        else:
            loss_val = float('nan')

        # early stopping
        if stop >= steps:
            break
        elif min_loss_val > loss_val:
            min_loss_val = loss_val
            stop = 0
        else:
            stop += 1

    # 二重forループをbreakするためのおまじない
    else:
        continue
    break

ごちゃつきの原因は、ミニバッチを生成するforループを介して各種処理がつながっており切り離すことが難しいからです。各種処理を切り離すには、その前段として「ミニバッチを生成する処理」と「ミニバッチに対する各種処理」を切り離す必要があります。具体的には、全ミニバッチを集めたリストをあらかじめ作っておき、そのリストの各要素に対して各種処理を適用する、という順に処理が流れるよう実装し直す必要があります。

しかし、正格評価では「全ミニバッチを集めたリストをあらかじめ作っておく処理」を呼び出した時点で本当に全ミニバッチを集めたリストをメモリ上に作ってしまいます。学習データがメモリに載らないからミニバッチに切り分けていた (という側面もある) わけなので、全ミニバッチを集めたリストも当然メモリに載りません。さらに、early stoppingのようなステップ数を動的に決める処理を含む場合はそもそも必要なミニバッチ数が事前にはわからないため、ミニバッチを無限に含むリストを準備しておく必要がありますが、これは正格評価だと無限ループに陥ってしまいます。

一方で、遅延評価であれば「全ミニバッチを集めたリストをあらかじめ作っておく処理」は各要素が必要になるまで実際には計算されず保留されます。そのため、コード上は「ミニバッチを無限に含むリストをあらかじめ作っておき、そのリストの各要素に対して各種処理を適用する」という処理でありつつも、実際には「ミニバッチを一つ作り、そのミニバッチに対して各種処理を適用する、という処理を必要回数だけ順に繰り返す」という手順で実行される処理を実装することができます。本記事で紹介した実装は、まさしくこの遅延評価の特性を利用して「ミニバッチを生成する処理」と「ミニバッチに対する各種処理」を切り離したものになっています (※)。

※本記事で紹介した私の実装だと history に全ミニバッチが含まれてしまっているので「正格評価で実装した場合と同じで全部メモリに載せてるじゃないか」と思われるかもしれませんが、それはあくまで私の実装の問題であり遅延評価とは関係ありません。

最後に、遅延評価を使った実装を改めて見てみましょう。1行ごとにコメントを入れてみました。遅延評価を使って「ミニバッチを生成する処理」と「ミニバッチに対する各種処理」を切り離したおかげで、上から順に読んでいくだけで処理内容を理解できる非常に見通しの良いコードになっていることがわかります。これが遅延評価の威力です。

# データセットからミニバッチを生成し、
g = train_dataloader
g = sample(g)
# そのミニバッチでパラメーターを更新し、
g = train(g)
# バリデーションデータに対するメトリクスを求め、
g = eval(g, dataloader = valid_dataloader, freq = 100)
# 損失とメトリクスを定期的に出力する。
g = print_loss(g, freq = 100)
# ……以上の処理を、2エポック分、
g = epoch(g, epochs = 2)
# もしくは、300ステップ以上メトリクスが改善しなくなるまで繰り返す。
g = early_stopping(g, steps = 300)

history = [(epoch, step, batch) for epoch, step, batch in g]

まとめ

本記事では、機械学習パイプラインを構築する例を通して、遅延評価がコードのモジュール化に役立つことを示しました。遅延評価と無限リストを用いたモジュール化は、本記事で紹介した学習処理に代表されるような「停止条件を満たすまでforループを回す」系の処理に一般的に使えるパターンですので、機会があればぜひ使ってみてください。

Discussion

yKesamaruyKesamaru

素晴らしい記事をありがとうございます😊

処理全体のメモリ使用量はどのように変化するのでしょうか?
yieldを使っているので、メモリ使用量は減る…のか、かえって使用量が増えてしまうのか気になります🤔

mtmarumtmaru

ご質問ありがとうございます!
こんな感じ↓で、実際に計測してみました。

import psutil

@Pipe
def watch_memory(g):
    for epoch, step, (x, y, model, loss, loss_val) in g:
        m = psutil.virtual_memory()
        memory = m.used
        yield epoch, step, (x, y, model, loss, loss_val, memory)

g = (
    train_dataloader
    | sample()
    | train()
    | eval(dataloader = valid_dataloader, freq = 100)
    | print_loss(freq = 100)
    | epoch(epochs = 2)
    # ステップ終了時点のメモリ使用量を計測する。
    | watch_memory()
)

memory_history_lazy = [memory for (_, _, (_, _, _, _, _, memory)) in g]

lazy: yieldを使った実装、eager: forループを使った実装。横軸はステップ数、縦軸はステップ0終了時点のメモリ使用量からの増分です。

yieldを使ったほうがGCがうまく働くみたいで、メモリ使用量は横ばいになりました。何回か実行してみましたが、どれも似たような傾向でした。

本記事で扱ったPyTorchの学習の例に限っていうと、そもそもforループを使った実装でも、yieldを使った実装でも、1ステップ実行するのに必要なミニバッチだけをメモリに載せているので (そうしないとGPUのメモリに載らないので……) メモリ使用量に大きな差は出ないと思われます。yield自体のオーバーヘッドがあるかもしれないと思いましたが、今回実験してみた感じだと気にしなくても良さそうです。

yKesamaruyKesamaru

詳細な実験と解析、ありがとうございます😃

ミニバッチの大きさに依存しそうですね。グラフの縦軸が10MBずつなので、大きな差が生まれるかどうか判断がつきづらいですが、仮にこの状態でスケールするなら大きなアドバンテージになりそうです。

機能をモジュールとして分けてパイプでつなぐ、という発想は斬新で、コードがとてもスッキリする気がします。うちはfor文の「あのコードの形」に慣れてしまっているので、最初拝見した時は「???」となりましたが、何回か見ているとスッと頭に入ってくるようになり、便利かも…と思いました。

うちはスッキリしたコードにしたい時、pytorch-metric-learningを使います。
パイプでつなぐやり方は、初見の壁と大きなメリット(メモリとか)があれば、「スッキリしたコードにしたい」一定数の方々の需要を満たせる気がしますし、各モジュールをライブラリ化すれば将来性があると感じました。

mtmarumtmaru

pytorch-metric-learningというライブラリもあるんですね、
知りませんでした👀
教えていただきありがとうございます!

tenkohtenkoh

とても参考になります、ありがとうございます!

ちょうど実践のチャンスがあったため、本記事の内容を参考に、遅延評価にトライしてみています。その際に「こうしても良いかも」と思ったことがありましたのでコメントさせて頂きます。

バッチごとの結果を記録していく際にタプルの中に追加していくのもアリですが、辞書にしても使い勝手が良いのかもと思いました。タプルを展開する時の変数の数も固定でき、また、利用したいパラメータを名前指定して取り出しやすくなるのかな、と。

具体的には、最初のsampleの中で空の辞書を追加しておき、取っておきたい値などがあればその辞書に追加していくイメージです。

空の辞書を作っておく
def sample(g):
    for epoch in itertools.count():
        for step, (x, y) in enumerate(g):
            yield epoch, step, (x, y, dict())
辞書に値を放り込んでいく
def train(g):
    # 省略
    for epoch, step, (x, y, results) in g:
        optimizer.zero_grad()
        logit = model(x)
        loss = criterion(logit, y)
        loss.backward()
        optimizer.step()
        results['model'] = model
        results['loss'] = loss.item()
        yield epoch, step, (x, y, results)
mtmarumtmaru

コメントありがとうございます! 実践してみた方の知見を共有していただけると私も勉強になって大変助かります! 確かに辞書で取り回したほうが使い勝手が良さそうですね、私も試してみます👀