Open9
RTStereoNet の ONNX エクスポートのワークアラウンド (ScatterND と grid_sample (grid_sampler) の置き換え)
from __future__ import print_function
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import math
import numpy as np
norm_layer2d = nn.BatchNorm2d
norm_layer3d = nn.BatchNorm3d


def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation, groups=1):
    """Return ``nn.Sequential(Conv2d(bias=False), norm_layer2d)``.

    When ``dilation > 1`` the dilation value is used as the padding;
    otherwise the explicit ``pad`` argument is used.
    """
    padding = pad if dilation <= 1 else dilation
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Sequential(conv, norm_layer2d(out_planes))
class featexchange(nn.Module):
    """Exchange information between the three feature-pyramid levels.

    ``forward(x2, x4, x8, attention)`` takes 4-, 8- and 20-channel maps
    (finest to coarsest) plus an attention tensor that is split along dim 1
    into 4/8/20-channel gates. Each level is summed with resized/projected
    versions of the other levels and then refined with a gated residual:
    ``fus = fusion(fus) * gate + fus``.
    """

    def __init__(self):
        super(featexchange, self).__init__()
        self.x2_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(4, 4, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(4, 4, 3, 1, 1, bias=False)
        )
        self.upconv4 = nn.Sequential(
            nn.Conv2d(8, 4, 1, 1, 0, bias=False),
            norm_layer2d(4)
        )
        self.upconv8 = nn.Sequential(
            nn.Conv2d(20, 4, 1, 1, 0, bias=False),
            norm_layer2d(4)
        )
        self.x4_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(8, 8, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(8, 8, 3, 1, 1, bias=False)
        )
        self.downconv4 = nn.Sequential(
            nn.Conv2d(4, 8, 3, 2, 1, bias=False),
            norm_layer2d(8)
        )
        self.upconv8_2 = nn.Sequential(
            nn.Conv2d(20, 8, 1, 1, 0, bias=False),
            norm_layer2d(8)
        )
        self.x8_fusion = nn.Sequential(
            nn.ReLU(),
            convbn(20, 20, 3, 1, 1, 1, 1),
            nn.ReLU(),
            nn.Conv2d(20, 20, 3, 1, 1, bias=False)
        )
        self.downconv81 = nn.Sequential(
            nn.Conv2d(8, 20, 3, 2, 1, bias=False),
            norm_layer2d(20)
        )
        self.downconv82 = nn.Sequential(
            nn.Conv2d(8, 20, 3, 2, 1, bias=False),
            norm_layer2d(20)
        )

    def forward(self, x2, x4, x8, attention):
        """Return the fused ``(x2, x4, x8)`` maps with unchanged shapes."""
        A = torch.split(attention, [4, 8, 20], dim=1)
        size2 = (x2.size()[2], x2.size()[3])
        size4 = (x4.size()[2], x4.size()[3])

        # F.upsample is deprecated; F.interpolate with mode='nearest' (which was
        # F.upsample's default mode) is its drop-in replacement.
        x4tox2 = self.upconv4(F.interpolate(x4, size2, mode='nearest'))
        x8tox2 = self.upconv8(F.interpolate(x8, size2, mode='nearest'))
        fusx2 = x2 + x4tox2 + x8tox2
        fusx2 = self.x2_fusion(fusx2) * A[0].contiguous() + fusx2

        x2tox4 = self.downconv4(x2)
        x8tox4 = self.upconv8_2(F.interpolate(x8, size4, mode='nearest'))
        fusx4 = x4 + x2tox4 + x8tox4
        fusx4 = self.x4_fusion(fusx4) * A[1].contiguous() + fusx4

        # The finest level reaches the coarsest one through the already
        # downsampled x2tox4 map.
        x2tox8 = self.downconv81(x2tox4)
        x4tox8 = self.downconv82(x4)
        fusx8 = x8 + x2tox8 + x4tox8
        fusx8 = self.x8_fusion(fusx8) * A[2].contiguous() + fusx8
        return fusx2, fusx4, fusx8
class feature_extraction(nn.Module):
    """Three-level feature pyramid followed by cross-scale fusion.

    ``forward(x)`` returns a list of the fused 1/16-, 1/8- and 1/4-resolution
    features (20, 8 and 4 channels), coarsest first.
    """

    def __init__(self):
        super().__init__()
        self.inplanes = 1  # not read by any code visible here

        def pointwise_then_depthwise(c_in, c_out, stride):
            # 1x1 channel projection, then a grouped (groups=c_out) 3x3 convbn.
            # Returned as a flat list so the Sequential indices are unchanged.
            return [
                nn.Conv2d(c_in, c_out, 1, 1, 0, bias=False),
                convbn(c_out, c_out, 3, stride, 1, 1, c_out),
            ]

        # 1/4 resolution
        self.firstconv = nn.Sequential(
            nn.Conv2d(3, 3, 3, 2, 1, bias=False),
            nn.Conv2d(3, 3, 3, 2, 1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            *pointwise_then_depthwise(3, 4, 1),
            nn.ReLU(),
            *pointwise_then_depthwise(4, 4, 1),
        )
        # 1/8 resolution
        self.stage2 = nn.Sequential(
            nn.ReLU(),
            *pointwise_then_depthwise(4, 8, 2),
            nn.ReLU(),
            *pointwise_then_depthwise(8, 8, 1),
        )
        # 1/16 resolution
        self.stage3 = nn.Sequential(
            nn.ReLU(),
            *pointwise_then_depthwise(8, 20, 2),
            nn.ReLU(),
            *pointwise_then_depthwise(20, 20, 1),
        )
        # Global 32-channel sigmoid gates computed from the coarsest map.
        self.stage4 = nn.Sequential(
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(20, 10, 1, 1, 0, bias=True),
            nn.ReLU(),
            nn.Conv2d(10, 32, 1, 1, 0, bias=True),
            nn.Sigmoid(),
        )
        self.fusion = featexchange()

    def forward(self, x):
        quarter = self.firstconv(x)
        eighth = self.stage2(quarter)
        sixteenth = self.stage3(eighth)
        gates = self.stage4(sixteenth)
        quarter, eighth, sixteenth = self.fusion(quarter, eighth, sixteenth, gates)
        return [sixteenth, eighth, quarter]
def batch_relu_conv3d(in_planes, out_planes, kernel_size=3, stride=1, pad=1, bn3d=True):
    """Pre-activation 3-D unit: ``[norm_layer3d ->] ReLU -> Conv3d(bias=False)``.

    The normalisation layer is included only when ``bn3d`` is true.
    """
    layers = [norm_layer3d(in_planes)] if bn3d else []
    layers.append(nn.ReLU())
    layers.append(
        nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            padding=pad,
            stride=stride,
            bias=False,
        )
    )
    return nn.Sequential(*layers)
def post_3dconvs(layers, channels):
    """3-D cost-volume filter: 1 -> ``channels`` conv, ``layers`` residual-free
    ``batch_relu_conv3d`` blocks, then a ``channels`` -> 1 projection."""
    stem = nn.Conv3d(1, channels, kernel_size=3, padding=1, stride=1, bias=False)
    body = [batch_relu_conv3d(channels, channels) for _ in range(layers)]
    head = batch_relu_conv3d(channels, 1)
    return nn.Sequential(stem, *body, head)
class RTStereoNet(nn.Module):
    """Coarse-to-fine real-time stereo network, modified for ONNX export.

    In-place slice assignments (exported as ONNX ScatterND) are replaced by
    concatenations, and ``F.grid_sample`` is replaced by a gather-based
    bilinear sampler that works on pixel coordinates.
    """

    def __init__(self, maxdisp):
        super(RTStereoNet, self).__init__()
        self.feature_extraction = feature_extraction()
        # Stored only; forward() uses fixed disparity ranges (12 and -2..2).
        self.maxdisp = maxdisp
        layer_setting = [8, 4, 4]
        self.volume_postprocess = nn.ModuleList(
            [post_3dconvs(3, channels) for channels in layer_setting]
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and m.bias is not None:
                m.bias.data.zero_()

    def warp(self, x, disp):
        """Sample ``x`` at horizontal pixel position ``u - disp`` for every pixel.

        x: [B, C, H, W] tensor to sample from.
        disp: horizontal displacement in pixels, broadcastable to [B, 1, H, W].
        Returns a [B, C, H, W] tensor; samples whose 2x2 neighbourhood is not
        fully inside ``x`` are 0.
        """
        B, C, H, W = x.size()
        # Pixel mesh grid, created on x's device. The original only defined
        # vgrid inside `if x.is_cuda`, so CPU inputs raised UnboundLocalError.
        xx = torch.arange(0, W, device=x.device).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H, device=x.device).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        grid = torch.cat((xx, yy), 1).float()
        # Shift x by the disparity with torch.cat instead of slice assignment
        # (slice assignment would export as ScatterND).
        vgrid = torch.cat([grid[:, :1, :, :] - disp, grid[:, 1:2, :, :]], dim=1)
        # The gather-based sampler takes *pixel* coordinates, so the [-1, 1]
        # normalisation needed by grid_sample must not be applied here.
        vgrid = vgrid.permute(0, 2, 3, 1)
        output, _ = self._bilinear_sample_noloop(x, vgrid)
        return output

    @staticmethod
    def _bilinear_sample_noloop(image, grid):
        """Bilinearly sample ``image`` at pixel coordinates without grid_sample.

        :param image: sampling source of shape [N, C, H, W]
        :param grid: (x, y) pixel coordinates of shape [N, grid_H, grid_W, 2]
        :return: (samples of shape [N, C, grid_H, grid_W],
                  validity mask of shape [N, grid_H, grid_W, 1])
        """
        Nt, C, H, W = image.shape
        grid_H = grid.shape[1]
        grid_W = grid.shape[2]
        xgrid, ygrid = grid.split([1, 1], dim=-1)
        # 1 where the full 2x2 neighbourhood lies inside the image, else 0.
        mask = ((xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1)).float()
        x0 = torch.floor(xgrid)
        x1 = x0 + 1
        y0 = torch.floor(ygrid)
        y1 = y0 + 1
        # Corner weights permuted to [1, N, grid_H, grid_W] so they broadcast
        # over the channel-first gathers below.
        wa = ((x1 - xgrid) * (y1 - ygrid)).permute(3, 0, 1, 2)
        wb = ((x1 - xgrid) * (ygrid - y0)).permute(3, 0, 1, 2)
        wc = ((xgrid - x0) * (y1 - ygrid)).permute(3, 0, 1, 2)
        wd = ((xgrid - x0) * (ygrid - y0)).permute(3, 0, 1, 2)
        # Masked points get index 0 so every gather stays in bounds; their
        # values are zeroed by the mask afterwards.
        x0 = (x0 * mask).view(Nt, grid_H, grid_W).long()
        y0 = (y0 * mask).view(Nt, grid_H, grid_W).long()
        x1 = (x1 * mask).view(Nt, grid_H, grid_W).long()
        y1 = (y1 * mask).view(Nt, grid_H, grid_W).long()
        ind = torch.arange(Nt, device=image.device)
        ind = ind.view(Nt, 1).expand(-1, grid_H).view(Nt, grid_H, 1).expand(-1, -1, grid_W).long()
        image = image.permute(1, 0, 2, 3)  # [C, N, H, W]
        output_tensor = (
            image[:, ind, y0, x0] * wa
            + image[:, ind, y1, x0] * wb
            + image[:, ind, y0, x1] * wc
            + image[:, ind, y1, x1] * wd
        ).permute(1, 0, 2, 3)
        output_tensor = output_tensor * mask.permute(0, 3, 1, 2).expand(-1, C, -1, -1)
        return output_tensor, mask

    def _build_volume_2d(self, feat_l, feat_r, maxdisp, stride=1):
        """Build the initial cost volume of shape [B, 1, maxdisp // stride, H, W].

        Slice ``i // stride`` holds sum_c |feat_l[..., u] - feat_r[..., u - i]|
        for columns u >= i and 0 for the first ``i`` columns. Each slice is
        built once and all slices are stacked at the end. Avoiding in-place
        assignment means no ScatterND on export, and the work is linear in the
        number of disparities.
        """
        assert maxdisp % stride == 0  # Assume maxdisp is multiple of stride
        b, c, h, w = feat_l.size()
        slices = []
        for i in range(0, maxdisp, stride):
            valid = max(w - i, 0)
            diff = feat_l[:, :, :, i:] - feat_r[:, :, :, :valid]
            # Element-wise |x| written as sqrt(x^2), as in the original
            # workaround for torch.norm(p=1).
            cost_i = torch.sum(torch.sqrt(torch.square(diff)), 1, keepdim=True)
            if i > 0:
                pad = feat_l.new_zeros(b, 1, h, min(i, w))
                cost_i = torch.cat([pad, cost_i], dim=3)
            slices.append(cost_i)
        return torch.stack(slices, dim=2).contiguous()

    def _build_volume_2d3(self, feat_l, feat_r, maxdisp, disp, stride=1):
        """Build a residual cost volume around the disparity estimate ``disp``.

        For every shift s in ``range(-maxdisp + 1, maxdisp) * stride``, feat_r
        is warped by ``disp - s`` and compared with feat_l using a channel-summed
        absolute difference. Returns a tensor of shape [B, 1, 2 * maxdisp - 1, H, W].
        """
        b, c, h, w = feat_l.size()
        n_shift = maxdisp * 2 - 1
        batch_disp = disp[:, None, :, :, :].repeat(1, n_shift, 1, 1, 1).view(-1, 1, h, w)
        # The shift list is tiled once per batch item, matching the
        # (batch, shift) ordering of the repeated tensors.
        temp_array = np.tile(np.arange(-maxdisp + 1, maxdisp), b) * stride
        batch_shift = torch.as_tensor(
            temp_array, dtype=torch.float32, device=feat_l.device
        ).view(-1, 1, 1, 1)
        batch_disp = batch_disp - batch_shift
        batch_feat_l = feat_l[:, None, :, :, :].repeat(1, n_shift, 1, 1, 1).view(-1, c, h, w)
        batch_feat_r = feat_r[:, None, :, :, :].repeat(1, n_shift, 1, 1, 1).view(-1, c, h, w)
        batch_sub = batch_feat_l - self.warp(batch_feat_r, batch_disp)
        cost = torch.sum(torch.sqrt(torch.square(batch_sub)), 1, keepdim=True)
        return cost.view(b, 1, -1, h, w).contiguous()

    def forward(self, left, right):
        """Predict disparity at the input resolution.

        Returns all three scale predictions in training mode, otherwise only
        the finest one.
        """
        img_size = left.size()
        feats_l = self.feature_extraction(left)
        feats_r = self.feature_extraction(right)
        pred = []
        for scale in range(len(feats_l)):
            feat_h, feat_w = feats_l[scale].size(2), feats_l[scale].size(3)
            if scale > 0:
                # Previous full-resolution prediction resized to this scale and
                # converted to this scale's pixel units.
                wflow = F.interpolate(
                    pred[scale - 1], (feat_h, feat_w),
                    mode='bilinear', align_corners=False,
                ) * feat_h / img_size[2]
                cost = self._build_volume_2d3(
                    feats_l[scale], feats_r[scale], 3, wflow, stride=1
                )
            else:
                cost = self._build_volume_2d(
                    feats_l[scale], feats_r[scale], 12, stride=1
                )
            cost = self.volume_postprocess[scale](cost)
            cost = cost.squeeze(1)
            if scale == 0:
                regress = disparityregression2(0, 12)
            else:
                regress = disparityregression2(-2, 3, stride=1)
            pred_low_res = regress(F.softmax(cost, dim=1))
            pred_low_res = pred_low_res * img_size[2] / pred_low_res.size(2)
            # align_corners=False is what the deprecated F.upsample(mode='bilinear') used.
            disp_up = F.interpolate(
                pred_low_res, (img_size[2], img_size[3]),
                mode='bilinear', align_corners=False,
            )
            pred.append(disp_up if scale == 0 else disp_up + pred[scale - 1])
        if self.training:
            return pred[0], pred[1], pred[2]
        return pred[-1]
class disparityregression2(nn.Module):
    """Soft-argmin disparity regression over a fixed candidate range.

    Candidates are ``start*stride, (start+1)*stride, ..., (end-1)*stride``.
    ``forward`` returns ``sum_d x[:, d] * candidate_d`` with keepdim on dim 1.
    """

    def __init__(self, start, end, stride=1):
        super(disparityregression2, self).__init__()
        # Shaped [1, D, 1, 1] to broadcast over dim 1 of the input. Kept on CPU
        # here and moved to the input's device in forward(); the original
        # hard-coded .cuda(), which failed on CPU-only machines.
        self.disp = torch.arange(
            start * stride,
            end * stride,
            stride
        ).view(1, -1, 1, 1).float()

    def forward(self, x):
        """x: per-candidate weights with the candidates on dim 1 (the call sites
        pass a softmax over dim 1). Returns the weighted sum with dim 1 kept."""
        disp = self.disp.to(x.device)
        return torch.sum(x * disp, 1, keepdim=True)
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math
from models import *
import cv2
from PIL import Image
# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/
# Command-line options. Several help strings were copy-paste errors
# (e.g. --no-cuda claimed to *enable* CUDA) and are corrected here.
parser = argparse.ArgumentParser(description='PSMNet')
parser.add_argument('--KITTI', default='2015', help='KITTI version')
parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/', help='KITTI testing data directory')
parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar', help='checkpoint file to load')
parser.add_argument('--leftimg', default='./VO04_L.png', help='left image path')
parser.add_argument('--rightimg', default='./VO04_R.png', help='right image path')
parser.add_argument('--model', default='RTStereoNet', help='select model')
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
    model = nn.DataParallel(model)
elif args.model == 'basic':
    model = basic(args.maxdisp)
    model = nn.DataParallel(model)
elif args.model == 'RTStereoNet':
    model = RTStereoNet(args.maxdisp)
else:
    # Stop here: previously this only printed and the script then crashed
    # with a NameError on `model` a few lines later.
    raise SystemExit('no model')

if args.cuda:
    model.cuda()

if args.loadmodel is not None:
    print('load model')
    # map_location='cpu' allows a GPU-saved checkpoint to load on CPU-only
    # machines; load_state_dict copies the tensors onto the model's device.
    state_dict = torch.load(args.loadmodel, map_location='cpu')
    model.load_state_dict(state_dict['state_dict'])

# Uncomment to export to ONNX and stop instead of running inference.
# H=180
# W=320
# x = torch.randn(1, 3, H, W).cuda()
# torch.onnx.export(model, args=(x,x), f=f"rtstereonet_maxdisp{args.maxdisp}_{H}x{W}.onnx", opset_version=11)
# import sys
# sys.exit(0)

print('Number of model parameters: {}'.format(sum(p.numel() for p in model.parameters())))
def test(imgL, imgR):
    """Run the global ``model`` in eval mode on one stereo pair.

    Inputs are moved to CUDA when ``args.cuda`` is set. Returns the squeezed
    prediction as a NumPy array.
    """
    model.eval()
    if args.cuda:
        imgL, imgR = imgL.cuda(), imgR.cuda()
    with torch.no_grad():
        disp = torch.squeeze(model(imgL, imgR))
    return disp.cpu().numpy()
def main():
    """Predict disparity for ``args.leftimg`` / ``args.rightimg`` and save a heatmap.

    The images are normalised and padded to multiples of 16, passed through
    ``test()``, and cropped back to their original size. The result is
    min-max scaled to 0..255, colour-mapped with JET and written to
    ``Test_disparity.png``.
    """
    normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
    infer_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**normal_mean_var),
    ])

    imgL_o = Image.open(args.leftimg).convert('RGB')
    imgR_o = Image.open(args.rightimg).convert('RGB')
    imgL = infer_transform(imgL_o)
    imgR = infer_transform(imgR_o)

    # Pad height (at the top) and width (at the right) up to the next multiple
    # of 16. ``-n % 16`` is the distance to that multiple, or 0 if n already is one.
    top_pad = -imgL.shape[1] % 16
    right_pad = -imgL.shape[2] % 16
    imgL = F.pad(imgL, (0, right_pad, top_pad, 0)).unsqueeze(0)
    imgR = F.pad(imgR, (0, right_pad, top_pad, 0)).unsqueeze(0)

    start_time = time.time()
    pred_disp = test(imgL, imgR)
    print('time = %.2f' % (time.time() - start_time))

    # Remove the padding. An explicit end index also covers right_pad == 0,
    # where a plain ``:-right_pad`` slice would be empty.
    img = pred_disp[top_pad:, :pred_disp.shape[1] - right_pad]

    # Heatmap
    depth_map = np.squeeze(img)
    d_min = np.min(depth_map)
    d_max = np.max(depth_map)
    if d_max > d_min:
        depth_map = (depth_map - d_min) / (d_max - d_min)
    else:
        # Constant prediction: 0/0 would produce NaNs, whose uint8 cast is undefined.
        depth_map = np.zeros_like(depth_map)
    depth_map = depth_map * 255.0
    depth_map = np.asarray(depth_map, dtype="uint8")
    depth_map = cv2.applyColorMap(depth_map, cv2.COLORMAP_JET)
    # applyColorMap returns BGR, but PIL reads arrays as RGB; without this
    # conversion red and blue are swapped in the saved PNG.
    depth_map = cv2.cvtColor(depth_map, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(depth_map)
    img.save('Test_disparity.png')


if __name__ == '__main__':
    main()
# Run inference on one stereo pair; Test_img.py writes Test_disparity.png.
python3 Test_img.py \
    --model RTStereoNet \
    --leftimg 0479_left.png \
    --rightimg 0479_right.png \
    --loadmodel trained/pretrained_Kitti2015_realtime.tar
# ONNX export. Runs inside Test_img.py after the checkpoint has been loaded
# (it uses `model` and `args`), then exits.
H = 180
W = 320
onnx_file = f"rtstereonet_maxdisp{args.maxdisp}_{H}x{W}.onnx"
model.eval()
# Create the dummy input on the model's device instead of hard-coding .cuda(),
# so the export also works with --no-cuda or on CPU-only machines.
device = next(model.parameters()).device
x = torch.randn(1, 3, H, W, device=device)
torch.onnx.export(
    model,
    args=(x, x),
    f=onnx_file,
    opset_version=11
)

import onnx
from onnxsim import simplify

# Use a separate name so the PyTorch `model` is not shadowed.
onnx_model = onnx.load(onnx_file)
model_simp, check = simplify(onnx_model)
if not check:
    raise RuntimeError(f"onnx-simplifier could not validate the simplified {onnx_file}")
onnx.save(model_simp, onnx_file)

import sys
sys.exit(0)
H=180
W=320
MODEL=rtstereonet
MAXDISP=192
# Base name shared by the ONNX file, the IR files and the output directories.
NAME=${MODEL}_maxdisp${MAXDISP}_${H}x${W}

# ONNX -> OpenVINO IR in FP32 and FP16.
"$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo.py" \
  --input_model "${NAME}.onnx" \
  --data_type FP32 \
  --output_dir "${NAME}/openvino/FP32"
"$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo.py" \
  --input_model "${NAME}.onnx" \
  --data_type FP16 \
  --output_dir "${NAME}/openvino/FP16"

# FP16 IR -> Myriad blob. The .xml is named after the ONNX file, so derive it
# from ${NAME} instead of hard-coding "rtstereonet".
mkdir -p "${NAME}/openvino/myriad"
"${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/myriad_compile" \
  -m "${NAME}/openvino/FP16/${NAME}.xml" \
  -ip U8 \
  -VPU_NUMBER_OF_SHAVES 4 \
  -VPU_NUMBER_OF_CMX_SLICES 4 \
  -o "${NAME}/openvino/myriad/${NAME}.blob"
- [WIP] TensorFlow version (bilinear sampler only)
import tensorflow as tf
def bilinear_sample_noloop(image, grid):
    """Bilinearly sample an NHWC image at pixel coordinates (TensorFlow port).

    :param image: sampling source of shape [N, H, W, C]
    :param grid: (x, y) pixel coordinates of shape [N, grid_H, grid_W, 2]
    :return: samples of shape [N, grid_H, grid_W, C]; points whose 2x2
        neighbourhood is not fully inside the image are 0.
    """
    shape = tf.shape(image)
    H = tf.cast(shape[1], grid.dtype)
    W = tf.cast(shape[2], grid.dtype)
    xgrid, ygrid = tf.split(value=grid, num_or_size_splits=2, axis=-1)
    mask = tf.cast(
        (xgrid >= 0) & (ygrid >= 0) & (xgrid < W - 1) & (ygrid < H - 1),
        dtype=grid.dtype,
    )
    x0 = tf.math.floor(xgrid)
    x1 = x0 + 1
    y0 = tf.math.floor(ygrid)
    y1 = y0 + 1
    # Weights keep shape [N, grid_H, grid_W, 1] and broadcast over C of the
    # NHWC gathers, so no transposes are needed.
    wa = tf.cast((x1 - xgrid) * (y1 - ygrid), image.dtype)
    wb = tf.cast((x1 - xgrid) * (ygrid - y0), image.dtype)
    wc = tf.cast((xgrid - x0) * (y1 - ygrid), image.dtype)
    wd = tf.cast((xgrid - x0) * (ygrid - y0), image.dtype)

    def to_index(v):
        # Masked points get index 0 so the gathers stay in bounds; their
        # values are zeroed by the mask at the end.
        return tf.cast(tf.squeeze(v * mask, axis=-1), tf.int64)

    x0i, x1i, y0i, y1i = to_index(x0), to_index(x1), to_index(y0), to_index(y1)
    # Batch index for every output point; uses the dynamic batch size.
    ind = tf.range(tf.shape(image, out_type=tf.int64)[0], dtype=tf.int64)
    ind = tf.broadcast_to(tf.reshape(ind, [-1, 1, 1]), tf.shape(x0i, out_type=tf.int64))

    def gather(yi, xi):
        # tf.Tensor does not support NumPy-style indexing with index tensors
        # (image[:, ind, y, x] raises TypeError), so gather explicitly.
        return tf.gather_nd(image, tf.stack([ind, yi, xi], axis=-1))

    output_tensor = (
        gather(y0i, x0i) * wa
        + gather(y1i, x0i) * wb
        + gather(y0i, x1i) * wc
        + gather(y1i, x1i) * wd
    )
    return output_tensor * tf.cast(mask, image.dtype)