
【Knowledge Distillation explained】Part 1: Intro ~ Data Loading


First, I recommend reading the PyTorch knowledge distillation tutorial (1). It's easy to understand.

I'll write a breakdown of that tutorial, and add more if I have time.

1. Introduction

1.1 Knowledge Distillation

Knowledge Distillation is a method that improves the performance of a small model by using a similar, larger model.
A small model predicts quickly and uses little memory. This makes it useful when you build models in constrained environments such as data competitions.

1.2 How it works

Knowledge Distillation (KD from here on) uses soft targets. The name contrasts them with the normal (hard) targets.
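Here is a concrete view of "soft target". A hard target is a one-hot label. A soft target is the full probability distribution that the teacher assigns over the classes.

The minimal sketch below assumes the common KD practice of dividing the logits by a temperature T before the softmax. The function name, the example logits and T=4 are all made up for illustration. A larger T moves more probability onto the non-top classes, and that distribution is what tells the student how the teacher ranks the classes.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Turn raw model outputs (logits) into probabilities.

    T=1 is the ordinary softmax; T>1 (an assumed KD setting) flattens
    the distribution into a "softer" target.
    """
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for a 3-class problem
teacher_logits = [4.0, 1.0, 0.5]
hard_target = [1, 0, 0]  # one-hot label: only class 0 counts

print(softmax_with_temperature(teacher_logits, T=1.0))  # ≈ [0.926, 0.046, 0.028]
print(softmax_with_temperature(teacher_logits, T=4.0))  # ≈ [0.529, 0.250, 0.221]
```

Both outputs still rank class 0 first. The softened version, however, also shows that the teacher considers class 1 slightly more plausible than class 2.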
Call the small model the student model and the large model the teacher model. KD feeds the same input to both models and trains the student on the teacher's outputs as well as on the true labels.

This gives two different losses:
・the normal loss (cross-entropy or similar), computed against the true labels
・the soft-target loss, computed by applying softmax to each model's output and comparing the two

Each loss gets its own weight, and the two are combined.

# Weighted sum of the two losses
loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss
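Below is a minimal PyTorch sketch of the weighted loss above. It assumes a temperature-scaled softmax and uses KL divergence for the soft-target term. This is a standard formulation, but it is not necessarily identical line for line to the tutorial's code. The function `kd_loss` and its default values for `T`, `soft_target_loss_weight` and `ce_loss_weight` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, labels,
            T=2.0, soft_target_loss_weight=0.25, ce_loss_weight=0.75):
    # Soft targets: the teacher's probabilities, softened by temperature T
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    # Student log-probabilities at the same temperature
    soft_log_prob = F.log_softmax(student_logits / T, dim=-1)
    # KL(teacher || student) averaged over the batch; the T**2 factor keeps
    # the gradient scale of this term comparable when T changes
    soft_targets_loss = F.kl_div(soft_log_prob, soft_targets, reduction="batchmean") * (T ** 2)
    # Normal loss against the true labels
    label_loss = nn.CrossEntropyLoss()(student_logits, labels)
    # Weighted sum of the two losses
    return soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss


# Example with random logits: batch of 8 images, 10 classes
torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)  # in real training: computed under torch.no_grad()
labels = torch.randint(0, 10, (8,))
loss = kd_loss(student_logits, teacher_logits, labels)
loss.backward()  # gradients flow only into the student's logits
print(loss.item())
```

If the student's logits exactly match the teacher's, the soft-target term drops to zero and only the weighted cross-entropy remains. In real training, compute the teacher's outputs under `torch.no_grad()` so that only the student is updated.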

Now let's look at the code in more detail.

2. Import

Import the required libraries.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. Loading CIFAR-10

We use CIFAR-10 as the dataset. It is a popular image dataset with ten classes.
We have to predict one class for each input image.
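Predicting one class per image usually means taking the argmax over the model's ten output scores.

The sketch below assumes a model that outputs logits of shape (batch_size, 10). The class-name list follows the label order of torchvision's CIFAR10 and should match `train_dataset.classes`. `predict_classes` and the example logits are hypothetical.

```python
import torch

# Class names in the label order used by torchvision's CIFAR10 (labels 0-9)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def predict_classes(logits):
    """Pick the highest-scoring class for each image in the batch."""
    # logits has shape (batch_size, 10): one score per class
    indices = logits.argmax(dim=1)
    return [CIFAR10_CLASSES[i] for i in indices.tolist()]


# Hypothetical model output for a batch of two images
logits = torch.tensor([[0.1, 2.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0, 3.5, 0.2, 1.0, 0.0, 0.0, 0.0, 0.0]])
print(predict_classes(logits))  # ['automobile', 'cat']
```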

・Loading CIFAR-10

# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    # convert the PIL image (values 0-255) to a float tensor with values in [0, 1]
    transforms.ToTensor(),
    # normalize each channel with pre-computed mean and std values
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)

Each input image is RGB with shape 3 x 32 x 32, i.e. 3072 numbers ranging from 0 to 255. ToTensor() rescales these values to [0, 1] before normalization.
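We download the CIFAR-10 datasets with torchvision's datasets module and normalize them with fixed per-channel mean and std values. The tutorial describes these values as pre-computed from the training set. Note, however, that [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] are also the widely used ImageNet channel statistics.
The point to be careful about is that the test dataset is normalized with the same values, even though they come from the training data only. There are two reasons for this:
・We assume the training data is representative enough that its statistics also fit the test data.
・In a real situation, the test data is not available in advance, so its statistics could not be computed anyway.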
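The values in transforms_cifar are fixed constants. The sketch below shows how such per-channel statistics can be computed from the training split and then reused unchanged for the test split.

It uses random tensors as a stand-in for CIFAR-10 images that ToTensor() has already scaled to [0, 1]. With the real data, you would instead stack the training images from a copy of the dataset loaded with only ToTensor(). The helper names `channel_mean_std` and `normalize` are hypothetical.

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and std over a batch of images shaped (N, 3, H, W)."""
    # Reduce over every dimension except the channel dimension (dim=1)
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


def normalize(images, mean, std):
    """Apply (x - mean) / std channel-wise, as transforms.Normalize does."""
    return (images - mean[None, :, None, None]) / std[None, :, None, None]


# Stand-in data: random tensors in place of real CIFAR-10 images
torch.manual_seed(0)
train_images = torch.rand(100, 3, 32, 32)
test_images = torch.rand(20, 3, 32, 32)

# Statistics come from the training split only...
mean, std = channel_mean_std(train_images)
# ...and the same values are applied to both splits
train_norm = normalize(train_images, mean, std)
test_norm = normalize(test_images, mean, std)
print(mean, std)
```

After this step the training split has roughly zero mean and unit std per channel. The test split is close to that but not exactly, which is expected because its own statistics were never used.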

Next, briefly define the DataLoaders.

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
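To see what the loaders yield without downloading CIFAR-10, the sketch below builds a small stand-in dataset with `TensorDataset`. The 300 random images are an assumption for illustration, and `num_workers=0` keeps everything in a single process.

Each batch holds images of shape (batch, 3, 32, 32) and a label vector. Because `drop_last` defaults to False, the final batch is smaller than the others.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 300 random "images" with labels in 0-9
images = torch.rand(300, 3, 32, 32)
labels = torch.randint(0, 10, (300,))
loader = DataLoader(TensorDataset(images, labels), batch_size=128, shuffle=True, num_workers=0)

batch_sizes = []
for x, y in loader:
    # x: (batch, 3, 32, 32) images, y: (batch,) class indices
    batch_sizes.append(x.shape[0])
print(batch_sizes)  # [128, 128, 44]
print(len(loader))  # 3 batches per epoch

# The same rule applied to the real CIFAR-10 training split (50,000 images)
print(math.ceil(50000 / 128))  # 391 batches; the last one holds 80 images
```

With `num_workers=2` as above, some platforms (Windows, macOS) start worker processes with spawn. On those platforms, the loading code should sit under an `if __name__ == "__main__":` guard.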

That's all for this part. I'll continue in the next one.
Thank you for reading.


(1) PyTorch Tutorial: Knowledge Distillation Tutorial
(2) Kaggle BirdCLEF 4th place solution
