📷

CIFAR10の読み込みと正規化をソースコードベースで解説

2022/09/25に公開約17,900字

概要

Pytorchのチュートリアル「TRAINING A CLASSIFIER」(日本語名:クラス分類の学習方法)の最初の章、「CIFAR10の読み込みと正規化」のみ解説していきます。

Pytorchのチュートリアル↓

https://colab.research.google.com/github/YutaroOgawa/pytorch_tutorials_jp/blob/main/notebook/1_Learning PyTorch/1_4_cifar10_tutorial_jp.ipynb#scrollTo=2UbJrC7Fq43T

このチュートリアルは「Deep Learning with PyTorch: A 60 Minute Blitz」の最終章であり、これを理解できればPyTorchを理解したと言っても過言ではないわけです(いや、過言ですね)。

ネットワークの定義や実行などは細かく解説されているサイトが多いですが前処理の部分をソースコードまで読んで細かく解説している人がいなかったので備忘録も兼ねて書いていきます。
C言語しか勉強してないので変に細かいところまで書いていますがご了承ください。

モジュールの読み込み

まずは必要なモジュールを読み込む。

import torch
import torchvision
import torchvision.transforms as transforms

PyTorchをインポートする際はPyTorchではなくtorchとします。
torchvisionは画像のデータセットの処理を、
torchvision.transformsはデータセットの変換などを行うモジュールです。
torchvision.transformstransformsとしてインポートしています。

データセットの前処理

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                          shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

6つインスタンス化している。長いのでコードを分割して解説。

transform

まずは最初の3行、

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

これは次の行'trainset'

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

torchvision.datasets.CIFAR10transform引数に渡すためのオブジェクトになっている。
transformは読み込んだデータセットの画像を変換するためのオブジェクトで、それを先ほどの3行で定義している。

transforms.Compose

この3行のメインはtransforms.Composeオブジェクト。
コンストラクタの引数に、2種類の変換用オブジェクトを指定しているという入れ子構造になっている。

ソースコードを読んでtransforms.Composeクラスが何をしているのか見ていく。

ソースコード

[docs]class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

特殊メソッドが3つ。

__init__関数

def __init__(self, transforms):
    self.transforms = transforms

self.transformsに変換用オブジェクトtransformsを格納している。

__call__関数

call関数の説明↓

https://techacademy.jp/magazine/50298
def __call__(self, img):
    for t in self.transforms:
        img = t(img)
    return img

このコードから、コンストラクタで格納したtransformstに格納して、先頭から順に変換処理が行われていることがわかる。

__repr__関数

repr関数とは

オブジェクトを表す公式な文字列を生成する関数のこと

らしい。
reprはrepresentationの略であり、インスタンスの情報を文字列として返す。
主にデバッグなどで使われる。今回の学習では使わない

repr関数の使い方の例
class Person:
    def __init__(self, name, age):
        self._name = name
        self._age = age
    
    def __repr__(self):
        return f'{self._name} {self._age}'


my_instance = Person('Alice', 10)
detail_instance = repr(my_instance)
print(f"The instance information is ({detail_instance})")

このコードではrepr関数を使ってmy_instanceの情報を取得している。

def __repr__(self):
    format_string = self.__class__.__name__ + '('
    for t in self.transforms:
        format_string += '\n'
        format_string += '    {0}'.format(t)
    format_string += '\n)'
    return format_string

self.__class__.__name__でクラス名(Compose)を取得している。

https://lightgauge.net/language/python/8516/

{0}'.format(t)tの情報を埋め込んでいる(c言語のprintfの%dや%sのようなもの)。

https://bbh.bz/2019/11/21/python-format-func/

transforms.Composeには引数としてtransforms.ToTensortransforms.Normalizeの2つのオブジェクトを指定している。その二つも見ていく。

transforms.ToTensor

ソースコード

[docs]class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

説明文を日本語訳すると、

PIL Image や numpy.ndarray をテンソルに変換します。この変換は、torchscriptをサポートしていません。

PIL Image "がモード(L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)のいずれかに属しているか、numpy.ndarrayがdtype = "np.uint8 "である場合、[0, 255]の範囲の "PIL Image "または "numpy.ndarray" (H x W x C)を、[0.0, 1.0]の範囲の形状の "torch.FloatTensor "に変換します。
それ以外の場合、テンソルはスケーリングされずに返されます。

CIFAR-10の画像はPIL形式なので、PyTorchで扱うためにはTensor型に変換しなければならない。

PILとは↓
Pythonに、各種形式の画像ファイルの読み込み・操作・保存を行う機能を提供するフリーのライブラリ。

https://ja.wikipedia.org/wiki/Python_Imaging_Library#:~:text=Python Imaging Library(略称 PIL,利用することができる。

Tensor型とは↓
Tensor型は勾配情報の保持とGPU使用が可能。

https://qiita.com/mathlive/items/241bfb42d852bb801b96

つまりPIL(画像の輝度0~255の範囲)をTensor(0.0~1.0の範囲)に変換している。

transforms.Normalize

先ほどのtransforms.ToTensorによって画像のデータは0~1に範囲になった。
それらのデータを、より良く活用するため-1~1の範囲に正規化する必要がある。

ソースコード

[docs]class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

説明文によると正規化の計算は

output[channel] = (input[channel] - mean[channel]) / std[channel]

なので今回は引数にRGB3チャンネル分で、平均を0.5、標準偏差を0.5にしている。

init 関数

class Normalize(torch.nn.Module):
    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

見慣れないsuper関数が出てきた。

super関数とはクラス(子クラス)で別のクラス(親クラス)を継承している時に、親クラスのメソッドを子クラスから呼び出す際に使う関数。

https://techacademy.jp/magazine/28283

__init__関数が被っているため親クラスの__init__だと示す必要がある。今回だと親クラスであるtorch.nn.Moduleのコンストラクタを呼び出してNormalizeだけでなくtorch.nn.Modlueも初期化するために使用している。

ちなみに今回は使っていないが第3引数のinplaceとは
新しいオブジェクトを作成するのではなく、元のオブジェクトを更新するかどうかTrue or Falseで指定できる。

https://stackoverflow.com/questions/29603510/mutable-objects-changed-in-place-means-what

trainset & testset

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

CIFAR-10のデータセットは、torchvision.datasets.CIFAR10オブジェクトによって作成することができる。
訓練データセットのtorchvision.datasets.CIFAR10オブジェクトが生成され、変数trainsetに格納される。

  • root
    指定された位置にデータを保存。
GoogleColabでの実行結果
!ls
data drive sample_date

確かにカレントディレクトリにdataが保存されていた。

  • train
    Trueなら訓練データ、Falseならテストデータ。

  • download
    ダウンロードするか否か指定。

  • transform
    読み込んだデータセットの画像を変換するためのオブジェクトを指定。
    この変数transformには、先ほど定義した変換用オブジェクトが格納してある。

trainloader & test loader

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                          shuffle=False, num_workers=2)

先ほど作成したデータセットをミニバッチ単位で連続して取り出せる形式に変換する。

  • dataset
    第一引数には変換したいデータセット(訓練データ、テストデータ)を渡す。

  • batch_size
    ミニバッチサイズを指定する。

  • shuffle
    データセットをシャッフルするかどうか指定する。
    訓練データでTrueにしているのは訓練の偏りを防ぐため。
    逆にテストデータはシャッフルする必要がないのでFalseにしている。

  • num_workes
    デフォルトでは0になっておりミニバッチの取り出し方がSingle processになっている。
    num_workrs=2などに設定することによってmulti-process data loadingとなり、処理が高速化される。

classes

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

クラスのラベル名を、変更できない配列であるタプル形式で定義している。後々このタプルを使って整数データ(0が飛行機~9がトラック)に対応させて名前を付ける。

データ画像の出力

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

imshow関数

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    
# imshow(torchvision.utils.make_grid(images))
  • img = img / 2 + 0.5/ 2 + 2.5って何してるの?

    もし計算しない場合 → 画像もラベルも正常に表示されるが

    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    

    という文字が出てくる。意味は

    RGBデータでimshowの有効範囲(浮動小数なら[0〜1]、整数なら[0〜255])に入力データをクリッピングします。

    というもの。

    では引数のtorchvision.utils.make_grid(images)とは何なのか↓

    print(torchvision.utils.make_grid(images))
    

    これの出力結果は

    tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.1216,  ..., -0.6235,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000, -0.0667,  ..., -0.0196,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
    
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.1059,  ..., -0.7647,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000, -0.1451,  ..., -0.2235,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
    
            [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0667,  ..., -0.8745,  0.0000,  0.0000],
             ...,
             [ 0.0000,  0.0000, -0.2784,  ..., -0.4196,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
             [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
    

    負の値がある。値の範囲を詳しく調べてみる

    a = torchvision.utils.make_grid(images)
    print(len(a[a > 0.9]))
    print(len(a[a < -0.9]))
    
    print(len(a[a > 1.0]))
    print(len(a[a < -1.0]))
    

    出力結果

    174
    250
    0
    0
    

    ダメ押しでtorchの関数も使ってみる

    a = torchvision.utils.make_grid(images)
    # 最大値
    print(torch.max(a).item())
    # 最小値
    print(torch.min(a).item())
    

    出力結果

    0.9921568632125854
    -0.9843137264251709
    

    正規化しているので-1.0~1.0の範囲になっている。

    なのでこれをimshowの有効範囲,浮動小数なら[0~1]に合わせるために計算した↓

    \frac{-1}{2} + 0.5 = 0

    \frac{1}{2} + 0.5 = 1

    よって(-1.0~1.0)の範囲が(0.0~1.0)になった。

  • npimg = img.numpy()

    tensor型からndarray型に変換している。

    a = torchvision.utils.make_grid(images) / 2 + 0.5
    print(type(a))
    b = a.numpy()
    print(type(b))
    

    実行結果

    <class 'torch.Tensor'>
    <class 'numpy.ndarray'>
    
  • plt.imshow(np.transpose(npimg, (1, 2, 0)))

イメージ画像とラベルの取り出し

dataiter = iter(trainloader)
images, labels = dataiter.next()

iter関数の解説↓

https://techacademy.jp/magazine/28379

イメージ画像の出力

imshow(torchvision.utils.make_grid(images))

torchvision.utils.make_gridって何?

他のサイトの説明だとgrid線を表示させるものらしい → grid線いらないし消してもいいのでは?

試してみた↓

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()

imshow(images)  # ここを変えている!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# imshow(torchvision.utils.make_grid(images))  # 変える前のコード
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

実行結果

ValueError                                Traceback (most recent call last)
<ipython-input-62-5b8b26759ed6> in <module>()
     11 images, labels = dataiter.next()
     12 
---> 13 imshow(images)
     14 # imshow(torchvision.utils.make_grid(images))
     15 print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

2 frames
<__array_function__ internals> in transpose(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56 
     57     try:
---> 58         return bound(*args, **kwds)
     59     except TypeError:
     60         # A TypeError occurs if the object does have such a method in its

ValueError: axes don't match array

配列の次元数が違う、とインタプリタに怒られた。

次元数を調べてみる↓

print(f"images : {images.shape}")
x = torchvision.utils.make_grid(images)
print(f"grid : {x.shape}")

実行結果

images : torch.Size([4, 3, 32, 32])
grid : torch.Size([3, 36, 138])

確かに配列の次元数が異なっていた。

どうやらplt.imshow関数は次元が3なのでそれに合わせなくてはならなかったみたい。

よってtorchvision.utils.make_gridはgrid線を書くだけじゃなく次元数を調整する役目があるので必須関数。

ラベルの出力

print(' '.join('%5s' % classes[labels[j]] for j in range(4)))  # (5)

変数labelsの4つのラベルの値を使い、ラベル名を表示する処理。なお、タプルclassesの前にある%は剰余ではなく、文字列にタプルやリストの各要素を順に埋め込むための記号である。

join関数とは
書式 : リストの要素(文字列に限る)をセパレータで連結した文字列を作る
セパレータ.join(iterableなオブジェクト(イテレータ))

join関数の引数はリストだけではない↓

https://qiita.com/conf8o/items/d57f74b4bcb67882be37
members = ["Tom", "Jerry", "Spike"]
print(members)
name = " and ".join(members)
print(name)
name = " ".join(members)
print(name)

出力結果

['Tom', 'Jerry', 'Spike']
Tom and Jerry and Spike
Tom Jerry Spike

%とは
c言語のprintfと似たようなもの

c言語   printf("%s", "hello");
python  print('%s' % 'hello')

https://qiita.com/takahiro_itazuri/items/e585b46d096036bc837f

%5s5とは?

print "|%5s|" % 'ABC'        #=> |  ABC| : 右寄せ5文字分

https://www.tohoho-web.com/python/types.html

' '.join('%5s' % classes[labels[j]] for j in range(4))はジェネレータ式と言われる書き方↓

https://qiita.com/conf8o/items/d57f74b4bcb67882be37#ジェネレータ式

右寄せなどを使わずにC言語っぽく書くと、今回のコードが示しているのはこれ↓

a = []
for j in range(4):
    a.append(classes[labels[j]])
print(' '.join(a))

出力結果(ランダム)

car dog deer deer

Discussion

ログインするとコメントできます