🐥

CNNのデータ拡張による過学習防止

2023/02/05に公開

CNNの実装

こちらに単純なCNNの実装を用意しました。
CNNのメインとなる畳み込み演算は、VGG16を簡素にしたものです。

class CNN(nn.Module):
  def __init__(self, num_classes):
    super().__init__()
    # torch.Size([32, 3, 32, 32])
    self.features = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2)
    )
    self.flatten = nn.Flatten()
    self.classifier = nn.Linear(in_features=4096, out_features=num_classes)

  def forward(self, x):
    x = self.features(x) # torch.Size([32, 256, 4, 4])
    x = self.flatten(x) # torch.Size([32, 4096])
    x = self.classifier(x)
    return x

データセットにはCIFAR10を利用しました

##### 学習データ/テストデータ作成
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 正解ラベル

この実装で学習を進めると、訓練データに対する精度は上がるのですが、テストデータに対する精度は頭打ちになってしまいます。

lossも増えており、過学習となっていることがわかります。

過学習を抑えるために、データ拡張を行います。

データ拡張

transforms.RandomHorizontalFlip(), ColorJitter(), RandomRotation(10), でデータ拡張をします。
データ拡張はデータの水増しで、RandomHorizontalFlip()でランダムな左右逆転、ColorJitter()で明るさやコントラストの変更、RandomRotation(10)で回転を行います。
これにより、エポックごとに別の画像が生成され、過学習を抑えることが期待できます。

データセット作成時のコードを少し変更します。
ソースコード全文はこちらです。

##### 学習データ/テストデータ作成
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(), # Data Augmentation
    transforms.ColorJitter(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 正解ラベル

学習アルゴリズムなどは変更せず学習させます。
すると、テストデータに対する精度が以下の画像のように改善します。

Discussion