
【ML】What is EMA in machine learning?

Published on 2024/08/04

1. Exponential Moving Average (EMA)

EMA is used in many places to show how a value changes over time in an easy-to-understand way. Here, I explain how EMA is used to stabilize machine learning models.

1.1 Purpose

EMA is used to stabilize and smooth the model’s parameter updates over time.

1.2 How It Works

  1. Initialize the EMA values to be the same as the model weights.
  2. During training, an EMA of the model’s parameters is calculated and maintained alongside the regular model parameters at each training step.
  3. At evaluation or inference time, the EMA values can be used instead of the current model parameters to obtain more stable and potentially better-performing results.
    Typically, the EMA values at the end of the final epoch are used as the model weights (see the minimal sketch after this list).
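To build intuition, here is a minimal sketch in plain Python that follows the three steps above (the scalar "weights" are hypothetical values, not from a real model):

decay = 0.9
weights = [1.0, 0.5, 1.5, 0.8, 1.2]  # noisy "model weights" over training steps

ema = weights[0]  # step 1: initialize the EMA to the first weight
for w in weights[1:]:
    ema = decay * ema + (1 - decay) * w  # step 2: update the EMA at each step
    print(f"weight={w:.2f}, ema={ema:.2f}")

# step 3: at evaluation time, read `ema` instead of the latest raw weight.
# The printed EMA values move far less from step to step than the raw weights.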

1.3 EMA Update formula

・Formula
\text{EMA}_{\text{new}} = \text{decay} \times \text{EMA}_{\text{current}} + (1 - \text{decay}) \times \text{Weights}_{\text{current}}

・In code

def update(self, model):
    self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

e: Represents the current EMA value (ema_v in the code).
m: Represents the current model parameter value (model_v in the code).
self.decay: The decay factor, typically a value close to 1 (e.g., 0.9999).
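For a concrete one-step example (the numbers are chosen only for illustration), note how little a single update moves the EMA when the decay is close to 1:

decay = 0.9999
ema_current = 1.0     # e: current EMA value
weight_current = 0.5  # m: current model parameter value

ema_new = decay * ema_current + (1 - decay) * weight_current
print(ema_new)  # ≈ 0.99995: each step absorbs only 0.01% of the new weight

This is why EMA smooths out step-to-step noise: individual updates barely perturb the running average.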

2. Practical code

・Model with EMA (based on timm's ModelEmaV2 [1], with apply_shadow/restore helpers added)

import torch
import torch.nn as nn
from copy import deepcopy

class ModelEmaV2(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)
        self.backup = {}

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

    def apply_shadow(self, model):
        """Back up the model's current parameters and replace them with the EMA parameters."""
        self.backup = {name: param.data.clone() for name, param in model.named_parameters()}
        ema_state = self.module.state_dict()
        for name, param in model.named_parameters():
            param.data.copy_(ema_state[name])

    def restore(self, model):
        """Restore the model's original (backed-up) parameters."""
        for name, param in model.named_parameters():
            param.data.copy_(self.backup[name])
        self.backup = {}

・Usage Example

import torch.optim as optim

# Initialize model and optimizer
model = YourModel()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
ema = ModelEmaV2(model, decay=0.9999)
ema.set(model)  # Initialize EMA to the current parameters (the deepcopy in __init__ already does this; shown for clarity)

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        ema.update(model)  # Update EMA after optimizer step

    # Optionally evaluate using EMA parameters
    ema.apply_shadow(model)
    model.eval()
    evaluate_model()
    ema.restore(model)

・The model parameters are updated by the optimizer; after each update step, the EMA values are refreshed with ema.update(model).
・Use ema.apply_shadow(model) to replace the model's parameters with the EMA values for evaluation, and ema.restore(model) to revert to the original parameters afterward.

This approach ensures that the model benefits from the stability of EMA while still being trained with regular updates.
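Since ModelEmaV2 keeps a full copy of the model in self.module, an alternative (a sketch assuming the same setup as above, with a hypothetical val_dataloader) is to run evaluation on the EMA copy directly, without swapping parameters back and forth:

# Alternative: evaluate the EMA copy itself.
ema.module.eval()
with torch.no_grad():
    for data, target in val_dataloader:
        output = ema.module(data)
        # ... compute metrics on `output` ...

# The EMA weights can also be saved as an ordinary checkpoint:
torch.save(ema.module.state_dict(), "model_ema.pth")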

3. Summary

In this article, I explained EMA in machine learning.
EMA makes the model more stable by smoothing the model weights over training. Please try it if you are interested.

Reference

[1] timmdocs, Model EMA (Exponential Moving Average)
