There is an index tensor like this (shape is batch * length):
[[1,2,3],[1,2,3]]
and a value tensor like this (shape is batch * length * deep):
[[[0.9,0.9,0.1,0.1],[0.9,0.1,0.8,0.1],[0.9,0.1,0.1,0.6]],
 [[0.1,0.9,0.8,1],[1,2,0.8,0.1],[0.1,0.1,2,0.6]]]
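Each index picks one element along the last (deep) axis, e.g. index[0][1] = 2 selects value[0][1][2] = 0.8.
How can I get [[0.9,0.8,0.6],[0.9,0.8,0.6]] with TensorFlow?
Is tf.math.reduce_max(tensor, axis=-1) the right approach?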
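A minimal sketch, assuming TensorFlow 2.x and 0-based indices (which is what the expected output implies). reduce_max does not do this: it ignores the index tensor and returns the largest entry along the deep axis, so for the second batch it gives [1, 2, 2] rather than [0.9, 0.8, 0.6]. tf.gather with batch_dims=2 selects value[b, l, index[b, l]] directly. The one_hot product is shown as an equivalent alternative. The variable names are only illustrative.

```python
import tensorflow as tf

# index: shape (batch, length); each entry is a 0-based position on the "deep" axis
index = tf.constant([[1, 2, 3], [1, 2, 3]])

# value: shape (batch, length, deep) = (2, 3, 4)
value = tf.constant([
    [[0.9, 0.9, 0.1, 0.1], [0.9, 0.1, 0.8, 0.1], [0.9, 0.1, 0.1, 0.6]],
    [[0.1, 0.9, 0.8, 1.0], [1.0, 2.0, 0.8, 0.1], [0.1, 0.1, 2.0, 0.6]],
])

# batch_dims=2 treats (batch, length) as batch axes and gathers along the next
# axis, so picked[b, l] = value[b, l, index[b, l]]; result shape is (batch, length).
picked = tf.gather(value, index, batch_dims=2)

# Alternative: build a one-hot mask over the deep axis and sum that axis out.
mask = tf.one_hot(index, depth=tf.shape(value)[-1], dtype=value.dtype)
picked_onehot = tf.reduce_sum(value * mask, axis=-1)

# reduce_max ignores `index` and just takes the largest entry per row.
maxed = tf.math.reduce_max(value, axis=-1)

print(picked.numpy())         # values [[0.9, 0.8, 0.6], [0.9, 0.8, 0.6]]
print(picked_onehot.numpy())  # same values as `picked`
print(maxed.numpy())          # values [[0.9, 0.9, 0.9], [1.0, 2.0, 2.0]]
```

tf.gather with batch_dims avoids materializing the (batch, length, deep) mask. The one_hot version can be handy when you later want a soft (weighted) selection instead of a hard index.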