📚
データ変換パイプラインのテスト
ドキュメントに記載されている変換関数の挙動をテストする
import
import random
import mmcv
import numpy as np
from mmcv.transforms import TRANSFORMS, BaseTransform, Compose
from my_transform import MyTransform
from PIL import Image
from torchvision import transforms
1.元画像
# Input dict handed to the mmcv pipelines in this file; it carries the image
# location under the 'img_path' key.
raw_results = dict(img_path='demo.jpg')

# Inspect the untouched source image as a tensor for reference.
image_path = 'demo.jpg'
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Open the file in a context manager so its handle is released as soon as the
# pixel data has been converted; `image` is closed after this block.
with Image.open(image_path) as image:
    tensor_image = transform(image)
print('Tensor shape:', tensor_image.shape)
> Tensor shape: torch.Size([3, 1024, 1024])
※画像の周囲に見える灰色の枠は、スクリーンショットを撮った際に入ったものです
2. Resize (keep_ratio=True)
# Load the image, then resize it with scale=(800, 400) while keeping the
# aspect ratio.
transform = Compose([
    {'type': 'LoadImageFromFile'},
    {'type': 'Resize', 'scale': (800, 400), 'keep_ratio': True},
])
results = transform(raw_results)
# Show the resulting array shape and save the image for visual inspection.
print(results['img'].shape)
mmcv.imwrite(results['img'], '2.jpg')
> (400, 400, 3)
scale=(100, 100) でも試したが、keep_ratio=True の場合はアスペクト比を保ったまま、scale に収まるよう小さいほうのサイズに合わせてリサイズされる仕様らしい
3. Resize (keep_ratio=False)
# Same resize as before, but with keep_ratio=False.
transform = Compose([
    {'type': 'LoadImageFromFile'},
    {'type': 'Resize', 'scale': (800, 400), 'keep_ratio': False},
])
results = transform(raw_results)
# Show the resulting array shape and save the image for visual inspection.
print(results['img'].shape)
mmcv.imwrite(results['img'], '3.jpg')
> (400, 800, 3)
4. custom transform
@TRANSFORMS.register_module()
class MyFlip(BaseTransform):
    """Flip the image stored under ``results['img']``.

    The class is registered in ``TRANSFORMS`` so that a pipeline config can
    refer to it as ``dict(type='MyFlip', direction=...)``.

    Args:
        direction: Flip direction forwarded to ``mmcv.imflip``; one of
            ``'horizontal'``, ``'vertical'`` or ``'diagonal'``.

    Raises:
        ValueError: If ``direction`` is not one of the supported values.
    """

    # Directions documented for ``mmcv.imflip``.
    _DIRECTIONS = ('horizontal', 'vertical', 'diagonal')

    def __init__(self, direction: str):
        super().__init__()
        # Reject a misspelled direction when the pipeline is built rather
        # than on the first processed sample.
        if direction not in self._DIRECTIONS:
            raise ValueError(
                f'direction must be one of {self._DIRECTIONS}, '
                f'got {direction!r}')
        self.direction = direction

    def transform(self, results: dict) -> dict:
        """Replace ``results['img']`` with the flipped image.

        Args:
            results: Pipeline dict that must contain the key ``'img'``.

        Returns:
            The same ``results`` dict, with ``'img'`` updated.
        """
        results['img'] = mmcv.imflip(results['img'], direction=self.direction)
        return results

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(direction={self.direction!r})'
# Use the custom transform by its registered name, flipping horizontally.
transform = Compose([
    {'type': 'LoadImageFromFile'},
    {'type': 'MyFlip', 'direction': 'horizontal'},
])
results = transform(raw_results)
# Save the flipped image for visual inspection.
mmcv.imwrite(results['img'], '4.jpg')
@TRANSFORMS.register_module() をデコレータとして付けるだけで、自作の MyFlip を type='MyFlip' として Compose から呼び出せるようになり、horizontal flip ができている
5. Pad
# Load the image, then pad it with size_divisor=500.
transform = Compose([
    {'type': 'LoadImageFromFile'},
    {'type': 'Pad', 'size_divisor': 500},  # multiple of 500
])
results = transform(raw_results)
# Show the resulting array shape and save the image for visual inspection.
print(results['img'].shape)
mmcv.imwrite(results['img'], '5.jpg')
> (1500, 1500, 3)
6. Normalize
# Load the image, then normalize it with mean 0 and std 2 on every channel
# (to_rgb=True).
transform = Compose([
    {'type': 'LoadImageFromFile'},
    {
        'type': 'Normalize',
        'mean': [0, 0, 0],
        'std': [2, 2, 2],
        'to_rgb': True,
    },
])
results = transform(raw_results)
# Show the resulting array shape and save the image for visual inspection.
print(results['img'].shape)
mmcv.imwrite(results['img'], '6.jpg')
> (1024, 1024, 3)
7. RandomGrayscale
# Load the image, then apply RandomGrayscale with prob=1.0 and
# color_format='bgr'.
transform = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='RandomGrayscale', prob=1.0, color_format='bgr'),
])
results = transform(raw_results)
print(results['img'].shape)
# Section 7 output: named after the section number like the other steps
# (the original wrote '8.jpg', breaking the N.jpg convention).
mmcv.imwrite(results['img'], '7.jpg')
> (1024, 1024)
.
.
.
Discussion