Open2

torch.nn.functional.affine_grid のONNXエクスポートのための置き換え

PINTOPINTO
# https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
def affine_grid(theta, size, align_corners=False):
    """Replacement for ``torch.nn.functional.affine_grid`` (4-D / 2-D-image case).

    Builds the sampling grid from plain tensor ops (empty/copy_/bmm), intended
    to be ONNX-exportable. Only a 4-entry ``size`` (N, C, H, W) is supported.

    Args:
        theta: batch of affine matrices; presumably shaped (N, 2, 3) as in
            ``torch.nn.functional.affine_grid`` (it is transposed on dims 1/2
            and multiplied against an (N, H*W, 3) grid).
        size: target output size (N, C, H, W).
        align_corners: same meaning as in ``torch.nn.functional.affine_grid``.

    Returns:
        Tensor of shape (N, H, W, 2) with the same dtype and device as ``theta``.

    Raises:
        ValueError: if ``size`` does not have exactly four entries (raised by
            the tuple unpacking).
    """
    N, C, H, W = size
    # create_grid builds a float32 grid on the default device; bring it onto
    # theta's dtype/device so bmm works for float64/half or CUDA theta
    # (a no-op for float32 CPU theta).
    grid = create_grid(N, C, H, W, align_corners).to(
        dtype=theta.dtype, device=theta.device
    )
    # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2)
    grid = grid.view(N, H * W, 3).bmm(theta.transpose(1, 2))
    return grid.view(N, H, W, 2)

def create_grid(N, C, H, W, align_corners, *, dtype=torch.float32, device=None):
    """Build the homogeneous base grid consumed by ``affine_grid``.

    Returns an (N, H, W, 3) tensor whose last axis holds (x, y, 1): x varies
    along W and y varies along H, both produced by ``linspace_from_neg_one``.

    Args:
        N: batch size.
        C: channel count; unused here, kept to mirror the ATen signature.
        H: output height.
        W: output width.
        align_corners: forwarded to ``linspace_from_neg_one``.
        dtype: dtype of the returned grid (default float32, as before).
        device: device of the returned grid (default: torch's default device).

    Returns:
        Tensor of shape (N, H, W, 3).
    """
    grid = torch.empty((N, H, W, 3), dtype=dtype, device=device)
    # copy_ broadcasts and casts: a length-W row fills x for every (n, h), and
    # an (H, 1) column fills y for every (n, w); dtype/device are converted.
    grid.select(-1, 0).copy_(linspace_from_neg_one(W, align_corners))
    grid.select(-1, 1).copy_(linspace_from_neg_one(H, align_corners).unsqueeze(-1))
    # Homogeneous coordinate so the (2, 3) affine matrix applies translation.
    grid.select(-1, 2).fill_(1)
    return grid
    
def linspace_from_neg_one(num_steps, align_corners, dtype=torch.float32):
    """Return ``num_steps`` evenly spaced normalized coordinates in [-1, 1].

    Python port of ATen's ``linspace_from_neg_one`` (AffineGridGenerator.cpp).

    Args:
        num_steps: number of samples along one axis (W or H).
        align_corners: if True, the samples include -1 and 1 exactly. If False,
            the range is scaled by ``(num_steps - 1) / num_steps`` (the
            ``align_corners=False`` convention of ``F.affine_grid``).
        dtype: dtype of the returned tensor (previously ignored).

    Returns:
        1-D tensor of length ``num_steps``.
    """
    if num_steps <= 1:
        # Match ATen: a degenerate axis samples only the center (0), whatever
        # align_corners is. torch.linspace(-1, 1, 1) would give [-1] instead.
        return torch.zeros(num_steps, dtype=dtype)
    r = torch.linspace(-1, 1, num_steps, dtype=dtype)
    if not align_corners:
        r = r * (num_steps - 1) / num_steps
    return r

def patch_affine_grid_generator():
    """Replace ``torch.nn.functional.affine_grid`` with this module's ``affine_grid``.

    After the call, attribute lookups such as ``F.affine_grid(...)`` resolve to
    the replacement defined in this file. Names bound earlier with
    ``from torch.nn.functional import affine_grid`` keep the old function.
    """
    functional_module = torch.nn.functional
    setattr(functional_module, "affine_grid", affine_grid)