
【ML】One example of TTA with PyTorch

Published on 2024/08/07

In this post, I'll explain TTA (test-time augmentation) with PyTorch.

1. What is TTA?

TTA (test-time augmentation) is a technique that applies data augmentation at inference time.
For example, if you run inference on both an image and its horizontally flipped copy and combine the outputs, the prediction tends to be more stable and less affected by overfitting than inference on the single image alone. In short, TTA applies augmentations at inference time (typically the same ones used during training) and merges the resulting predictions.
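As a minimal sketch of the idea (the model and input batch here are placeholders, not code from this post), averaging predictions over each image and its horizontal flip looks like this:

```python
import torch

def predict_with_tta(model, images):
    """Average predictions over each image and its horizontally
    flipped copy. `images` is an (N, C, H, W) batch."""
    model.eval()
    with torch.no_grad():
        pred_normal = model(images)
        # torch.flip along the last dim flips each image left-right
        pred_flipped = model(torch.flip(images, dims=[-1]))
    return (pred_normal + pred_flipped) / 2
```

For a flip-invariant input the two predictions coincide, so the average equals the plain prediction; for real images the average smooths out orientation-dependent noise.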

2. How to do it?

First, here is the pseudo-code.
It uses PyTorch and Albumentations.

・Pseudo code

# Build the normal and TTA dataloaders (in the full code below, this is done inside the fold loop)
val_dataset = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
if CFG.TTA:
    # The only difference is transform=val_transform_TTA in the dataset
    val_dataset_TTA = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform_TTA)
    val_loader_TTA = torch.utils.data.DataLoader(val_dataset_TTA, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

# Calculate OOF predictions twice (normal and with TTA)
for fold_id in range(5):
    val_pred = run_inference_loop(model, val_loader, device)
    val_idx = train_meta[train_meta["fold"] == fold_id].index.values
    oof_pred_arr[val_idx] = val_pred
    if CFG.TTA:
        val_pred_TTA = run_inference_loop(model, val_loader_TTA, device)
        val_idx_TTA = train_meta[train_meta["fold"] == fold_id].index.values
        oof_pred_arr_TTA[val_idx_TTA] = val_pred_TTA

# Merge the predictions with whatever ratio you like
oof_pred_arr = (oof_pred_arr * 0.7) + (oof_pred_arr_TTA * 0.3)
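The pseudo-code above relies on a `run_inference_loop` helper that is not shown. A minimal version of such a loop might look like this (the batch handling is an assumption, since the real dataset class is not shown here):

```python
import torch

def run_inference_loop(model, loader, device):
    """Collect model outputs over the whole dataloader into one
    NumPy array (predictions stay in dataset order because the
    loaders are built with shuffle=False)."""
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in loader:
            # the real dataset may yield (x, y) tuples or bare tensors
            x = batch if torch.is_tensor(batch) else batch[0]
            preds.append(model(x.to(device)).cpu())
    return torch.cat(preds).numpy()
```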

Roughly, there are three steps.

  1. Define two dataloaders: a normal one and one for TTA (whose dataset contains the augmentation transform)
  2. Run inference twice, normally and with TTA, and save each prediction array (the model output)
  3. Merge the prediction arrays with a ratio of your choice

This is the basic flow of TTA. You can also add other augmentations to the pipeline and increase the number of predictions to merge, but note that this increases inference time.
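As a sketch of the merging step generalized to any number of augmentation pipelines (the function name and weights are hypothetical, not part of the original code), a weighted average with normalized weights keeps the ratios valid no matter how many predictions you add:

```python
import numpy as np

def merge_predictions(preds, weights):
    """Weighted average of several prediction arrays of the same
    shape. Weights are normalized so the ratios always sum to 1."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    return sum(w * p for w, p in zip(weights, preds))
```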

・Whole code
if not CFG.is_infer:

    # Drop duplicate rows based on the specified column
    train_meta = train_meta.drop_duplicates(subset=[column_to_check], keep='first').reset_index(drop=True)

    # label_arr = train[CLASSES].values
    oof_pred_arr = np.zeros((len(train_meta), CFG.n_classes-1))
    if CFG.TTA:
        oof_pred_arr_TTA = np.zeros((len(train_meta), CFG.n_classes-1))
    score_list = []

    for fold_id in range(CFG.n_folds):
        print(f"\n[fold {fold_id}]")
        device = torch.device(CFG.device)

        # get dataloaders; choose the HDF5 file first to avoid duplicating the loader setup
        val_transform_TTA, val_transform = get_transforms()
        fp_hdf = CFG.TRAIN_HDF5_COMBINED if CFG.alldata_archive else CFG.TRAIN_HDF5
        val_dataset = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=fp_hdf, transform=val_transform)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
        if CFG.TTA:
            val_dataset_TTA = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=fp_hdf, transform=val_transform_TTA)
            val_loader_TTA = torch.utils.data.DataLoader(
                val_dataset_TTA, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

        # get model
        model_path = CFG.OUTPUT_DIR / f"best_model_fold{fold_id}.pth"
        model = timmModel(
            model_name=CFG.model_name, 
            pretrained=False, 
            in_channels=6,
            num_classes=0,
            is_training=False
        )
        model.initialize_dummy()  # Initialize with dummy data for dynamic linear
        model.load_state_dict(torch.load(model_path, map_location=device))

        # inference
        val_pred = run_inference_loop(model, val_loader, device)
        val_idx = train_meta[train_meta["fold"] == fold_id].index.values
        oof_pred_arr[val_idx] = val_pred
        if CFG.TTA:
            val_pred_TTA = run_inference_loop(model, val_loader_TTA, device)
            val_idx_TTA = train_meta[train_meta["fold"] == fold_id].index.values
            oof_pred_arr_TTA[val_idx_TTA] = val_pred_TTA

        del val_idx
        del model, val_loader
        torch.cuda.empty_cache()
        gc.collect()
    
    if CFG.TTA:
        # ref. CFG.TTA_rate = {'None':0.7, 'with_train_aug':0.3}
        oof_pred_arr = (oof_pred_arr * CFG.TTA_rate['None']) + (oof_pred_arr_TTA * CFG.TTA_rate['with_train_aug'])

3. Summary

In this post, I introduced one example of how to apply TTA.
Please try it out.
