Linear Algebra for AI — Presentation 12

Tensors, Einsum & Modern Tricks

When you have batch, head, sequence, head-dim and feature, matrix algebra runs out of letters. Tensors, einsum and named indices give you the language to write — and parallelise — the algorithms that move large models across thousands of GPUs.

[Flow diagram: matrix → tensor → einsum → tile / shard → FlashAttention / MoE / tensor & pipeline parallelism]
00

Topics We'll Cover

01

From Matrix to Tensor

A scalar is a 0D array. A vector is a 1D array. A matrix is a 2D array. Anything with more dimensions is a tensor. Mathematically the word has a more specific meaning (multilinear maps with covariance under change of basis); in ML it's used loosely for "$k$-dimensional array of numbers".

Real ML code lives in 3D and 4D tensors: activations $(B, L, d)$, attention tensors $(B, H, L, d_h)$, and so on.

Two ways to read a tensor

Either as a (typically batched) collection of matrices — in which case the "extra" axes are just iterators — or as a multilinear object whose contractions express more than matrix algebra can. Most code lives at the first reading; deep theory uses the second.

Tensor algebra is matrix algebra plus indexing

Almost everything you saw in earlier decks — matmul, projection, attention, FFN — is matrix algebra. The "tensors" are just batched stacks of those matrices. The few genuine higher-order operations (high-order convolutions, certain quantum-circuit contractions) are rare in practice. Einsum is the bridge: it lets you write batched operations with the right dimensions implicit.

02

Einsum — the Notation that Scales

Einstein summation convention: any repeated index is summed over. In the form used by NumPy / PyTorch / JAX:

$$\texttt{einsum("ij,jk->ik", A, B)} \quad\Longleftrightarrow\quad C_{ik} = \sum_j A_{ij} B_{jk}.$$

That's standard matmul. The indices to the right of the arrow are the output indices; any index that does not appear there is summed. To the left of the arrow, comma-separated index strings name the axes of the input tensors.
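A quick check of the correspondence in NumPy (PyTorch and JAX accept the same string):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))

# C_ik = sum_j A_ij B_jk : the repeated index j is contracted
C = np.einsum("ij,jk->ik", A, B)

assert np.allclose(C, A @ B)   # identical to ordinary matmul
```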

The three operations einsum captures

Contraction

Repeated index → summed. Matrix multiply, dot product, attention scoring, all matmul views (deck 03) are contractions.

Broadcasting

Index that appears in only one input → preserved unchanged. Adds an "outer" dimension. The B (batch) and H (head) indices in attention are broadcast.

Permutation

Index reordering: "ij->ji" is transpose. "bhld->blhd" permutes a multi-head tensor from (batch, head, seq, dim) to (batch, seq, head, dim).
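One toy example per operation, in NumPy:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 4))   # (batch, seq, dim)
y = rng.standard_normal((2, 3, 4))

# Contraction: the repeated index d is summed -> per-position dot products
dots = np.einsum("bld,bld->bl", x, y)

# Broadcasting: b and l appear in only one input and are carried through
W = rng.standard_normal((4, 5))
proj = np.einsum("bld,df->blf", x, W)

# Permutation: pure index reordering, no summation at all
xt = np.einsum("bld->dlb", x)

assert dots.shape == (2, 3)
assert proj.shape == (2, 3, 5)
assert xt.shape == (4, 3, 2)
```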

Why einsum scales

Standard matrix notation has two indices — rows and columns. Once you have batch, head, query, key, head-dim and feature at the same time, you need six. Naming each one explicitly turns the algebra into something you can read. The einsum string is a shape annotation that's also executable.

03

Reading Real Transformer Code in Einsum

The single attention-block forward pass in einsum:

# Input: x[B, L, d]; weights W_Q[d, H, d_h], W_K[d, H, d_h], W_V[d, H, d_h], W_O[H, d_h, d]
q = einsum("bld,dhk->bhlk", x, W_Q)              # q[B, H, L, d_h]
k = einsum("bld,dhk->bhlk", x, W_K)              # k[B, H, L, d_h]
v = einsum("bld,dhk->bhlk", x, W_V)              # v[B, H, L, d_h]

scores = einsum("bhlk,bhmk->bhlm", q, k) / sqrt(d_h)   # scores[B, H, L, L]
attn   = softmax(scores + mask, axis=-1)                  # attn[B, H, L, L]
out    = einsum("bhlm,bhmk->bhlk", attn, v)            # out[B, H, L, d_h]

y = einsum("bhlk,hkd->bld", out, W_O)              # y[B, L, d]

Seven lines, six einsum calls. Read each einsum string and you can immediately see which axes are batched, which are contracted, which are permuted. No .transpose(), no .reshape(), no .permute(). The semantics are right there in the string.
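The same einsums run as-is in NumPy; a self-contained version with toy shapes, a causal mask, and a softmax written out (shape names match the comments above):

```python
import numpy as np

B, L, d, H, d_h = 2, 5, 16, 4, 8
rng = np.random.default_rng(2)
x   = rng.standard_normal((B, L, d))
W_Q = rng.standard_normal((d, H, d_h))
W_K = rng.standard_normal((d, H, d_h))
W_V = rng.standard_normal((d, H, d_h))
W_O = rng.standard_normal((H, d_h, d))

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

q = np.einsum("bld,dhk->bhlk", x, W_Q)
k = np.einsum("bld,dhk->bhlk", x, W_K)
v = np.einsum("bld,dhk->bhlk", x, W_V)

scores = np.einsum("bhlk,bhmk->bhlm", q, k) / np.sqrt(d_h)
mask   = np.triu(np.full((L, L), -np.inf), k=1)   # causal: -inf above diagonal
attn   = softmax(scores + mask, axis=-1)
out    = np.einsum("bhlm,bhmk->bhlk", attn, v)
y      = np.einsum("bhlk,hkd->bld", out, W_O)

assert y.shape == (B, L, d)
assert np.allclose(attn.sum(-1), 1.0)   # each attention row is a distribution
```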

Index conventions you'll see

JAX makes einsum the lingua franca

JAX's jax.numpy.einsum handles arbitrary contractions. jax.experimental.shard_map and the GSPMD partitioner reason about which axis is sharded across which device dim — using exactly the einsum index names. PyTorch is moving in the same direction with torch.einsum and DTensor. The notation that lets you write fast batched code on one device is also the notation that lets you partition it across thousands.

04

Strides, Views & Layout

A tensor is two things: a flat 1D buffer of numbers and a small "metadata" header that says how to interpret it. The header has the shape and the strides — the number of elements you skip in the flat buffer to advance one step along each axis.

For $X \in \mathbb{R}^{B \times L \times d}$ stored in C-order: stride is $(L \cdot d, \ d, \ 1)$. Element $(b, l, k)$ is at flat offset $b \cdot L d + l \cdot d + k$.
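The stride arithmetic above, checked directly in NumPy (which reports strides in bytes rather than elements):

```python
import numpy as np

B, L, d = 2, 3, 4
X = np.arange(B * L * d).reshape(B, L, d)   # C-order buffer 0..23

# Strides in elements are (L*d, d, 1); NumPy scales them by the itemsize
assert X.strides == (L * d * X.itemsize, d * X.itemsize, X.itemsize)

# A transpose is only a new header over the same buffer: the strides
# permute, and no data moves
Xt = X.transpose(1, 0, 2)
assert Xt.strides == (d * X.itemsize, L * d * X.itemsize, X.itemsize)
assert np.shares_memory(X, Xt)
```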

Why this matters

The "BHLD vs BLHD" decision

Multi-head attention can store $(B, H, L, d_h)$ or $(B, L, H, d_h)$. The score matmul einsum("bhlk,bhmk->bhlm", q, k) is a batched matmul; with BHLD layout, both operands share the same (B, H) outer dims and the matmul runs over the last two axes, which is cuBLAS-friendly. With BLHD layout you'd need a transpose first. Most modern code uses BHLD for attention.
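A NumPy illustration of the layout point, using nothing beyond stride semantics: the BLHD-to-BHLD move is a stride permutation that leaves the view non-contiguous, so a real batched-matmul kernel would first have to copy it:

```python
import numpy as np

B, H, L, d_h = 2, 4, 8, 16
q_bhld = np.zeros((B, H, L, d_h))   # BHLD: batched-matmul ready as stored
q_blhd = np.zeros((B, L, H, d_h))   # BLHD: heads interleaved along seq

# Permute BLHD into BHLD order: same buffer, permuted strides
view = q_blhd.transpose(0, 2, 1, 3)
assert view.shape == (B, H, L, d_h)
assert not view.flags["C_CONTIGUOUS"]    # a kernel would need a copy here
assert q_bhld.flags["C_CONTIGUOUS"]
```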

05

FlashAttention as Block Matmul

Naive attention materialises the full $L \times L$ score matrix in memory. For $L = 32k$, in fp16, that's ~2 GB per attention head, per layer — plainly infeasible.

FlashAttention (Dao 2022, FlashAttention-2 2023, FlashAttention-3 2024) computes the same thing without ever writing the full score matrix to HBM. The trick is block matmul plus an online softmax.

The block decomposition

Split the queries into row-blocks $Q^{(1)}, Q^{(2)}, \ldots$ each of $B_r$ rows, and the keys/values into column-blocks $K^{(1)}, K^{(2)}, \ldots$ of $B_c$ rows. Each query-block, against each key/value-block:

  1. Load $Q^{(i)}, K^{(j)}, V^{(j)}$ tiles into on-chip SRAM.
  2. Compute the partial score $S^{(ij)} = Q^{(i)} (K^{(j)})^\top$ — a $B_r \times B_c$ tile of scores in SRAM.
  3. Update a running softmax: keep the maximum-so-far $m^{(i)}$ and the running denominator $\ell^{(i)}$, recompute the partial outputs and rescale.
  4. Accumulate into $O^{(i)}$.

The score matrix never leaves SRAM. In the analysis of Dao et al. (2022), HBM accesses drop from $\Theta(Nd + N^2)$ for naive attention to $\Theta(N^2 d^2 / M)$, where $M$ is the SRAM size; for typical head dimensions and SRAM sizes that is roughly an order of magnitude less memory traffic. Result: the same output up to floating-point rounding, much faster wall-clock, much less peak memory.

The online-softmax algorithm

At its heart, FlashAttention is a generalised reduction. The softmax-then-weighted-sum can be re-formulated as a streaming reduction over key-blocks: keep the running max $m$, running denominator $\ell$, and running output $o$; on each new tile, rescale them to a new common max. The arithmetic is identical to plain softmax-attention up to floating-point rounding. Block matmul plus clever bookkeeping yields a severalfold wall-clock speedup at no quality cost. This is the rare "free lunch" trick that has stuck across an entire field.
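A single-head sketch of the streaming reduction (one pass over key blocks, never forming the full score matrix; the block size `Bc` and shapes are illustrative, not the kernel's actual tiling):

```python
import numpy as np

rng = np.random.default_rng(3)
L, d = 64, 16
Bc = 16                                    # key/value block size
Q = rng.standard_normal((L, d))
K = rng.standard_normal((L, d))
V = rng.standard_normal((L, d))

# Naive reference: materialise the full L x L score matrix
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(-1, keepdims=True))
ref = (P / P.sum(-1, keepdims=True)) @ V

# Streaming version: online softmax over key blocks
m = np.full((L, 1), -np.inf)               # running row max
l = np.zeros((L, 1))                       # running denominator
o = np.zeros((L, d))                       # running (unnormalised) output
for j in range(0, L, Bc):
    S_blk = Q @ K[j:j + Bc].T / np.sqrt(d)           # L x Bc score tile
    m_new = np.maximum(m, S_blk.max(-1, keepdims=True))
    scale = np.exp(m - m_new)                        # rescale old stats
    P_blk = np.exp(S_blk - m_new)
    l = l * scale + P_blk.sum(-1, keepdims=True)
    o = o * scale + P_blk @ V[j:j + Bc]
    m = m_new
out = o / l

assert np.allclose(out, ref)   # same result, up to rounding
```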

06

MoE as Sparse Projection

A Mixture of Experts FFN replaces the single $W_{\text{up}}, W_{\text{down}}$ pair with $E$ "experts" of which only $k$ activate per token (typically $k = 2$ out of $E = 8, 64$, or 128).

For each token, a learned router produces a softmax distribution over experts, picks the top-$k$, and the FFN is

$$\mathrm{MoE}(\mathbf{x}) = \sum_{e \in \mathrm{TopK}(\mathbf{x})} g_e(\mathbf{x})\, \mathrm{Expert}_e(\mathbf{x}),$$

with $g_e$ the gating weights (renormalised softmax over chosen experts).

Linear-algebra reading

Each $\mathrm{Expert}_e$ is a standard FFN: $W_{\text{up},e}, W_{\text{gate},e}, W_{\text{down},e}$. The router output is a sparse vector of length $E$ with exactly $k$ non-zero entries. Read MoE as a sparse-projection FFN: only $k$ of the available "directions" of feature-space contribute on any given token.
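A minimal dense-weights sketch of the routing math. Real implementations batch tokens per expert, and the expert here is a plain two-matrix ReLU FFN rather than the gated three-matrix variant; all names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(4)
T, d, d_ff, E, k = 6, 8, 16, 4, 2          # tokens, dims, experts, top-k
x = rng.standard_normal((T, d))
W_router = rng.standard_normal((d, E))
W_up   = rng.standard_normal((E, d, d_ff))
W_down = rng.standard_normal((E, d_ff, d))

def softmax(z):
    e = np.exp(z - z.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

logits = x @ W_router                       # (T, E) routing logits
topk = np.argsort(logits, -1)[:, -k:]       # chosen experts per token
gates = softmax(np.take_along_axis(logits, topk, -1))  # renormalised over chosen

y = np.zeros_like(x)
for t in range(T):                          # only k of E experts run per token
    for slot in range(k):
        e = topk[t, slot]
        h = np.maximum(x[t] @ W_up[e], 0.0)           # expert FFN (ReLU)
        y[t] += gates[t, slot] * (h @ W_down[e])

assert y.shape == (T, d)
assert np.allclose(gates.sum(-1), 1.0)      # gates renormalise over top-k
```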

The numbers (Mixtral 8×7B)

Mixtral 8×7B uses $E = 8$ experts per FFN layer with top-2 routing: roughly 47B parameters in total, but only about 13B active on any given token.

This is why MoE has eaten the high-end frontier: at fixed inference FLOPs, MoE models train cheaper and serve more capability.

The downside, in linear-algebra form

The router's output is non-differentiable (top-$k$ is a hard, discrete selection). Standard practice: differentiate through only the chosen experts; use auxiliary load-balancing losses to encourage even use; tolerate some routing instability. The gating logic is the part of MoE most likely to misbehave, and it's the part that doesn't fit cleanly into matrix algebra.

07

Data, Tensor & Pipeline Parallelism

For very large models, parameters and activations don't fit on one GPU. Three orthogonal ways to shard:

Data Parallelism (DP)

Replicate the whole model on every GPU; split the batch. Each GPU does a forward+backward on its slice of inputs, then all-reduces gradients. Fully model-replicating — bandwidth-limited at large scales.

Variants: ZeRO-1/2/3 (Rajbhandari 2020) shard optimiser state / gradients / parameters across DP ranks to reduce memory.

Tensor Parallelism (TP)

Split a single matmul across GPUs. E.g. $W_{\text{up}} \in \mathbb{R}^{d \times d_{ff}}$ becomes $\big[W_{\text{up}}^{(1)} \mid W_{\text{up}}^{(2)}\big]$ on two GPUs. Output activation is concatenated. Megatron-LM's column/row split is the canonical recipe.

Communication: one all-gather + one reduce-scatter per layer. Bandwidth-heavy; usually kept within a single node (NVLink).
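A two-"GPU" toy of the column split, with NumPy arrays standing in for per-device shards:

```python
import numpy as np

rng = np.random.default_rng(5)
d, d_ff = 8, 12
x = rng.standard_normal((4, d))
W_up = rng.standard_normal((d, d_ff))

# Column split: each "device" holds half of W_up's output columns
W1, W2 = np.split(W_up, 2, axis=1)
y1 = x @ W1          # computed on device 0
y2 = x @ W2          # computed on device 1

# Concatenating the per-device activations recovers the full output,
# with no reduction needed in the forward pass
assert np.allclose(np.concatenate([y1, y2], axis=1), x @ W_up)
```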

Pipeline Parallelism (PP)

Split the layers: GPU 0 does layers 1-10, GPU 1 does layers 11-20, etc. Activations stream forward, gradients stream back. Overlap with micro-batching to keep devices busy. GPipe / 1F1B schedules are standard.

Bubble overhead: with $P$ pipeline stages and $m$ micro-batches, the idle "bubble" fraction is $(P-1)/(m+P-1)$; more micro-batches and smarter schedules shrink it.
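The standard GPipe-style bubble estimate, as a two-line calculation:

```python
# With P pipeline stages and m micro-batches, the fraction of time a
# device sits idle (the "bubble") is (P - 1) / (m + P - 1)
def bubble_fraction(P, m):
    return (P - 1) / (m + P - 1)

# No micro-batching: 4 stages are idle 75% of the time
assert abs(bubble_fraction(4, 1) - 0.75) < 1e-12
# Many micro-batches shrink the bubble below 10%
assert bubble_fraction(4, 32) < 0.1
```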

The three are orthogonal — you stack them

A typical big training run combines all three: DP across nodes (pods of GPUs), TP within a node (8 GPUs sharing NVLink), PP across stages of a layer-stack. With MoE the picture also includes expert parallelism: experts sharded across devices, routed to as a fourth axis.

The 4D mesh of (DP, TP, PP, EP) on a typical GPU cluster is a non-trivial scheduling problem. Frameworks (PyTorch FSDP, JAX shard_map / GSPMD, Megatron-DeepSpeed) handle it; the einsum view of the underlying op (next slide) is what makes the parallelism legible.

08

Sharding via Einsum — The GSPMD Picture

Google's GSPMD framework (Xu et al. 2021) made an observation: parallelising an einsum is a property of the indices, not of the algorithm.

Take the FFN up-projection: einsum("bld,df->blf", x, W_up). Each axis can be sharded (S) or replicated (R) across a 1D device mesh:

| x sharding (b, l, d) | W sharding (d, f) | Output | Communication |
|----------------------|-------------------|--------|---------------|
| (S, R, R) | (R, R) | (S, R, R) — DP | None for forward; all-reduce gradients. |
| (R, R, R) | (R, S) | (R, R, S) — TP column-split | None for forward; all-gather / reduce-scatter at later stages. |
| (R, R, S) | (S, R) | (R, R, R) — TP row-split | All-reduce on output (the sum over the contracted, sharded $d$ axis). |
| (S, R, S) | (S, S) | 2D-sharded | Two-axis collectives. |

Reading the table: shard axes that are not contracted, and you get parallelism with no communication. Shard axes that are contracted, and you need a reduction. The pattern of communication is determined by which einsum indices were sharded; the algorithm itself doesn't care.
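The row-split case in the same toy setting: sharding the *contracted* $d$ axis leaves each device with a partial sum, and the all-reduce is simply the sum that finishes the contraction:

```python
import numpy as np

rng = np.random.default_rng(6)
d, d_ff = 8, 12
x = rng.standard_normal((4, d))
W_up = rng.standard_normal((d, d_ff))

# Shard the contracted d axis: x column-split, W_up row-split
x1, x2 = np.split(x, 2, axis=1)
W1, W2 = np.split(W_up, 2, axis=0)

# Each "device" holds a partial sum over its shard of d ...
partial1 = x1 @ W1
partial2 = x2 @ W2

# ... so an all-reduce (here, an ordinary sum) completes the contraction
assert np.allclose(partial1 + partial2, x @ W_up)
```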

How modern frameworks express this

JAX's shard_map and PartIR/GSPMD let you annotate each tensor with a sharding spec; the compiler inserts the all-reduces, all-gathers, reduce-scatters that the einsum requires. PyTorch DTensor has the same idea. You write the algorithm in einsum; the system figures out the parallel collective primitives.

The grand unification

Decks 01-04 taught you that everything in a transformer is matrix algebra. Decks 05-09 gave you the structure theorems that make it work. Decks 10-11 showed you the actual transformer block in those terms. This deck closes the loop: the same algebra plus axis names is also the language of distributed training. Sharding is just sharding einsum indices. There is one structure here, all the way down.

09

Where the Series Lands

After twelve decks: vectors and matrices, projections and SVD, gradients and chain rules, attention and FFNs, all the way to FlashAttention and tensor parallelism. The arc has been deliberately one-way:

Decks 01-04 — Foundations

Vectors, matrices, matmul, inner products. The language every later deck reads from.

Decks 05-09 — Structure theorems

Projection, eigen, SVD, QR, gradients. The why behind transformer shape and gradient flow.

Decks 10-12 — Inside the transformer

Attention, the full block forward pass, and the tensor / sharding view that scales it across pods.

Where to go next from here

Linear algebra is the floor everything else stands on. Every non-trivial idea in modern AI — attention, RoPE, LoRA, MLA, FlashAttention, GQA, MoE, GSPMD — is a specific structural fact about matrices and tensors, applied to a specific problem. With this series in hand you can read those structural facts directly off the algorithms.

10

Cheat Sheet

End of series

Thanks for reading. The companion sub-hubs in the LLMs hub take it from here — transformer architecture, GPU/TPU hardware, modern architectures, training, evaluations and production. Drop into whichever interests you next.