Open3
torch.nn.functional.affine_grid のONNXエクスポートのための置き換え
# https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AffineGridGenerator.cpp
def affine_grid(theta, size, align_corners=False):
    """ONNX-exportable replacement for ``torch.nn.functional.affine_grid`` (4-D only).

    Follows ATen's ``affine_grid_generator_4D``. It builds a base grid of
    homogeneous coordinates ``(x, y, 1)`` and multiplies it by the transposed
    affine matrices with ``bmm``.

    Args:
        theta: Batch of affine matrices. The ``bmm``/``view`` below require
            shape ``(N, 2, 3)``.
        size: Target output size ``(N, C, H, W)``.
        align_corners: Forwarded to the coordinate generator. A falsy value
            (including ``None``) selects the non-aligned range.

    Returns:
        Sampling grid of shape ``(N, H, W, 2)``, with the dtype and device of ``theta``.

    Raises:
        ValueError: If ``size`` does not have exactly four entries.
    """
    if len(size) != 4:
        raise ValueError(
            f"affine_grid replacement only supports 4-D sizes (N, C, H, W), got {tuple(size)}"
        )
    N, C, H, W = size
    # create_grid yields a float32 tensor on the default device. Match theta so
    # bmm works for other dtypes/devices and the result carries theta's dtype.
    # This is a no-op for a float32 theta on the default device.
    grid = create_grid(N, C, H, W, align_corners).to(device=theta.device, dtype=theta.dtype)
    # (N, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2)
    grid = grid.view(N, H * W, 3).bmm(theta.transpose(1, 2))
    return grid.view(N, H, W, 2)
def create_grid(N, C, H, W, align_corners, dtype=torch.float32, device=None):
    """Build the base grid of homogeneous coordinates used by ``affine_grid``.

    Follows ATen's ``make_base_grid_4D``. Entry ``[n, h, w]`` is ``(x_w, y_h, 1)``,
    where the x values come from ``linspace_from_neg_one`` over ``W`` and the
    y values over ``H``.

    Args:
        N, H, W: Batch size and spatial height/width of the grid.
        C: Channel count. Unused here; kept so the arguments mirror ``size``.
        align_corners: Forwarded to ``linspace_from_neg_one``.
        dtype: dtype of the returned grid (float32 by default, as before).
        device: Device of the returned grid. ``None`` uses torch's default device.

    Returns:
        Tensor of shape ``(N, H, W, 3)``.
    """
    grid = torch.empty((N, H, W, 3), dtype=dtype, device=device)
    grid.select(-1, 0).copy_(linspace_from_neg_one(W, align_corners, dtype=dtype))
    # Turning the H coordinates into a column lets copy_ broadcast them across W.
    grid.select(-1, 1).copy_(linspace_from_neg_one(H, align_corners, dtype=dtype).unsqueeze_(-1))
    grid.select(-1, 2).fill_(1)
    return grid
def linspace_from_neg_one(num_steps, align_corners, dtype=torch.float32, device=None):
    """Return ``num_steps`` evenly spaced normalized coordinates between -1 and 1.

    Python port of ATen's ``linspace_from_neg_one``. With a truthy
    ``align_corners`` the end points are exactly -1 and 1. Otherwise the values
    are scaled by ``(num_steps - 1) / num_steps``.

    Args:
        num_steps: Number of coordinates to generate.
        align_corners: Selects the unscaled (truthy) or scaled (falsy, incl. ``None``) range.
        dtype: dtype of the result. Previously this was ignored and float32 was always used.
        device: Device of the result. ``None`` uses torch's default device.

    Returns:
        1-D tensor of length ``num_steps``.
    """
    if num_steps <= 1:
        # ATen returns 0 for a single step. A plain linspace would give -1
        # when align_corners is truthy.
        return torch.zeros(num_steps, dtype=dtype, device=device)
    r = torch.linspace(-1, 1, num_steps, dtype=dtype, device=device)
    if not align_corners:
        r = r * (num_steps - 1) / num_steps
    return r
def patch_affine_grid_generator():
    """Install the ONNX-exportable ``affine_grid`` as ``torch.nn.functional.affine_grid``.

    Only code that looks the function up through ``torch.nn.functional`` at call
    time sees the patch. Names bound earlier (e.g. via ``from torch.nn.functional
    import affine_grid``) still refer to the previous function.

    Returns:
        The function installed before patching. Callers can restore it with
        ``torch.nn.functional.affine_grid = previous``.
    """
    previous = torch.nn.functional.affine_grid
    torch.nn.functional.affine_grid = affine_grid
    return previous
2025年3月14日、torch==2.8.0.dev20250313+cu126 では、上記の置き換えを使わなくても affine_grid の ONNX エクスポート（dynamo=True）に対応していることを確認できた。
test.py
import torch
import torch.nn as nn
class Model(nn.Module):
    """Minimal module that only calls ``affine_grid``; used for the ONNX export check."""

    def forward(self, theta, size):
        # align_corners=None is PyTorch's own default value for this argument.
        return torch.nn.functional.affine_grid(theta, size, align_corners=None)
model = Model()
theta = torch.ones(1, 2, 3)
size = torch.Size([1, 3, 24, 24])
example_args = (theta, size)
# Export with the dynamo-based exporter; the model is written to test.onnx.
torch.onnx.export(model, example_args, 'test.onnx', dynamo=True)
推論も正常に動いた。