🧮

【PyTorch】TorchEval を使って精度評価しよう

2022/11/13に公開約13,000字2件のコメント

はじめに

今回は PyTorch で Deep Learning (深層学習,機械学習) を行う際に用いる,評価指標の計算方法について記述していきます.

本記事では,TorchEval という Facebook 社が開発を主導している PyTorch と同時に使われることを想定している,2022/10/30 にリリースされたばかりの精度評価用のライブラリです.

ドキュメントを以下に貼っておきます.

この記事を読むメリット

  • 複雑な評価計算を 1 行で実装できる
  • PyTorch が公式で出しているライブラリのためコードが綺麗になる

といった 2 点のメリットがあります.

これを使うことで,Accuracy,Top-k Accuracy はもちろん Precision(適合率)や Recall(再現率)や F1-score(Dice),混同行列,PR曲線,AUCに至るまで Tensor型のまま手軽に扱うことができます!

ちなみに,それぞれ macro平均と micro平均を選択できる実装となっています.至れり尽くせりでとても嬉しいですね!

環境

現時点(2022/11/14)で stable である Python 3.10.6 を使用します.

PC MacBook Pro (16-inch, 2019)
OS Monterey
CPU 2.3 GHz 8コアIntel Core i9
メモリ 16GB
Python 3.10.6

ライブラリ

現在の情報(2022/11/14)では

Python >= 3.7 and PyTorch >= 1.11

の必要があるようです.

今回の記事で筆者が用いるライブラリとバージョンをまとめます.
また,seaborn については任意ですが,出力するグラフの見た目を整えるのに使用します.

ライブラリ バージョン
matplotlib 3.6.2
seaborn 0.12.1
torch 1.13.0
torcheval 0.0.5
torchvision 0.14.0
terminal
pip install matplotlib seaborn torch torchvision torcheval

上記のコマンドで必要なライブラリをインストールできますが,PyTorch には GPU環境と CPU環境があるので,公式ページを見ながら注意してインストールしましょう.

もしエラーを吐かれてしまい上手く動作しなかった場合には,上記のバージョンを指定して再度インストールしてみてください.

今までの書き方

ここでは,TorchEval を使わない今までの記法でプログラムを書いてみます.
ただし,Precision,Recall,F1-score については大変なので省略します.

out = model(images)
preds = out.argmax(dim=1)
# preds.shape == torch.Size([batch_size])

accuracy = (preds == labels).sum() / len(labels)

TorchEvalを使った書き方

ここからは,TorchEval を使った評価指標の計算方法を紹介します.
クラス で書かれたものと関数で書かれたものと 2 種類存在するので,条件に合わせて使いやすい方を使ってください.

サンプルコードを以下に記述します.評価指標を全て 1 行で簡単に書けました!
なお,今回のサンプルコードでは,こちらの様なラベルと Sofmax関数を通した後の出力(確率値)を使用しているので注意してください.

"""
labels = tensor([8, 7, 9, 3, 7, 6, 1, 1, 3, 2])
"""

out = model(images)
"""
out = tensor(
    [[0.2558, 0.0244, 0.1184, 0.0310, 0.1463, 0.0152, 0.2133, 0.0512, 0.0105, 0.1340],
    [0.1786, 0.0246, 0.1293, 0.0243, 0.0443, 0.0400, 0.0683, 0.2168, 0.0808, 0.1930],
    [0.1311, 0.3172, 0.0659, 0.1582, 0.0350, 0.0399, 0.0162, 0.0911, 0.0632, 0.0820],
    [0.3395, 0.0520, 0.0117, 0.1835, 0.0559, 0.0198, 0.1181, 0.0534, 0.0599, 0.1062],
    [0.0147, 0.0716, 0.0287, 0.0799, 0.1378, 0.0601, 0.2032, 0.0776, 0.2142, 0.1122],
    [0.0575, 0.1023, 0.0310, 0.1131, 0.0523, 0.0441, 0.2783, 0.1829, 0.0874, 0.0511],
    [0.0976, 0.0729, 0.0602, 0.1615, 0.1284, 0.0748, 0.1066, 0.1356, 0.0491, 0.1132],
    [0.1591, 0.2803, 0.0285, 0.0558, 0.0611, 0.0189, 0.0371, 0.1770, 0.0442, 0.1380],
    [0.0606, 0.0690, 0.1088, 0.1799, 0.1497, 0.0697, 0.1711, 0.0488, 0.1159, 0.0266],
    [0.1326, 0.0567, 0.0791, 0.1281, 0.0108, 0.0933, 0.3284, 0.0407, 0.0277, 0.1026]],
    grad_fn=<SoftmaxBackward0>
)
out.argmax(dim=1) = tensor([0, 7, 1, 0, 8, 6, 3, 1, 3, 6])
"""

Accuracy

最もイメージしやすく,みなさんが最もよく使う指標ではないでしょうか.
精度や正解率と言われるものです.

from torcheval.metrics.functional import multiclass_accuracy
# from torcheval.metrics import MulticlassAccuracy


out = model(images)
accuracy = multiclass_accuracy(
    input=out,
    target=labels,
    num_classes=10,
    average="micro"
).item()
# metric = MulticlassAccuracy(num_classes=10, average="micro")
# metric.update(out, labels)
# accuracy = metric.compute().item()
"""
accuracy = 0.4000000059604645
"""

Precision

適合率と訳されます.
推論したものがどれだけ正しかったかを表す評価指標で,以下の式で求めることが可能です.

Precision = \frac{TP}{TP + FP}
from torcheval.metrics.functional import multiclass_precision
# from torcheval.metrics import MulticlassPrecision


out = model(images)
precision = multiclass_precision(
    input=out,
    target=labels,
    num_classes=10,
    average="micro"
).item()
# metric = MulticlassPrecision(num_classes=10, average="micro")
# metric.update(out, labels)
# precision = metric.compute().item()
"""
precision = 0.4000000059604645
"""

Recall

再現率と訳され,以下の式で求めることが可能です.
どれだけ取りこぼすことなく推論することができたかを表す評価指標のため,主に医療分野で重要視されます.

Recall = \frac{TP}{TP + FN}
from torcheval.metrics.functional import multiclass_recall
# from torcheval.metrics import MulticlassRecall


out = model(images)
recall = multiclass_recall(
    input=out,
    target=labels,
    num_classes=10,
    average="micro"
).item()
# metric = MulticlassRecall(num_classes=10, average="micro")
# metric.update(out, labels)
# recall = metric.compute().item()
"""
recall = 0.4000000059604645
"""

F1-score (Dice)

Precision と Recall の調和平均のことで,F-measure や Dice 係数とも呼ばれます.
Precision と Recall ともにそれなりに精度が欲しい時に使う指標で,以下の式で求めることが可能です.

\begin{align*} F1\_score =& 2 \div \left( \frac{1}{Precision} + \frac{1}{Recall} \right) \\ =& \frac{2}{\frac{1}{Precision} + \frac{1}{Recall}} \\ =& 2 \times \frac{Precision \times Recall}{Precision + Recall} \end{align*}

余談ですが,F-score にも種類があります.
それらは,Precision 重視の F0.5-score と Recall 重視の F2-score で,以下の式で求めることが可能です.

\begin{align*} F1\_score =& \frac{2 \times Precision \times Recall}{Precision + Recall} \\ F0.5\_score =& \frac{1.25 \times Precision \times Recall}{0.25Precision + Recall} \\ F2\_score =& \frac{5 \times Precision \times Recall}{4Precision + Recall} \end{align*}
from torcheval.metrics.functional import multiclass_f1_score
# from torcheval.metrics import MulticlassF1Score


out = model(images)
f1 = multiclass_f1_score(
    input=out,
    target=labels,
    num_classes=10,
    average="micro"
).item()
# metric = MulticlassF1Score(num_classes=10, average="micro")
# metric.update(out, labels)
# f1 = metric.compute().item()
"""
f1 = 0.4000000059604645
"""

ConfusionMatrix

混同行列と訳され,AIモデルがどのような予測をしているのか把握したい時に用います.

from torcheval.metrics.functional import multiclass_confusion_matrix
# from torcheval.metrics import MulticlassConfusionMatrix


out = model(images)
cmat = multiclass_confusion_matrix(
    input=out,
    target=labels,
    num_classes=10,
)
# metric = MulticlassConfusionMatrix(num_classes=10)
# metric.update(out, labels)
# cmat = metric.compute()
"""
cmat = tensor(
    [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]
)
"""

AUC

AUC は Area Under Curve の頭文字を取ったもので,ROC曲線 (Receiver Operating Characteristic) の曲線下面積を計算した指標です.
AUC は 0 から 1 の範囲で値を取り,1 に近ければより良いモデルだと言えます.

from torcheval.metrics.functional import multiclass_auroc
# from torcheval.metrics import MulticlassAUROC


out = model(images)
auc = multiclass_auroc(
    input=out,
    target=labels,
    num_classes=10,
    average="macro"
).item()
# metric = MulticlassAUROC(num_classes=10, average="macro")
# metric.update(out, labels)
# auc = metric.compute().item()
"""
auc = 0.5840
"""

PR曲線

PR曲線についてはクラスごとに出力されるみたいです.
少し複雑ですが,DocString に仕様が詳しく書かれていたので参考にしてください.
大事なのはpr_curve = (precision, recall, thresholds)を返すということですね!
さらに,precision.shape = recall.shape = (num_classes, batch_size + 1)となっています.
https://github.com/pytorch/torcheval/blob/main/torcheval/metrics/functional/classification/precision_recall_curve.py#L107-L139

サンプルコードを以下に記述します.
また,multiclass_precision_recall_curveMulticlassPrecisionRecallCurveは信頼度のしきい値を適当に決めてくれますが,自分で信頼度のしきい値を指定したい場合はmulticlass_binned_precision_recall_curveMulticlassBinnedPrecisionRecallCurveを使用してください.

from torcheval.metrics.functional import multiclass_precision_recall_curve
# from torcheval.metrics import MulticlassPrecisionRecallCurve


out = model(images)
pr_curve = multiclass_precision_recall_curve(
    input=out,
    target=labels,
    num_classes=10,
)
# metric = MulticlassPrecisionRecallCurve(num_classes=10)
# metric.update(out, labels)
# pr_curve = metric.compute()
"""
pr_curve = (
    [tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]),
    tensor([0.2000, 0.2222, 0.2500, 0.2857, 0.3333, 0.4000, 0.5000, 0.3333, 0.5000, 0.0000, 1.0000]),
    tensor([0.1000, 0.1111, 0.1250, 0.1429, 0.1667, 0.2000, 0.2500, 0.0000, 0.0000, 0.0000, 1.0000]),
    tensor([0.2000, 0.2222, 0.2500, 0.2857, 0.3333, 0.4000, 0.5000, 0.6667, 1.0000, 1.0000, 1.0000]),
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]),
    tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]),
    tensor([0.1000, 0.1111, 0.1250, 0.1429, 0.1667, 0.2000, 0.2500, 0.3333, 0.5000,0.0000, 1.0000]),
    tensor([0.2000, 0.2222, 0.2500, 0.2857, 0.3333, 0.2000, 0.2500, 0.3333, 0.5000, 1.0000, 1.0000]),
    tensor([0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]),
    tensor([0.1000, 0.1111, 0.1250, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000])],

    [tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]),
    tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.0000, 0.0000]),
    tensor([1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]),
    tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000]),
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]),
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]),
    tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.]),
    tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.0000]),
    tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
    tensor([1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.])],

    [tensor([0.0147, 0.0575, 0.0606, 0.0976, 0.1311, 0.1326, 0.1591, 0.1786, 0.2558,0.3395]),
    tensor([0.0244, 0.0246, 0.0520, 0.0567, 0.0690, 0.0716, 0.0729, 0.1023, 0.2803,0.3172]),
    tensor([0.0117, 0.0285, 0.0287, 0.0310, 0.0602, 0.0659, 0.0791, 0.1088, 0.1184,0.1293]),
    tensor([0.0243, 0.0310, 0.0558, 0.0799, 0.1131, 0.1281, 0.1582, 0.1615, 0.1799,0.1835]),
    tensor([0.0108, 0.0350, 0.0443, 0.0523, 0.0559, 0.0611, 0.1284, 0.1378, 0.1463,0.1497]),
    tensor([0.0152, 0.0189, 0.0198, 0.0399, 0.0400, 0.0441, 0.0601, 0.0697, 0.0748,0.0933]),
    tensor([0.0162, 0.0371, 0.0683, 0.1066, 0.1181, 0.1711, 0.2032, 0.2133, 0.2783,0.3284]),
    tensor([0.0407, 0.0488, 0.0512, 0.0534, 0.0776, 0.0911, 0.1356, 0.1770, 0.1829,0.2168]),
    tensor([0.0105, 0.0277, 0.0442, 0.0491, 0.0599, 0.0632, 0.0808, 0.0874, 0.1159,0.2142]),
    tensor([0.0266, 0.0511, 0.0820, 0.1026, 0.1062, 0.1122, 0.1132, 0.1340, 0.1380,0.1930])]
)
"""

PR曲線を描画したい時は以下の様なプログラムを使うことで描画することができます.

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

PRECISION, RECALL = 0, 1

plt.axes().set_aspect("equal")
plt.xlim(0, 1)
plt.ylim(0, 1)

for n in num_classes:
    precision = pr_curve[PRECISION][n]
    recall = pr_curve[RECALL][n]
    plt.plot(recall, precision, label=n)

plt.legend()
plt.savefig("pr.pdf")

使用例

ResNet18 を用いた MNIST の画像データ(10クラス)を分類する機械学習プログラムとその出力結果を置いておきます.
もしよければ参考にしてください.

プログラム

main.py
from typing import OrderedDict

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torcheval.metrics import MulticlassAccuracy
from torchvision import transforms
from tqdm import tqdm


def train(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    total_epoch: int = 100
) -> None:
    metric = MulticlassAccuracy(average="micro", num_classes=10)
    model.train()

    for epoch in range(total_epoch):
        with tqdm(dataloader) as pbar:
            pbar.set_description(f"[Epoch {epoch + 1}/{total_epoch}]")

            for images, labels in pbar:
                optimizer.zero_grad()

                out = model(images)
                loss = criterion(out, labels)

                loss.backward()
                optimizer.step()

                metric.update(out, labels)
                accuracy = metric.compute().item()

                pbar.set_postfix(OrderedDict(Accuracy=accuracy))

        metric.reset()


def main():
    dataset = MNIST(
        root="data/MNIST",
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=128,
        shuffle=True,
        drop_last=True,
    )

    model = resnet18(pretrained=True)
    model.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=model.conv1.kernel_size,
        stride=model.conv1.stride,
        padding=model.conv1.padding,
        bias=False,
    )
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)

    train(
        dataloader=dataloader,
        model=model,
        optimizer=optim.SGD(params=model.parameters(), lr=1e-3),
        criterion=nn.CrossEntropyLoss()
    )


if __name__ == "__main__":
    main()

出力結果

上記のプログラムを走らせた時の出力結果です.

[Epoch 1/100]:   2%|█       | 698/30000 [01:02<45:11, 10.81it/s, Accuracy=0.21]

おわりに

本記事では,TorchEval を用いた精度評価の計算方法について書いてきました.
今までは,自分で複雑なプログラムを書いて間違っていないかビクビクしていましたが,公式に評価指標の計算がサポートされてとても嬉しいです.
今後は積極的に活用していこうと思います.

GitHubで編集を提案

Discussion

公式に↓とあるので、GPU環境の人がCPU版のpytorchを入れないようにコメント書いてあげると良いかも。
Contains a rich collection of high performance metric calculations out of the box. We utilize vectorization and GPU acceleration where possible via PyTorch.

コメントありがとうございます!
追記修正しました!

ログインするとコメントできます