8  Multi-Head Attention — Many Conversations at Once

A single attention head is expressive, but limited. It produces one set of attention weights — one pattern of “who attends to whom.” In natural language, multiple distinct relationships coexist in the same sentence. Consider:

"The animal didn't cross the street because it was too tired."

This one sentence carries several distinct relationships at once, which Section 8.1 unpacks. Multi-head attention runs H independent attention heads in parallel, each free to specialize on a different relationship type. Their outputs are concatenated and projected back to the model dimension.

8.1 The Idea

A single attention operation produces one pattern of “who attends to whom.” But a sentence carries many different kinds of relationships at the same time.

In “The animal didn’t cross the street because it was too tired”:

  • it refers back to animal — that is a coreference relationship.
  • cross links to street — that is a verb-object relationship.
  • because connects a cause to an effect — that is a logical relationship.

A single attention head produces only one weight distribution per token, so it cannot represent all of these relationships cleanly at once. Multi-head attention runs several independent attention operations in parallel, each one free to specialize on a different pattern. Each head sees the same input but learns to ask a different question of it.

At the end, the results from all heads are stitched back together and projected into a single vector, the same size as before. The model learns entirely from data which head should track grammar, which should track meaning, which should track proximity — no one programs this in explicitly.

The result is a richer representation than any single head could produce: each token’s final vector carries signals from multiple independent attention patterns at once.

8.2 The Math

Step 1 — Compute each head independently.

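Each head \(h \in \{1, \ldots, H\}\) projects the input \(X \in \mathbb{R}^{T\times d}\) into its own query, key, and value spaces, then runs standard scaled dot-product attention. Here \(W_q^{h}, W_k^{h} \in \mathbb{R}^{d\times d_k}\), \(W_v^{h} \in \mathbb{R}^{d\times d_v}\), and \(M\) is the additive attention mask (for causal attention, \(-\infty\) at future positions and 0 elsewhere):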

\[ \begin{aligned} Q^{h} &= X W_q^{h} \in \mathbb{R}^{T\times d_k} \\ K^{h} &= X W_k^{h} \in \mathbb{R}^{T\times d_k} \\ V^{h} &= X W_v^{h} \in \mathbb{R}^{T\times d_v} \\ \text{head}^{h} &= \operatorname{softmax}\!\left(\frac{Q^{h}{K^{h}}^{\top}}{\sqrt{d_k}} + M\right) V^{h} \end{aligned} \]
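As a minimal sketch of Step 1, the function below computes one head in plain Python. The names `matmul`, `softmax_row`, and `attention_head` are illustrative rather than the book's helpers, and the mask \(M\) is assumed to be causal (each token sees itself and earlier tokens only).

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def softmax_row(row):
    """Numerically stable softmax; entries equal to -inf get weight 0."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(x, wq, wk, wv):
    """One head: softmax(Q K^T / sqrt(d_k) + M) V, with an assumed causal mask M."""
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    d_k = len(q[0])
    weights = []
    for i in range(len(q)):
        scores = [
            sum(qi * kj for qi, kj in zip(q[i], k[j])) / math.sqrt(d_k)
            if j <= i else -math.inf          # M = -inf at future positions
            for j in range(len(k))
        ]
        weights.append(softmax_row(scores))
    return matmul(weights, v), weights

# Tiny usage example: two tokens, identity projections.
x = [[1.0, 0.0], [0.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
out, w = attention_head(x, identity, identity, identity)
```

The first token can only attend to itself, so `w[0]` is `[1.0, 0.0]`; every row of `w` sums to 1.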

Step 2 — Concatenate.

The \(H\) outputs are placed side by side:

\[ \operatorname{MultiHead} = \operatorname{concat}(\text{head}^1, \text{head}^2, \ldots, \text{head}^{H}) \in \mathbb{R}^{T \times (H\cdot d_v)} \]

With the usual choice \(d_v = d/H\), the result has shape \([T \times d]\), the same as the input.
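The concatenation in Step 2 is a row-by-row join of the head outputs. A short sketch, where `concat_heads` is a hypothetical name chosen for this example:

```python
def concat_heads(head_outputs):
    """Join H matrices of shape [T x d_v] into one [T x (H * d_v)] matrix."""
    num_rows = len(head_outputs[0])
    return [
        [value for head in head_outputs for value in head[t]]
        for t in range(num_rows)
    ]

head_a = [[1, 2], [3, 4], [5, 6]]          # T = 3, d_v = 2
head_b = [[7, 8], [9, 10], [11, 12]]
merged = concat_heads([head_a, head_b])    # shape 3 x 4
```

Each row of `merged` holds head 1's features for that token followed by head 2's, so token order is preserved and only the width grows.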

Step 3 — Output projection.

A learned matrix \(W_o \in \mathbb{R}^{d\times d}\) mixes information across heads:

\[ \text{Output} = \operatorname{MultiHead} \cdot W_o \in \mathbb{R}^{T\times d} \]
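To see what "mixes information across heads" means, the sketch below applies two hand-picked projections to a one-token concatenation. The numbers and matrix names are made up for illustration; a trained \(W_o\) would be learned, not chosen by hand.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

# One token: head 1 contributed (1, 2), head 2 contributed (10, 20).
concat = [[1.0, 2.0, 10.0, 20.0]]

# Identity projection: each head's features stay in their own columns.
w_identity = [[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1]]

# Mixing projection: head 2's features are also added into head 1's columns.
w_mixing = [[1, 0, 0, 0],
            [0, 1, 0, 0],
            [1, 0, 1, 0],
            [0, 1, 0, 1]]

unmixed = matmul(concat, w_identity)   # [[1.0, 2.0, 10.0, 20.0]]
mixed = matmul(concat, w_mixing)       # [[11.0, 22.0, 10.0, 20.0]]
```

With the mixing matrix, the first two output columns depend on both heads, which concatenation alone cannot achieve.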

Putting it all together:

\[ \text{MHA}(X) = [\text{head}^1 \| \text{head}^2 \| \ldots \| \text{head}^{H}]\, W_o \]

where \(\text{head}^{h} = \operatorname{Attention}(X W_q^{h}, X W_k^{h}, X W_v^{h})\).

8.3 The Matrix: Worked Example

Let T = 3, d = 4, H = 2 heads, so \(d_k = d_v = 2\).

Input:

X = [[1, 0, 1, 0],
     [0, 1, 0, 1],
     [1, 1, 0, 0]]   (3×4)

Head 1 reads only the first 2 dimensions, because the last two rows of its weight matrices are zero.

Wq¹ = Wk¹ = Wv¹ = [[1,0],[0,1],[0,0],[0,0]]   (4×2)

Q¹ = K¹ = V¹ = X Wq¹ = [[1,0],[0,1],[1,1]]   (3×2)

Scores: \(S^1 = Q^1{K^1}^{\top} / \sqrt{2}\):

Q¹K¹ᵀ = [[1,0,1],[0,1,1],[1,1,2]]
S¹    = [[0.71, 0.00, 0.71],
         [0.00, 0.71, 0.71],
         [0.71, 0.71, 1.41]]

After causal mask and softmax (row 2 is softmax(0.00, 0.71); row 3 is softmax(0.71, 0.71, 1.41)):

A¹ = [[1.000, 0.000, 0.000],
      [0.330, 0.670, 0.000],
      [0.248, 0.248, 0.503]]

head¹ = A¹ V¹ = [[1.000, 0.000],
                 [0.330, 0.670],
                 [0.752, 0.752]]    (3×2)

Head 2 focuses on the last 2 dimensions and produces an output of the same shape (its weight matrices are not listed here). Concatenating both heads gives the matrix below, to which the output projection \(W_o \in \mathbb{R}^{4\times 4}\) is then applied:

MultiHead = [head¹ ‖ head²] =
  [[1.000, 0.000, 1.000, 0.000],
   [0.330, 0.670, 0.500, 0.500],
   [0.752, 0.752, 0.421, 0.211]]   (3×4)
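The head 1 arithmetic can be recomputed with a few lines of standard-library Python. This is a sketch; `matmul`, `causal_softmax`, and `rounded` are illustrative helpers, not the book's code, and only head 1 is covered because its weights are the ones stated above.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def causal_softmax(scores):
    """Row-wise softmax where row i only sees columns 0..i."""
    weights = []
    for i, row in enumerate(scores):
        exps = [math.exp(s) if j <= i else 0.0 for j, s in enumerate(row)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

def rounded(m):
    """Round every entry to 3 decimals for display."""
    return [[round(v, 3) for v in row] for row in m]

X = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0]]
W1 = [[1, 0], [0, 1], [0, 0], [0, 0]]            # Wq1 = Wk1 = Wv1

Q1 = K1 = V1 = matmul(X, W1)                     # same weights, same result
K1_T = [list(col) for col in zip(*K1)]           # transpose of K1
S1 = [[s / math.sqrt(2) for s in row] for row in matmul(Q1, K1_T)]
A1 = causal_softmax(S1)
head1 = matmul(A1, V1)

print(rounded(A1))     # [[1.0, 0.0, 0.0], [0.33, 0.67, 0.0], [0.248, 0.248, 0.503]]
print(rounded(head1))  # [[1.0, 0.0], [0.33, 0.67], [0.752, 0.752]]
```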

Figure 8.1 shows parallel attention heads whose outputs are concatenated and projected back to model width.

Figure 8.1: Multi-head attention. H parallel heads each attend independently; their outputs are then concatenated and projected.

8.4 The Code: Multi-Head Attention in Python

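import random
from dataclasses import dataclass

# Matrix (a list of rows) and random_matrix are assumed to come from the earlier chapters' code.


@dataclass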
class AttentionHead:
    wq: Matrix
    wk: Matrix
    wv: Matrix


@dataclass
class MultiHeadAttention:
    heads: list[AttentionHead]
    wo: Matrix


def make_multi_head_attention(d_model: int, num_heads: int, rng: random.Random) -> MultiHeadAttention:
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_key = d_model // num_heads
    heads = [
        AttentionHead(
            random_matrix(d_model, d_key, rng),
            random_matrix(d_model, d_key, rng),
            random_matrix(d_model, d_key, rng),
        )
        for _ in range(num_heads)
    ]
    return MultiHeadAttention(heads=heads, wo=random_matrix(d_model, d_model, rng))

Each attention head is a triple of weight matrices \((W_q, W_k, W_v)\), each of shape \([d \times d_k]\). MultiHeadAttention groups the heads with the output projection \(W_o\). make_multi_head_attention checks that the model width can be split evenly, allocates one parameter set per head, and creates the final output projection.

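def multi_head_attention(x: Matrix, params: MultiHeadAttention) -> tuple[Matrix, list[Matrix]]:
    # Run every head on the same input; each call returns (output, weights).
    results = [
        self_attention(x, head.wq, head.wk, head.wv)
        for head in params.heads
    ]
    # Concatenate the per-head outputs column-wise: H blocks of [T x d_k] -> [T x d].
    concatenated = hstack([output for output, _weights in results])
    # Mix information across heads with W_o; also return the per-head weights.
    projected = matrix_multiply(concatenated, params.wo)
    return projected, [weights for _output, weights in results]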

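multi_head_attention runs scaled dot-product attention (SDPA) for each head independently, concatenates the per-head outputs column-wise, and applies the output projection. It returns the projected output together with one attention-weight matrix per head.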

def chapter_08(seed: int = 7) -> dict[str, object]:
    rng = random.Random(seed)
    x = random_matrix(4, 8, rng)
    params = make_multi_head_attention(8, 2, rng)
    output, weights = multi_head_attention(x, params)
    return {
        "output_shape": (len(output), len(output[0])),
        "num_heads": len(weights),
    }

The demo creates a two-head attention module, runs it over a four-token sequence, and reports the output shape plus the number of attention-weight matrices. Run it with python3 src/python/chapter_demos.py.

8.5 Why Multi-Head Attention Works

Each head learns to specialize. Research has identified heads that:

  • Track syntactic structure (subject-verb agreement)
  • Resolve coreference (“it” → “the animal”)
  • Handle positional offsets (“look 2 tokens back”)
  • Track rare-word semantics

These specializations emerge from training, not from explicit design.
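One simple way to look for a positional head of the "look 2 tokens back" kind is to check, for each token, how far back its most-attended position lies. The sketch below is a toy diagnostic with made-up weights; `dominant_offsets` is a hypothetical helper, not a method taken from the published analyses.

```python
def dominant_offsets(weights):
    """For each query position i, return i minus the index it attends to most."""
    offsets = []
    for i, row in enumerate(weights):
        j = max(range(len(row)), key=lambda idx: row[idx])
        offsets.append(i - j)
    return offsets

# Made-up causal attention weights for a 4-token sequence (illustrative only).
toy_weights = [
    [1.0, 0.0, 0.0, 0.0],
    [0.6, 0.4, 0.0, 0.0],
    [0.8, 0.1, 0.1, 0.0],
    [0.1, 0.7, 0.1, 0.1],
]
offsets = dominant_offsets(toy_weights)   # [0, 1, 2, 2]
```

Here tokens 2 and 3 both put most of their weight two positions back, which is the signature such a head would show once enough context is available.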

Math Minute — Expressivity

H heads of dimension d/H give each token H separate attention distributions, whereas a single head of dimension d gives it only one, so the multi-head layer can represent combinations of patterns that a single head cannot easily learn. This is analogous to having H different “lenses” looking at the same sequence; each lens focuses on different features. The output projection \(W_o\) then combines the views.

8.6 Key Takeaways

  • Multi-head attention runs H attention heads in parallel, each in a lower-dimensional subspace \(d_k = d/H\).
  • Each head has independent \(W_q, W_k, W_v\) matrices — each learns a different “question to ask.”
  • Outputs are concatenated (not averaged), then projected back to d with \(W_o\).
  • The total parameter count (ignoring biases) is \(3Hd\cdot d_k + d^2 = 3d^2 + d^2 = 4d^2\), the same as one large head with \(d_k = d\).
  • Different heads specialize in different linguistic relationships.
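The parameter-count takeaway is easy to check numerically. This sketch assumes no bias terms and \(d_k = d_v = d/H\), matching the code in Section 8.4; `mha_param_count` is a name introduced here for illustration.

```python
def mha_param_count(d_model, num_heads):
    """Count weights in W_q, W_k, W_v for every head plus W_o (no biases)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    qkv = 3 * num_heads * d_model * d_k   # three [d x d_k] matrices per head
    out = d_model * d_model               # W_o is [d x d]
    return qkv + out

counts = {h: mha_param_count(8, h) for h in (1, 2, 4, 8)}
# Every entry equals 4 * 8 * 8 = 256, independent of the number of heads.
```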

What’s next? After attention mixes information across tokens, each token’s vector goes through a small feed-forward network: a two-layer MLP applied identically to every position. Much of the model’s stored “knowledge” is thought to live there. See Chapter 9.