
Workaround for exporting RTStereoNet to ONNX (replacing ScatterND and grid_sample (grid_sampler))
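
With opset 11, torch.onnx.export cannot emit grid_sample (aten::grid_sampler), and in-place slice assignment is lowered to ScatterND, which causes trouble in the converters used further below. The model definition below therefore replaces grid_sample with a loop-free bilinear sampler and rebuilds the cost volume with split/cat instead of in-place writes.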

from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import math
import numpy as np

norm_layer2d = nn.BatchNorm2d
norm_layer3d = nn.BatchNorm3d


def convbn(
    in_planes,
    out_planes,
    kernel_size,
    stride,
    pad,
    dilation,
    groups = 1
):

    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation if dilation > 1 else pad,
            dilation = dilation,
            groups = groups,
            bias=False
        ),
        norm_layer2d(out_planes)
    )

class featexchange(nn.Module):
    def __init__(self):
        super(featexchange, self).__init__()

        self.x2_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(4, 4, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(4, 4, 3, 1, 1, bias=False)
        )
        self.upconv4 = nn.Sequential(
            nn.Conv2d(8, 4, 1, 1, 0, bias=False),
            norm_layer2d(4)
        )
        self.upconv8 = nn.Sequential(
            nn.Conv2d(20, 4, 1, 1, 0, bias=False),
            norm_layer2d(4)
        )

        self.x4_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(8, 8, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(8, 8, 3, 1, 1, bias=False)
        )
        self.downconv4 = nn.Sequential(
            nn.Conv2d(4, 8, 3, 2, 1, bias=False),
            norm_layer2d(8)
        )
        self.upconv8_2 = nn.Sequential(
            nn.Conv2d(20, 8, 1, 1, 0, bias=False),
            norm_layer2d(8)
        )

        self.x8_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(20, 20, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(20, 20, 3, 1, 1, bias=False)
        )
        self.downconv81 = nn.Sequential(
            nn.Conv2d(8, 20, 3, 2, 1, bias=False),
            norm_layer2d(20)
        )
        self.downconv82 = nn.Sequential(
            nn.Conv2d(8, 20, 3, 2, 1, bias=False),
            norm_layer2d(20)
        )

    def forward(self, x2, x4, x8, attention):

        A = torch.split(attention,[4,8,20],dim=1)

        x4tox2 = self.upconv4(
            F.upsample(x4, (x2.size()[2],x2.size()[3]))
        )
        x8tox2 = self.upconv8(
            F.upsample(x8, (x2.size()[2],x2.size()[3]))
        )
        fusx2  = x2 + x4tox2 + x8tox2
        fusx2  = self.x2_fusion(fusx2)*A[0].contiguous()+fusx2

        x2tox4 = self.downconv4(x2)
        x8tox4 = self.upconv8_2(
            F.upsample(x8, (x4.size()[2],x4.size()[3]))
        )
        fusx4  = x4 + x2tox4 + x8tox4
        fusx4  = self.x4_fusion(fusx4)*A[1].contiguous()+fusx4

        x2tox8 = self.downconv81(x2tox4)
        x4tox8 = self.downconv82(x4)
        fusx8  = x8 + x2tox8 + x4tox8
        fusx8  = self.x8_fusion(fusx8)*A[2].contiguous()+fusx8

        return fusx2, fusx4, fusx8

class feature_extraction(nn.Module):
    def __init__(self):
        super(feature_extraction, self).__init__()

        self.inplanes = 1
        self.firstconv = nn.Sequential(
            nn.Conv2d(3, 3, 3, 2, 1, bias=False),
            nn.Conv2d(3, 3, 3, 2, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Conv2d(3, 4, 1, 1, 0, bias=False),
            convbn(4, 4, 3, 1, 1, 1, 4),
            nn.ReLU(),
            nn.Conv2d(4, 4, 1, 1, 0, bias=False),
            convbn(4, 4, 3, 1, 1, 1, 4)
        ) # 1/4

        self.stage2 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(4, 8, 1, 1, 0, bias=False),
            convbn(8, 8, 3, 2, 1, 1, 8),
            nn.ReLU(),
            nn.Conv2d(8, 8, 1, 1, 0, bias=False),
            convbn(8, 8, 3, 1, 1, 1, 8)
        ) # 1/8

        self.stage3 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(8, 20, 1, 1, 0, bias=False),
            convbn(20, 20, 3, 2, 1, 1, 20),
            nn.ReLU(),
            nn.Conv2d(20, 20, 1, 1, 0, bias=False),
            convbn(20, 20, 3, 1, 1, 1,20)
        ) #1/16

        self.stage4 = nn.Sequential(
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(20, 10, 1, 1, 0, bias=True),
            nn.ReLU(),
            nn.Conv2d(10, 32, 1, 1, 0, bias=True),
            nn.Sigmoid(),
        )

        self.fusion = featexchange()


    def forward(self, x):
        # stage 1: 1/4 resolution
        out_s1 = self.firstconv(x)
        out_s2 = self.stage2(out_s1)
        out_s3 = self.stage3(out_s2)
        attention = self.stage4(out_s3)
        out_s1, out_s2, out_s3 = self.fusion(
            out_s1,
            out_s2,
            out_s3,
            attention
        )
        return [out_s3, out_s2, out_s1]

def batch_relu_conv3d(
    in_planes,
    out_planes,
    kernel_size=3,
    stride=1,
    pad=1,
    bn3d=True
):

    if bn3d:
        return nn.Sequential(
            norm_layer3d(in_planes),
            nn.ReLU(),
            nn.Conv3d(
                in_planes,
                out_planes,
                kernel_size=kernel_size,
                padding=pad,
                stride=stride,
                bias=False
            )
        )
    else:
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv3d(
                in_planes,
                out_planes,
                kernel_size=kernel_size,
                padding=pad,
                stride=stride,
                bias=False
            )
        )

def post_3dconvs(layers, channels):
    net  = [
        nn.Conv3d(
            1,
            channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=False
        )
    ]
    net += [
        batch_relu_conv3d(channels, channels) for _ in range(layers)
    ]
    net += [
        batch_relu_conv3d(channels, 1)
    ]
    return nn.Sequential(*net)

class RTStereoNet(nn.Module):
    def __init__(self, maxdisp):
        super(RTStereoNet, self).__init__()

        self.feature_extraction = feature_extraction()
        self.maxdisp = maxdisp
        self.volume_postprocess = []

        layer_setting = [8,4,4]
        for i in range(3):
            net3d = post_3dconvs(3, layer_setting[i])
            self.volume_postprocess.append(net3d)
        self.volume_postprocess = nn.ModuleList(self.volume_postprocess)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def warp(self, x, disp):
        """
        warp an image/tensor (im2) back to im1, according to the optical flow
        x: [B, C, H, W] (im2)
        flo: [B, 2, H, W] flow
        """
        B, C, H, W = x.size()
        # mesh grid
        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        grid = torch.cat((xx, yy), 1).float()

        if x.is_cuda:
            vgrid = grid.cuda()
        else:
            vgrid = grid
        # The in-place write below would be exported as ScatterND, so the
        # shifted grid is rebuilt with cat instead:
        # vgrid[:,:1,:,:] = vgrid[:,:1,:,:] - disp
        vgrid = torch.cat([vgrid[:,:1,:,:] - disp, vgrid[:,[1],:,:]], dim=1)

        # scale grid to [-1,1] -- needed only for grid_sample; the manual
        # bilinear sampler used below works in raw pixel coordinates, so the
        # normalization is skipped here.
        # vgrid = torch.cat([2.0 * vgrid[:, [0], :, :] / max(W - 1, 1) - 1.0,
        #                    2.0 * vgrid[:, [1], :, :] / max(H - 1, 1) - 1.0], dim=1)
        vgrid = vgrid.permute(0, 2, 3, 1)


        # output = nn.functional.grid_sample(x, vgrid)

        def bilinear_sample_noloop(image, grid):
            """
            :param image: sampling source of shape [N, C, H, W]
            :param grid: sampling pixel coordinates (float, in pixels, not normalized) of shape [N, grid_H, grid_W, 2]
            :return: sampling result of shape [N, C, grid_H, grid_W]
            """
            Nt, C, H, W = image.shape
            grid_H = grid.shape[1]
            grid_W = grid.shape[2]
            xgrid, ygrid = grid.split([1, 1], dim=-1)
            mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
            x0 = torch.floor(xgrid)
            x1 = x0 + 1
            y0 = torch.floor(ygrid)
            y1 = y0 + 1
            wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
            wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
            wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
            wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
            x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
            y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
            x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
            y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
            ind = torch.arange(Nt, device=image.device) #torch.linspace(0, Nt - 1, Nt, device=image.device)
            ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
            image = image.permute(1, 0, 2, 3)
            output_tensor = (
                image[:, ind, y0, x0] * wa + image[:, ind, y1, x0] * wb + image[:, ind, y0, x1] * wc + image[:, ind, y1, x1] * wd
            ).permute(1, 0, 2, 3)
            output_tensor *= mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
            image = image.permute(1, 0, 2, 3)
            return output_tensor, mask
        output, _ = bilinear_sample_noloop(x, vgrid)

        return output


    def _build_volume_2d(
        self,
        feat_l,
        feat_r,
        maxdisp,
        stride=1
    ):

        assert maxdisp % stride == 0  # Assume maxdisp is multiple of stride
        b,c,h,w = feat_l.size()
        cost = torch.zeros(b, 1, maxdisp//stride, h, w).cuda().requires_grad_(False)
        dim1,dim2,dim3,dim4,dim5 = cost.shape
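        # Writing cost[:, :, i//stride, :, i:] in place (see the commented-out
        # lines below) would be exported as ScatterND, so the volume is
        # rebuilt each iteration by splitting along the disparity axis and
        # concatenating the updated slice back in.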
        for i in range(0, maxdisp, stride):
            if i > 0:
                # cost[:, :, i//stride, :, i:] = torch.norm(feat_l[:, :, :, i:] - feat_r[:, :, :, :-i], p=1, dim = 1,keepdim=True)
                tmp_cost = []
                for j in range(dim3):
                    if j == (i // stride):
                        tmp1 = cost[:, :, [j], :, :i]

                        # tmp2 = torch.norm(feat_l[:, :, :, i:] - feat_r[:, :, :, :-i], p=1, dim = 1,keepdim=True)
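                        # torch.norm is decomposed into elementwise ops;
                        # sqrt(square(x)) stands in for abs(x) in this L1 sum.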
                        tmp2_sub = feat_l[:, :, :, i:] - feat_r[:, :, :, :-i]
                        tmp2_abs = torch.sqrt(torch.square(tmp2_sub))
                        tmp2 = torch.sum(tmp2_abs, 1, keepdim=True)

                        tmp2 = tmp2.unsqueeze(2)
                        tmp_cost.append(torch.cat([tmp1,tmp2],axis=4))
                    else:
                        tmp = cost[:, :, [j], :, :]
                        tmp_cost.append(tmp)
                cost = torch.cat(tmp_cost,axis=2)
            else:
                # cost[:, :, i//stride, :, i:] = torch.norm(feat_l[:, :, :, :] - feat_r[:, :, :, :], p=1,  dim =1,keepdim=True)
                tmp_cost = []
                for j in range(dim3):
                    if j == (i // stride):
                        tmp1 = cost[:, :, [j], :, :i]

                        # tmp2 = torch.norm(feat_l[:, :, :, :] - feat_r[:, :, :, :], p=1,  dim =1,keepdim=True)
                        tmp2_sub = feat_l[:, :, :, :] - feat_r[:, :, :, :]
                        tmp2_abs = torch.sqrt(torch.square(tmp2_sub))
                        tmp2 = torch.sum(tmp2_abs, 1, keepdim=True)

                        tmp2 = tmp2.unsqueeze(2)
                        tmp_cost.append(torch.cat([tmp1,tmp2],axis=4))
                    else:
                        tmp = cost[:, :, [j], :, :]
                        tmp_cost.append(tmp)
                cost = torch.cat(tmp_cost,axis=2)
        return cost.contiguous()

    def _build_volume_2d3(
        self,
        feat_l,
        feat_r,
        maxdisp,
        disp,
        stride=1
    ):

        b,c,h,w = feat_l.size()
        batch_disp = disp[:,None,:,:,:].repeat(1, maxdisp*2-1, 1, 1, 1).view(-1,1,h,w)
        temp_array = np.tile(np.array(range(-maxdisp + 1, maxdisp)), b) * stride
        batch_shift = torch.Tensor(
            np.reshape(
                temp_array,
                [len(temp_array), 1, 1, 1]
            )
        ).cuda().requires_grad_(False)
        batch_disp = batch_disp - batch_shift
        batch_feat_l = feat_l[:,None,:,:,:].repeat(1,maxdisp*2-1, 1, 1, 1).view(-1,c,h,w)
        batch_feat_r = feat_r[:,None,:,:,:].repeat(1,maxdisp*2-1, 1, 1, 1).view(-1,c,h,w)

        # cost = torch.norm(batch_feat_l - self.warp(batch_feat_r, batch_disp), 1, 1, keepdim=True)
        batch_sub = batch_feat_l - self.warp(batch_feat_r, batch_disp)
        batch_abs = torch.sqrt(torch.square(batch_sub))
        cost = torch.sum(batch_abs, 1, keepdim=True)

        cost = cost.view(b,1 ,-1, h, w).contiguous()
        return cost

    def forward(self, left, right):

        img_size = left.size()

        feats_l = self.feature_extraction(left)
        feats_r = self.feature_extraction(right)

        # print(f'@@@@@@@@@@@@@@@@@@@@@@ range(len(feats_l)): {range(len(feats_l))}')

        pred = []
        for scale in range(len(feats_l)):
            if scale > 0:
                wflow = F.upsample(
                    pred[scale-1],
                    (feats_l[scale].size(2), feats_l[scale].size(3)),
                    mode='bilinear'
                ) * feats_l[scale].size(2) / img_size[2]
                cost = self._build_volume_2d3(
                    feats_l[scale],
                    feats_r[scale],
                    3,
                    wflow,
                    stride=1
                )
            else:
                cost = self._build_volume_2d(
                    feats_l[scale],
                    feats_r[scale],
                    12,
                    stride=1
                )

            #cost = torch.unsqueeze(cost, 1)
            cost = self.volume_postprocess[scale](cost)
            cost = cost.squeeze(1)
            if scale == 0:
                pred_low_res = disparityregression2(0, 12)(F.softmax(cost, dim=1))
                pred_low_res = pred_low_res * img_size[2] / pred_low_res.size(2)
                disp_up = F.upsample(
                    pred_low_res,
                    (img_size[2], img_size[3]),
                    mode='bilinear'
                )
                pred.append(disp_up)
            else:
                pred_low_res = disparityregression2(-2, 3, stride=1)(F.softmax(cost, dim=1))
                pred_low_res = pred_low_res * img_size[2] / pred_low_res.size(2)
                disp_up = F.upsample(
                    pred_low_res,
                    (img_size[2], img_size[3]),
                    mode='bilinear'
                )
                pred.append(disp_up + pred[scale-1])
        if self.training:
            return pred[0],pred[1],pred[2]
        else:
            return pred[-1]

class disparityregression2(nn.Module):
    def __init__(self, start, end, stride=1):
        super(disparityregression2, self).__init__()
        self.disp = torch.arange(
            start*stride,
            end*stride,
            stride
        ).view(1, -1, 1, 1).type(torch.FloatTensor).cuda().requires_grad_(False)
    def forward(self, x):
        out = torch.sum(x * self.disp, 1, keepdim=True)
        return out
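
A quick sanity check of the manual sampler against F.grid_sample on interior coordinates (a hypothetical snippet, not part of the original code; it assumes bilinear_sample_noloop has been lifted out of warp() to module scope, and uses align_corners=True to match the pixel-coordinate convention):

torch.manual_seed(0)
img = torch.rand(1, 3, 16, 16)                    # N, C, H, W
px = torch.rand(1, 8, 8, 1) * 14.0                # x in [0, W-2]
py = torch.rand(1, 8, 8, 1) * 14.0                # y in [0, H-2]
pix_grid = torch.cat([px, py], dim=-1)            # raw pixel coordinates
norm_grid = torch.cat([2.0 * px / 15.0 - 1.0,     # normalized for grid_sample
                       2.0 * py / 15.0 - 1.0], dim=-1)
ref = F.grid_sample(img, norm_grid, mode='bilinear', align_corners=True)
out, _ = bilinear_sample_noloop(img, pix_grid)
print(torch.allclose(out, ref, atol=1e-5))        # expected: True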
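
The inference test script (Test_img.py), adapted from the PSMNet test script: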
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math
from models import *
import cv2
from PIL import Image

# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/

parser = argparse.ArgumentParser(description='PSMNet')
parser.add_argument('--KITTI', default='2015', help='KITTI version')
parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/', help='data path')
parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar', help='model checkpoint to load')
parser.add_argument('--leftimg', default= './VO04_L.png', help='left input image')
parser.add_argument('--rightimg', default= './VO04_R.png', help='right input image')
parser.add_argument('--model', default='RTStereoNet', help='select model')
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
    model = nn.DataParallel(model)
elif args.model == 'basic':
    model = basic(args.maxdisp)
    model = nn.DataParallel(model)
elif args.model == 'RTStereoNet':
    model = RTStereoNet(args.maxdisp)
else:
    print('no model')

if args.cuda:
    model.cuda()

if args.loadmodel is not None:
    print('load model')
    state_dict = torch.load(args.loadmodel)
    model.load_state_dict(state_dict['state_dict'])

# H=180
# W=320
# x = torch.randn(1, 3, H, W).cuda()
# torch.onnx.export(model, args=(x,x), f=f"rtstereonet_maxdisp{args.maxdisp}_{H}x{W}.onnx", opset_version=11)
# import sys
# sys.exit(0)


print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

def test(imgL,imgR):
        model.eval()

        if args.cuda:
            imgL = imgL.cuda()
            imgR = imgR.cuda()

        with torch.no_grad():
            disp = model(imgL,imgR)

        disp = torch.squeeze(disp)
        pred_disp = disp.data.cpu().numpy()

        return pred_disp


def main():

        normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                            'std': [0.229, 0.224, 0.225]}
        infer_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(**normal_mean_var)
            ]
        )

        imgL_o = Image.open(args.leftimg).convert('RGB')
        imgR_o = Image.open(args.rightimg).convert('RGB')

        imgL = infer_transform(imgL_o)
        imgR = infer_transform(imgR_o)


        # pad width and height up to multiples of 16
        if imgL.shape[1] % 16 != 0:
            times = imgL.shape[1]//16
            top_pad = (times+1)*16 -imgL.shape[1]
        else:
            top_pad = 0

        if imgL.shape[2] % 16 != 0:
            times = imgL.shape[2]//16
            right_pad = (times+1)*16-imgL.shape[2]
        else:
            right_pad = 0

        imgL = F.pad(imgL,(0,right_pad, top_pad,0)).unsqueeze(0)
        imgR = F.pad(imgR,(0,right_pad, top_pad,0)).unsqueeze(0)

        start_time = time.time()
        pred_disp = test(imgL,imgR)
        print('time = %.2f' %(time.time() - start_time))


        if top_pad !=0 and right_pad != 0:
            img = pred_disp[top_pad:,:-right_pad]
        elif top_pad ==0 and right_pad != 0:
            img = pred_disp[:,:-right_pad]
        elif top_pad !=0 and right_pad == 0:
            img = pred_disp[top_pad:,:]
        else:
            img = pred_disp

        # img = (img*256).astype('uint16')
        # img = Image.fromarray(img)


        # Heatmap
        depth_map = np.squeeze(img)
        d_min = np.min(depth_map)
        d_max = np.max(depth_map)
        depth_map = (depth_map - d_min) / (d_max - d_min)
        depth_map = depth_map * 255.0
        depth_map = np.asarray(depth_map, dtype="uint8")
        depth_map = cv2.applyColorMap(depth_map, cv2.COLORMAP_JET)
        img = Image.fromarray(depth_map)

        img.save('Test_disparity.png')

if __name__ == '__main__':
    main()
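
Running the inference test: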
python3 Test_img.py \
--model RTStereoNet \
--loadmodel trained/pretrained_Kitti2015_realtime.tar \
--leftimg 0479_left.png \
--rightimg 0479_right.png
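
ONNX export followed by an onnxsim (onnx-simplifier) pass -- the expanded version of the commented-out block in Test_img.py, run right after the checkpoint is loaded: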
H=180
W=320
onnx_file = f"rtstereonet_maxdisp{args.maxdisp}_{H}x{W}.onnx"
x = torch.randn(1, 3, H, W).cuda()
torch.onnx.export(
    model,
    args=(x,x),
    f=onnx_file,
    opset_version=11
)
import onnx
from onnxsim import simplify
model = onnx.load(onnx_file)
model_simp, check = simplify(model)
onnx.save(model_simp, onnx_file)
import sys
sys.exit(0)
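
Conversion to OpenVINO IR (FP32/FP16) with the Model Optimizer, then compilation to a Myriad blob: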
H=180
W=320
MODEL=rtstereonet
MAXDISP=192
$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo.py \
--input_model ${MODEL}_maxdisp${MAXDISP}_${H}x${W}.onnx \
--data_type FP32 \
--output_dir ${MODEL}_maxdisp${MAXDISP}_${H}x${W}/openvino/FP32
$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo.py \
--input_model ${MODEL}_maxdisp${MAXDISP}_${H}x${W}.onnx \
--data_type FP16 \
--output_dir ${MODEL}_maxdisp${MAXDISP}_${H}x${W}/openvino/FP16

mkdir -p ${MODEL}_maxdisp${MAXDISP}_${H}x${W}/openvino/myriad
${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/myriad_compile \
-m ${MODEL}_maxdisp${MAXDISP}_${H}x${W}/openvino/FP16/rtstereonet_maxdisp${MAXDISP}_${H}x${W}.xml \
-ip U8 \
-VPU_NUMBER_OF_SHAVES 4 \
-VPU_NUMBER_OF_CMX_SLICES 4 \
-o ${MODEL}_maxdisp${MAXDISP}_${H}x${W}/openvino/myriad/${MODEL}_maxdisp${MAXDISP}_${H}x${W}.blob
  • [WIP] TensorFlow version - bilinear only
import tensorflow as tf

def bilinear_sample_noloop(image, grid):
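    # Sampling source image is NHWC [Nt, H, W, C]; grid holds pixel
    # coordinates (x, y) of shape [Nt, grid_H, grid_W, 2], matching the
    # convention of the PyTorch version above.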
    Nt, H, W, C = image.shape
    grid_H = grid.shape[1]
    grid_W = grid.shape[2]
    xgrid, ygrid = tf.split(
        value=grid,
        num_or_size_splits=2,
        axis=-1,
    )
    mask = tf.cast(
        (xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1),
        dtype=tf.float32
    )
    x0 = tf.math.floor(xgrid)
    x1 = x0 + 1
    y0 = tf.math.floor(ygrid)
    y1 = y0 + 1

    # Bilinear weights kept in channels-last layout [Nt, grid_H, grid_W, 1]
    # so they broadcast against the gathered [Nt, grid_H, grid_W, C] pixels.
    wa = (x1 - xgrid) * (y1 - ygrid)
    wb = (x1 - xgrid) * (ygrid - y0)
    wc = (xgrid - x0) * (y1 - ygrid)
    wd = (xgrid - x0) * (ygrid - y0)

    x0 = tf.cast(
        tf.reshape(
            tensor=(x0 * mask),
            shape=[Nt, grid_H, grid_W],
        ),
        dtype=tf.int64,
    )
    y0 = tf.cast(
        tf.reshape(
            tensor=(y0 * mask),
            shape=[Nt, grid_H, grid_W]
        ),
        dtype=tf.int64,
    )
    x1 = tf.cast(
        tf.reshape(
            tensor=(x1 * mask),
            shape=[Nt, grid_H, grid_W]
        ),
        dtype=tf.int64,
    )
    y1 = tf.cast(
        tf.reshape(
            tensor=(y1 * mask),
            shape=[Nt, grid_H, grid_W]
        ),
        dtype=tf.int64,
    )

    ind = tf.range(limit=Nt)
    ind = tf.reshape(tensor=ind, shape=[Nt, 1])
    ind = tf.tile(input=ind, multiples=[1, grid_H])
    ind = tf.reshape(tensor=ind, shape=[Nt, grid_H, 1])
    ind = tf.tile(input=ind, multiples=[1, 1, grid_W])
    ind = tf.cast(ind, dtype=tf.int64)

    # TF tensors do not support NumPy-style fancy indexing such as
    # image[:, ind, y0, x0], so the four neighbouring pixels are gathered
    # from the NHWC image with tf.gather_nd instead.
    output_tensor = \
        tf.gather_nd(image, tf.stack([ind, y0, x0], axis=-1)) * wa \
        + tf.gather_nd(image, tf.stack([ind, y1, x0], axis=-1)) * wb \
        + tf.gather_nd(image, tf.stack([ind, y0, x1], axis=-1)) * wc \
        + tf.gather_nd(image, tf.stack([ind, y1, x1], axis=-1)) * wd
    mask = tf.tile(
        input=mask,
        multiples=[1,1,1,C],
    )
    output_tensor = output_tensor * mask

    return output_tensor
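
A quick shape check for the TF sampler (hypothetical example with made-up shapes): sampling a 1x8x8x4 NHWC image on a 6x6 grid of pixel coordinates should yield a 1x6x6x4 result.

img = tf.random.uniform([1, 8, 8, 4])
gx = tf.random.uniform([1, 6, 6, 1], minval=0.0, maxval=7.0)  # x in [0, W-1)
gy = tf.random.uniform([1, 6, 6, 1], minval=0.0, maxval=7.0)  # y in [0, H-1)
out = bilinear_sample_noloop(img, tf.concat([gx, gy], axis=-1))
print(out.shape)  # (1, 6, 6, 4)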