Google Colaboratoryの無料GPU環境でPytorchによる画像認識を行う
概要
Deep learningによる画像認識は、GPU環境で実行するのがスタンダードになっています。
最近の高精度な深層学習アーキテクチャの学習・推論を行うためには、それなりに高価なGPUが必要ですが、個人で導入するにはハードルが高く、クラウドサービス(AWS,GCP等)を利用するのが便利です。
ただ、計算を試すだけのためにAWSを借りるのもな、、というときは、Google Colaboratoryという無料でGPUが使える環境があるので、これを使うと良いです。
Google driveでcolaboratoryの準備
まずは、自分のGoogle drive上で開発できたほうがいいので、Google driveにアプリとしてcolaboratoryを追加し、ドライブ上でGoogle colaboratoryノートブックを作成します。この辺りの手順は、下記手順などを参照してください。
GPUランタイムの設定
ノートブックが立ち上がったら、GPUをセッティングします。
右上の「接続」タブをクリックし、ランタイムに接続します。
接続できると、RAMとディスクの容量が表示されます。
ただ、これではGPUが使えませんので、「ランタイム」→「ランタイムのタイプの変更」から、ハードウェアアクセラレータとして「GPU」を指定し、保存します。すると、自動的にGPUランタイムに接続されます。
GPUが使えるかどうか確認します。
!nvidia-smi
と入力し実行すると、GPUのステータスが表示されます。Notebook上では、!から始まるセンテンスはLinuxコマンドとして認識されますので、このように通常のLinuxコマンドが使えます。
私の環境の場合、Nvidia Tesla T4に接続されていることがわかります。使えるGPUは接続するたびに変わるようですが、随時性能のいいものが使えるようにアップデートされているようです。また、Cudaのバージョンは10.1であることがわかります。
Pytorchのインストール
次に、深層学習フレームワークのPytorchをインストールします。Pytorch公式HP(https://pytorch.org/)から、OSやCudaのバージョンを適切に選択し、表示されるpipコマンドをコピペします。今回の環境はCuda10.1なので、下記コマンドをnotebook上で実行します。
!pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
Colaboratoryは機械学習に必要な様々なフレームワークがすでにインストールされているので、Pytorchのlatestバージョンがインストール済みの場合もあります。
画像認識コードの実行
次に、画像認識のサンプルコードを実行していきます。Pytorch公式にCIFAR10の画像分類のサンプルがありましたので、これをベースに改変して使います。
まずは、Pytorch関連のモジュールをインポートします。%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
次に、計算するデバイスを指定します。"cuda:0"としてGPUデバイスを指定できます。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
次に、CNNのアーキテクチャを記述します。畳み込み層が2層の簡単なネットワークを使います。net.to(device) は、作成したネットワークはGPUデバイス上で計算するために必要です。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
net.to(device)
次に、ロス関数と最適化手法を設定します。
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
そして、トレーニングを行うコードを書きます。
for epoch in range(10): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data[0].to(device), data[1].to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
ここでは、トレーニングのループは10エポックにしてみます。最後にCNNの重みを保存しておきます。
実行すると、学習が進行して、ロスが下がっていきます。私の環境では、10分ほどでトレーニングが終了し、ロスが0.834となりました。
無事学習を行うことができましたが、学習したCNNの汎化性能を見るためにテストデータでテストします。
そのために、画像を表示する関数を作成します。
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
テストデータから4枚画像をサンプルしてGround truthを表示します。
dataiter = iter(testloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
保存した重みをロードし、サンプルした4枚の画像を入力とした推論を行います。
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
3枚目のshipはplaneと推論されてしまっていますが、概ね正しく学習できていることが分かります。
最後に
Google colaboratoryは、GPUを使ってDeep learningをやりたい!というときに非常に便利です。無料なので、時間制限(ノートブックのセッションが切れてから90分、インスタンスは起動してから12時間)があったり、GPUやCPUの細かいチューニングができなかったりしますが、手軽に計算を回せるので重宝します。
Discussion