I'll give some intuition for multi-head attention first. Instead of taking the query, key and value embeddings and applying the scaled dot-product operation on them all at once, multi-head attention splits the query, key and value vectors into smaller sub-vectors and applies the scaled dot-product operation on these sub-vectors in parallel, so that during training each head of the model can "focus" on a different part of the input.
For example,
Single Head Attention
----------------------
Query: [q1 q2 q3 q4 q5 q6] (dimension = 6)
Key: [k1 k2 k3 k4 k5 k6] (dimension = 6)
Value: [v1 v2 v3 v4 v5 v6] (dimension = 6)
Multi-Head Attention (num_heads = 3)
-------------------------------------
Head 1 | Head 2 | Head 3
(dim=2) | (dim=2) | (dim=2)
Query: [q1 q2] | [q3 q4] | [q5 q6]
Key: [k1 k2] | [k3 k4] | [k5 k6]
Value: [v1 v2] | [v3 v4] | [v5 v6]
Each head processes its own subset of the dimensions independently:
Head 1: Attention( [q1 q2], [k1 k2], [v1 v2] ) → [o1 o2]
Head 2: Attention( [q3 q4], [k3 k4], [v3 v4] ) → [o3 o4]
Head 3: Attention( [q5 q6], [k5 k6], [v5 v6] ) → [o5 o6]
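To make the split concrete, here is a minimal sketch (a hypothetical tensor, not from your code) of slicing one 6-dimensional query into 3 heads of dimension 2 with .view:

import torch

q = torch.arange(1, 7, dtype=torch.float32)   # [q1 q2 q3 q4 q5 q6], dimension = 6
q_heads = q.view(3, 2)                        # 3 heads, each of dimension 2
print(q_heads)
# tensor([[1., 2.],    <- head 1 sees [q1 q2]
#         [3., 4.],    <- head 2 sees [q3 q4]
#         [5., 6.]])   <- head 3 sees [q5 q6]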
So to answer this part of your question,
How does adding num_heads as a new dimension affect the computation of attention, and what would happen if we skipped this step and kept the shape as "batch_size,sequence_length,d_in"
That would be fine: it just becomes single-head attention, and you get a simpler implementation of scaled dot-product attention, but in practice the model might not work as well as it would with multi-head attention.
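For reference, here is what that simpler single-head version could look like. This is only a sketch, assuming the projected tensors have shape (batch_size, seq_len, d_in); the function name is just illustrative:

import math
import torch

def single_head_attention(queries, keys, values):
    # queries, keys, values: (batch_size, seq_len, d_in)
    d = queries.size(-1)
    scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(d)   # (batch_size, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)                                 # attention weights over positions
    return torch.matmul(weights, values)                                    # (batch_size, seq_len, d_in)

out = single_head_attention(torch.randn(2, 4, 6), torch.randn(2, 4, 6), torch.randn(2, 4, 6))
print(out.shape)   # torch.Size([2, 4, 6])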
Now, going into your code to show how scaled dot-product attention over multiple heads would work:
keys = self.key_weight(x)
queries = self.query_weight(x)
values = self.value_weight(x)
At this stage, you have applied the query, key and value projections to the input. Each of these is of size (batch_size, seq_len, d_in), where d_in must equal num_heads * head_dim for the reshape below to work.
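If you want to sanity-check that, here is a small standalone sketch with made-up sizes; I'm assuming key_weight and the other projections are nn.Linear layers mapping d_in -> d_in:

import torch
import torch.nn as nn

batch_size, sequence_length, d_in = 2, 4, 6        # hypothetical toy sizes
x = torch.randn(batch_size, sequence_length, d_in)
key_weight = nn.Linear(d_in, d_in)                 # stand-in for self.key_weight
keys = key_weight(x)
print(keys.shape)                                  # torch.Size([2, 4, 6]) -> (batch_size, seq_len, d_in)

The next lines of your code then reshape these projections: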
keys = keys.view(batch_size, sequence_length, self.num_heads, self.head_dim)
values = values.view(batch_size, sequence_length, self.num_heads, self.head_dim)
queries = queries.view(batch_size, sequence_length, self.num_heads, self.head_dim)
Now, after the reshaping, each of these is of size (batch_size, seq_len, num_heads, head_dim).
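Again as a standalone shape check with made-up sizes (d_in = 6 split into num_heads = 3 heads of head_dim = 2):

import torch

batch_size, sequence_length, num_heads, head_dim = 2, 4, 3, 2
keys = torch.randn(batch_size, sequence_length, num_heads * head_dim)   # (2, 4, 6)
keys = keys.view(batch_size, sequence_length, num_heads, head_dim)
print(keys.shape)   # torch.Size([2, 4, 3, 2]) -> (batch_size, seq_len, num_heads, head_dim)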
With this, you would do a separate scaled dot-product operation for each head; this is the part of the code you have not written yet.
Normally, you'd move the num_heads dimension to the 2nd position to get (batch_size, num_heads, seq_len, head_dim), like this:
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
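With the same toy shapes as above, the transpose just swaps the two middle dimensions:

import torch

queries = torch.randn(2, 4, 3, 2)    # (batch_size, seq_len, num_heads, head_dim), hypothetical sizes
queries = queries.transpose(1, 2)    # swap seq_len and num_heads
print(queries.shape)                 # torch.Size([2, 3, 4, 2]) -> (batch_size, num_heads, seq_len, head_dim)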
Now, to compute the dot products and get the attention scores:
scores = torch.matmul(queries, keys.transpose(-1, -2))
Think of this operation as:
queries: a tensor of shape (batch_size, num_heads, seq_len, head_dim), meaning that for each example in the batch, and for each head in that example, you have a (seq_len, head_dim) matrix.
keys: ditto, same shape as queries.
Now what you want is, for each example in the batch and for each head in that example, the dot products between the 2D query and key matrices. keys.transpose(-1, -2) flips the last two dimensions of keys to (head_dim, seq_len), so that multiplying the (seq_len, head_dim) query matrix by it produces a (seq_len, seq_len) matrix, i.e. a score for every pair of positions in the sequence.
So in the end, scores has shape (batch_size, num_heads, seq_len, seq_len).
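You can verify that shape with a quick sketch using random tensors (toy sizes again):

import torch

queries = torch.randn(2, 3, 4, 2)   # (batch_size=2, num_heads=3, seq_len=4, head_dim=2)
keys = torch.randn(2, 3, 4, 2)
# matmul batches over the leading (batch_size, num_heads) dimensions automatically
scores = torch.matmul(queries, keys.transpose(-1, -2))
print(scores.shape)   # torch.Size([2, 3, 4, 4]) -> one (seq_len, seq_len) score matrix per head, per example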
I'll leave the remaining steps of scaled dot-product attention out (scaling, softmax, multiplying by the values, and merging the heads back together), because I think this adequately addresses the reshaping question you originally asked.