
【Knowledge Distillation explained】Part 1: Intro ~ Data Loading

Published 2024/04/24

First, I recommend reading the PyTorch tutorial; it's easy to understand.

This post is a breakdown of that tutorial; if I have time, I'll add more.

1. Introduction

1.1 Knowledge Distillation

Knowledge Distillation is a method for improving a small model's performance by transferring knowledge from a larger model.
A small model offers fast prediction and low memory usage, which is useful when building a model in a constrained environment such as a data competition.

1.2 Concrete way

Knowledge Distillation (called KD from here on) uses soft targets, so named in contrast to the normal hard targets.
Calling the small model the student model and the large model the teacher model, an example of KD looks like this:
・Example

We then have two different losses: the normal loss (cross-entropy or similar) on the hard labels, and the soft-target loss (computed from the softmax outputs of the two models). These are combined as a weighted sum:

# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
loss.backward()
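
For concreteness, here is a minimal sketch of how those two losses could be computed. The dummy models, the temperature T, and the batch shapes below are assumptions for illustration only; the real models appear later in the tutorial.

・Computing the two losses (sketch)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy stand-ins so the sketch runs; in practice these are real networks
teacher = nn.Linear(3072, 10)
student = nn.Linear(3072, 10)
inputs = torch.randn(128, 3072)
labels = torch.randint(0, 10, (128,))

T = 2.0  # temperature (assumed value); higher T gives softer distributions
ce_loss = nn.CrossEntropyLoss()

with torch.no_grad():  # the teacher is frozen during distillation
    teacher_logits = teacher(inputs)
student_logits = student(inputs)

# Soft-target loss: KL divergence between the softened teacher and student
# distributions, scaled by T**2 (one common formulation)
soft_targets = F.softmax(teacher_logits / T, dim=-1)
soft_prob = F.log_softmax(student_logits / T, dim=-1)
soft_targets_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * T**2

# Normal hard-label loss on the student's raw logits
label_loss = ce_loss(student_logits, labels)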

Now, let's look at the code in more detail.

2. Import

Import the related libraries.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. Loading CIFAR-10

We use CIFAR-10 as the dataset. It is a popular dataset with ten classes,
and the task is to predict one class for each input image.

・Loading CIFAR-10

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    # convert to tensor
    transforms.ToTensor(),
    # pre-computed value for normalize
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
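
As a quick sanity check, the torchvision dataset object exposes the class names and the sample counts:

・Checking the dataset

# The ten CIFAR-10 classes and the train/test split sizes
print(train_dataset.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(len(train_dataset), len(test_dataset))  # 50000 10000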

Each raw RGB image has shape 3 x 32 x 32, i.e. 3072 values ranging from 0 to 255; ToTensor() converts them to floats in [0, 1] before normalization.
We download the CIFAR-10 datasets with the torchvision datasets API and normalize them. The mean and std used for normalization are pre-computed from the training data. What you should be careful about here is to use the same values when normalizing the test dataset, even though those statistics come only from the training data.
The reason is that we can assume the training data is representative enough, and in a real situation it is difficult to obtain such statistics for unseen data.
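
If you want to see what a single preprocessed sample looks like, a short check like this works:

・Inspecting one sample

# One sample after ToTensor() and Normalize()
image, label = train_dataset[0]
print(image.shape)  # torch.Size([3, 32, 32])
print(image.dtype)  # torch.float32
print(label)        # an integer class index in 0..9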

Here, we briefly define the dataloaders.
・DataLoader

# Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
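
To verify that the dataloaders produce what we expect, we can pull one batch and check its shape:

・Checking one batch

# Fetch a single batch; shapes follow batch_size=128
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32])
print(labels.shape)  # torch.Size([128])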

That's all for this time; I'll continue in the next part.
Thank you for reading.

Reference

(1) PyTorch Knowledge Distillation Tutorial
(2) Kaggle BirdCLEF 4th place solution
