When you have batch, head, sequence, head-dim and feature, matrix algebra runs out of letters. Tensors, einsum and named indices give you the language to write — and parallelise — the algorithms that move large models across thousands of GPUs.
A scalar is a 0D array. A vector is a 1D array. A matrix is a 2D array. Anything with more dimensions is a tensor. Mathematically the word has a more specific meaning (multilinear maps with covariance under change of basis); in ML it's used loosely for "$k$-dimensional array of numbers".
Real ML code lives in 3D and 4D tensors:
Either as a (typically batched) collection of matrices — in which case the "extra" axes are just iterators — or as a multilinear object whose contractions express more than matrix algebra can. Most code lives at the first reading; deep theory uses the second.
Almost everything you saw in earlier decks — matmul, projection, attention, FFN — is matrix algebra. The "tensors" are just batched stacks of those matrices. The few genuine higher-order operations (high-order convolutions, certain quantum-circuit contractions) are rare in practice. Einsum is the bridge: it lets you write batched operations with the right dimensions implicit.
Einstein summation convention: any repeated index is summed over. The form taken in NumPy / PyTorch / JAX:
$$\texttt{einsum("ij,jk->ik", A, B)} \quad\Longleftrightarrow\quad C_{ik} = \sum_j A_{ij} B_{jk}.$$
That's standard matmul. The indices to the right of the arrow are the output indices; every index that does not appear on the right is summed over. Inside the quotes, the comma-separated index strings describe the input tensors, one string per argument.
Repeated index → summed. Matrix multiply, dot product, attention scoring, all matmul views (deck 03) are contractions.
Index that appears in only one input → preserved unchanged. Adds an "outer" dimension. The B (batch) and H (head) indices in attention are broadcast.
Index reordering: "ij->ji" is a transpose. "bhld->blhd" permutes multi-head activations from (batch, head, seq, dim) to (batch, seq, head, dim); no data is summed, only the axis order changes. All three patterns appear in the sketch below.
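All three behaviours in a minimal NumPy sketch (the shapes and the extra weight matrix are illustrative, not tied to any example in this deck):

```python
import numpy as np

A = np.random.randn(3, 4)
B = np.random.randn(4, 5)
x = np.random.randn(2, 6, 3, 4)          # pretend (batch, head, seq, dim)
W = np.random.randn(4, 5)

# Repeated index j -> contracted: ordinary matmul.
C = np.einsum("ij,jk->ik", A, B)
assert np.allclose(C, A @ B)

# Index appearing in only one input -> preserved: b, h, l are batched through.
y = np.einsum("bhld,dk->bhlk", x, W)     # shape (2, 6, 3, 5)

# No contraction at all -> pure axis permutation.
t = np.einsum("bhld->blhd", x)
assert t.shape == (2, 3, 6, 4)
```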
Standard matrix notation has two indices — rows and columns. Once you have $\{batch, head, query, key, head\_dim, feature\}$ at the same time, you need six. Naming each one explicitly turns the algebra into something you can read. The einsum string is a shape annotation that's also executable.
The single attention-block forward pass in einsum:
```python
# Input: x[B, L, d]; weights W_Q[d, H, d_h], W_K[d, H, d_h], W_V[d, H, d_h], W_O[H, d_h, d]
q = einsum("bld,dhk->bhlk", x, W_Q)                    # q[B, H, L, d_h]
k = einsum("bld,dhk->bhlk", x, W_K)                    # k[B, H, L, d_h]
v = einsum("bld,dhk->bhlk", x, W_V)                    # v[B, H, L, d_h]
scores = einsum("bhlk,bhmk->bhlm", q, k) / sqrt(d_h)   # scores[B, H, L, L]
attn = softmax(scores + mask, axis=-1)                 # attn[B, H, L, L]
out = einsum("bhlm,bhmk->bhlk", attn, v)               # out[B, H, L, d_h]
y = einsum("bhlk,hkd->bld", out, W_O)                  # y[B, L, d]
```
Seven lines, six of them einsum calls. Read each einsum string and you can immediately see which axes are batched, which are contracted, and which are permuted. No .transpose(), no .reshape(), no .permute(). The semantics are right there in the string.
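To actually execute those seven lines, a setup like the following is enough; the sizes, the scipy softmax, and the mask construction are my own illustrative choices, not part of the deck:

```python
import numpy as np
from numpy import einsum, sqrt
from scipy.special import softmax        # any softmax-along-an-axis works here

B, L, d, H, d_h = 2, 8, 16, 4, 4         # illustrative sizes
rng = np.random.default_rng(0)
x = rng.standard_normal((B, L, d))
W_Q, W_K, W_V = (rng.standard_normal((d, H, d_h)) for _ in range(3))
W_O = rng.standard_normal((H, d_h, d))
mask = np.triu(np.full((L, L), -np.inf), k=1)   # additive causal mask, broadcasts to [B, H, L, L]

# ...with these defined, the seven lines above run verbatim and y has shape (B, L, d).
```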
Index legend: b = batch; l, m, n = sequence positions; h = head; k = head-dim or contraction; d = model dim; e = expert (in MoE).

JAX's jax.numpy.einsum handles arbitrary contractions. jax.experimental.shard_map and the GSPMD partitioner reason about which axis is sharded across which device dim — using exactly the einsum index names. PyTorch is moving in the same direction with torch.einsum and DTensor. The notation that lets you write fast batched code on one device is also the notation that lets you partition it across thousands.
A tensor is two things: a flat 1D buffer of numbers and a small "metadata" header that says how to interpret it. The header has the shape and the strides — the number of elements you skip in the flat buffer to advance one step along each axis.
For $X \in \mathbb{R}^{B \times L \times d}$ stored in C-order: stride is $(L \cdot d, \ d, \ 1)$. Element $(b, l, k)$ is at flat offset $b \cdot L d + l \cdot d + k$.
x.transpose(1,2) in PyTorch usually returns a view: same buffer, different strides. x.contiguous() forces a copy with C-order strides. Some kernels require contiguous inputs (cuBLAS) and silently allocate a copy if they don't get them.

x.view(...) changes shape without moving data. If the requested shape isn't compatible with the current strides, you must .contiguous().view(...).

Multi-head attention can store $(B, H, L, d_h)$ or $(B, L, H, d_h)$. The score matmul einsum("bhlk,bhmk->bhlm", q, k) is a batched matmul; with BHLD layout, both operands have the same (B, H) outer dims and the matmul runs over the last two axes — cuBLAS-friendly. With BLHD layout you'd need a transpose first. Most modern code uses BHLD for attention.
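A small PyTorch sketch of the same facts (the sizes are arbitrary; .stride() reports strides in elements, matching the $(L \cdot d,\ d,\ 1)$ formula above):

```python
import torch

B, L, d = 2, 3, 4
x = torch.arange(B * L * d).reshape(B, L, d)   # C-order buffer

print(x.stride())                   # (12, 4, 1) == (L*d, d, 1)
xt = x.transpose(1, 2)              # view: same buffer, swapped strides
print(xt.stride())                  # (12, 1, 4)
print(xt.is_contiguous())           # False

# view() needs strides compatible with the requested shape:
xc = xt.contiguous()                # copy with C-order strides
flat = xc.view(B, -1)               # fine now; xt.view(B, -1) would raise
print(x.data_ptr() == xt.data_ptr(), x.data_ptr() == xc.data_ptr())   # True False
```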
Naive attention materialises the full $L \times L$ score matrix in memory. For $L = 32k$, in fp16, that's ~2 GB per attention head, per layer — immediately infeasible.
FlashAttention (Dao 2022, FlashAttention-2 2023, FlashAttention-3 2024) computes the same thing without ever writing the full score matrix to HBM. The trick is block matmul plus an online softmax.
Split the queries into row-blocks $Q^{(1)}, Q^{(2)}, \ldots$ of $B_r$ rows each, and the keys/values into column-blocks $K^{(1)}, K^{(2)}, \ldots$ of $B_c$ rows each. Each query-block is processed against each key/value-block entirely in on-chip SRAM, carrying running softmax statistics from one block to the next (see below).
The score matrix never leaves SRAM. HBM traffic drops from $\Theta(Nd + N^2)$ for naive attention to $\Theta(N^2 d^2 / M)$, where $M$ is the SRAM size: roughly a factor of $M/d^2$ on the dominant term, around an order of magnitude for typical head dimensions and SRAM sizes. Result: the same output (up to floating-point rounding), much faster wall-clock, much less peak memory.
At its heart, FlashAttention is a generalised reduction. The softmax-then-weighted-sum can be re-formulated as a streaming reduction over key-blocks: keep the running max $m$, running denominator $\ell$, and running output $o$; on each new tile, rescale them to a new common max. The arithmetic is identical to plain softmax-attention up to floating-point rounding. Block matmul + clever bookkeeping = a 5-10× wall-clock speedup at no quality cost. This is the rare "free lunch" trick that has stuck across an entire field.
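Here is a minimal single-head NumPy sketch of that streaming reduction. It captures the bookkeeping only, not the fused SRAM kernel; the block size and variable names are illustrative:

```python
import numpy as np

def attention_streaming(q, k, v, block=64):
    """Single-head attention, one key/value block at a time (online softmax)."""
    L, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(L, -np.inf)              # running max of scores, per query row
    l = np.zeros(L)                      # running softmax denominator
    o = np.zeros((L, d))                 # running (unnormalised) output
    for s0 in range(0, k.shape[0], block):
        kb, vb = k[s0:s0 + block], v[s0:s0 + block]
        s = (q @ kb.T) * scale                        # scores for this tile
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)                     # rescale old statistics to the new max
        p = np.exp(s - m_new[:, None])
        l = alpha * l + p.sum(axis=1)
        o = alpha[:, None] * o + p @ vb
        m = m_new
    return o / l[:, None]

# Matches the naive "materialise the full L x L matrix" computation up to rounding.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
scores = q @ k.T / np.sqrt(64)
naive = np.exp(scores - scores.max(axis=1, keepdims=True))
naive = (naive / naive.sum(axis=1, keepdims=True)) @ v
assert np.allclose(attention_streaming(q, k, v), naive)
```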
A Mixture of Experts FFN replaces the single $W_{\text{up}}, W_{\text{down}}$ pair with $E$ "experts" of which only $k$ activate per token (typically $k = 2$ out of $E = 8, 64$, or 128).
For each token, a learned router produces a softmax distribution over experts, picks the top-$k$, and the FFN is
$$\mathrm{MoE}(\mathbf{x}) = \sum_{e \in \mathrm{TopK}(\mathbf{x})} g_e(\mathbf{x})\, \mathrm{Expert}_e(\mathbf{x}),$$
with $g_e$ the gating weights (renormalised softmax over chosen experts).
Each $\mathrm{Expert}_e$ is a standard FFN: $W_{\text{up},e}, W_{\text{gate},e}, W_{\text{down},e}$. The router output is a sparse vector of length $E$ with exactly $k$ non-zero entries. Read MoE as a sparse-projection FFN: only $k$ of the available "directions" of feature-space contribute on any given token.
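A dense-arithmetic NumPy sketch of that reading, for one token. The weight layout, the plain-ReLU expert, and the argsort-based top-k are illustrative assumptions rather than any particular model's code:

```python
import numpy as np

def moe_ffn(x, W_up, W_down, W_router, k=2):
    """Top-k MoE FFN for a single token x[d].

    W_router[d, E] produces router logits; W_up[E, d, d_ff] and W_down[E, d_ff, d]
    hold the per-expert FFN weights.
    """
    logits = x @ W_router                          # [E] router logits
    top = np.argsort(logits)[-k:]                  # indices of the top-k experts
    gates = np.exp(logits[top] - logits[top].max())
    gates = gates / gates.sum()                    # renormalised softmax over chosen experts
    out = np.zeros_like(x)
    for g, e in zip(gates, top):                   # only k experts do any work
        h = np.maximum(x @ W_up[e], 0.0)           # expert e: up-project + ReLU
        out += g * (h @ W_down[e])                 # gate-weighted down-projection
    return out
```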
This sparsity is why MoE has eaten the high-end frontier: parameter count scales with $E$ while per-token compute scales with $k$, so at fixed inference FLOPs, MoE models train cheaper and serve more capability.
The router's output is non-differentiable (top-$k$ is a hard, discrete selection). Standard practice: differentiate through only the chosen experts; use auxiliary load-balancing losses to encourage even use; tolerate some routing instability. The gating logic is the part of MoE most likely to misbehave, and it's the part that doesn't fit cleanly into matrix algebra.
For very large models, parameters and activations don't fit on one GPU. Three orthogonal ways to shard:
Replicate the whole model on every GPU; split the batch. Each GPU does a forward+backward on its slice of inputs, then all-reduces gradients. Because every device holds the full model, memory cost doesn't shrink with scale, and the gradient all-reduce makes it bandwidth-limited at large scales.
Variants: ZeRO-1/2/3 (Rajbhandari 2020) shard optimiser state / gradients / parameters across the DP ranks to reduce memory.
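A toy NumPy check of the invariant the data-parallel all-reduce relies on: with an equal split of the batch, averaging per-shard gradients reproduces the full-batch gradient (the linear least-squares model here is purely illustrative):

```python
import numpy as np

np.random.seed(0)
W = np.random.randn(4, 3)
X, Y = np.random.randn(16, 4), np.random.randn(16, 3)

def grad(W, X, Y):                       # dL/dW for L = mean ||XW - Y||^2
    return 2 * X.T @ (X @ W - Y) / len(X)

g_full = grad(W, X, Y)
# Two "DP ranks", each on half the batch; the mean plays the role of the all-reduce.
g_dp = np.mean([grad(W, X[i::2], Y[i::2]) for i in range(2)], axis=0)
assert np.allclose(g_full, g_dp)
```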
Split a single matmul across GPUs. E.g. $W_{\text{up}} \in \mathbb{R}^{d \times d_{ff}}$ becomes $\big[W_{\text{up}}^{(1)} \mid W_{\text{up}}^{(2)}\big]$ on two GPUs. Output activation is concatenated. Megatron-LM's column/row split is the canonical recipe.
Communication: one all-gather + one reduce-scatter per layer. Bandwidth-heavy; usually kept within a single node (NVLink).
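A toy NumPy sketch of the two Megatron-style splits, with the two "devices" simulated as array halves and a plain sum standing in for the all-reduce:

```python
import numpy as np

d, d_ff = 8, 32
x = np.random.randn(d)
W_up = np.random.randn(d, d_ff)

# Column split: each device holds half of W_up's columns and computes its half of
# the output; the halves are concatenated -- no reduction needed.
W1, W2 = W_up[:, :d_ff // 2], W_up[:, d_ff // 2:]
y_cols = np.concatenate([x @ W1, x @ W2])
assert np.allclose(y_cols, x @ W_up)

# Row split: each device holds half of W_up's rows (and the matching slice of x);
# the partial outputs must be summed -- this sum is the all-reduce.
xa, xb = x[:d // 2], x[d // 2:]
Wa, Wb = W_up[:d // 2], W_up[d // 2:]
y_rows = (xa @ Wa) + (xb @ Wb)
assert np.allclose(y_rows, x @ W_up)
```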
Split the layers: GPU 0 does layers 1-10, GPU 1 does layers 11-20, etc. Activations stream forward, gradients stream back. Overlap with micro-batching to keep devices busy. GPipe / 1F1B schedules are standard.
Bubble overhead: with $P$ pipeline stages and no micro-batching, each device is busy only ~$1/P$ of the time. With $M$ micro-batches the idle ("bubble") fraction drops to roughly $(P-1)/(M+P-1)$ (for $P=8$, $M=32$ that's ~18%), and 1F1B-style schedules shrink it further.
A typical big training run combines all three: DP across nodes (pods of GPUs), TP within a node (8 GPUs sharing NVLink), PP across stages of a layer-stack. With GQA + MoE the picture also includes expert parallelism (EP): experts are sharded across devices and tokens are routed to them, adding a fourth parallelism axis.
The 4D mesh of (DP, TP, PP, EP) on a typical GPU cluster is a non-trivial scheduling problem. Frameworks (PyTorch FSDP, JAX shard_map / GSPMD, Megatron-DeepSpeed) handle it; the einsum view of the underlying op (next slide) is what makes the parallelism legible.
Google's GSPMD framework (Xu et al. 2021) made an observation: parallelising an einsum is a property of the indices, not of the algorithm.
Take the FFN up-projection: einsum("bld,df->blf", x, W_up). Each axis can be sharded (S) or replicated (R) across the device mesh:
| x sharding (b, l, d) | W sharding (d, f) | Output | Communication |
|---|---|---|---|
| (S, R, R) | (R, R) | (S, R, R) — DP | None for forward; all-reduce gradients. |
| (R, R, R) | (R, S) | (R, R, S) — TP column-split | None for forward; all-gather/reduce-scatter at later stages. |
| (R, R, S) | (S, R) | (R, R, R) — TP row-split | All-reduce on output (the sum over the contracted, sharded $d$ axis). |
| (S, R, S) | (S, S) | 2D-sharded | Two-axis collectives. |
Reading the table: shard axes that are not contracted, and you get parallelism with no communication. Shard axes that are contracted, and you need a reduction. The pattern of communication is determined by which einsum indices were sharded; the algorithm itself doesn't care.
JAX's shard_map and PartIR/GSPMD let you annotate each tensor with a sharding spec; the compiler inserts the all-reduces, all-gathers, reduce-scatters that the einsum requires. PyTorch DTensor has the same idea. You write the algorithm in einsum; the system figures out the parallel collective primitives.
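A minimal JAX sketch of the column-split row of the table. It assumes a runtime with several devices (on CPU you can fake them with XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax); the shapes are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))   # 1D device mesh
B, L, d, f = 4, 16, 32, 64
x = jnp.ones((B, L, d))
W_up = jnp.ones((d, f))

# TP column split: x replicated, W_up sharded along its output axis f.
x_r = jax.device_put(x, NamedSharding(mesh, P()))                 # (R, R, R)
W_s = jax.device_put(W_up, NamedSharding(mesh, P(None, "tp")))    # (R, S)

@jax.jit
def up_proj(x, W):
    return jnp.einsum("bld,df->blf", x, W)

y = up_proj(x_r, W_s)    # typically comes out sharded (R, R, S); no forward collective
print(y.sharding)        # the partitioner chose this from the input shardings
```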
Decks 01-05 taught you that everything in a transformer is matrix algebra. Decks 06-08 told you which structure theorems make it work. Decks 10-11 showed you the actual transformer block in those terms. This deck closes the loop: the same algebra plus axis names is also the language of distributed training. Sharding is just sharding einsum indices. There is one structure here, all the way down.
After twelve decks: vectors and matrices, projections and SVD, gradients and chain rules, attention and FFNs, all the way to FlashAttention and tensor parallelism. The arc has been deliberately one-way:
Vectors, matrices, matmul, inner products. The language every later deck reads from.
Projection, eigen, SVD, QR, gradients. The why behind transformer shape and gradient flow.
Attention, the full block forward pass, and the tensor / sharding view that scales it across pods.
Linear algebra is the floor everything else stands on. Every non-trivial idea in modern AI — attention, RoPE, LoRA, MLA, FlashAttention, GQA, MoE, GSPMD — is a specific structural fact about matrices and tensors, applied to a specific problem. With this series in hand you can read those structural facts directly off the algorithms.
Thanks for reading. The companion sub-hubs in the LLMs hub take it from here — transformer architecture, GPU/TPU hardware, modern architectures, training, evaluations and production. Drop into whichever interests you next.