
How to Implement Mixup by Wrapping a Dataset Class

Published on 2024/05/27

1. Mixup

Mixup is a data augmentation method that blends two samples, and their labels, using the same mixing ratio.
For details, see the Albumentations docs listed in the Reference section; they are well organized and easy to understand.
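Written as a formula, the idea looks like this. The symbols are chosen here for illustration: $(x_i, y_i)$ and $(x_j, y_j)$ are two samples with their labels, and $\lambda \in [0, 1]$ is the mixing ratio (called `p` in the code in section 2).

```latex
\tilde{x} = \lambda \, x_i + (1 - \lambda) \, x_j
\qquad
\tilde{y} = \lambda \, y_i + (1 - \lambda) \, y_j
```

In the original mixup formulation, $\lambda$ is drawn from a $\mathrm{Beta}(\alpha, \alpha)$ distribution. The code in this article picks the ratio from a fixed grid of values between 0 and 1 instead.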

2. Code

Let's go straight to the code.

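import numpy as np
import torch

class Dataset_sample(torch.utils.data.Dataset):
    pass
# ... (your own dataset; __getitem__ returns (image, label) as NumPy arrays)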

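def mixup(image1, image2, label1, label2, alpha=0.5):
    # `alpha` is unused here; it is kept so existing calls keep working.
    # Pick the mixing ratio p from {0.0, 0.1, ..., 1.0}.
    p = np.random.choice(np.linspace(0.0, 1.0, 11))
    mixed_image = p * image1 + (1 - p) * image2
    # Labels must be numeric arrays (e.g. one-hot) for this blend to be meaningful.
    mixed_label = p * label1 + (1 - p) * label2
    # Assuming 0-255 pixel values: uint8 covers that range, int8 overflows above 127.
    return mixed_image.astype("uint8"), mixed_label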

class MixupDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, alpha=0.5, transform=None):
        self.dataset = dataset
        self.alpha = alpha
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]  # both are NumPy arrays

        # Here `alpha` is the probability of applying mixup to this sample.
        if np.random.rand() < self.alpha:
            # Blend with a second sample drawn uniformly at random.
            image2, label2 = self.dataset[np.random.randint(0, len(self.dataset))]
            image, label = mixup(image, image2, label, label2, self.alpha)

        if self.transform:
            # Albumentations-style call. The .float() below assumes the
            # transform returns a torch.Tensor (e.g. it ends with ToTensorV2).
            image = self.transform(image=image)["image"]
            image = image.float()
        return (image, label)
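One detail in `mixup` worth checking is the cast of the blended image back to an integer dtype. The sketch below is a small standalone check, assuming images with pixel values in 0 to 255; the arrays `bright` and `dark` are made up for illustration.

```python
import numpy as np

# Hypothetical 2x2 grayscale "images" with pixel values in 0..255.
bright = np.full((2, 2), 220, dtype=np.uint8)
dark = np.full((2, 2), 180, dtype=np.uint8)

p = 0.5  # mixing ratio
blended = p * bright + (1 - p) * dark  # float64 array, every value is 200.0

# uint8 can represent 0..255, so the blended value survives the cast.
as_uint8 = blended.astype(np.uint8)

# int8 only covers -128..127, so 200 cannot be stored and the pixel is corrupted.
with np.errstate(invalid="ignore"):
    as_int8 = blended.astype(np.int8)

print(as_uint8.max(), as_int8.max())
```

If your images are floats in the range 0 to 1 instead, an integer cast would discard almost all information, so in that case it is probably better to keep the blended image as a float array.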

This is a very simple implementation.
Mixup blends two (or more) inputs by multiplying each one by a weight, such as 0.4 and 0.6, and applies the same weights to the labels. The result is a mixed input with a matching mixed label.
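As a small worked example of the label side, here is a sketch with a ratio of 0.4, assuming one-hot labels for a hypothetical 3-class problem (the label values are made up for illustration).

```python
import numpy as np

p = 0.4  # weight of the first sample; the second sample gets 1 - p = 0.6

# One-hot labels for a hypothetical 3-class problem.
label1 = np.array([1.0, 0.0, 0.0])  # class 0
label2 = np.array([0.0, 1.0, 0.0])  # class 1

# Same blend as for the images: a soft label with weight on both classes.
mixed_label = p * label1 + (1 - p) * label2
print(mixed_label)
```

Note that this only makes sense for one-hot (or otherwise vector-valued) labels. Blending integer class indices, for example 0 and 1, would give 0.6, which is not a class, so such labels should be one-hot encoded first and the loss function has to accept soft targets.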

Mixup is an easy augmentation to add, and it can improve model performance.
It is very useful, so please try it in your own environment.

Reference

(1) MixUp transform in Albumentations
