I would suggest looking into the math a little more. All of the projections in the attention layer are (linear) functions of the input sequence, so the attention output is the softmax of a quadratic form in the input, IIRC.
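A quick numpy sketch of that claim (shapes and names are illustrative, not from any particular model): the scores computed from separate query/key projections are exactly a quadratic form in the input, i.e. `(X W_q)(X W_k)^T == X (W_q W_k^T) X^T`.

```python
import numpy as np

# Hypothetical sizes: n tokens, model dim d, head dim k.
rng = np.random.default_rng(0)
n, d, k = 4, 8, 3
X = rng.standard_normal((n, d))      # input sequence, one row per token
W_q = rng.standard_normal((d, k))
W_k = rng.standard_normal((d, k))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Two projections vs. one fused matrix: identical scores by associativity.
scores_two = (X @ W_q) @ (X @ W_k).T
scores_one = X @ (W_q @ W_k.T) @ X.T
assert np.allclose(scores_two, scores_one)

A = softmax(scores_one)              # attention weights, rows sum to 1
```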
Can you show ordering equivariance of the single matrix with the two matrices?
This form of attention must be equivariant with respect to token order, e.g.
attn(ABCD) == rot2(attn(rot2(ABCD))) == rot2(attn(CDAB))
I am using rot2 here for rotation of the token order by two positions.
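That equivariance is easy to check numerically. A minimal sketch, assuming unmasked single-head attention with no positional encoding (all names and shapes are illustrative); with 4 tokens, rot2 is its own inverse, so the identity reads exactly as above:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, k = 4, 8, 3                      # 4 tokens: "ABCD"
X = rng.standard_normal((n, d))
W_q, W_k, W_v = (rng.standard_normal((d, k)) for _ in range(3))

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attn(X):
    # Unmasked attention, no positional encoding.
    return softmax((X @ W_q) @ (X @ W_k).T) @ (X @ W_v)

def rot2(M):
    return np.roll(M, 2, axis=0)       # rotate token order by two

# attn(ABCD) == rot2(attn(rot2(ABCD))): attention is permutation-equivariant.
assert np.allclose(attn(X), rot2(attn(rot2(X))))
```

Note that a causal mask or positional encodings would break this property, which is why they carry the order information.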
If I'm understanding your question correctly, it probably doesn't make any difference computation-wise. But if we took query-dot-key as a single input, the attention layer would have just two inputs: 1. the query-dot-key matrix; 2. the value matrix. I think that would be a worse formulation than the original paper's, although the two are the same computation-wise. By allowing separate query and key matrices, the data flow is clearer. For example, the encoder-decoder attention layer takes the encoder block's output as the key and value but the processed target sequence as the query. This idea is very clear with the original attention layer formulation.
It's the same mathematically but not computation-wise: the tokens are projected down to a smaller dimension. The projections cost 2Nd parameters (N the model width, d the head dimension), whereas it'd be N² if you fused the weight matrices.
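The parameter count works out like this (a sketch with illustrative sizes, reading N as the model width and d as the head/projection dimension):

```python
# Parameter count: two thin projections vs. one fused square matrix.
N, d = 512, 64                 # illustrative sizes, not from the thread
separate = 2 * N * d           # W_q and W_k, each N x d
fused = N * N                  # one fused N x N matrix
assert separate == 65536
assert fused == 262144
assert separate < fused        # holds whenever 2d < N
```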
This.
These answers seem weird to me. Am I misunderstanding? Here's the obvious-seeming answer to me:
You need two different matrices because you need an attention coefficient for every single pair of vectors.
If there are n tokens, then for the n-th token you need n-1 different attention coefficients (one for each token it attends to). For the (n-1)-th token you need n-2 coefficients, and so on, down to the 2nd token, which needs only one, and the 1st, which needs zero (it has nothing before it to attend to).
That's ~n² coefficients in total. If you instead compute key and query vectors, you only need 2n vectors (one key and one query for each of the n tokens). If the key/query vectors are d-dimensional, that's 2dn numbers, which is smaller than n² whenever the context size n is bigger than twice the key/query dimension d.
So using separate vectors is more efficient and more scalable.
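The counting above can be checked with illustrative numbers (these sizes are hypothetical, not from any particular model); note this counts per-sequence activations, complementing the parameter-count argument elsewhere in the thread:

```python
# Pairwise attention coefficients vs. stored key/query vectors.
n, d = 1024, 64                       # context size and key/query dim
pairwise_coeffs = n * n               # ~n^2 attention scores per sequence
kq_numbers = 2 * d * n                # one d-dim key and query per token
assert pairwise_coeffs == 1048576
assert kq_numbers == 131072
assert kq_numbers < pairwise_coeffs   # true whenever n > 2d
```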
The other answers on this thread seem different, which is surprising to me since this answer feels very straightforward. If I'm missing something, I'd love an explanation
dammit all of the answers are fkin terrible. Looks like the ai bots took over or everyone in this subreddit has become braindead since the blackout.
You obviously don't do W_q @ W_k. That's totally stupid.
What transformers do is (x_i@W_q) @ (x_j@W_k) where x_i and x_j are two tokens in the sequence. This is an interaction operation. This can't be precomputed. What you see noted in the papers is Q = x_i @ W_q, and K = x_j @ W_k.
(Transposes omitted for notational clarity, work that out yourself)
Your answer is also terrible. It does not answer his question.
Look at the top 2 replies to see correct interpretations of the question.
If you keep the matrices separate, you can control the rank of the learned weights.
Otherwise, the (single) matrix will be full rank.
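A small numpy demonstration of the rank point (sizes are illustrative): the implicit fused matrix W_q W_k^T is d x d but its rank is capped at the head dimension k, whereas a directly parameterized d x d matrix is generically full rank.

```python
import numpy as np

rng = np.random.default_rng(2)
d, k = 64, 8                          # head dim k much smaller than d

W_q = rng.standard_normal((d, k))
W_k = rng.standard_normal((d, k))
fused = W_q @ W_k.T                   # d x d, but low-rank by construction
assert np.linalg.matrix_rank(fused) == k

# A single learned d x d matrix has no such constraint:
W = rng.standard_normal((d, d))
assert np.linalg.matrix_rank(W) == d
```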
Something to add to the other great answers here: you can say something similar about the head-specific matrices W_V and W_O, which always act together as well. In fact, Anthropic recommends thinking of W_O W_V and W_Q^T W_K as basic primitives in their transformer interpretability framework: https://transformer-circuits.pub/2021/framework/index.html
I don't remember the ref but I browsed a theory paper at some point that did consider that representation (the product explicitly), possibly with something like nuclear norm regularization to keep the rank low.
On the EleutherAI Discord, someone once asked that question. And someone else replied that yeah, obviously having 1 matrix instead of 2 should be better in theory, but in practice, empirically, it makes things worse. Why? No one knows.