# Open2
# ArgMinのスクラッチ実装 (ArgMin implemented from scratch)
import pprint

import numpy as np

# Show every element (no "..." summarisation) and allow wide rows so the
# 4-D array prints compactly.
np.set_printoptions(linewidth=500, threshold=np.inf)

# Fixed seed so the printed arrays are the same on every run.
np.random.seed(seed=32)

# Random integers in [0, 50) with shape (1, 10, 3, 4).
# NOTE(review): an earlier comment recorded the shape (1, 10, 100, 17);
# presumably that is the real tensor this small example stands in for —
# confirm.
a = np.random.randint(0, 50, (1, 10, 3, 4))
print(a)

# NumPy's built-in argmin along axis 2 (the length-3 axis). Reducing that
# axis yields shape (1, 10, 4); NumPy returns the first index when the
# minimum value repeats.
b = np.argmin(a, axis=2)
print(b.shape)
print(b)
"""
1 1 0 1
[[[[23 43 5 24]
[19 7 25 3]
[37 42 9 4]]
[[11 17 3 1]
[34 35 24 42]
[18 10 20 36]]
[[36 5 38 13]
[11 11 26 29]
[29 26 47 0]]
[[11 6 46 30]
[ 4 41 30 16]
[19 37 10 35]]
[[45 24 47 11]
[29 47 43 11]
[34 41 43 49]]
[[ 4 28 48 20]
[18 17 34 16]
[18 14 25 6]]
[[14 31 0 44]
[14 41 15 1]
[48 14 19 33]]
[[12 7 40 14]
[ 6 6 42 34]
[40 17 36 8]]
[[25 9 18 17]
[31 18 32 4]
[ 0 7 0 48]]
[[30 35 18 20]
[ 8 2 37 8]
[14 28 25 18]]]]
(1, 10, 4)
[[[1 1 0 1]
[0 2 0 0]
[1 0 1 2]
[1 0 2 1]
[1 0 1 0]
[0 2 2 2]
[0 2 0 1]
[1 1 2 2]
[2 2 2 1]
[1 1 0 1]]]
"""
import pprint
import tensorflow as tf

# ArgMin expressed through ArgMax: the position of the smallest element of
# `a` is the position of the largest element of `-a`, so the result should
# reproduce `b` from the NumPy section.
# NOTE(review): negation only preserves ordering for signed-integer or float
# data; an unsigned dtype would wrap around (and the most negative signed
# integer overflows). `a` comes from np.random.randint with values in
# [0, 50), so it is safe here.
# NOTE(review): TensorFlow does not guarantee which index argmax returns on
# ties, whereas NumPy's argmin returns the first one — results could differ
# if a minimum repeats along axis 2.
negated = -a
con = tf.constant(negated)
m = tf.math.argmax(con, axis=2)

# Print a label followed by the value, for the tensor and then its shape.
for label, value in (('m', m), ('m.shape', m.shape)):
    pprint.pprint(label)
    pprint.pprint(value)
"""
'm'
<tf.Tensor: shape=(1, 10, 4), dtype=int64, numpy=
array([[[1, 1, 0, 1],
[0, 2, 0, 0],
[1, 0, 1, 2],
[1, 0, 2, 1],
[1, 0, 1, 0],
[0, 2, 2, 2],
[0, 2, 0, 1],
[1, 1, 2, 2],
[2, 2, 2, 1],
[1, 1, 0, 1]]])>
'm.shape'
TensorShape([1, 10, 4])
"""