I have noticed the following while trying to increase the performance of my code:
>>> a, b = torch.randn(1000,1000), torch.randn(1000,1000)
>>> c, d = torch.randn(10000, 100), torch.randn(100, 1000)
>>> e, f = torch.randn(100000, 10), torch.randn(10, 1000)
>>> %timeit torch.mm(a, b)
17 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit torch.mm(c, d)
24.4 ms ± 575 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit torch.mm(e, f)
138 ms ± 590 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Theoretically, each of the multiplications above requires the same 10^9 scalar multiplications (1000·1000·1000 = 10000·100·1000 = 100000·10·1000 = 10^9), yet there is a big difference in practice! The more rectangular the matrices become, the worse the performance gets. I considered cache misses as one possible cause, but each of these multiplications seems cache-friendly. Why is it so much faster to multiply the square matrices?
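For reference, here is a small sanity-check script I put together (the `shapes` list and `reps` count are just values I picked, not part of the benchmark above). It confirms that all three shapes do the same amount of arithmetic and converts the measured time into effective throughput, which makes the gap easier to see:

import time
import torch

# Each (m, k, n) triple multiplies an (m, k) matrix by a (k, n) matrix,
# so the scalar multiplication count is m * k * n = 10**9 in every case.
shapes = [(1000, 1000, 1000), (10000, 100, 1000), (100000, 10, 1000)]
reps = 10  # number of timed repetitions; chosen arbitrarily for this sketch

for m, k, n in shapes:
    a = torch.randn(m, k)
    b = torch.randn(k, n)
    torch.mm(a, b)  # warm-up run

    start = time.perf_counter()
    for _ in range(reps):
        torch.mm(a, b)
    elapsed = (time.perf_counter() - start) / reps

    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    print(f"({m:>6} x {k:>4}) @ ({k:>4} x {n:>4}): "
          f"{elapsed * 1e3:7.1f} ms, {flops / elapsed / 1e9:6.1f} GFLOP/s")

Plugging in the timings from my run above gives roughly 118, 82, and 14 GFLOP/s for the three shapes, even though the FLOP count is identical.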