Input example:
I have a NumPy array, e.g.
import numpy as np
a = np.array([[0, 1], [2, 1], [4, 8]])
Desired output:
I would like to produce a boolean mask array of the same shape as a, in which the max value along a given axis (axis 1 in my case) is True and all other entries are False. For this example:
mask = np.array([[False, True], [True, False], [False, True]])
Attempt:
I have tried approaches using np.amax, but this returns the max values themselves as a 1-D array rather than a mask with the shape of a:
>>> np.amax(a, axis=1)
array([1, 2, 8])
Similarly, np.argmax returns the indices of the max values along that axis, not a mask:
>>> np.argmax(a, axis=1)
array([1, 0, 1])
I could iterate over this result in some way, but as these arrays get bigger I want the solution to stay within native numpy operations rather than a Python loop.
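For reference, a loop-free approach along the following lines might do what is described above. This is a minimal sketch, not a confirmed solution: the names mask_eq and mask_first are illustrative, and the two options differ in how they treat ties (the question does not say which behaviour is wanted).

```python
import numpy as np

a = np.array([[0, 1], [2, 1], [4, 8]])

# Option 1: compare every element with the maximum of its own row.
# keepdims=True keeps the result as shape (3, 1) so it broadcasts against a.
# Assumption: if a row contains ties, every maximal entry is marked True.
mask_eq = a == a.max(axis=1, keepdims=True)

# Option 2: build the mask from the argmax indices with fancy indexing.
# Only the first maximum in each row is marked True.
mask_first = np.zeros(a.shape, dtype=bool)
mask_first[np.arange(a.shape[0]), np.argmax(a, axis=1)] = True

print(mask_eq)
print(mask_first)
```

For the example array both options give the desired mask. Option 1 is shorter and works for any axis by changing the axis argument; Option 2 is written for a 2-D array with axis 1 and would need adapting for other shapes.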