How to Apply Mixup by Overriding the Dataset Class
1. Mixup
Mixup is an augmentation method that blends two samples, mixing both the input data and the labels.
If you want to know the details, the Albumentations docs are well organized and easy to understand.
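As a rough sketch of the idea (the symbols are my own notation, not taken from any particular library), given two samples $(x_i, y_i)$ and $(x_j, y_j)$ and a mixing ratio $\lambda$, the mixed sample can be written as:

```latex
\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \qquad
\tilde{y} = \lambda y_i + (1 - \lambda) y_j, \qquad
\lambda \in [0, 1]
```

Here $\lambda$ plays the role of the variable `p` in the code, and the labels $y_i$, $y_j$ are assumed to be one-hot (or otherwise numeric) vectors so that the weighted sum is meaningful.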
2. Code
Let's jump straight into the code.
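import numpy as np
import torch


class Dataset_sample(torch.utils.data.Dataset):
    # Placeholder for your own dataset; it returns (image, label) as ndarrays.
    pass
    # ...


def mixup(image1, image2, label1, label2, alpha=0.5):
    # Note: alpha is accepted but not used inside this function.
    possible_values = np.arange(0, 1.1, 0.1)  # any value from 0 to 1 in steps of 0.1
    p = np.random.choice(possible_values)
    mixed_image = p * image1 + (1 - p) * image2
    mixed_label = p * label1 + (1 - p) * label2
    # uint8 holds 0-255 pixel values; int8 would overflow above 127.
    return mixed_image.astype("uint8"), mixed_label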
class MixupDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, alpha=0.5, transform=None):
        self.dataset = dataset
        self.alpha = alpha  # probability of applying mixup to a sample
        self.transform = transform  # Albumentations-style transform (optional)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image1, label1 = self.dataset[idx]
        # print(type(image1), type(label1))  # both are ndarray
        random_number = np.random.rand()
        # mixup
        if random_number < self.alpha:
            # pick a random partner sample from the same dataset
            image2, label2 = self.dataset[np.random.randint(0, len(self.dataset))]
            mixed_image, mixed_label = mixup(image1, image2, label1, label2, self.alpha)
            # print(type(mixed_image), type(mixed_label))
            if self.transform:
                mixed_image = self.transform(image=mixed_image)["image"]
                # .float() assumes the transform returns a torch.Tensor
                mixed_image = mixed_image.float()
            return (mixed_image, mixed_label)
        # as is
        else:
            if self.transform:
                image1 = self.transform(image=image1)["image"]
                image1 = image1.float()
            return (image1, label1)
This is a very simple implementation.
Mixup blends two (or more) inputs by multiplying each one by a weight, such as 0.4 and 0.6, and applies the same weights to the labels. The result is a mixed input and a matching mixed label.
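To make the 0.4 / 0.6 weighting concrete, here is a minimal worked example with made-up data. The tiny 2x2 images, the 3-class setup, and the fixed value of `p` are assumptions chosen for illustration. The labels are converted to one-hot vectors first, because a weighted sum of raw class indices would not be meaningful.

```python
import numpy as np

# Two tiny 2x2 grayscale "images" with 0-255 pixel values (made-up data).
image1 = np.full((2, 2), 200, dtype=np.uint8)
image2 = np.full((2, 2), 100, dtype=np.uint8)

# Integer class labels for a hypothetical 3-class problem, turned into one-hot vectors.
num_classes = 3
label1 = np.eye(num_classes)[0]  # class 0 -> [1, 0, 0]
label2 = np.eye(num_classes)[2]  # class 2 -> [0, 0, 1]

p = 0.4  # weight for the first sample; the second sample gets 1 - p = 0.6

# Blend the images, round to the nearest integer, and cast back to uint8.
mixed_image = np.round(p * image1 + (1 - p) * image2).astype(np.uint8)

# Blend the labels with the same weights.
mixed_label = p * label1 + (1 - p) * label2

print(mixed_image)  # every pixel is 0.4 * 200 + 0.6 * 100 = 140
print(mixed_label)  # roughly 0.4 for class 0 and 0.6 for class 2
```

The mixed label still sums to 1, so it can be used as a soft target with a loss that accepts class probabilities.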
It is an easy augmentation to add, and it can improve model performance.
It is very useful, so please try it in your own environment.