
Testing the data transform pipeline

Published 2024/04/14

Testing how the transform functions described in the documentation actually behave.

https://mmdetection.readthedocs.io/en/v2.19.1/tutorials/data_pipeline.html

import

import random
import mmcv
import numpy as np
from mmcv.transforms import TRANSFORMS, BaseTransform, Compose  # mmcv >= 2.0
from PIL import Image
from torchvision import transforms

1. Original image

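# Input dict for the mmcv pipelines below; LoadImageFromFile reads the 'img_path' key
raw_results = dict(img_path='demo.jpg')
# Check the original image size with torchvision; ToTensor returns (C, H, W)
image_path = 'demo.jpg'
image = Image.open(image_path)
transform = transforms.Compose([
    transforms.ToTensor(),
])
tensor_image = transform(image)
print('Tensor shape:', tensor_image.shape)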

> Tensor shape: torch.Size([3, 1024, 1024])

Note: the gray border in the image is only there because it is a screenshot.
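torchvision reports the shape as (C, H, W), with RGB values scaled to [0, 1]. The mmcv pipeline below works on NumPy arrays in (H, W, C) order instead, and LoadImageFromFile decodes to BGR by default. Here is a small NumPy sketch of converting between the two layouts. The helper names are my own, and the random array just stands in for demo.jpg.

```python
import numpy as np


def chw_float_to_hwc_bgr(chw: np.ndarray) -> np.ndarray:
    """ToTensor-style (C, H, W) RGB float in [0, 1] -> mmcv-style (H, W, C) BGR uint8."""
    hwc_rgb = np.transpose(chw, (1, 2, 0))   # (C, H, W) -> (H, W, C)
    hwc_bgr = hwc_rgb[..., ::-1]             # RGB -> BGR by reversing the channel axis
    return np.clip(np.rint(hwc_bgr * 255.0), 0, 255).astype(np.uint8)


def hwc_bgr_to_chw_float(hwc: np.ndarray) -> np.ndarray:
    """Inverse: (H, W, C) BGR uint8 -> (C, H, W) RGB float32 in [0, 1]."""
    hwc_rgb = hwc[..., ::-1]
    return np.transpose(hwc_rgb, (2, 0, 1)).astype(np.float32) / 255.0


# Dummy 3x1024x1024 array in place of the real demo.jpg tensor
chw = np.random.rand(3, 1024, 1024).astype(np.float32)
hwc = chw_float_to_hwc_bgr(chw)
print(hwc.shape)  # (1024, 1024, 3)
```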

2. Resize (keep_ratio=True)

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(800, 400), keep_ratio=True),
])
results = transform(raw_results)
print(results['img'].shape)
mmcv.imwrite(results['img'], '2.jpg')
> (400, 400, 3)

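I also tried (100, 100). With keep_ratio=True, the image appears to be scaled to fit the smaller of the two target values. For this square 1024x1024 image, scale=(800, 400) therefore gives (400, 400).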
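The keep_ratio=True result is easier to reason about as arithmetic. My understanding of mmcv's rescale rule is as follows, but treat it as an assumption rather than the library source. The larger number in scale bounds the long edge, the smaller bounds the short edge, and the stricter of the two ratios is applied to both sides. Below is a stdlib-only sketch. keep_ratio_size is a hypothetical helper that returns (width, height), so the pipeline's array shape would be (height, width, 3).

```python
def keep_ratio_size(w: int, h: int, scale: tuple) -> tuple:
    """Hypothetical helper: size (new_w, new_h) after a keep-ratio rescale into `scale`."""
    max_long_edge, max_short_edge = max(scale), min(scale)
    # The stricter of the long-edge and short-edge ratios wins
    factor = min(max_long_edge / max(w, h), max_short_edge / min(w, h))
    # Round to the nearest integer size
    return int(w * factor + 0.5), int(h * factor + 0.5)


print(keep_ratio_size(1024, 1024, (800, 400)))  # (400, 400)
print(keep_ratio_size(1024, 768, (800, 400)))   # (533, 400)
```

For a square image, both ratios reduce to "smaller target / side length", which matches the observation above. For a non-square image, the long and short edges are constrained separately.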

3. Resize (keep_ratio=False)

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(800, 400), keep_ratio=False),
])
results = transform(raw_results)
print(results['img'].shape)
mmcv.imwrite(results['img'], '3.jpg')
> (400, 800, 3)
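The shape (400, 800, 3) looks swapped compared to scale=(800, 400). This is because scale is given as (width, height), while the NumPy array is (height, width, channels). With keep_ratio=False, both sides are forced to the target regardless of the aspect ratio. Below is a hypothetical nearest-neighbor resize in NumPy that uses the same (width, height) convention. It is only an illustration, not mmcv's implementation.

```python
import numpy as np


def nearest_resize(img: np.ndarray, size: tuple) -> np.ndarray:
    """Hypothetical helper: resize to size=(width, height), ignoring aspect ratio."""
    new_w, new_h = size
    h, w = img.shape[:2]
    # For every output row/column, pick the nearest source row/column
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    return img[rows[:, None], cols[None, :]]


img = np.zeros((1024, 1024, 3), dtype=np.uint8)
out = nearest_resize(img, (800, 400))  # (width, height) in, (height, width, C) out
print(out.shape)  # (400, 800, 3)
```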

4. Custom transform

@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    def __init__(self, direction: str):
        super().__init__()
        self.direction = direction

    def transform(self, results: dict) -> dict:
        img = results['img']
        results['img'] = mmcv.imflip(img, direction=self.direction)
        return results

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='MyFlip', direction='horizontal'),
])
results = transform(raw_results)
mmcv.imwrite(results['img'], '4.jpg')

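Just using @TRANSFORMS.register_module() as a decorator is enough. MyFlip can then be referenced by name (type='MyFlip') in the Compose config, and the image is flipped horizontally.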
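A decorator is enough because of the registry pattern. register_module() records the class under its name. Compose then takes each dict, looks up its 'type' key in the registry, and instantiates that class with the remaining keys as constructor arguments. Below is a toy version of that mechanism in plain Python + NumPy. REGISTRY, register_module and build here are hypothetical stand-ins, not mmcv's API.

```python
import numpy as np

# Hypothetical minimal registry: class name -> class.
# mmcv/MMEngine's real Registry does much more; this only shows the idea.
REGISTRY = {}


def register_module(cls):
    """Decorator: store the class under its own name and return it unchanged."""
    REGISTRY[cls.__name__] = cls
    return cls


def build(cfg: dict):
    """Instantiate a registered class from a dict like {'type': 'MyFlip', ...}."""
    cfg = dict(cfg)  # copy so the caller's config is not modified
    cls = REGISTRY[cfg.pop('type')]
    return cls(**cfg)


@register_module
class MyFlip:
    def __init__(self, direction: str):
        if direction not in ('horizontal', 'vertical'):
            raise ValueError(f'unsupported direction: {direction}')
        self.direction = direction

    def __call__(self, results: dict) -> dict:
        # horizontal: reverse columns (axis 1); vertical: reverse rows (axis 0)
        axis = 1 if self.direction == 'horizontal' else 0
        results['img'] = np.flip(results['img'], axis=axis)
        return results


flip = build(dict(type='MyFlip', direction='horizontal'))
out = flip({'img': np.array([[1, 2, 3]])})
print(out['img'])  # [[3 2 1]]
```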

5. Pad

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Pad', size_divisor=500),  # pad height and width up to a multiple of 500
])
results = transform(raw_results)
print(results['img'].shape)
mmcv.imwrite(results['img'], '5.jpg')
> (1500, 1500, 3)
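The (1500, 1500, 3) result comes from rounding each side up to the next multiple of size_divisor: ceil(1024 / 500) * 500 = 3 * 500 = 1500. Below is a NumPy sketch of that arithmetic. pad_to_divisor is a hypothetical helper. Padding only at the bottom and right with 0 is my assumption about mmcv's defaults, not something verified here.

```python
import numpy as np


def pad_to_divisor(img: np.ndarray, size_divisor: int, pad_val: int = 0) -> np.ndarray:
    """Hypothetical helper: pad (H, W[, C]) so H and W are multiples of size_divisor."""
    h, w = img.shape[:2]
    # Integer ceiling division, then back to a multiple of the divisor
    pad_h = -(-h // size_divisor) * size_divisor
    pad_w = -(-w // size_divisor) * size_divisor
    # Pad at the bottom and right only (assumed placement); leave channels untouched
    pad_width = ((0, pad_h - h), (0, pad_w - w)) + ((0, 0),) * (img.ndim - 2)
    return np.pad(img, pad_width, mode='constant', constant_values=pad_val)


img = np.ones((1024, 1024, 3), dtype=np.uint8)
print(pad_to_divisor(img, 500).shape)  # (1500, 1500, 3)
```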

6. Normalize

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Normalize', mean=[0, 0, 0], std=[2, 2, 2], to_rgb=True),
])
results = transform(raw_results)
print(results['img'].shape)
mmcv.imwrite(results['img'], '6.jpg')
> (1024, 1024, 3)
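Normalize computes (img - mean) / std per channel. With to_rgb=True, it first reorders the BGR image loaded by LoadImageFromFile into RGB. With mean=0 and std=2, every value is simply halved and the array becomes floating point, while the shape stays (1024, 1024, 3). Below is a NumPy re-implementation of that computation, written for illustration rather than taken from mmcv.

```python
import numpy as np


def normalize(img: np.ndarray, mean, std, to_rgb: bool = True) -> np.ndarray:
    """(img - mean) / std per channel, optionally converting BGR -> RGB first."""
    img = img.astype(np.float32)
    if to_rgb:
        img = img[..., ::-1]  # BGR -> RGB
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    return (img - mean) / std  # broadcasts over the last (channel) axis


bgr = np.zeros((1, 1, 3), dtype=np.uint8)
bgr[0, 0] = [200, 100, 50]  # B=200, G=100, R=50
print(normalize(bgr, mean=[0, 0, 0], std=[2, 2, 2]).tolist())  # [[[25.0, 50.0, 100.0]]]
```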

7. RandomGrayscale

transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='RandomGrayscale', prob=1.0, color_format='bgr'),  # prob=1.0: always convert
])
results = transform(raw_results)
print(results['img'].shape)
mmcv.imwrite(results['img'], '7.jpg')
> (1024, 1024)
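The output is 2-D, (1024, 1024), because the three color channels are collapsed into a single luminance channel. Grayscale conversion is a weighted sum over the channels. The sketch below uses the BT.601 weights that OpenCV applies for BGR to gray. Whether mmcv's RandomGrayscale uses exactly these weights is an assumption here. The keep_channels flag in this sketch is hypothetical. It repeats the result to restore an (H, W, 3) shape for later steps that expect 3 channels.

```python
import numpy as np

# ITU-R BT.601 luma weights (as used by OpenCV's BGR2GRAY), in B, G, R order
BGR_WEIGHTS = np.array([0.114, 0.587, 0.299])


def to_gray(img_bgr: np.ndarray, keep_channels: bool = False) -> np.ndarray:
    """Weighted sum over the channel axis: (H, W, 3) -> (H, W)."""
    gray = img_bgr.astype(np.float64) @ BGR_WEIGHTS
    gray = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
    if keep_channels:
        # Hypothetical option: repeat the single channel to get (H, W, 3) back
        gray = np.repeat(gray[..., None], 3, axis=2)
    return gray


img = np.full((1024, 1024, 3), 128, dtype=np.uint8)
print(to_gray(img).shape)                      # (1024, 1024)
print(to_gray(img, keep_channels=True).shape)  # (1024, 1024, 3)
```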
