
I have noticed the following while trying to increase the performance of my code:

>>> a, b = torch.randn(1000,1000), torch.randn(1000,1000)
>>> c, d = torch.randn(10000, 100), torch.randn(100, 1000)
>>> e, f = torch.randn(100000, 10), torch.randn(10, 1000)
  
>>> %timeit torch.mm(a, b)
17 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit torch.mm(c, d)
24.4 ms ± 575 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit torch.mm(e, f)
138 ms ± 590 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Theoretically speaking, each matrix multiplication above requires 10^9 scalar multiplications, yet there is a big difference in practice! As the matrices become more rectangular, performance degrades. I thought of cache misses as one reason, but each of these multiplications seems cache-friendly. Why is it faster to multiply the square matrices?
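For what it's worth, the multiplication count n·m·k for each pair of shapes above can be checked in a couple of lines (shapes copied from the snippets):

```python
# Multiplying an (n, m) matrix by an (m, k) matrix takes n*m*k scalar
# multiplications with the standard algorithm.
pairs = [((1000, 1000), (1000, 1000)),
         ((10000, 100), (100, 1000)),
         ((100000, 10), (10, 1000))]
for (n, m), (m2, k) in pairs:
    assert m == m2          # inner dimensions must match
    print(n * m * k)        # 1000000000 for every pair
```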

1 Answer


I am a bit confused that you state that all three operations require 10^9 multiplications. O(n^3) is only true for square matrices.

Note that I will neglect faster matrix algorithms like Strassen's, which lower this bound significantly. I'll make my point with the naive implementation as you would have learned it in school or university.
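For concreteness, here is a textbook sketch of that naive algorithm (pure Python, illustrative only; real libraries use blocked, vectorized kernels):

```python
def naive_matmul(A, B):
    """Multiply an (n x m) list-of-lists A by an (m x k) list-of-lists B."""
    n, m, k = len(A), len(B), len(B[0])
    C = [[0.0] * k for _ in range(n)]
    for i in range(n):
        for j in range(k):
            s = 0.0
            for t in range(m):          # m multiplications, m - 1 additions
                s += A[i][t] * B[t][j]
            C[i][j] = s
    return C
```

For example, `naive_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])` returns `[[19.0, 22.0], [43.0, 50.0]]`.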

For this naive multiplication of matrices with shapes (MxN) and (NxM), the result has M^2 entries, each a dot product of length N, so the following rules hold for the number of operations:

  • The number of multiplications is M^2 * N
  • The number of additions is M^2 * (N - 1)

  M       N     Additions [1e9]  Multiplications [1e9]  Operations [1e9]
  1000    1000  0.999            1.0                    1.999
  10000   100   9.9              10.0                   19.9
  100000  10    90.0             100.0                  190.0
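These counts can be reproduced with a few lines (a minimal sketch that counts N multiplications and N - 1 additions for each entry of the M x M result):

```python
def op_counts(M, N):
    # Each of the M*M output entries of an (M x N) @ (N x M) product is a
    # dot product of length N: N multiplications and N - 1 additions.
    mults = M * M * N
    adds = M * M * (N - 1)
    return adds, mults, adds + mults

for M, N in [(1000, 1000), (10000, 100), (100000, 10)]:
    adds, mults, total = op_counts(M, N)
    print(f"M={M}, N={N}: {adds/1e9:g} + {mults/1e9:g} = {total/1e9:g} [1e9 ops]")
```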

As mentioned before, this table neglects more efficient algorithms as well as hardware effects like cache locality, but it shows the general trend: square matrix multiplication requires fewer operations, which is reflected in the times you measured. (torch.mm is already heavily optimized, which is one reason the times do not scale directly with the operation counts in the table.)


1 Comment

Thank you, but multiplying matrices of shapes (n, m) and (m, k) requires n*m*k multiplications. The examples I provided are not of the form (n, m) and (m, n); if you look closely at the shapes, you'll see that each product requires exactly 10^9 multiplications.
