An interview question about transformers
Here is an interview question about the transformer architecture that I think is fairly solid:
Why should the unembedding matrix not simply be the transpose of the embedding matrix? The key insight is that optimization pressure pushes the unembedding matrix toward something like a bigram statistics head, rather than a pure “unembedding” operation that inverts the embedding. Think of it this way: if the embedding matrix maps the one-hot vector for the token “William” to a vector, the unembedding matrix will be optimized to map that vector to “Shakespeare”, not back to “William”. This is because the pretraining task that sets the matrices’ parameters is next-token prediction, not recovering the original tokens. This intuition is described well in this tutorial.
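To make the “William” -> “Shakespeare” intuition concrete, here is a minimal sketch, assuming a tiny invented corpus and plain SGD: an embedding `E` and a separate unembedding `U` with nothing in between, trained on next-token prediction. Both are stored as V × d lists of rows (so `U` is the transpose of the usual d × V unembedding), and all function names are hypothetical. After training, the embedding of “william” is mapped to “shakespeare”, not back to itself.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def logits_for(h, U):
    # One logit per vocabulary entry: dot product of each unembedding row with h.
    return [sum(u_k * h_k for u_k, h_k in zip(row, h)) for row in U]


def train_embed_unembed(tokens, d=16, lr=0.1, epochs=500, seed=0):
    """Next-token prediction with only an embedding E and a separate
    unembedding U (both V x d lists of rows), trained by plain SGD."""
    rng = random.Random(seed)
    vocab = sorted(set(tokens))
    idx = {t: n for n, t in enumerate(vocab)}
    V = len(vocab)
    E = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(V)]
    U = [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(V)]
    pairs = [(idx[a], idx[b]) for a, b in zip(tokens, tokens[1:])]
    for _ in range(epochs):
        for cur, nxt in pairs:
            h = list(E[cur])  # copy of the current token's embedding
            p = softmax(logits_for(h, U))
            # Cross-entropy gradient w.r.t. the logits is p - onehot(next).
            g = [p_j - (1.0 if j == nxt else 0.0) for j, p_j in enumerate(p)]
            # Backpropagate to the embedding using U before U is updated.
            grad_h = [sum(g[j] * U[j][k] for j in range(V)) for k in range(d)]
            for j in range(V):
                for k in range(d):
                    U[j][k] -= lr * g[j] * h[k]
            for k in range(d):
                E[cur][k] -= lr * grad_h[k]
    return vocab, idx, E, U


def predict_next(word, vocab, idx, E, U):
    # Highest-logit next token for the given input token.
    scores = logits_for(E[idx[word]], U)
    return vocab[max(range(len(vocab)), key=scores.__getitem__)]


corpus = "william shakespeare wrote plays and william shakespeare wrote poems".split()
vocab, idx, E, U = train_embed_unembed(corpus)
print(predict_next("william", vocab, idx, E, U))
```

The corpus and hyperparameters are arbitrary; the point is only that the training signal rewards mapping each token to its successor, never to itself.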
To put it more formally, the unembedding matrix is pushed toward approximating $P(\text{next token} \mid \text{current token})$ rather than an identity mapping. Imagine the model had nothing but the embedding and unembedding matrices. Given only the current token’s embedding, what is the best the model can do? It cannot bring in contextual information through attention layers, nor memorized facts through MLPs. It can only learn how often token pairs occur and store those statistics in its weights. This manifests in striking ways: for highly deterministic sequences like “Saudi” -> “Arabia”, the logits the model produces for “Saudi” will have a sharp peak at what’s by far the most probable next token (“Arabia”). On the other hand, for tokens with more entropy in their continuations (like “the” or “New”), the distribution will be spread more evenly over plausible next tokens.
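A predictor that sees only the current token minimizes cross-entropy exactly when it outputs the corpus’s conditional bigram distribution, so these counts are the target the embedding-unembedding pair is pushed toward. The sketch below estimates that distribution and its entropy on a small invented corpus (the corpus and function names are hypothetical), showing the sharp peak after “saudi” and the flat spread after “the”.

```python
import math
from collections import Counter, defaultdict


def bigram_conditionals(tokens):
    """Estimate P(next | current) from bigram counts in a token list."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(tokens, tokens[1:]):
        counts[cur][nxt] += 1
    return {
        cur: {nxt: c / sum(ctr.values()) for nxt, c in ctr.items()}
        for cur, ctr in counts.items()
    }


def entropy_bits(dist):
    """Shannon entropy (in bits) of a {token: probability} distribution."""
    return sum(-p * math.log2(p) for p in dist.values() if p > 0)


corpus = (
    "saudi arabia is large . the cat sat . the dog ran . "
    "saudi arabia is dry . the sun set . the man ate ."
).split()
cond = bigram_conditionals(corpus)
print(cond["saudi"], entropy_bits(cond["saudi"]))  # {'arabia': 1.0} 0.0
print(cond["the"], entropy_bits(cond["the"]))      # four continuations at 0.25 each, 2.0 bits
```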
This is what a language model will do if it has only an embedding and an unembedding layer. And, crucially, this is approximately what the two layers will do even when there are many attention and MLP layers in between, if for no other reason than the residual connection: it gives the embedding a direct path to the unembedding, so the product of the two matrices contributes to the logits regardless of what the intermediate layers add.
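The residual argument can be written out as a decomposition, under simplifying assumptions: ignore positional embeddings and the final LayerNorm (which makes the split only approximate in a real model), and write $x$ for the one-hot vector of the current token, $W_E$ for the embedding matrix, $W_U$ for the unembedding matrix, and $f_\ell$ for what layer $\ell$ adds to the residual stream $h_{\ell-1}$, with $h_0 = W_E x$. This notation is introduced here for illustration.

$$
\text{logits} = W_U\Big(W_E\,x + \sum_{\ell=1}^{L} f_\ell(h_{\ell-1})\Big) = \underbrace{W_U W_E\,x}_{\text{direct path}} + \sum_{\ell=1}^{L} W_U\, f_\ell(h_{\ell-1})
$$

The first term depends only on the current token, which makes it the natural place for bigram statistics to live.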
Tying the embedding and unembedding matrices is suboptimal because they serve different purposes. The embedding matrix captures semantic and syntactic features useful to the entire network, while the unembedding matrix approximates a conditional probability distribution. (This theoretical observation suggests an experiment: if you yank the embedding and unembedding matrices out of a large language model like Llama and compose them, do you get something like a bigram-frequency next-token predictor? In other words, for each input token, how closely do its top-k predicted next tokens match the top-k continuations in actual corpus bigram statistics? And if the unembedding matrices in powerful LLMs are doing something more sophisticated, what is it?)
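One possible way to score that experiment, sketched under assumptions: for each input token, take its row of logits from the composed embedding and unembedding (extracting real weights from a checkpoint is not shown), and measure how many of the corpus’s top-k continuations also appear in the model’s top k. The function names and the numbers for “New” below are hypothetical.

```python
def top_k(scores, k):
    """The k highest-scoring tokens in a {token: score} dict, as a set."""
    ranked = sorted(scores, key=scores.get, reverse=True)
    return set(ranked[:k])


def top_k_overlap(model_scores, corpus_probs, k):
    """Fraction of the corpus's top-k next tokens that the model also ranks in its top k.
    Softmax preserves order, so raw logits can be compared with probabilities directly."""
    return len(top_k(model_scores, k) & top_k(corpus_probs, k)) / k


# Hypothetical numbers for the input token "New".
model_logits = {"York": 5.1, "Zealand": 3.2, "Jersey": 2.9, "the": 0.4}
corpus_probs = {"York": 0.55, "Jersey": 0.20, "Zealand": 0.15, "Mexico": 0.10}
print(top_k_overlap(model_logits, corpus_probs, k=3))  # 1.0
print(top_k_overlap(model_logits, corpus_probs, k=2))  # 0.5
```

Averaging this score over the vocabulary (or weighting it by token frequency) would give a single number per model; which aggregate is most informative is a design choice left open here.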
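This has implications for model pretraining: the unembedding matrix should not be initialized as the transpose of the embedding matrix. That is actively worse than random initialization. With tied weights, the logit for each candidate next token is the dot product of its embedding with the current token’s embedding, which is usually largest for the current token itself, so the freshly initialized model predicts that every token repeats. But most tokens aren’t followed by a copy of themselves!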
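The copy bias of tied initialization is easy to check numerically. This is a minimal sketch assuming Gaussian random embeddings (an illustrative choice, not any particular model’s initializer): in high dimensions a vector’s squared norm almost always exceeds its dot product with any other random vector, so with the unembedding tied to the embedding every token’s top prediction is itself, while an independently initialized unembedding picks the current token only about 1/V of the time.

```python
import random


def self_prediction_rate(E, U):
    """Fraction of tokens whose top logit (U applied to the token's embedding) is the
    token itself. E and U are both V x d lists of row vectors."""
    V = len(E)
    hits = 0
    for i in range(V):
        logits = [sum(u * e for u, e in zip(U[j], E[i])) for j in range(V)]
        if max(range(V), key=logits.__getitem__) == i:
            hits += 1
    return hits / V


rng = random.Random(0)
V, d = 100, 64
E = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(V)]
U_random = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(V)]

tied_rate = self_prediction_rate(E, E)           # tied: unembedding rows = embedding rows
random_rate = self_prediction_rate(E, U_random)  # independently initialized unembedding
print(tied_rate, random_rate)
```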
You might point out that many models, like GPT-2, do share the embedding and unembedding parameters. The explanation is that this is a parameter-saving measure. In the smallest GPT-2, the embedding layer hogs a whopping 38.6 million of the model’s 124 million parameters. Intuitively, untying the embedding and unembedding layers can’t deliver performance gains that justify a ~31% increase in model size. And you can trust the optimization process to produce a single embedding matrix that serves two distinct objectives: generating semantically and syntactically rich representations for the other layers, and making bigram predictions when used for unembedding.
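The GPT-2 arithmetic can be reproduced from the model’s configuration. This sketch assumes the standard GPT-2 small settings (50,257-token vocabulary, 1,024 positions, width 768, 12 layers, a 4× MLP, biases on every linear layer, LayerNorm with weight and bias); the function name is hypothetical.

```python
def gpt2_param_count(n_vocab=50257, n_ctx=1024, d=768, n_layer=12):
    """Parameter count of a GPT-2-style model whose unembedding is tied to the embedding."""
    token_emb = n_vocab * d                   # token embedding, reused as the unembedding
    pos_emb = n_ctx * d                       # learned positional embedding
    ln = 2 * d                                # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)  # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)
    per_block = 2 * ln + attn + mlp
    total = token_emb + pos_emb + n_layer * per_block + ln  # final LayerNorm
    return token_emb, total


token_emb, total = gpt2_param_count()
print(f"{token_emb:,} of {total:,} parameters ({token_emb / total:.1%})")
# 38,597,376 of 124,439,808 parameters (31.0%)
```

Untying would add another `token_emb` parameters for a separate unembedding, which is where the ~31% figure comes from.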