CIFAR10の読み込みと正規化をソースコードベースで解説
概要
Pytorchのチュートリアル「TRAINING A CLASSIFIER」(日本語名:クラス分類の学習方法)の最初の章、「CIFAR10の読み込みと正規化」のみ解説していきます。
Pytorchのチュートリアル↓
このチュートリアルは「Deep Learning with PyTorch: A 60 Minute Blitz」の最終章であり、これを理解できればPyTorchを理解したと言っても過言ではないわけです(いや、過言ですね)。
ネットワークの定義や実行などは細かく解説されているサイトが多いですが前処理の部分をソースコードまで読んで細かく解説している人がいなかったので備忘録も兼ねて書いていきます。
C言語しか勉強してないので変に細かいところまで書いていますがご了承ください。
モジュールの読み込み
まずは必要なモジュールを読み込む。
import torch
import torchvision
import torchvision.transforms as transforms
PyTorchをインポートする際はPyTorchではなくtorch
とします。
torchvision
は画像のデータセットの処理を、
torchvision.transforms
はデータセットの変換などを行うモジュールです。
torchvision.transforms
をtransforms
としてインポートしています。
データセットの前処理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
6つインスタンス化している。長いのでコードを分割して解説。
transform
まずは最初の3行、
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
これは次の行'trainset'
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
のtorchvision.datasets.CIFAR10
のtransform
引数に渡すためのオブジェクトになっている。
transform
は読み込んだデータセットの画像を変換するためのオブジェクトで、それを先ほどの3行で定義している。
transforms.Compose
この3行のメインはtransforms.Compose
オブジェクト。
コンストラクタの引数に、2種類の変換用オブジェクトを指定しているという入れ子構造になっている。
ソースコードを読んでtransforms.Compose
クラスが何をしているのか見ていく。
[docs]class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
特殊メソッドが3つ。
__init__関数
def __init__(self, transforms):
self.transforms = transforms
self.transforms
に変換用オブジェクトtransforms
を格納している。
__call__関数
call関数の説明↓
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
このコードから、コンストラクタで格納したtransforms
をt
に格納して、先頭から順に変換処理が行われていることがわかる。
__repr__関数
repr関数とは
オブジェクトを表す公式な文字列を生成する関数のこと
らしい。
reprはrepresentationの略であり、インスタンスの情報を文字列として返す。
主にデバッグなどで使われる。今回の学習では使わない。
repr関数の使い方の例
class Person:
def __init__(self, name, age):
self._name = name
self._age = age
def __repr__(self):
return f'{self._name} {self._age}'
my_instance = Person('Alice', 10)
detail_instance = repr(my_instance)
print(f"The instance information is ({detail_instance})")
このコードではrepr
関数を使ってmy_instance
の情報を取得している。
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
self.__class__.__name__
でクラス名(Compose)を取得している。
{0}'.format(t)
でt
の情報を埋め込んでいる(c言語のprintfの%dや%sのようなもの)。
transforms.Compose
には引数としてtransforms.ToTensor
とtransforms.Normalize
の2つのオブジェクトを指定している。その二つも見ていく。
transforms.ToTensor
[docs]class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
説明文を日本語訳すると、
PIL Image や numpy.ndarray をテンソルに変換します。この変換は、torchscriptをサポートしていません。
PIL Image "がモード(L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)のいずれかに属しているか、numpy.ndarrayがdtype = "np.uint8 "である場合、[0, 255]の範囲の "PIL Image "または "numpy.ndarray" (H x W x C)を、[0.0, 1.0]の範囲の形状の "torch.FloatTensor "に変換します。
それ以外の場合、テンソルはスケーリングされずに返されます。
CIFAR-10の画像はPIL形式なので、PyTorchで扱うためにはTensor型に変換しなければならない。
PILとは↓
Pythonに、各種形式の画像ファイルの読み込み・操作・保存を行う機能を提供するフリーのライブラリ。
Tensor型とは↓
Tensor型は勾配情報の保持とGPU使用が可能。
つまりPIL(画像の輝度0~255の範囲)をTensor(0.0~1.0の範囲)に変換している。
transforms.Normalize
先ほどのtransforms.ToTensor
によって画像のデータは0~1に範囲になった。
それらのデータを、より良く活用するため-1~1の範囲に正規化する必要がある。
[docs]class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation.
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``torch.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation in-place.
"""
def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace
説明文によると正規化の計算は
output[channel] = (input[channel] - mean[channel]) / std[channel]
なので今回は引数にRGB3チャンネル分で、平均を0.5、標準偏差を0.5にしている。
init 関数
class Normalize(torch.nn.Module):
def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace
見慣れないsuper
関数が出てきた。
super関数とはクラス(子クラス)で別のクラス(親クラス)を継承している時に、親クラスのメソッドを子クラスから呼び出す際に使う関数。
__init__
関数が被っているため親クラスの__init__
だと示す必要がある。今回だと親クラスであるtorch.nn.Module
のコンストラクタを呼び出してNormalize
だけでなくtorch.nn.Modlue
も初期化するために使用している。
ちなみに今回は使っていないが第3引数のinplace
とは
新しいオブジェクトを作成するのではなく、元のオブジェクトを更新するかどうかTrue
or False
で指定できる。
trainset & testset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
CIFAR-10のデータセットは、torchvision.datasets.CIFAR10
オブジェクトによって作成することができる。
訓練データセットのtorchvision.datasets.CIFAR10
オブジェクトが生成され、変数trainsetに格納される。
- root
指定された位置にデータを保存。
GoogleColabでの実行結果
!ls
data drive sample_date
確かにカレントディレクトリにdataが保存されていた。
-
train
Trueなら訓練データ、Falseならテストデータ。 -
download
ダウンロードするか否か指定。 -
transform
読み込んだデータセットの画像を変換するためのオブジェクトを指定。
この変数transformには、先ほど定義した変換用オブジェクトが格納してある。
trainloader & test loader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
先ほど作成したデータセットをミニバッチ単位で連続して取り出せる形式に変換する。
-
dataset
第一引数には変換したいデータセット(訓練データ、テストデータ)を渡す。 -
batch_size
ミニバッチサイズを指定する。 -
shuffle
データセットをシャッフルするかどうか指定する。
訓練データでTrueにしているのは訓練の偏りを防ぐため。
逆にテストデータはシャッフルする必要がないのでFalseにしている。 -
num_workes
デフォルトでは0になっておりミニバッチの取り出し方がSingle processになっている。
num_workrs=2
などに設定することによってmulti-process data loadingとなり、処理が高速化される。
classes
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
クラスのラベル名を、変更できない配列であるタプル形式で定義している。後々このタプルを使って整数データ(0が飛行機~9がトラック)に対応させて名前を付ける。
データ画像の出力
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
imshow関数
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# imshow(torchvision.utils.make_grid(images))
-
img = img / 2 + 0.5
の/ 2 + 2.5
って何してるの?もし計算しない場合 → 画像もラベルも正常に表示されるが
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
という文字が出てくる。意味は
RGBデータでimshowの有効範囲(浮動小数なら[0〜1]、整数なら[0〜255])に入力データをクリッピングします。
というもの。
では引数の
torchvision.utils.make_grid(images)
とは何なのか↓print(torchvision.utils.make_grid(images))
これの出力結果は
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.1216, ..., -0.6235, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, -0.0667, ..., -0.0196, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.1059, ..., -0.7647, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, -0.1451, ..., -0.2235, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0667, ..., -0.8745, 0.0000, 0.0000], ..., [ 0.0000, 0.0000, -0.2784, ..., -0.4196, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]])
負の値がある。値の範囲を詳しく調べてみる
a = torchvision.utils.make_grid(images) print(len(a[a > 0.9])) print(len(a[a < -0.9])) print(len(a[a > 1.0])) print(len(a[a < -1.0]))
出力結果
174 250 0 0
ダメ押しで
torch
の関数も使ってみるa = torchvision.utils.make_grid(images) # 最大値 print(torch.max(a).item()) # 最小値 print(torch.min(a).item())
出力結果
0.9921568632125854 -0.9843137264251709
正規化しているので-1.0~1.0の範囲になっている。
なのでこれをimshowの有効範囲,浮動小数なら[0~1]に合わせるために計算した↓
\frac{-1}{2} + 0.5 = 0 \frac{1}{2} + 0.5 = 1 よって(-1.0~1.0)の範囲が(0.0~1.0)になった。
-
npimg = img.numpy()
tensor
型からndarray
型に変換している。a = torchvision.utils.make_grid(images) / 2 + 0.5 print(type(a)) b = a.numpy() print(type(b))
実行結果
<class 'torch.Tensor'> <class 'numpy.ndarray'>
-
plt.imshow(np.transpose(npimg, (1, 2, 0)))
-
np.transpose
shapeを
imshow
のために転置した。 -
plt.imshow
plt.imshow
の第1引数のイメージの取り方は
-
イメージ画像とラベルの取り出し
dataiter = iter(trainloader)
images, labels = dataiter.next()
iter
関数の解説↓
イメージ画像の出力
imshow(torchvision.utils.make_grid(images))
torchvision.utils.make_grid
って何?
他のサイトの説明だとgrid線を表示させるものらしい → grid線いらないし消してもいいのでは?
試してみた↓
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(images) # ここを変えている!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# imshow(torchvision.utils.make_grid(images)) # 変える前のコード
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
実行結果
ValueError Traceback (most recent call last)
<ipython-input-62-5b8b26759ed6> in <module>()
11 images, labels = dataiter.next()
12
---> 13 imshow(images)
14 # imshow(torchvision.utils.make_grid(images))
15 print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
2 frames
<__array_function__ internals> in transpose(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56
57 try:
---> 58 return bound(*args, **kwds)
59 except TypeError:
60 # A TypeError occurs if the object does have such a method in its
ValueError: axes don't match array
配列の次元数が違う、とインタプリタに怒られた。
次元数を調べてみる↓
print(f"images : {images.shape}")
x = torchvision.utils.make_grid(images)
print(f"grid : {x.shape}")
実行結果
images : torch.Size([4, 3, 32, 32])
grid : torch.Size([3, 36, 138])
確かに配列の次元数が異なっていた。
どうやらplt.imshow
関数は次元が3なのでそれに合わせなくてはならなかったみたい。
よってtorchvision.utils.make_grid
はgrid線を書くだけじゃなく次元数を調整する役目があるので必須関数。
ラベルの出力
print(' '.join('%5s' % classes[labels[j]] for j in range(4))) # (5)
変数labels
の4つのラベルの値を使い、ラベル名を表示する処理。なお、タプルclassesの前にある%は剰余ではなく、文字列にタプルやリストの各要素を順に埋め込むための記号である。
join
関数とは
書式 : リストの要素(文字列に限る)をセパレータで連結した文字列を作る
セパレータ.join(iterableなオブジェクト(イテレータ))
join関数の引数はリストだけではない↓
members = ["Tom", "Jerry", "Spike"]
print(members)
name = " and ".join(members)
print(name)
name = " ".join(members)
print(name)
出力結果
['Tom', 'Jerry', 'Spike']
Tom and Jerry and Spike
Tom Jerry Spike
%
とは
c言語のprintfと似たようなもの
c言語 printf("%s", "hello");
python print('%s' % 'hello')
%5s
の5
とは?
print "|%5s|" % 'ABC' #=> | ABC| : 右寄せ5文字分
' '.join('%5s' % classes[labels[j]] for j in range(4))
はジェネレータ式と言われる書き方↓
右寄せなどを使わずにC言語っぽく書くと、今回のコードが示しているのはこれ↓
a = []
for j in range(4):
a.append(classes[labels[j]])
print(' '.join(a))
出力結果(ランダム)
car dog deer deer
以上!
これでCIFAR10のデータをだいぶ理解できたのであとは機械学習するだけですね〜
Discussion