In my PyTorch implementation of multi-head attention, I have this in __init__():

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, use_bias=False):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        # In multi-head attention, the output dimension (d_out) is split across multiple attention heads.
        # Each head processes its portion of the output dimensions independently before the per-head results are concatenated back together.
        self.head_dim = d_out // num_heads
        self.query_weight = nn.Linear(d_in, d_out, bias=use_bias)
        self.key_weight = nn.Linear(d_in, d_out, bias=use_bias)
        self.value_weight = nn.Linear(d_in, d_out, bias=use_bias)

This is the forward method:

def forward(self, x):
    batch_size, sequence_length, d_in = x.shape
    keys = self.key_weight(x)
    queries = self.query_weight(x)
    values = self.value_weight(x)
    # RESHAPING
    # .view() is a PyTorch tensor method that reshapes a tensor without changing its underlying data. It returns a new tensor with the same data but in a different shape.
    keys = keys.view(batch_size, sequence_length, self.num_heads, self.head_dim)
    values = values.view(batch_size, sequence_length, self.num_heads, self.head_dim)
    queries = queries.view(batch_size, sequence_length, self.num_heads, self.head_dim)

I understand that d_out is split across multiple attention heads, but I'm not entirely sure why this reshaping is necessary. How does adding num_heads as a new dimension affect the computation of attention, and what would happen if we skipped this step and kept the shape as (batch_size, sequence_length, d_in)?

2 Answers

I'll give some intuition for multi-head attention first. Instead of taking the query, key and value embeddings and applying the scaled dot product operation on them all at once, multi-head attention splits the query, key and value vectors into smaller sub-vectors and applies the scaled dot product operation on these sub-vectors in parallel, so that during training the model can "focus" on different parts of the input.

For example,

Single Head Attention
----------------------
Query: [q1 q2 q3 q4 q5 q6]  (dimension = 6)
Key:   [k1 k2 k3 k4 k5 k6]  (dimension = 6)
Value: [v1 v2 v3 v4 v5 v6]  (dimension = 6)

Multi-Head Attention (num_heads = 3)
-------------------------------------
       Head 1  | Head 2  |  Head 3
       (dim=2) | (dim=2) |  (dim=2)
Query: [q1 q2] | [q3 q4] | [q5 q6]
Key:   [k1 k2] | [k3 k4] | [k5 k6]
Value: [v1 v2] | [v3 v4] | [v5 v6]

Each head processes its own subset of the dimensions independently:

Head 1: Attention( [q1 q2], [k1 k2], [v1 v2] ) → [o1 o2]
Head 2: Attention( [q3 q4], [k3 k4], [v3 v4] ) → [o3 o4]
Head 3: Attention( [q5 q6], [k5 k6], [v5 v6] ) → [o5 o6]
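
In code, that split is exactly what the .view call in your forward method produces. Here's a toy demonstration (purely illustrative, with an arbitrary tensor standing in for a single token's query vector):

    import torch

    # One batch, one token, embedding dimension 6 (matching the diagram above)
    q = torch.tensor([[[1., 2., 3., 4., 5., 6.]]])   # shape: (1, 1, 6)

    num_heads, head_dim = 3, 2
    q_heads = q.view(1, 1, num_heads, head_dim)      # shape: (1, 1, 3, 2)

    print(q_heads[0, 0, 0])  # tensor([1., 2.])  -> head 1's sub-vector [q1 q2]
    print(q_heads[0, 0, 1])  # tensor([3., 4.])  -> head 2's sub-vector [q3 q4]
    print(q_heads[0, 0, 2])  # tensor([5., 6.])  -> head 3's sub-vector [q5 q6]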

So to answer this part of your question,

How does adding num_heads as a new dimension affect the computation of attention, and what would happen if we skipped this step and kept the shape as (batch_size, sequence_length, d_in)?

This is fine; it just becomes single-head attention, and you would have a simpler implementation of the scaled dot product, but in practice the model may not work as well as it does with multi-head attention.
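
To make that concrete, here is a minimal single-head sketch inside your forward method (just an illustration, skipping the .view entirely and using the queries, keys and values from your projections):

    # Single-head attention: no reshaping, the scaled dot product runs over the full d_out at once
    scores = torch.matmul(queries, keys.transpose(-1, -2)) / (self.d_out ** 0.5)  # (batch_size, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)
    context = torch.matmul(weights, values)                                       # (batch_size, seq_len, d_out)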

Now, going into your code, here is how the scaled dot product over multiple heads works:

    keys = self.key_weight(x)
    queries = self.query_weight(x)
    values = self.value_weight(x)

At this stage, you have applied the query, key and value projections to the input. Each of these is of size (batch_size, seq_len, d_out).

    keys = keys.view(batch_size, sequence_length, self.num_heads, self.head_dim)
    values = values.view(batch_size, sequence_length, self.num_heads, self.head_dim)
    queries = queries.view(batch_size, sequence_length, self.num_heads, self.head_dim)

After the reshaping, each of these is of size (batch_size, seq_len, num_heads, head_dim).

With this, you would run a separate scaled dot product operation for each head. You have not written this part of the code yet, so here is how it typically continues.

Normally, you'd first move the num_heads dimension into the second position, so the shape becomes (batch_size, num_heads, seq_len, head_dim), like this:

    queries = queries.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

Now, to compute the dot products and get the attention scores:

    scores = torch.matmul(queries, keys.transpose(-1, -2))

Think of this operation as:

  • queries: a tensor of shape (batch_size, num_heads, seq_len, head_dim), meaning that for each example in the batch and each head in that example, you have a (seq_len, head_dim) matrix

  • keys: ditto, same as queries

  • Now what you want to do is, for each example in the batch and each head in that example, a matrix product between the 2D query and key matrices. keys.transpose(-1, -2) flips the last two dimensions of keys to make it (head_dim, seq_len), so that multiplying the (seq_len, head_dim) query matrix by it produces (seq_len, seq_len), which is just a score for every pair of positions in the sequence.

  • So in the end you have scores of shape (batch_size, num_heads, seq_len, seq_len)

I'll leave a detailed walkthrough of the remaining scaled dot product steps out, because I think this adequately addresses the need for the reshaping that you originally asked about.
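
For reference, the remaining steps (scaling, softmax, the weighted sum over values, and merging the heads back together) commonly look something like the sketch below. This is a generic pattern rather than your exact code, and it omits the causal masking and dropout that your __init__ arguments suggest you will add:

    # scores: (batch_size, num_heads, seq_len, seq_len), from the matmul above
    scores = scores / (self.head_dim ** 0.5)         # scale by sqrt(head_dim)
    attn_weights = torch.softmax(scores, dim=-1)     # normalized per head, per query position
    context = torch.matmul(attn_weights, values)     # (batch_size, num_heads, seq_len, head_dim)

    # Merge the heads back: (batch_size, seq_len, num_heads, head_dim) -> (batch_size, seq_len, d_out)
    context = context.transpose(1, 2).contiguous().view(batch_size, sequence_length, self.d_out)

Many implementations also apply a final linear output projection to the merged context; your __init__ doesn't define one, so it's omitted here.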

Reshaping is necessary for the softmax. You want to softmax each attention head separately. This is done by creating a new axis for the attention heads and softmaxing all heads separately, in parallel.

If you didn't reshape for the head dimension, you would get the same result as using a single attention head.
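
To see the per-head softmax in action, here is a small shape-only sketch (arbitrary scores, just to show that a single softmax call normalizes every head independently and in parallel):

    import torch

    batch_size, num_heads, seq_len = 2, 3, 4
    scores = torch.randn(batch_size, num_heads, seq_len, seq_len)

    weights = torch.softmax(scores, dim=-1)  # applied along the last axis, independently for every head
    print(weights.shape)                     # torch.Size([2, 3, 4, 4])
    print(weights[0, 0].sum(dim=-1))         # each row of head 0, example 0 sums to 1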

1 Comment

Softmax's purpose is simply to normalize the activations so that they sum to 1. The purpose of reshaping goes well beyond that, i.e., all the way to the scaled dot product operation. By reshaping, you get to do the scaled dot product operation on the split-up embeddings, and after that the final step would be to do the softmax. But really, the core operation is the parallel scaled dot product, which the reshaping facilitates.
