🎏

【ML Method】2-stage learning explained with similar methods

Published on 2024/04/27

I came across the expression "2-stage learning", but I couldn't understand what it meant at the time, so I'm explaining it here for people like me.

From what I found, this term seems to have different meanings in different fields.
Here I focus on 2-stage learning in the machine learning field, where it is used with a meaning close to fine-tuning or transfer learning.

1. 2-stage learning

2-stage learning is a method that splits training into separate phases. Here we consider two phases: a pre-training stage and a fine-tuning stage.

1.1 Pre-training stage

In the pre-training stage, the model is trained on a vast and varied dataset and learns general features from it. This helps the model in the fine-tuning stage know where to pay attention.

Examples of pre-trained models include models trained on ImageNet for computer vision tasks and BERT for natural language processing tasks. If you want to train a model on your own domain data, you can use the whole dataset in this phase (and use only the carefully curated part of it in the fine-tuning stage).

1.2 Fine-tuning stage

In the fine-tuning stage, the pre-trained model is adapted to a specific target task using a smaller dataset.

The pre-trained model's architecture is typically modified by replacing the last few layers with task-specific layers while keeping the earlier layers fixed or partially frozen. Those weights serve as a good initialization for the fine-tuning process, enabling faster convergence and better generalization.

Fine-tuning is often performed with a lower learning rate to prevent the pre-trained weights from drifting too far from what was learned during pre-training.
The fine-tuned model can achieve high performance on the target task, even with limited labeled data, by leveraging the knowledge learned during pre-training.
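For example, here is a minimal sketch of giving the pre-trained layers a smaller learning rate than the newly added head by using optimizer parameter groups (the learning-rate values are just illustrative, not tuned):

import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Load a pre-trained ResNet18 and replace its head for a 10-class task
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 10)

# Separate the pre-trained backbone parameters from the new head's parameters
backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc.")]

optimizer = optim.Adam([
    {"params": backbone_params, "lr": 1e-5},        # pre-trained layers: small updates
    {"params": model.fc.parameters(), "lr": 1e-3},  # new head: larger updates
])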

2. Example (ResNet18 with CIFAR-10)

So, let's try it out simply with PyTorch. We compare two patterns, one using a pre-trained model (trained on ImageNet) and one without, and train each model on the same dataset (CIFAR-10).

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader

PreTrain = False

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stage 1: Pre-training (here we simply load a model that was pre-trained on ImageNet)
if PreTrain:
    pretrained_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    pretrained_model.to(device)  # Move the model to the GPU
    # Freeze the weights of the pre-trained layers
    for param in pretrained_model.parameters():
        param.requires_grad = False
else:
    pretrained_model = models.resnet18(weights=None)
    pretrained_model.to(device)  # Move the model to the GPU


# Replace the last fully connected layer with a new one
num_classes = 10  # Number of classes in CIFAR-10
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, num_classes)
pretrained_model.fc.to(device)  # Move the new fully connected layer to the GPU

# Stage 2: Fine-tuning
# Prepare the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the images to match the input size of ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Set the model to training mode
pretrained_model.train()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss().to(device)  # Move the loss function to the GPU
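# Note: only the new fc layer's parameters are passed to the optimizer, so the backbone stays unchanged;
# to re-train all weights, pass pretrained_model.parameters() instead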
optimizer = optim.Adam(pretrained_model.fc.parameters(), lr=0.001)

# Fine-tune the model
num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)  # Move the input data to the GPU
        labels = labels.to(device)  # Move the labels to the GPU

        # Forward pass
        outputs = pretrained_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Save the fine-tuned model
torch.save(pretrained_model.state_dict(), 'fine_tuned_model.pth')

・Output

# Normal learning
Epoch [1/10], Loss: 1.9358
Epoch [2/10], Loss: 1.6871
Epoch [3/10], Loss: 1.5769
Epoch [4/10], Loss: 1.8629
Epoch [5/10], Loss: 2.1075
Epoch [6/10], Loss: 1.7070
Epoch [7/10], Loss: 1.6018
Epoch [8/10], Loss: 2.0057
Epoch [9/10], Loss: 1.6282
Epoch [10/10], Loss: 1.7061

# 2-stage learning (pre-trained on ImageNet, weights frozen except for the fc layer)
Epoch [1/10], Loss: 0.6391
Epoch [2/10], Loss: 0.8150
Epoch [3/10], Loss: 0.6262
Epoch [4/10], Loss: 0.3903
Epoch [5/10], Loss: 0.5749
Epoch [6/10], Loss: 0.5714
Epoch [7/10], Loss: 0.3196
Epoch [8/10], Loss: 0.3552
Epoch [9/10], Loss: 0.8488
Epoch [10/10], Loss: 0.6973

The result is obvious: 2-stage learning clearly achieved the better score. It is very useful for many kinds of tasks.

Additionally, I tried 2-stage learning without freezing the weights (i.e. the pre-trained weights are used only as the initial weights), and the output is shown below.

# 2-stage learning (pre-trained on ImageNet, re-training all weights)
Epoch [1/10], Loss: 1.0034
Epoch [2/10], Loss: 0.6238
Epoch [3/10], Loss: 0.6553
Epoch [4/10], Loss: 0.4157
Epoch [5/10], Loss: 0.8527
Epoch [6/10], Loss: 0.3583
Epoch [7/10], Loss: 0.5348
Epoch [8/10], Loss: 1.1739
Epoch [9/10], Loss: 0.7181
Epoch [10/10], Loss: 0.7674

It doesn't achieve a better score. From this, we can see that for the combination of ResNet18 pre-trained on ImageNet and CIFAR-10, it is better not to change the pre-trained weights (maybe overfitting?).

※Here we used the CIFAR-10 dataset as the prediction target. If you want to use your own dataset, you have to prepare it in a folder structure like the one below (a loading snippet follows the tree).

format
root/
├── train/
│   ├── class1/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   ├── class2/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   └── ...
├── val/
│   ├── class1/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   ├── class2/
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── ...
│   └── ...
└── test/
    ├── class1/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    ├── class2/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    └── ...
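Assuming the folder layout above, torchvision's ImageFolder can read it directly; the path './mydata' below is just a placeholder:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Each subfolder name (class1, class2, ...) automatically becomes a class label
train_dataset = datasets.ImageFolder(root='./mydata/train', transform=transform)
val_dataset = datasets.ImageFolder(root='./mydata/val', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)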

3. Similar Methods

2-stage learning (transfer learning / fine-tuning) is a common method, and there are several similar methods, so I'll introduce some of them.

3.1 Knowledge Distillation

The goal of knowledge distillation is to build a smaller model that keeps almost the same performance as the original while running inference faster and more efficiently.

The basic idea is to transfer the knowledge held in a specific layer of the pre-trained model (the teacher) to the smaller model (the student), for example the outputs just before or after the softmax.
But unlike 2-stage learning, the model architecture changes between the two phases, so we can't reuse the weights as they are. Instead, we need a technique such as a combined loss that mixes the ordinary task loss with a loss measuring the difference between the teacher's layer outputs and the student's.
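For example, a minimal sketch of such a combined loss (the temperature T and the mixing weight alpha are common hyperparameters; the exact formulation differs between papers):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Ordinary cross-entropy loss against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, labels)
    # Difference between the teacher's and student's softened predictions
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Mix the two losses
    return alpha * hard_loss + (1 - alpha) * soft_loss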

3.2 Few-Shot Learning

Few-shot learning focuses on learning from a limited number of labeled examples per class. It aims to quickly adapt a model to recognize new classes with just a few examples.

The model is typically pre-trained on a large dataset and then fine-tuned on a small dataset containing a few classes. In this setting, aligning the number of examples across classes helps train a better model (imbalanced data typically hurts performance).
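For instance, here is a minimal sketch of building a balanced few-shot subset by keeping the same small number of examples per class (k_shot and the use of CIFAR-10 are just assumptions for illustration):

import random
from collections import defaultdict
from torch.utils.data import Subset
from torchvision import datasets, transforms

k_shot = 5  # number of labeled examples kept per class

dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                           transform=transforms.ToTensor())

# Group sample indices by class label
indices_by_class = defaultdict(list)
for idx, label in enumerate(dataset.targets):
    indices_by_class[label].append(idx)

# Keep exactly k_shot examples per class so the classes stay balanced
few_shot_indices = []
for indices in indices_by_class.values():
    few_shot_indices.extend(random.sample(indices, k_shot))

few_shot_dataset = Subset(dataset, few_shot_indices)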

3.3 Self-Supervised Learning

Self-supervised learning is a technique where the model learns from unlabeled data by solving pretext tasks.

Pretext tasks are designed to capture meaningful representations and patterns in the data without requiring explicit labels. Typically, those are generated from large unlabeled datasets.
Examples:
・Predicting missing parts of an image
・Colorizing grayscale images
・Predicting the next word in a sentence
and others tailored to each task...

After pre-training, the model can be fine-tuned on a downstream task using labeled data, leveraging the learned representations.
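As one concrete sketch of the first pretext task listed above (predicting missing parts of an image), assuming some encoder-decoder model and a pixel-wise loss:

import torch
import torch.nn as nn

def mask_random_patch(images, patch_size=8):
    # Zero out one random square patch; the "label" is the original image itself
    masked = images.clone()
    _, _, h, w = images.shape
    top = torch.randint(0, h - patch_size, (1,)).item()
    left = torch.randint(0, w - patch_size, (1,)).item()
    masked[:, :, top:top + patch_size, left:left + patch_size] = 0.0
    return masked

# Usage sketch: the model reconstructs the original image from the masked one,
# so it learns useful features without any human-provided labels.
# reconstruction = model(mask_random_patch(images))
# loss = nn.MSELoss()(reconstruction, images)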

3.4 Multitask Learning

Multitask learning involves training a model to solve multiple related tasks simultaneously. The idea is that by sharing representations and learning across tasks, the model can improve its generalization and efficiency.

The model is trained on a dataset that contains multiple tasks, often with shared input features but different output targets. During training, the model learns to jointly optimize the objectives of all tasks, allowing for knowledge transfer and regularization.

Multitask learning can be seen as a form of inductive transfer, where knowledge from one task helps improve performance on other tasks.
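For example, a minimal sketch of a shared encoder with one head per task (the layer sizes and the simple sum of losses are assumptions for illustration):

import torch.nn as nn

class MultitaskModel(nn.Module):
    def __init__(self, in_features, num_classes_a, num_classes_b):
        super().__init__()
        # Shared representation used by every task
        self.encoder = nn.Sequential(nn.Linear(in_features, 64), nn.ReLU())
        self.head_a = nn.Linear(64, num_classes_a)  # task A head
        self.head_b = nn.Linear(64, num_classes_b)  # task B head

    def forward(self, x):
        features = self.encoder(x)
        return self.head_a(features), self.head_b(features)

# Training sketch: jointly optimize both task losses
# out_a, out_b = model(x)
# loss = criterion_a(out_a, y_a) + criterion_b(out_b, y_b)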

To handle multi-domain data, there is no need to rely solely on this method. We can also ensemble predictions from several models, each specialized for a specific input domain.
