🐥
CNNのデータ拡張による過学習防止
CNNの実装
こちらに単純なCNNの実装を用意しました。
CNNのメインとなる畳み込み演算は、VGG16を簡素にしたものです。
class CNN(nn.Module):
def __init__(self, num_classes):
super().__init__()
# torch.Size([32, 3, 32, 32])
self.features = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2)
)
self.flatten = nn.Flatten()
self.classifier = nn.Linear(in_features=4096, out_features=num_classes)
def forward(self, x):
x = self.features(x) # torch.Size([32, 256, 4, 4])
x = self.flatten(x) # torch.Size([32, 4096])
x = self.classifier(x)
return x
データセットにはCIFAR10を利用しました
##### 学習データ/テストデータ作成
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 正解ラベル
この実装で学習を進めると、訓練データに対する精度は上がるのですが、テストデータに対する精度は頭打ちになってしまいます。
lossも増えており、過学習となっていることがわかります。
過学習を抑えるために、データ拡張を行います。
データ拡張
transforms.RandomHorizontalFlip(), ColorJitter(), RandomRotation(10), でデータ拡張をします。
データ拡張はデータの水増しで、RandomHorizontalFlip()でランダムな左右逆転、ColorJitter()で明るさやコントラストの変更、RandomRotation(10)で回転を行います。
これにより、エポックごとに別の画像が生成され、過学習を抑えることが期待できます。
データセット作成時のコードを少し変更します。
ソースコード全文はこちらです。
##### 学習データ/テストデータ作成
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # Data Augmentation
transforms.ColorJitter(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 正解ラベル
学習アルゴリズムなどは変更せず学習させます。
すると、テストデータに対する精度が以下の画像のように改善します。
Discussion