Open2

OpenCV warpAffine transform あるいは from skimage import transform as trans を PyTorch で置き換える

PINTOPINTO
pytorch_transform.py
import cv2
import numpy as np
import torch
import torch.nn.functional as F

def get_rotation_matrix2d_torch(center, angle_deg, scale, device='cpu'):
    """
    OpenCV の cv2.getRotationMatrix2D(center, angle_deg, scale) と同等の
    2x3 アフィン変換行列を PyTorch で作成する
    """
    angle = angle_deg * torch.pi / 180
    alpha = scale * torch.cos(angle)
    beta  = scale * torch.sin(angle)
    cx, cy = center
    M_torch = torch.tensor([
        [ alpha,  beta,  (1 - alpha)*cx - beta*cy ],
        [-beta,   alpha,  beta*cx + (1 - alpha)*cy ]
    ], dtype=torch.float32, device=device)
    return M_torch

def warp_affine_cv(
    img: np.ndarray,  # (H, W, C), RGB or BGR
    center,           # (cx, cy) 入力画像上の回転中心
    angle_deg: float, # +で反時計回り
    scale: float,
    out_size: tuple   # (width, height)
):
    """
    OpenCV 版 warpAffine: (H,W,C) の画像を (out_size) に変換して返す。
    M_cv: 入力->出力 の 2x3 行列
    """
    M_cv = cv2.getRotationMatrix2D(center, angle_deg, scale)
    w_out, h_out = out_size
    warped_cv = cv2.warpAffine(
        img, M_cv, (w_out, h_out),
        flags=cv2.INTER_LINEAR,
        borderValue=0  # 外は黒
    )
    return warped_cv, M_cv  # (H_out, W_out, C), shape=(2,3)

def warp_affine_torch(
    img_torch: torch.Tensor,  # [C,H,W], float32
    M_cv_2x3: np.ndarray,     # OpenCV (入力->出力) の 2x3 行列
    out_size: tuple,          # (width, height)
    align_corners: bool = True
):
    """
    PyTorch 版 warpAffine:
    - OpenCV と同じピクセル座標系の行列(M_cv_2x3)を受け取り、
    - その逆行列 + 正規化座標系変換をしてから `affine_grid` に通す。
    """
    device = img_torch.device

    #----------------------------------
    # 1) 2x3 -> 3x3 に拡張
    #----------------------------------
    M_cv_3x3 = np.vstack([M_cv_2x3, [0, 0, 1]])  # shape=(3,3)
    M_cv_3x3_t = torch.from_numpy(M_cv_3x3).float().to(device)

    #----------------------------------
    # 2) 逆行列 (出力ピクセル -> 入力ピクセル)
    #----------------------------------
    M_inv_3x3_t = torch.inverse(M_cv_3x3_t)

    #----------------------------------
    # 3) ピクセル座標系 <-> 正規化座標系 の変換行列を用意
    #----------------------------------
    # 入力画像: [0, W_in-1]×[0, H_in-1] -> [-1,1]×[-1,1]
    # 出力画像: [0, W_out-1]×[0, H_out-1] -> [-1,1]×[-1,1]
    # (align_corners=True のときは下記の係数が (W-1) や (H-1) になる)
    C, H_in, W_in = img_torch.shape
    w_out, h_out = out_size

    # T_in : 入力ピクセル -> 入力正規化座標
    #   x_norm = 2*x/(W_in-1) - 1
    #   y_norm = 2*y/(H_in-1) - 1
    T_in = torch.tensor([
        [2.0/(W_in-1), 0,            -1],
        [0,            2.0/(H_in-1), -1],
        [0,            0,             1]
    ], dtype=torch.float32, device=device)

    # T_out : 出力ピクセル -> 出力正規化座標
    #   x_norm_out = 2*x_out/(W_out-1) - 1
    #   y_norm_out = 2*y_out/(H_out-1) - 1
    T_out = torch.tensor([
        [2.0/(w_out-1), 0,            -1],
        [0,             2.0/(h_out-1),-1],
        [0,             0,             1]
    ], dtype=torch.float32, device=device)

    # T_out^{-1} : 出力正規化座標 -> 出力ピクセル
    T_out_inv = torch.inverse(T_out)

    #----------------------------------
    # 4) 正規化座標系での行列 M_norm を作る
    #
    #   出力正規化座標 -> 入力正規化座標
    # = (入力ピクセル -> 入力正規化座標) x (出力ピクセル -> 入力ピクセル) x (出力正規化座標 -> 出力ピクセル)
    #----------------------------------
    M_norm_3x3_t = T_in @ M_inv_3x3_t @ T_out_inv

    # affine_grid に渡すのは 2x3
    M_norm_2x3_t = M_norm_3x3_t[:2, :]  # shape=(2,3)
    theta = M_norm_2x3_t.unsqueeze(0)   # shape=(1,2,3)

    #----------------------------------
    # 5) affine_grid -> grid_sample
    #----------------------------------
    x = img_torch.unsqueeze(0)  # (1,C,H_in,W_in)
    grid = F.affine_grid(
        theta,
        size=(1, C, h_out, w_out),  # (N=1, C, H_out, W_out)
        align_corners=align_corners
    )
    warped_t = F.grid_sample(
        x, grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=align_corners
    )
    return warped_t[0]  # -> [C, H_out, W_out]


if __name__ == "__main__":
    #--------------------------------------
    # 例: 同じ行列を使って OpenCV と PyTorch を比較
    #--------------------------------------
    img_bgr = cv2.imread("assets/test.png")  # 適当な画像パス
    if img_bgr is None:
        raise FileNotFoundError("画像が読み込めません。パスを確認してください。")

    # BGR->RGB
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    H, W, _ = img_rgb.shape

    center = (W/2, H/2)   # 入力画像の中心 (cx, cy)
    angle_deg = 30.0      # 反時計回り30度
    scale = 0.5
    out_size = (256, 256) # (width=256, height=256)

    # 1) OpenCV warpAffine (入力->出力 行列)
    warped_cv, M_cv_2x3 = warp_affine_cv(img_rgb, center, angle_deg, scale, out_size)

    # 2) PyTorch で同じ行列を用いて変換
    img_t = torch.from_numpy(img_rgb.transpose(2,0,1)).float()  # [C,H,W]
    M_2x3_t = get_rotation_matrix2d_torch(torch.tensor(center), torch.tensor(angle_deg), torch.tensor(scale))
    warped_t = warp_affine_torch(img_t, M_2x3_t, out_size, align_corners=True)
    warped_np = warped_t.permute(1,2,0).numpy().clip(0,255).astype(np.uint8)

    # 3) 表示 (BGRに戻して可視化)
    warped_cv_bgr = cv2.cvtColor(warped_cv, cv2.COLOR_RGB2BGR)
    warped_t_bgr  = cv2.cvtColor(warped_np, cv2.COLOR_RGB2BGR)
    cv2.imshow("OpenCV", warped_cv_bgr)
    cv2.imshow("PyTorch", warped_t_bgr)
    cv2.waitKey(0)
    cv2.destroyAllWindows()