【ML】An example of TTA with PyTorch
This time, I'll explain TTA (test-time augmentation) with PyTorch.
1. What is TTA?
TTA (test-time augmentation) is a technique that applies augmentation at inference time.
For example, suppose you feed both an image and its horizontally flipped version to the model and combine the two outputs: the resulting prediction tends to be more stable and less prone to overfitting than one from the single image alone. In other words, TTA means applying augmentations at inference time (typically the same kind used during training).
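As a minimal sketch of the idea (assuming a generic image classifier model and a batch tensor images of shape (N, C, H, W); these names are hypothetical and not from the code below), horizontal-flip TTA can be written as:

import torch

@torch.no_grad()
def predict_with_hflip_tta(model, images, w_normal=0.5, w_flipped=0.5):
    # Infer on the original batch and on a horizontally flipped copy,
    # then merge the two outputs with the given weights (an assumed 50/50 split here).
    model.eval()
    pred_normal = model(images)
    pred_flipped = model(torch.flip(images, dims=[-1]))  # flip along the width axis
    return w_normal * pred_normal + w_flipped * pred_flipped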
2. How to do it?
First, here is the pseudo-code.
It uses PyTorch and Albumentations.
・Pseudo code
val_dataset = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
if CFG.TTA:
    # The only difference is "transform=val_transform_TTA" in the dataset.
    val_dataset_TTA = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform_TTA)
    val_loader_TTA = torch.utils.data.DataLoader(val_dataset_TTA, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

# Calculate the OOF predictions twice (normal and with TTA)
for fold_id in range(5):
    val_pred = run_inference_loop(model, val_loader, device)
    val_idx = train_meta[train_meta["fold"] == fold_id].index.values
    oof_pred_arr[val_idx] = val_pred
    if CFG.TTA:
        val_pred_TTA = run_inference_loop(model, val_loader_TTA, device)
        val_idx_TTA = train_meta[train_meta["fold"] == fold_id].index.values
        oof_pred_arr_TTA[val_idx_TTA] = val_pred_TTA

# Merge the predictions with whatever ratio you like
oof_pred_arr = (oof_pred_arr * 0.7) + (oof_pred_arr_TTA * 0.3)
Roughly, there are three steps.
- Define the dataloaders, one normal and one for TTA (whose dataset contains the augmentation transform)
- Run inference twice, normally and with TTA, and save each prediction array (the model outputs)
- Merge the two prediction arrays with whatever ratio you like
This is the flow of TTA. You can also apply other augmentations to the pipeline and increase the number of predictions to merge, but note that this increases inference time.
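The whole code below calls get_transforms(), which is not shown there. As a sketch under the assumption that the TTA transform simply re-applies a training-style augmentation (here a horizontal flip, applied to every sample), it could look like this with Albumentations:

import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_transforms():
    # Normal validation transform: no augmentation, only tensor conversion.
    val_transform = A.Compose([
        ToTensorV2(),
    ])
    # TTA transform: the same kind of augmentation as training, e.g. a horizontal flip.
    # p=1.0 makes it deterministic, so every validation sample is flipped.
    val_transform_TTA = A.Compose([
        A.HorizontalFlip(p=1.0),
        ToTensorV2(),
    ])
    return val_transform_TTA, val_transform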
・Whole code
if not CFG.is_infer:
    # Drop duplicate rows based on the specified column
    train_meta = train_meta.drop_duplicates(subset=[column_to_check], keep='first').reset_index(drop=True)
    # label_arr = train[CLASSES].values
    oof_pred_arr = np.zeros((len(train_meta), CFG.n_classes - 1))
    if CFG.TTA:
        oof_pred_arr_TTA = np.zeros((len(train_meta), CFG.n_classes - 1))
    score_list = []

    for fold_id in range(CFG.n_folds):
        print(f"\n[fold {fold_id}]")
        device = torch.device(CFG.device)

        # get dataloaders
        val_transform_TTA, val_transform = get_transforms()
        if CFG.alldata_archive:
            val_dataset = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform)
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
            if CFG.TTA:
                val_dataset_TTA = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5_COMBINED, transform=val_transform_TTA)
                val_loader_TTA = torch.utils.data.DataLoader(
                    val_dataset_TTA, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
        else:
            val_dataset = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5, transform=val_transform)
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)
            if CFG.TTA:
                val_dataset_TTA = Example_Dataset(df=train_meta[train_meta["fold"] == fold_id], fp_hdf=CFG.TRAIN_HDF5, transform=val_transform_TTA)
                val_loader_TTA = torch.utils.data.DataLoader(
                    val_dataset_TTA, batch_size=CFG.batch_size, num_workers=4, shuffle=False, drop_last=False)

        # get model
        model_path = CFG.OUTPUT_DIR / f"best_model_fold{fold_id}.pth"
        model = timmModel(
            model_name=CFG.model_name,
            pretrained=False,
            in_channels=6,
            num_classes=0,
            is_training=False,
        )
        model.initialize_dummy()  # Initialize with dummy data for dynamic linear
        model.load_state_dict(torch.load(model_path, map_location=device))

        # inference
        val_pred = run_inference_loop(model, val_loader, device)
        val_idx = train_meta[train_meta["fold"] == fold_id].index.values
        oof_pred_arr[val_idx] = val_pred
        if CFG.TTA:
            val_pred_TTA = run_inference_loop(model, val_loader_TTA, device)
            val_idx_TTA = train_meta[train_meta["fold"] == fold_id].index.values
            oof_pred_arr_TTA[val_idx_TTA] = val_pred_TTA

        del val_idx
        del model, val_loader
        torch.cuda.empty_cache()
        gc.collect()

    if CFG.TTA:
        # ref. CFG.TTA_rate = {'None': 0.7, 'with_train_aug': 0.3}
        oof_pred_arr = (oof_pred_arr * CFG.TTA_rate['None']) + (oof_pred_arr_TTA * CFG.TTA_rate['with_train_aug'])
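run_inference_loop() is also not defined above. A minimal sketch, assuming the dataset yields (input, label) pairs and that you want the raw model outputs back as a NumPy array, could be:

import numpy as np
import torch

def run_inference_loop(model, loader, device):
    # Run the model over the whole dataloader and return the stacked outputs.
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        for x, _ in loader:  # assumes the dataset yields (input, label) pairs
            x = x.to(device)
            out = model(x)
            preds.append(out.cpu().numpy())
    return np.concatenate(preds)

Whether you apply a softmax or sigmoid to the outputs here depends on your model head and metric.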
3. Summary
This time, I introduced one example of how to apply TTA.
Please try it.