
【Instead of early stopping】"Checkpoint soup" explained

Published on 2024/06/16

I don't fully understand the details of checkpoint soup, so this article explains it as I understand it. I would be glad if you could let me know if anything here is incorrect.

Postscript:

Creating checkpoint soups follows the idea of model soups. But here, weights of the same model from different checkpoints from epochs 13-50 are averaged if they show an improvement in local CV score on one of the tracked metrics (LRAP, cMAP, F1, AUC). This led to more stable and sometimes even better LB scores.
Quote: [1] BirdCLEF2024 2nd place solution
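Here is a minimal sketch of how I read the selection rule in the quote. It assumes that "an improvement" means a new best value on at least one tracked metric compared with all earlier epochs, and that higher is better for every metric (true for LRAP, cMAP, F1 and AUC). The function `select_epochs` and the metric values are hypothetical, not code from the 2nd place solution. The weights saved at the selected epochs would then be averaged.

```python
from typing import Dict, List


def select_epochs(
    history: Dict[int, Dict[str, float]],
    start_epoch: int = 13,
    end_epoch: int = 50,
) -> List[int]:
    """Return epochs in [start_epoch, end_epoch] that set a new best on any metric.

    history maps epoch -> {metric name: local CV score}; higher is assumed better.
    Running bests are tracked over all epochs, including those before start_epoch
    (an assumption; the quote does not say).
    """
    best: Dict[str, float] = {}
    selected: List[int] = []
    for epoch in sorted(history):
        improved = False
        # Check every metric (no early break) so each running best stays up to date
        for name, score in history[epoch].items():
            if name not in best or score > best[name]:
                best[name] = score
                improved = True
        if improved and start_epoch <= epoch <= end_epoch:
            selected.append(epoch)
    return selected


# Hypothetical local CV history for five epochs
history = {
    1: {"auc": 0.50, "f1": 0.30},
    2: {"auc": 0.60, "f1": 0.20},
    3: {"auc": 0.55, "f1": 0.25},
    4: {"auc": 0.58, "f1": 0.35},
    5: {"auc": 0.59, "f1": 0.34},
}
print(select_epochs(history, start_epoch=2, end_epoch=5))  # [2, 4]
```

Epoch 2 is kept for its AUC and epoch 4 for its F1. This matches "an improvement on one of the tracked metrics" rather than requiring improvement on all of them.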

1. What is Checkpoint soup

1.1 Premise

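Most of us have had trouble deciding which epoch's weights to adopt.

Usually, we choose the epoch with the best validation score (early stopping), but this choice can be unstable. The validation score fluctuates from epoch to epoch, so the single "best" epoch partly reflects noise and may not be the best one on unseen data.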
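For comparison, here is a minimal sketch of patience-based early stopping, which keeps only one epoch. The scores and the `patience` values are made up for illustration.

```python
from typing import List


def early_stopping_best_epoch(val_scores: List[float], patience: int = 2) -> int:
    """Return the index of the epoch early stopping would keep (higher score is better).

    Training stops once the score has not improved for `patience` consecutive epochs.
    """
    best_epoch, best_score = 0, float("-inf")
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # stop training; later epochs are never evaluated
    return best_epoch


# Noisy validation curve: a temporary dip stops training before the later peak
scores = [0.50, 0.60, 0.58, 0.57, 0.59, 0.70]
print(early_stopping_best_epoch(scores, patience=2))  # 1
print(early_stopping_best_epoch(scores, patience=4))  # 5
```

The kept epoch depends on score noise and on the `patience` setting. This is the instability described above.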

1.2 Checkpoint soup

Checkpoint soup offers a different way to decide. Instead of keeping a single epoch, it saves several model checkpoints during training and averages their weights to create the final model.
This can sometimes perform better than selecting a single checkpoint. The averaged model draws on multiple snapshots taken at different points in training, instead of relying on one possibly noisy choice.
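Written as a formula, the uniform soup is simply the element-wise mean. The notation here is mine: $\theta_k$ is the weights saved at the $k$-th selected checkpoint, and $S$ is the set of selected checkpoints.

$$
\theta_{\text{soup}} = \frac{1}{|S|} \sum_{k \in S} \theta_k
$$

This works because all checkpoints come from the same training run and share the same architecture. As a result, corresponding parameters line up and can be averaged element by element.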

2. Sample code

Below is sample code to illustrate the idea. It is a simplified sketch, not a complete training pipeline.

・Checkpoint soup

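import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Assume you have your model class defined.
# The single linear layer below is only a placeholder so the sketch runs.
class YourModel(nn.Module):
    def __init__(self, in_features=16, num_classes=4):
        super().__init__()
        # Define your layers here
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        # Define the forward pass
        return self.fc(x)

# Function to average model parameters (uniform checkpoint soup)
def average_checkpoints(checkpoints):
    avg_model = YourModel()
    avg_state_dict = avg_model.state_dict()

    for key in avg_state_dict.keys():
        if avg_state_dict[key].is_floating_point():
            # Element-wise mean of this tensor across all checkpoints
            avg_state_dict[key] = sum(checkpoint[key] for checkpoint in checkpoints) / len(checkpoints)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
            # meaningfully averaged; take the value from the last checkpoint
            avg_state_dict[key] = checkpoints[-1][key].clone()

    avg_model.load_state_dict(avg_state_dict)
    return avg_model

# Training loop with checkpoint saving
def train_model(model, dataloader, num_epochs, checkpoint_interval):
    criterion = nn.CrossEntropyLoss()  # Example loss function
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Example optimizer

    checkpoints = []

    for epoch in range(num_epochs):
        model.train()
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Save checkpoints at intervals.
        # state_dict() returns tensors that share memory with the live parameters,
        # so deep-copy it; otherwise every stored checkpoint ends up equal to the final weights.
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = copy.deepcopy(model.state_dict())
            checkpoints.append(checkpoint)
            print(f"Checkpoint saved at epoch {epoch + 1}")

    return checkpoints

# Example usage
model = YourModel()
# Placeholder random data; replace with your own dataloader
inputs = torch.randn(256, 16)
labels = torch.randint(0, 4, (256,))
dataloader = DataLoader(TensorDataset(inputs, labels), batch_size=32, shuffle=True)
num_epochs = 50
checkpoint_interval = 10

checkpoints = train_model(model, dataloader, num_epochs, checkpoint_interval)
final_model = average_checkpoints(checkpoints)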

What this does is simple. It saves the model parameters (weights and biases) at a fixed epoch interval and averages them at the end.

As a result, the final model combines what the network learned at different stages of training, after it had seen the data in different orders and in different amounts.
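Keeping every checkpoint in memory, as the code above does, can become expensive for large models. One alternative, which is my own suggestion and not something the 2nd place solution describes, is to fold each new checkpoint into a running mean: $\bar\theta_n = \bar\theta_{n-1} + (\theta_n - \bar\theta_{n-1}) / n$. This gives the same result as the plain average. PyTorch has a related utility, `torch.optim.swa_utils.AveragedModel`.

The sketch below uses plain floats in place of tensors so it runs without PyTorch. The function names `update_running_mean` and `soup_from_stream` are hypothetical.

```python
from typing import Dict, Iterable, Optional


def update_running_mean(avg_state: Optional[Dict[str, float]],
                        new_state: Dict[str, float],
                        num_averaged: int) -> Dict[str, float]:
    """Fold new_state into avg_state, which is the mean of num_averaged checkpoints.

    Works for any values supporting +, - and / (e.g. floating-point torch tensors);
    integer buffers would need separate handling. The first state is stored as-is,
    so pass a deep copy when using live model tensors.
    """
    if avg_state is None:
        return dict(new_state)
    n = num_averaged + 1
    return {key: avg_state[key] + (new_state[key] - avg_state[key]) / n for key in avg_state}


def soup_from_stream(states: Iterable[Dict[str, float]]) -> Optional[Dict[str, float]]:
    """Average checkpoints one at a time, holding only the running mean in memory."""
    avg, count = None, 0
    for state in states:
        avg = update_running_mean(avg, state, count)
        count += 1
    return avg


# Three hypothetical "checkpoints" with one weight and one bias each
states = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}, {"w": 5.0, "b": 4.0}]
print(soup_from_stream(states))  # {'w': 3.0, 'b': 2.0}
```

The trade-off is that you can no longer pick the subset of checkpoints after training. Each one must be accepted or rejected as it arrives.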

3. Summary

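Checkpoint soup is a way to build a more stable model. Instead of selecting a single epoch, it uses an ensemble-like method: averaging the weights of several checkpoints. Because the result is still a single model, inference costs the same as for one checkpoint.
It doesn't always improve the score, but it may make the model more stable.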

Reference

[1] BirdCLEF2024 2nd place solution
