Open1

ONNX: ScatterND / TensorFlow: tf.tensor_scatter_nd_update を Slice (strided_slice) や Concat (concat) で置き換え

PINTO
  • 一例
  def forward(self, pred):
      """Decode raw detection-head outputs without in-place slice assignment.

      The reference decoder updates the xy / wh channels in place, e.g.
      ``y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + gr) * st[i]``. That exports to
      ONNX ``ScatterND`` (TensorFlow ``tf.tensor_scatter_nd_update``), which
      MyriadX does not support. This version computes each channel slice
      separately and joins them with ``torch.cat`` (ONNX ``Concat``) instead.

      Args:
          pred: indexable sequence with at least 3 head outputs. Each is
              reshaped below with ``view(bs, 3, 85, ny, nx)``, so each is
              expected to be ``(bs, 3 * 85, ny, nx)``. The sequence is NOT
              modified (a list or a tuple both work).

      Returns:
          Tensor of shape ``(bs, sum over levels of 3 * ny * nx, 85)``.
      """
      num_anchors = 3
      # Channels per anchor; presumably 4 box + 1 objectness + 80 classes — confirm.
      num_outputs = 85
      strides = [8, 16, 32]

      # NOTE(review): the anchor file is re-read from disk on every call;
      # consider loading it once at construction time if forward runs often.
      anchor_grid = np.load('model_105_anchor_grid.npy')
      anchor_grid = torch.tensor(anchor_grid, dtype=torch.float32).to(DEVICE)

      z = []
      for i, stride in enumerate(strides):
          bs, _, ny, nx = pred[i].shape
          # Keep the reshaped tensor in a local instead of writing it back into
          # pred[i], so the caller's sequence is left untouched.
          p = (
              pred[i]
              .view(bs, num_anchors, num_outputs, ny, nx)
              .permute(0, 1, 3, 4, 2)
              .contiguous()
          )
          y = p.sigmoid()
          # Match y's dtype AND device. Using `.to(pred[i].cpu())` always put the
          # grid on the CPU, which breaks the additions below when pred is on a GPU.
          gr = self._make_grid(nx, ny).to(y)
          ag = anchor_grid[i]

          # MyriadX-compatible decode: slice each channel, transform it, and
          # concatenate, instead of a ScatterND-producing in-place write.
          x_c = (y[..., 0:1] * 2. - 0.5 + gr[..., 0:1]) * stride  # x
          y_c = (y[..., 1:2] * 2. - 0.5 + gr[..., 1:2]) * stride  # y
          w = (y[..., 2:3] * 2) ** 2 * ag[..., 0:1]  # w
          h = (y[..., 3:4] * 2) ** 2 * ag[..., 1:2]  # h
          rest = y[..., 4:]  # objectness + class scores, passed through
          decoded = torch.cat([x_c, y_c, w, h, rest], dim=-1)
          z.append(decoded.view(bs, -1, num_outputs))

      return torch.cat(z, 1)