Open3

torch.nn.functional.affine_grid のONNXエクスポートのための置き換え

PINTOPINTO
# https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
def affine_grid(theta, size, align_corners=False):
    """Export-friendly drop-in for ``torch.nn.functional.affine_grid`` (2-D case).

    Re-implements the ATen affine grid generator for a 4-D ``size`` using only
    ``view`` / ``bmm`` style ops so that it can be exported to ONNX.

    Args:
        theta: batch of affine matrices; the ``view``/``bmm`` below require
            shape ``(N, 2, 3)``.
        size: target output size ``(N, C, H, W)``.
        align_corners: if true, -1 and 1 refer to the centers of the corner
            pixels; otherwise to their outer edges. Any falsy value
            (including ``None``) behaves like ``False``.

    Returns:
        Sampling grid of shape ``(N, H, W, 2)`` with the same dtype and device
        as ``theta``.

    Raises:
        ValueError: if ``size`` does not have exactly four entries (the 5-D /
            volumetric variant of the upstream function is not supported).
    """
    if len(size) != 4:
        raise ValueError(
            "affine_grid replacement only supports a 4-D size (N, C, H, W), "
            f"got {tuple(size)}"
        )
    N, C, H, W = size
    # create_grid defaults to float32 on the default device; match theta so the
    # bmm below does not fail for float64/float16 or CUDA inputs.
    base = create_grid(N, C, H, W, align_corners).to(theta)
    # Each row of `base` is a homogeneous point (x, y, 1); multiplying by
    # theta^T applies the affine transform to every point at once.
    grid = base.view(N, H * W, 3).bmm(theta.transpose(1, 2))
    return grid.view(N, H, W, 2)

def create_grid(N, C, H, W, align_corners, *, dtype=torch.float32, device=None):
    """Build the homogeneous base grid consumed by ``affine_grid``.

    Args:
        N, C, H, W: output size components. ``C`` is not used; it is accepted
            only to mirror the ``(N, C, H, W)`` layout of ``size``.
        align_corners: forwarded to ``linspace_from_neg_one`` for both axes.
        dtype: dtype of the returned grid (default ``torch.float32``).
        device: device of the returned grid (default: torch's default device).

    Returns:
        Tensor of shape ``(N, H, W, 3)``. ``[..., 0]`` holds x coordinates
        (varying along W), ``[..., 1]`` holds y coordinates (varying along H),
        and ``[..., 2]`` is the constant 1.
    """
    grid = torch.empty((N, H, W, 3), dtype=dtype, device=device)
    # x coordinates: shape (W,) broadcasts across the N and H axes.
    grid.select(-1, 0).copy_(linspace_from_neg_one(W, align_corners, dtype=dtype))
    # y coordinates: reshaped to (H, 1) so they broadcast across the W axis.
    grid.select(-1, 1).copy_(
        linspace_from_neg_one(H, align_corners, dtype=dtype).unsqueeze(-1)
    )
    # Homogeneous coordinate so theta's translation column is applied.
    grid.select(-1, 2).fill_(1)
    return grid
    
def linspace_from_neg_one(num_steps, align_corners, dtype=torch.float32):
    """Return ``num_steps`` evenly spaced normalized coordinates in [-1, 1].

    Port of ATen's ``linspace_from_neg_one`` helper.

    Args:
        num_steps: number of samples along one axis.
        align_corners: if true, the endpoints are exactly -1 and 1; otherwise
            the range is scaled by ``(num_steps - 1) / num_steps`` so the values
            are the centers of ``num_steps`` equal cells spanning [-1, 1].
        dtype: dtype of the returned tensor (previously ignored).

    Returns:
        1-D tensor of shape ``(num_steps,)``.
    """
    if num_steps <= 1:
        # As in ATen, a single sample sits at the center (0) regardless of
        # align_corners; plain linspace would place it at -1.
        return torch.zeros(num_steps, dtype=dtype)
    r = torch.linspace(-1, 1, num_steps, dtype=dtype)
    if not align_corners:
        r = r * (num_steps - 1) / num_steps
    return r

def patch_affine_grid_generator():
    """Install the export-friendly ``affine_grid`` into ``torch.nn.functional``.

    Only call sites that look the attribute up at call time (for example
    ``torch.nn.functional.affine_grid(...)`` or ``F.affine_grid(...)``) see the
    replacement. A name bound earlier via
    ``from torch.nn.functional import affine_grid`` keeps the original.

    Returns:
        The function that was installed before patching, so the caller can
        undo the patch with ``torch.nn.functional.affine_grid = previous``.
    """
    previous = torch.nn.functional.affine_grid
    torch.nn.functional.affine_grid = affine_grid
    return previous
PINTOPINTO

2025年3月14日: torch==2.8.0.dev20250313+cu126 で affine_grid の ONNX エクスポートに対応していることを確認できた。

test.py
import torch
import torch.nn as nn

class Model(nn.Module):
    """Minimal module whose forward is a single ``affine_grid`` call (export check)."""

    def forward(self, theta, size):
        # Pass align_corners explicitly: with None, torch emits a UserWarning
        # and then falls back to False anyway, so the result is unchanged.
        return torch.nn.functional.affine_grid(theta, size, align_corners=False)

if __name__ == "__main__":
    # Run the export only when executed as a script, not on import.
    model = Model()
    theta = torch.ones((1, 2, 3))
    size = torch.Size((1, 3, 24, 24))
    # dynamo=True selects the torch.export-based ONNX exporter.
    torch.onnx.export(model, (theta, size), 'test.onnx', dynamo=True)

推論も正常に動いた。