A warning in advance: I may be utterly confused at the moment, so let me first describe what I am actually trying to achieve, since that may clear things up. Say I have f(a,b,c,d,e), and I want to find arg max over (d,e) of f(a,b,c,d,e). Consider a (trivial example of a) discretized grid F of f:
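import numpy as np
# F has shape (A, B, C, D, E) = (10, 10, 100, 10, 10); its values vary only along c
F = np.tile(np.arange(0, 10, 0.1)[np.newaxis, np.newaxis, :, np.newaxis, np.newaxis], [10, 10, 1, 10, 10])
maxE = F.max(axis=-1)           # max over e, shape (10, 10, 100, 10)
argmaxD = maxE.argmax(axis=-1)  # maximizing d index, shape (10, 10, 100)
maxD = F.max(axis=-2)           # max over d, shape (10, 10, 100, 10)
argmaxE = maxD.argmax(axis=-1)  # maximizing e index, shape (10, 10, 100)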
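As a cross-check of the two-step reduction, the joint arg max over (d,e) can also be obtained by flattening the last two axes and unravelling the flat index. This is only a sketch on a small random grid: the shape and the names G, flat_idx, d_idx and e_idx are made up for illustration, and with ties the two approaches are not guaranteed to pick the same maximizer.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical small grid with axes (a, b, c, d, e); random values make ties unlikely
G = rng.random((2, 3, 4, 5, 6))

# Flatten (d, e) into one axis, take the arg max, then split the flat index again
flat_idx = G.reshape(G.shape[:-2] + (-1,)).argmax(axis=-1)
d_idx, e_idx = np.unravel_index(flat_idx, G.shape[-2:])

# Two-step version: max over one axis, then arg max over the other
argmaxD = G.max(axis=-1).argmax(axis=-1)
argmaxE = G.max(axis=-2).argmax(axis=-1)
```

Both routes should return index arrays of shape (2, 3, 4), one (d, e) pair per (a, b, c).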
The max-then-argmax code above is how I typically solve the discretized version. But now assume instead that I want to solve arg max over d of f(a,b,c,d,e=X): instead of choosing e optimally for every other input, e is fixed and given as an array X (of size AxBxCxD, which in this example would be 10x10x100x10). I am having trouble solving this.
My naive approach was:
# X has shape (10, 10, 100, 10) and holds valid indices (0..9) into the e axis
X = np.tile(np.arange(0, 10)[np.newaxis, np.newaxis, np.newaxis, :], [10, 10, 100, 1])
maxX = F[X]
argmaxD = maxX.argmax(axis=-1)
However, the huge surge in memory use, which crashes my IDE, suggests that F[X] is not what I was looking for.
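One possible reading of the memory surge: indexing with a single integer array, as in F[X], applies X to the first axis only, so the result has shape X.shape + F.shape[1:] rather than one value per (a,b,c,d). Below is a sketch of what per-element selection along e could look like with np.take_along_axis (assuming NumPy 1.15 or later). The grid is deliberately small, the names G, Y, picked and argmaxD_fixed are made up for illustration, and Y is assumed to hold valid integer indices into the e axis.

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.random((2, 3, 4, 5, 6))            # hypothetical grid, axes (a, b, c, d, e)
Y = rng.integers(0, 6, size=G.shape[:-1])  # assumed: one valid e index per (a, b, c, d)

# Indexing with a single integer array addresses axis 0 only, so the result
# would have this shape (computed here, not evaluated):
naive_shape = Y.shape + G.shape[1:]

# Select G[a, b, c, d, Y[a, b, c, d]] for every (a, b, c, d)
picked = np.take_along_axis(G, Y[..., np.newaxis], axis=-1)[..., 0]

# arg max over d with e fixed by Y
argmaxD_fixed = picked.argmax(axis=-1)     # shape (2, 3, 4)
```

On the grid from the question this would read np.take_along_axis(F, X[..., np.newaxis], axis=-1)[..., 0].argmax(axis=-1), provided X has shape F.shape[:-1]. Whether that is fast enough for the intended use is not something this sketch settles.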
Performance is key.
np.argmax(np.max(F, axis=-1), axis=-1)