Open2

ArgMinのスクラッチ実装

PINTOPINTO
import pprint

import numpy as np

# Print arrays in full (no "..." elision) on wide lines so the dumps
# below can be compared by eye.
np.set_printoptions(linewidth=500, threshold=np.inf)

# Fixed seed: every run draws the same integers.
np.random.seed(32)

# Small test input: integers in [0, 50) with shape (1, 10, 3, 4).
# NOTE(review): the earlier comment here read "(1, 10, 100, 17)",
# presumably the shape of the real target tensor -- confirm.
a = np.random.randint(low=0, high=50, size=(1, 10, 3, 4))
print(a)

# Reference answer: index of the minimum along axis 2. That axis is
# reduced away, so the result has shape (1, 10, 4).
b = np.argmin(a, axis=2)

print(b.shape)
print(b)

"""
     1  1  0  1   <- argmin over axis=2 of the first (3, 4) block below (equals b[0, 0])
[[[[23 43  5 24]
   [19  7 25  3]
   [37 42  9  4]]

  [[11 17  3  1]
   [34 35 24 42]
   [18 10 20 36]]

  [[36  5 38 13]
   [11 11 26 29]
   [29 26 47  0]]

  [[11  6 46 30]
   [ 4 41 30 16]
   [19 37 10 35]]

  [[45 24 47 11]
   [29 47 43 11]
   [34 41 43 49]]

  [[ 4 28 48 20]
   [18 17 34 16]
   [18 14 25  6]]

  [[14 31  0 44]
   [14 41 15  1]
   [48 14 19 33]]

  [[12  7 40 14]
   [ 6  6 42 34]
   [40 17 36  8]]

  [[25  9 18 17]
   [31 18 32  4]
   [ 0  7  0 48]]

  [[30 35 18 20]
   [ 8  2 37  8]
   [14 28 25 18]]]]

(1, 10, 4)
[[[1 1 0 1]
  [0 2 0 0]
  [1 0 1 2]
  [1 0 2 1]
  [1 0 1 0]
  [0 2 2 2]
  [0 2 0 1]
  [1 1 2 2]
  [2 2 2 1]
  [1 1 0 1]]]
"""

import tensorflow as tf
import pprint

# `a` is the seeded integer array created earlier in this file with
# np.random.randint(0, 50, (1, 10, 3, 4)).

# ArgMin expressed through ArgMax: negating the values turns the smallest
# entry along an axis into the largest, so argmax over -a lands on the
# position of the minimum of a.
con = tf.constant(-a)
m = tf.math.argmax(con, axis=2)

# Dump a quoted label followed by the value: first the result tensor,
# then its shape.
for item in ('m', m, 'm.shape', m.shape):
    pprint.pprint(item)
PINTOPINTO
'm'
<tf.Tensor: shape=(1, 10, 4), dtype=int64, numpy=
array([[[1, 1, 0, 1],
        [0, 2, 0, 0],
        [1, 0, 1, 2],
        [1, 0, 2, 1],
        [1, 0, 1, 0],
        [0, 2, 2, 2],
        [0, 2, 0, 1],
        [1, 1, 2, 2],
        [2, 2, 2, 1],
        [1, 1, 0, 1]]])>
'm.shape'
TensorShape([1, 10, 4])