Linear Algebra for AI — Presentation 11

Transformer Block Anatomy

Walk a tensor through a full pre-norm transformer block. Pre-norm, multi-head attention with all four projections, residual, FFN with up-then-down, residual. Every shape, every matmul, every parameter.

[Block diagram: residual stream x → RMSNorm (pre-norm) → MHA (Q, K, V, O projections) → + → RMSNorm → FFN (SwiGLU) → + → x'']
00

Topics We'll Cover

  • The Pre-Norm Block in One Equation
  • The Residual Stream as a Linear Bus
  • LayerNorm vs RMSNorm — the Geometry
  • Walking the Tensor — Shapes & Matmuls
  • Parameter Budget per Block
  • FLOP Budget per Token, per Block
  • Attention vs FFN — Where the Compute Lives
  • Reference PyTorch Implementation
  • Variants — Pre/Post Norm, GeLU/SwiGLU, Bias-Free
  • Cheat Sheet

01

The Pre-Norm Block in One Equation

Modern transformers (Llama, Mistral, Qwen, GPT-3+) use the pre-norm block:

$$\mathbf{x}' = \mathbf{x} + \mathrm{MHA}(\mathrm{Norm}(\mathbf{x})),$$

$$\mathbf{x}'' = \mathbf{x}' + \mathrm{FFN}(\mathrm{Norm}(\mathbf{x}')).$$

Two sub-layers, each with its own normalisation, sub-layer function and residual connection. The output of the block goes into the next block's input.

Three properties of this layout

  • Identity by default: if a sub-layer outputs zero, the block passes $\mathbf{x}$ through unchanged, so adding depth is safe.
  • Clean gradient path: the residual adds give gradients a direct route from the loss to every layer, bypassing the sub-layers.
  • Normalised inputs: each sub-layer sees a norm-controlled input regardless of how large the residual stream has grown.

Pre-norm vs post-norm

The original transformer (Vaswani 2017) used post-norm: $\mathbf{x}' = \mathrm{Norm}(\mathbf{x} + \mathrm{MHA}(\mathbf{x}))$. This requires careful warmup to train stably past ~12 layers. Pre-norm (Xiong et al., ICML 2020) trains stably to 100+ layers without warmup tricks. Modern models are universally pre-norm; we focus on it here.
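
The two layouts in toy code (nn.Linear and LayerNorm stand-ins for the real sub-layers, so the snippet runs; only the ordering matters here):

import torch
from torch import nn

d = 8                                        # toy width
norm1, norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
mha, ffn = nn.Linear(d, d), nn.Linear(d, d)  # stand-ins for the real sub-layers

def post_norm_block(x):                      # Vaswani 2017: norm sits after the residual add
    x = norm1(x + mha(x))
    return norm2(x + ffn(x))

def pre_norm_block(x):                       # modern default: norm only on the sub-layer input
    x = x + mha(norm1(x))
    return x + ffn(norm2(x))

print(pre_norm_block(torch.randn(2, d)).shape)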

02

The Residual Stream as a Linear Bus

The residual stream $\mathbf{x} \in \mathbb{R}^{d_{model}}$ is the bus that runs through every block. Each sub-layer reads from it (after norm) and writes back via the residual add.

Linear-algebraically: at the end of layer $L$, the residual stream is

$$\mathbf{x}^{(L)} = \mathbf{x}^{(0)} + \sum_{\ell=1}^L \big[\mathrm{MHA}^{(\ell)}(\cdots) + \mathrm{FFN}^{(\ell)}(\cdots)\big] = \mathbf{x}^{(0)} + \sum_\ell \Delta_\ell.$$

The residual stream is a sum: the embedding plus all the contributions written by every sub-layer.

What this implies

Shared coordinate system

Every layer reads and writes in the same $d_{model}$-dim space. Directions in this space have meaning that persists across layers; mechanistic interpretability builds on this.

Bandwidth budget

Each sub-layer can write $\le d_{model}$ independent dimensions per pass. The residual stream is linearly $d_{model}$-dimensional, but with superposition (deck 01) carries vastly more meaningful directions.

Decomposability

Reading "what does layer 5 contribute to the final logits?" is a linear question: project $\mathrm{MHA}^{(5)}(\cdots) + \mathrm{FFN}^{(5)}(\cdots)$ onto the unembedding directions. Logit lens, tuned lens, and direct logit attribution all rely on this.

"Mathematical Framework" (Anthropic 2021)

Elhage et al. introduced the residual-stream picture as the right way to read transformers: every component reads a linear projection of the stream and writes another linear contribution back. Non-linearities (softmax, SwiGLU) are local to each sub-layer; the inter-layer story is purely additive. This abstraction is what makes circuit-level interpretation tractable.

03

LayerNorm vs RMSNorm — the Geometry

LayerNorm (Ba 2016): subtract per-token mean, divide by per-token standard deviation, scale and shift:

$$\mathrm{LN}(\mathbf{x}) = \mathbf{g} \odot \frac{\mathbf{x} - \mu(\mathbf{x})}{\sqrt{\sigma^2(\mathbf{x}) + \varepsilon}} + \mathbf{b}.$$

RMSNorm (Zhang & Sennrich 2019): just divide by the per-token RMS:

$$\mathrm{RMS}(\mathbf{x}) = \mathbf{g} \odot \frac{\mathbf{x}}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \varepsilon}}.$$

RMSNorm is a strict simplification: drop the centering, drop the bias. Almost all modern LLMs (Llama, Mistral, Qwen, DeepSeek, Gemma) use RMSNorm.

The geometric reading

Both normalisations are non-linear, but they have a clean geometric interpretation. For RMSNorm with $\mathbf{g} = \mathbf{1}$:

$$\mathrm{RMS}(\mathbf{x}) = \sqrt{d}\, \frac{\mathbf{x}}{\|\mathbf{x}\|_2}.$$

RMSNorm projects $\mathbf{x}$ onto the sphere of radius $\sqrt{d}$: direction is preserved, magnitude is fixed. The learnable gain $\mathbf{g}$ then stretches the result per axis.

For LayerNorm: project onto the affine hyperplane $\sum_i x_i = 0$ (zero mean), then onto the sphere of radius $\sqrt{d}$ in that hyperplane. Same idea with one extra constraint.
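
A quick numerical check of the sphere picture (unit gain, eps omitted):

import torch

d = 4096
x = torch.randn(d) * 7.3                              # arbitrary magnitude
rms_x = x * torch.rsqrt(x.pow(2).mean())              # RMSNorm with g = 1
assert torch.allclose(rms_x, d**0.5 * x / x.norm())   # = sqrt(d) * x / ||x||_2
print(rms_x.norm())                                   # always ≈ sqrt(4096) = 64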

Why drop the centering

For activations with roughly zero mean — which residual streams already are by training dynamics — the mean subtraction is a no-op. Removing it saves one reduction per token and makes the gradient simpler. Empirically: RMSNorm matches LayerNorm in quality at lower compute. The simplification has stuck.

04

Walking the Tensor — Shapes & Matmuls

Take $d_{model} = 4096$, $H = 32$, $d_h = 128$, $d_{ff} = 11008$ (Llama-2 7B per-block). Walk a single token through.

| Step | Operation | Shape after | FLOPs (per token) |
|------|-----------|-------------|-------------------|
| 0 | Input x | (4096,) | |
| 1 | x_norm = RMSNorm(x) | (4096,) | ~$4d$ (negligible) |
| 2 | q = W_Q · x_norm (32 heads, concatenated) | (4096,) | $2d^2 = 33.5$ M |
| 3 | k = W_K · x_norm | (4096,) | $2d^2 = 33.5$ M |
| 4 | v = W_V · x_norm | (4096,) | $2d^2 = 33.5$ M |
| 5 | Reshape q, k, v to (32, 128); apply RoPE to q, k | (32, 128) ×3 | ~$4d$ (rotation) |
| 6 | Per head $h$: S_h = q_h · K_cache_h^T / √128 | (32, L) across heads | $2dL$ |
| 7 | A = softmax(S) row-wise | (32, L) | ~$3HL$ (negligible) |
| 8 | O_h = A_h · V_cache_h per head | (32, 128) | $2dL$ |
| 9 | Concat heads → (4096,); attn_out = W_O · concat | (4096,) | $2d^2 = 33.5$ M |
| 10 | Residual add: x1 = x + attn_out | (4096,) | $d$ adds |
| 11 | x1_norm = RMSNorm(x1) | (4096,) | ~$4d$ |
| 12 | gate = W_gate · x1_norm | (11008,) | $2d\,d_{ff} = 90.2$ M |
| 13 | up = W_up · x1_norm | (11008,) | $2d\,d_{ff} = 90.2$ M |
| 14 | h = SiLU(gate) ⊙ up (elementwise) | (11008,) | $3d_{ff}$ (elementwise) |
| 15 | ffn_out = W_down · h | (4096,) | $2d\,d_{ff} = 90.2$ M |
| 16 | Residual add: x2 = x1 + ffn_out | (4096,) | $d$ adds |
| 17 | Output x2 → next block | (4096,) | |

Per-token, per-block FLOPs (excluding $L$-dependent attention scoring): $4 \cdot 2d^2 + 3 \cdot 2 d \cdot d_{ff} \approx 4 \cdot 33.5\mathrm{M} + 3 \cdot 90\mathrm{M} \approx 404\mathrm{M}$. The $L$-dependent attention adds $4 d L = 16k \cdot L$ FLOPs on top — significant for long contexts.
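
A runnable miniature of the walk (random weights, a toy cache of L = 256 past tokens, RoPE skipped, unit RMSNorm gains assumed; the current token's k, v are folded into the toy cache for brevity):

import torch
import torch.nn.functional as F

d, H, dh, d_ff, L = 4096, 32, 128, 11008, 256
x = torch.randn(d)                                            # step 0
Wq, Wk, Wv, Wo = (torch.randn(d, d) / d**0.5 for _ in range(4))
Wg, Wu = (torch.randn(d_ff, d) / d**0.5 for _ in range(2))
Wd = torch.randn(d, d_ff) / d_ff**0.5
K_cache = torch.randn(H, L, dh); V_cache = torch.randn(H, L, dh)

xn = x * torch.rsqrt(x.pow(2).mean() + 1e-6)                  # step 1: RMSNorm, g = 1
q = (Wq @ xn).view(H, dh)                                     # steps 2 & 5 (RoPE skipped)
S = torch.einsum('hd,hld->hl', q, K_cache) / dh**0.5          # step 6: (32, L) scores
A = F.softmax(S, dim=-1)                                      # step 7
O = torch.einsum('hl,hld->hd', A, V_cache)                    # step 8: (32, 128)
x1 = x + Wo @ O.reshape(d)                                    # steps 9-10
x1n = x1 * torch.rsqrt(x1.pow(2).mean() + 1e-6)               # step 11
h = F.silu(Wg @ x1n) * (Wu @ x1n)                             # steps 12-14: SwiGLU, (11008,)
x2 = x1 + Wd @ h                                              # steps 15-16
assert x2.shape == (d,)                                       # step 17: back on the residual stream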

05

Parameter Budget per Block

Counting parameters in a single Llama-style block ($d_{model} = d$; SwiGLU targets $d_{ff} \approx 8d/3$ to keep the effective ratio, rounded to 11008 for $d = 4096$):

| Component | Shape | Params | For $d = 4096$ |
|-----------|-------|--------|----------------|
| $W_Q$ | $d \times d$ | $d^2$ | 16.8 M |
| $W_K$ | $d \times d_{kv}$ | $d \cdot d_{kv}$ | 16.8 M (MHA) |
| $W_V$ | $d \times d_{kv}$ | $d \cdot d_{kv}$ | 16.8 M (MHA) |
| $W_O$ | $d \times d$ | $d^2$ | 16.8 M |
| $W_{\text{gate}}$ | $d \times d_{ff}$ | $d \cdot d_{ff}$ | 45.1 M (at $d_{ff} \approx 11$k) |
| $W_{\text{up}}$ | $d \times d_{ff}$ | $d \cdot d_{ff}$ | 45.1 M |
| $W_{\text{down}}$ | $d_{ff} \times d$ | $d \cdot d_{ff}$ | 45.1 M |
| 2× RMSNorm gain | $d$ each | $2d$ | 8 K (negligible) |
| Total per block | | $4d^2 + 3d\,d_{ff}$ | ~202 M |

Where the parameters live

For Llama-2 7B with 32 blocks: $32 \times 202\mathrm{M} \approx 6.5$ B parameters from blocks alone, plus embedding + unembedding ($2 \times 32000 \times 4096 \approx 0.26$ B). The total matches the published 6.7 B figure.

For Llama-3 70B with GQA ($d_{kv} = d/8$): per-block QKVO becomes $d^2 + 2 d \cdot d/8 + d^2 = 2.25 d^2$ instead of $4d^2$ — nearly half off the attention parameter count. FFN remains the same. The GQA savings show up in both KV cache (deck 10) and static parameters.
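
A small calculator reproduces both counts (dimensions assumed from the public configs, not read from a checkpoint):

def block_params(d, d_ff, kv_ratio=1):
    d_kv = d // kv_ratio                 # kv_ratio = 1 for MHA, 8 for GQA-8
    qkvo = d*d + 2*d*d_kv + d*d          # W_Q, W_K, W_V, W_O
    ffn = 3 * d * d_ff                   # W_gate, W_up, W_down (SwiGLU)
    return qkvo + ffn + 2*d              # plus the two RMSNorm gains

print(f"{block_params(4096, 11008)/1e6:.1f} M")              # ≈ 202.5 M (Llama-2-7B-like, MHA)
print(f"{block_params(8192, 28672, kv_ratio=8)/1e6:.1f} M")  # Llama-3-70B-like dims with GQA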

06

FLOP Budget per Token, per Block

Each weight matrix used in a forward pass costs $2 \times \text{params}$ FLOPs per token (one multiply + one add per param). For our $d = 4096$, $d_{ff} = 11008$ block: QKVO costs $4 \cdot 2d^2 \approx 134$ M, the FFN $3 \cdot 2d\,d_{ff} \approx 271$ M, so $\approx 405$ M FLOPs per token per block.

For 32 blocks that's ~13 GFLOPs/token from the block matmuls alone, roughly ~15 GFLOPs once the unembedding and attention scoring at $L \approx 4$k are included. On an H100 at ~990 TF/s FP16 we'd expect ~66k tokens/sec at 100% utilisation; real throughput is lower because of memory bandwidth and cache effects on attention scoring.

The 6N rule

For a model with $N$ parameters, training one token costs ~$6N$ FLOPs (Hoffmann et al., Chinchilla 2022): one forward (~$2N$) plus a backward that's twice the forward (~$4N$). This is a back-of-envelope for any dense decoder LLM, and falls directly out of the "3× for backward" rule (deck 09) plus the parameter count above.
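
In code, the whole estimate is one line (the 6.7 B parameter figure taken from above):

N = 6.7e9
print(f"~{2*N/1e9:.0f} GFLOPs/token forward, ~{6*N/1e9:.0f} GFLOPs/token training")  # ≈ 13 and 40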

When attention scoring dominates

For short context ($L \ll d$), per-token FLOPs are dominated by static-parameter matmuls (~480 M). For long context ($L \sim 32k$), attention scoring becomes $4 d L \approx 0.5$ G per token, comparable to the FFN. This is what makes long-context attention expensive. FlashAttention (deck 12) and MLA (deck 07) tackle the memory-bandwidth side of the same problem.
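
A back-of-envelope sketch of where the crossover sits, using the Llama-2-7B-like dimensions assumed throughout:

d, d_ff, n_blocks = 4096, 11008, 32
static = n_blocks * (4 * 2*d*d + 3 * 2*d*d_ff)        # weight matmuls, ≈ 13 GFLOPs/token
for L in (512, 4096, 16384, 32768):
    scoring = n_blocks * 4 * d * L                    # QK^T scores + AV mix
    print(f"L={L:6d}: scoring {scoring/1e9:5.1f} G vs static {static/1e9:.1f} G")

At $L \approx 16$k the per-block scoring term $4dL$ matches the FFN's $\approx 0.54$ G; past roughly 25k it exceeds the block's entire static compute.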

07

Attention vs FFN — Where the Compute Lives

Stepping back to the empirical splits per block:

Attention sub-layer

  • Static params: $\sim 4 d^2$ (MHA) or $\sim 2.25 d^2$ (GQA).
  • Per-token matmul FLOPs: $2 \times$ static params.
  • Plus $L$-dependent attention: $4 d L$ FLOPs/token. KV-cache memory: $2 d L$ floats per layer (MHA), ~8× less with GQA.
  • Role: token-to-token routing; copying/induction/name-binding circuits live here.

FFN sub-layer

  • Static params: $\sim 3 d \cdot d_{ff} \approx 8 d^2$ (SwiGLU at ratio $8d/3$).
  • Per-token FLOPs: $2 \times$ static params $\approx 16 d^2$.
  • No $L$-dependent term. KV cache: zero.
  • Role: per-token "memory" / pattern bank; mechanistic-interpretability work argues each hidden neuron is a (key, value) pair.

The empirical split

For Llama-2 7B: ~33% of params and FLOPs in attention QKVO, ~67% in FFN. For Llama-3 70B with GQA: the split tilts further toward FFN (roughly 20% / 80%; see the sketch below). MoE models like DeepSeek-V3 have huge FFN-MoE blocks; the ratio goes nearly all-FFN, but only the $k$ activated experts contribute per token.
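
The same arithmetic as the parameter calculator above gives the split directly (public-config dims assumed):

attn_7b, ffn_7b = 4 * 4096**2, 3 * 4096 * 11008
print(f"7B attention share: {attn_7b / (attn_7b + ffn_7b):.0%}")      # ≈ 33% (MHA)
attn_70b, ffn_70b = int(2.25 * 8192**2), 3 * 8192 * 28672
print(f"70B attention share: {attn_70b / (attn_70b + ffn_70b):.0%}")  # ≈ 18% (GQA-8)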

What this tells us about transformers as a structure: most of the computation is per-token feature processing (FFN); the inter-token communication channel (attention) is comparatively narrow and grows only sub-linearly with model size. In an information-flow picture, attention is the network's bandwidth between positions; FFN is the bandwidth into the parameter store at each position.

08

Reference PyTorch Implementation

A minimal, correct decoder block in ~40 lines. This is essentially the structure inside Llama, Mistral, Qwen and most modern LLMs.

# Pre-norm decoder block, RMSNorm + GQA + RoPE + SwiGLU
import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.g = nn.Parameter(torch.ones(d))
        self.eps = eps
    def forward(self, x):
        # inv_rms = 1/sqrt(mean(x^2) + eps): rescales x onto the radius-sqrt(d) sphere
        inv_rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.g

class SwiGLUFFN(nn.Module):
    def __init__(self, d, d_ff):
        super().__init__()
        self.w_gate = nn.Linear(d, d_ff, bias=False)
        self.w_up   = nn.Linear(d, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d, bias=False)
    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

class GQAAttention(nn.Module):
    def __init__(self, d, n_heads, n_kv_heads, d_head):
        super().__init__()
        self.h, self.h_kv, self.dh = n_heads, n_kv_heads, d_head
        self.w_q = nn.Linear(d, n_heads     * d_head, bias=False)
        self.w_k = nn.Linear(d, n_kv_heads  * d_head, bias=False)
        self.w_v = nn.Linear(d, n_kv_heads  * d_head, bias=False)
        self.w_o = nn.Linear(n_heads * d_head, d,     bias=False)
    def forward(self, x, rope, mask):
        B, L, _ = x.shape
        q = self.w_q(x).view(B, L, self.h,    self.dh)
        k = self.w_k(x).view(B, L, self.h_kv, self.dh)
        v = self.w_v(x).view(B, L, self.h_kv, self.dh)
        q, k = rope(q), rope(k)
        # GQA: repeat K, V to match number of Q heads
        k = k.repeat_interleave(self.h // self.h_kv, dim=2)
        v = v.repeat_interleave(self.h // self.h_kv, dim=2)
        q = q.transpose(1, 2)  # (B, H, L, dh)
        k = k.transpose(1, 2); v = v.transpose(1, 2)
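        # SDPA applies the 1/sqrt(d_h) scaling itself; is_causal builds the causal mask only when no explicit mask is passed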
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=mask is None)
        return self.w_o(attn.transpose(1, 2).reshape(B, L, -1))

class Block(nn.Module):
    def __init__(self, d, n_h, n_kv, dh, d_ff):
        super().__init__()
        self.norm1 = RMSNorm(d)
        self.attn  = GQAAttention(d, n_h, n_kv, dh)
        self.norm2 = RMSNorm(d)
        self.ffn   = SwiGLUFFN(d, d_ff)
    def forward(self, x, rope, mask=None):
        x = x + self.attn(self.norm1(x), rope, mask)   # residual #1
        x = x + self.ffn(self.norm2(x))                # residual #2
        return x
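
A quick shape check, with a hypothetical identity function standing in for RoPE (a real implementation rotates coordinate pairs of q and k by position-dependent angles):

rope_identity = lambda t: t                                  # hypothetical stand-in, not real RoPE
block = Block(d=4096, n_h=32, n_kv=32, dh=128, d_ff=11008)   # Llama-2-7B-like dims (n_kv = n_h: plain MHA)
x = torch.randn(2, 16, 4096)                                 # (batch, sequence, d_model)
assert block(x, rope_identity).shape == x.shape              # the block maps the residual stream to itself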

Every line maps directly to the linear-algebra you've now seen: RMSNorm projects to a sphere, the four attention projections are matrix maps to/from head-spaces, the FFN is up-then-down with a SwiGLU gate, and the residual adds keep the block defaulting to identity.

09

Variants — Pre/Post Norm, GeLU/SwiGLU, Bias-Free

| Choice | Old | Modern (Llama-3 / Qwen / DeepSeek) | Why |
|--------|-----|------------------------------------|-----|
| Norm placement | Post-norm (Vaswani 2017) | Pre-norm | Trains stably to many more layers without warmup. |
| Norm type | LayerNorm (Ba 2016) | RMSNorm | One reduction, no centering, no bias: an almost free win. |
| FFN activation | ReLU / GeLU | SwiGLU | Gated; better quality at same params (Shazeer 2020). |
| FFN ratio | $4d$ (GeLU) | $\frac{8}{3}d$ (SwiGLU, 3 matrices) | Keeps total FFN params constant despite the extra matrix. |
| Position encoding | Sinusoidal / learned | RoPE | Relative position; better extrapolation; almost free. |
| Linear bias | $y = Wx + b$ | $y = Wx$ (no bias) | Negligible quality cost; saves params; matches RMSNorm's bias-free style. |
| Attention | MHA (one K, V per head) | GQA (one K/V per group of 8 heads) or MLA | Shrinks the KV cache; cuts long-context inference cost. |

What's stayed the same

The transformer block has been remarkably stable since 2017: each component has had one or two refinements, but the topology hasn't changed. New architectures (Mamba, Hyena, RWKV) propose alternatives to the attention sub-layer, the token mixer, but inside their blocks you still find a residual stream, a normalisation, two sub-layers and a residual add. The skeleton is genuinely the right one.

10

Cheat Sheet

  • Block: $\mathbf{x}' = \mathbf{x} + \mathrm{MHA}(\mathrm{Norm}(\mathbf{x}))$, then $\mathbf{x}'' = \mathbf{x}' + \mathrm{FFN}(\mathrm{Norm}(\mathbf{x}'))$.
  • RMSNorm ($\mathbf{g} = \mathbf{1}$): $\sqrt{d}\,\mathbf{x}/\|\mathbf{x}\|_2$, a projection onto the radius-$\sqrt{d}$ sphere.
  • Params per block: $4d^2 + 3d\,d_{ff}$ (~202 M at $d = 4096$, $d_{ff} = 11008$); GQA-8 cuts QKVO to $2.25d^2$.
  • FLOPs per token per block: $2 \times$ params ≈ 405 M, plus $4dL$ for attention scoring.
  • Training: ~$6N$ FLOPs per token for an $N$-parameter model (forward ~$2N$, backward ~$4N$).
  • KV cache: $2dL$ floats per layer (MHA); ~8× smaller with GQA-8.

Read next

Deck 12 — Tensors, Einsum & Modern Tricks takes the per-token block and scales it: batched / sharded / striped GEMMs, FlashAttention as block matmul, MoE as sparse projection, and the einsum view that makes parallelism strategies legible.