I'm trying to understand the math behind Transformers, specifically self-attention. This link, and many others, gives the formula to compute the output vectors from the input embeddings as:

$$Q=XW_Q,\;\;\;K=XW_K,\;\;\;V=XW_V$$ $$Attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

But this eventually becomes

$$Attention(Q,K,V)=softmax\left(X\frac{W_QW_K^T}{\sqrt{d_k}}X^T\right)V$$

If $W_Q$ and $W_K$ are only ever used in the form $\frac{W_QW_K^T}{\sqrt{d_k}}$, why do we initialize both matrices at all? Why not just define and initialize a single matrix $W_{QK}$, skip the extra matrix multiplication, and get rid of the redundant weights?
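
To make the question concrete, here is a small NumPy sketch (toy shapes and random weights of my own choosing; the `softmax` helper is mine) verifying that the two forms above compute the same output:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
p, n, m = 5, 16, 4                      # toy sizes: sequence length, embedding dim, d_k
X   = rng.standard_normal((p, n))       # input embeddings
W_Q = rng.standard_normal((n, m))
W_K = rng.standard_normal((n, m))
W_V = rng.standard_normal((n, m))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V

# Form 1: softmax(Q K^T / sqrt(d_k)) V
out1 = softmax(Q @ K.T / np.sqrt(m)) @ V

# Form 2: softmax(X (W_Q W_K^T / sqrt(d_k)) X^T) V, with W_Q and W_K fused into one matrix
W_QK = (W_Q @ W_K.T) / np.sqrt(m)
out2 = softmax(X @ W_QK @ X.T) @ V

print(np.allclose(out1, out2))          # True
```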

itrase

1 Answer


The weight matrices are $n \times m$ with $n \gg m$. So $W_QW_K^T$ is not just any matrix: it is $n \times n$ but has rank at most $m$. The factored form has fewer parameters, and computing $QK^T$ is much faster than computing $XW'X^T$ for a full-rank $n \times n$ matrix $W'$.
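
As a quick illustration (the sizes below are arbitrary, chosen only so that $n \gg m$), the fused matrix $W_QW_K^T$ has $n^2$ entries but rank at most $m$, while the factored form stores only $2nm$ parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 512, 64                          # illustrative: embedding dim much larger than d_k
W_Q = rng.standard_normal((n, m))
W_K = rng.standard_normal((n, m))

W_QK = W_Q @ W_K.T                      # n x n, but its rank cannot exceed m
print(np.linalg.matrix_rank(W_QK))      # 64
print(W_QK.size)                        # 262144 entries in the fused matrix
print(W_Q.size + W_K.size)              # 65536 parameters in the factored form
```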

shimao
  • Ah, that makes total sense! So even though $XW'X^T$ is one less matrix multiplication than $XW_QW_K^TX^T$: 1) $W'$ has $n^2$ weights while $W_Q$ and $W_K$ together have $2nm$ weights, so we would be adding unnecessary extra weights when $n>2m$, which will almost always be the case. 2) The number of operations to multiply an $i\times j$ matrix and a $j\times k$ matrix is approximately $ijk$, so $XW_QW_K^TX^T$ takes around $2pnm + p^2m$ operations, which will almost always be less than the $pn^2+p^2n$ operations that $XW'X^T$ would take (a rough numerical tally is sketched below). – itrase Mar 26 '21 at 14:27
  • 1
    I do not see why the weight matrices would not be square matrices. Take a look [here](https://nlp.seas.harvard.edu/2018/04/03/attention.html) for example, in the `MultiHeadedAttention` class where the line `self.linears = clones(nn.Linear(d_model, d_model), 4)` initializes these matrices, if I am not mistaken. – Timo Denk Jul 13 '21 at 18:32
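
For reference, a rough numerical tally of the operation counts from the first comment above, using the $ijk$ approximation for an $i\times j$ by $j\times k$ product (the values of $p$, $n$, $m$ are illustrative, not from the question):

```python
# Rough multiplication counts for the two ways of forming the attention scores,
# using the approximation that an (i x j) @ (j x k) product costs about i*j*k multiplies.
p, n, m = 128, 512, 64                  # illustrative: sequence length, embedding dim, d_k

# Factored form: Q = X W_Q and K = X W_K cost p*n*m each, then Q K^T costs p*m*p
factored = 2 * p * n * m + p * p * m

# Fused form: X W' costs p*n*n, then (X W') X^T costs p*n*p, with W' a full n x n matrix
fused = p * n * n + p * p * n

print(f"factored: {factored:,}")        # 9,437,184
print(f"fused:    {fused:,}")           # 41,943,040
```

Regarding the second comment: if I read the linked code correctly, the output of each `nn.Linear(d_model, d_model)` projection is reshaped into `h` heads of size `d_k = d_model // h`, so per head the projection is effectively `d_model x d_k`, which is the tall, thin shape the answer assumes.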