Open7

grid_sample (grid_sampler) to ONNX

dhirookadhirooka

こんにちは。私も同じことを調査しており、MMCVというライブラリでF.grid_sampleの置き換えが提供されていたので共有します。ご参考までに。
https://github.com/pytorch/pytorch/issues/27212#issuecomment-966334554

PINTOPINTO

ありがとうございます!
mmcv._ext の取り込みを強制されますが、onnxの拡張モジュールを使用しているわけではないのでしょうか?まだ mmcv をしっかりインストールしていないため、この質問が早とちりでしたらすみません。。。
実は、たまたま別件で昨晩こちらのドキュメントを見ていたのでカスタムOPをビルドして取り込む必要があるのかなぁ、と思いました。私は DCNv2 のコンバートをしているときに下記のドキュメントにあたりました。

https://github.com/open-mmlab/mmcv/blob/2b39d7a8ec638774d7789c6e7d23340d25bb3f50/docs/deployment/onnxruntime_op.md#list-of-operators-for-onnx-runtime-supported-in-mmcv

dhirookadhirooka

上記のGitHub issueで利用しているbilinear_grid_sampleに関しては、(MMCV内で提供されているものの)中身はピュアなPyTorchの処理で、ONNXの拡張モジュールではないという認識です。mmcv._extの取り込みも不要だと思います。
極端な話、以下のページのbilinear_grid_sampleをコピペすればMMCVをインストールしなくてもF.grid_sampleの代わりに利用できます(MMCVはApache 2.0ライセンスなのでその規約に従う必要はありますが)。

https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/point_sample.html

共有いただいたカスタムOPのビルド→mmcv._extを使うと、F.grid_sampleなどを書き換えることなくONNXに変換、ONNX Runtimeで推論できるようですね。

PINTOPINTO

おぉ!有難うございます。パフォーマンスが良くなりそうなら是非移行してみたいと思います!とても勉強になりました。

PINTOPINTO
def bilinear_grid_sample(im, grid, align_corners=False):
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid. Supported only bilinear interpolation
    method to sample the input pixels.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners {bool}: If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input’s
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input’s corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply default for grid_sample function zero padding
    im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # save points positions after padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to padded image size
    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)

    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)