Suppose I have an H×W NumPy array like this (H=3 and W=2):
```
[[0 1]
 [2 3]
 [4 5]]
```
I would like to replace each element with an M×N matrix filled with that element (M=2 and N=3), giving an array of shape (H, W, M, N):
```
[[[[0 0 0]
   [0 0 0]]

  [[1 1 1]
   [1 1 1]]]


 [[[2 2 2]
   [2 2 2]]

  [[3 3 3]
   [3 3 3]]]


 [[[4 4 4]
   [4 4 4]]

  [[5 5 5]
   [5 5 5]]]]
```
My current code first turns each element into a 1×1 matrix by calling expand_dims() twice, and then expands these matrices to M×N by calling repeat() twice:
```python
import numpy as np

H, W, M, N = 3, 2, 2, 3
array = np.arange(H * W).reshape(H, W)        # shape (H, W)
array = np.expand_dims(array, axis=2)         # shape (H, W, 1)
array = np.expand_dims(array, axis=3)         # shape (H, W, 1, 1)
array = np.repeat(array, M, axis=2)           # shape (H, W, M, 1)
array = np.repeat(array, N, axis=3)           # shape (H, W, M, N)
```
Is there a more straightforward and elegant way to obtain the same result?
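For comparison, here is a sketch of two broadcasting-based alternatives using `np.broadcast_to` and `np.tile`; these are suggestions and not part of the original question's code. Both insert the two trailing length-1 axes with `None` indexing instead of calling `expand_dims()` twice. Note that `np.broadcast_to` returns a read-only view, so `.copy()` is assumed here to get an ordinary writable array like the one `repeat()` produces.

```python
import numpy as np

H, W, M, N = 3, 2, 2, 3
array = np.arange(H * W).reshape(H, W)

# Original approach from the question: expand_dims twice, then repeat twice.
expected = np.expand_dims(np.expand_dims(array, axis=2), axis=3)
expected = np.repeat(np.repeat(expected, M, axis=2), N, axis=3)

# Broadcasting: view (H, W) as (H, W, 1, 1), then stretch the last two axes.
# broadcast_to gives a read-only view; .copy() materializes a writable array.
broadcasted = np.broadcast_to(array[:, :, None, None], (H, W, M, N)).copy()

# tile: repeat the (H, W, 1, 1) array M times along axis 2 and N times along axis 3.
tiled = np.tile(array[:, :, None, None], (1, 1, M, N))

print(np.array_equal(broadcasted, expected))  # True
print(np.array_equal(tiled, expected))        # True
```

If the result is only read and never modified, dropping `.copy()` avoids allocating H·W·M·N elements, since the broadcast view shares memory with the original array.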