From a Python jax.jit down to LLO machine code on the chip. The compiler, the partitioner, the kernel DSL, the orchestration runtime — what makes a 9,216-chip pod look like one machine.
Five frontends, two IRs, two compiler backends, one chip. Notice that everything lowers through StableHLO → HLO. HLO is the architectural dividing line — above it sit framework and portability concerns; below it, target-specific code generation.
02
XLA & HLO — The Compiler Core
XLA (Accelerated Linear Algebra) was introduced by the TensorFlow team in February 2017 at the first TensorFlow Developer Summit. It is the compiler that turns ML programs into TPU machine code (and CPU/GPU code via separate backends).
HLO — "High Level Operations"
Closed opset operating on immutable tensors with statically-known shapes.
HLO → LLO. The TPU backend lowers HLO to LLO ("Low Level Operations") — a TPU-specific representation expressing the MXU schedule, vector-unit operations, DMA descriptors, and VLIW issue.
LLO → machine code. Final assembly for the chip.
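You can watch the first two stages of that pipeline from Python. A minimal sketch (function and shapes are illustrative): lower() exposes the StableHLO, compile() exposes the optimized HLO that the TPU backend then lowers to LLO; the LLO and final machine code stay internal to the backend.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) @ x

x = jnp.ones((256, 256), dtype=jnp.bfloat16)
lowered = jax.jit(f).lower(x)       # StableHLO: the portable, framework-facing IR
print(lowered.as_text())
compiled = lowered.compile()        # optimized HLO: what the backend lowers to LLO
print(compiled.as_text())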
03
Operator Fusion & Layout Assignment
The two XLA passes that drive 80%+ of measured TPU performance.
Operator fusion
Multiple HLO ops merged into a single fusion cluster, so intermediate tensors stay in registers / VMEM rather than spilling to HBM. Three flavours:
Loop fusion / element-wise fusion: sequential pointwise ops (relu · add · mul) become one loop.
Input fusion: a chain of element-wise producers fuses into a non-element-wise root such as reduce, so the reduction consumes its operands as they are computed (e.g. the square and subtract feeding a variance reduction).
Output fusion: a producer (e.g. dot or conv) fuses with its consumers (often the bias-add, scale, and ReLU after a matmul).
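A minimal sketch of loop fusion in practice (shapes illustrative): the three element-wise ops below typically compile to one fused kernel, so the intermediates never round-trip through HBM.
import jax
import jax.numpy as jnp

def pointwise_chain(a, b, c):
    # multiply, add, maximum — three HLO ops that XLA usually emits as one fused loop
    return jax.nn.relu(a * b + c)

a = b = c = jnp.ones((4096, 4096), dtype=jnp.bfloat16)
hlo = jax.jit(pointwise_chain).lower(a, b, c).compile().as_text()
print("fusion" in hlo)   # the optimized HLO shows a fused computation, not three kernels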
Layout assignment
Picks the physical layout of each tensor in TPU memory. For TPU specifically:
Tiles match the MXU's 128×128 shape to avoid expensive transposes.
VMEM tiling chosen so that the working set of one fusion cluster fits in VMEM.
Layout decisions interact with sharding choices — the partitioner picks layouts that minimise cross-shard reshapes.
A bad layout means a 5–10× slowdown from the transpose/copy ops inserted to convert between layouts.
The reference paper
Snider & Liang's "Operator Fusion in XLA: Analysis and Evaluation" (arXiv:2301.13062, 2023) is the most thorough public analysis of XLA's fusion passes. Worth a read if you want to understand why your jax.jit'd function is faster than the equivalent imperative loop — usually the answer is "XLA fused 30 ops into 3 kernels".
04
OpenXLA & StableHLO
OpenXLA (March 2023)
Spun out of TensorFlow into the standalone openxla/ GitHub organisation.
Multi-vendor: AMD, Apple, Arm, Cerebras, Google, Graphcore, Intel, NVIDIA, SiFive, Hugging Face all collaborated on launch.
Community-led; not a Linux Foundation project (despite occasional confusion on this point).
Repos: openxla/xla (the compiler), openxla/stablehlo (the IR), openxla/iree (an alternative ML runtime).
StableHLO
Portability layer between ML frameworks and ML compilers.
Built on the MHLO MLIR dialect, with 94 statically-shaped operations.
MLIR-bytecode serialisation — you can save a StableHLO program and reload it later.
Compatibility guarantee: 5 years backward, 2 years forward per the StableHLO compatibility RFC.
The canonical artefact emerging from jax.jit(f).lower(...) — you can extract it before compilation and feed it to other XLA-compatible compilers.
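A sketch of that export path, assuming the jax.export API (export / serialize / deserialize, as documented in JAX's export guide): the serialized blob is a versioned StableHLO module you can store and reload under the compatibility guarantee.
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(x) * 2.0

# export captures a StableHLO module for the given input shape/dtype
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exported.serialize()                       # MLIR bytecode, safe to persist
rehydrated = export.deserialize(blob)
y = rehydrated.call(jnp.arange(8, dtype=jnp.float32))   # run the reloaded program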
Why this matters
Before StableHLO, the only stable interface to XLA was the Python frontend. Now compilers like IREE, hardware vendors (e.g. AMD with its ROCm XLA backend), and runtime systems can consume StableHLO directly. The TPU stack is increasingly portable — PyTorch can target XLA, JAX programs can be exported and re-run elsewhere, and OpenXLA itself has third-party backends.
05
JAX — Composable Function Transforms
JAX started as a Google Brain research project in 2018. Original team: Matt Johnson, Roy Frostig, Dougal Maclaurin, Chris Leary. Now under the jax-ml/ GitHub organisation (moved from google/jax).
Any JAX transform can wrap any other JAX transform. jit(vmap(grad(f))) is meaningful: take the gradient of f, vectorise it over a batch axis, then JIT-compile the whole thing. This composability is what makes JAX uniquely productive for research code — you can write a small, clean function, then wrap it in transforms to scale it to a TPU pod.
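A small illustration of that composition (loss function and shapes are hypothetical): grad differentiates a scalar-valued function, vmap maps it over the batch to produce per-example gradients, and jit compiles the whole pipeline into one XLA program.
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2                  # scalar loss for one example

# per-example gradients w.r.t. w, batched over (x, y), compiled end-to-end
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(4)
xs, ys = jnp.ones((32, 4)), jnp.ones(32)
print(per_example_grads(w, xs, ys).shape)            # (32, 4)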
pmap is legacy
The original SPMD primitive in JAX was pmap — "parallel map" across devices. Modern JAX uses jax.jit with explicit shardings (GSPMD-managed) for automatic SPMD, and shard_map for the manual counterpart. pmap still works but isn't the recommended path for new code.
06
Mesh, NamedSharding & GSPMD
This is the abstraction that turns a 9,216-chip pod into a single mesh you can shard tensors across.
a 4-axis pod mesh, with sharded tensors
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
# a 9,216-chip Ironwood pod arranged as 4 axes
devices = np.asarray(jax.devices()).reshape(8, 8, 12, 12)
mesh = Mesh(devices, axis_names=('data', 'fsdp', 'tensor', 'expert'))
# weight matrix sharded across 'tensor' and 'expert' axes
W = jax.device_put(W, NamedSharding(mesh, P('tensor', 'expert')))
# input batch sharded across 'data' (and replicated across the rest)
x = jax.device_put(x, NamedSharding(mesh, P('data', None)))
# JIT the model — GSPMD will figure out where to insert collectives
y = jax.jit(model)(W, x)
What GSPMD does for you
Propagates the input shardings through every operation in your function.
Infers shardings for intermediates that the user didn't annotate.
Inserts all-reduce, all-gather, reduce-scatter, all-to-all collectives where the shardings disagree.
Optimises collective placement to overlap with compute when possible.
Reference: Xu et al., "GSPMD: General and Scalable Parallelization for ML Computation Graphs", 2021.
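A small sketch of steering that propagation, reusing the mesh, W and x from the snippet above: GSPMD infers shardings for intermediates and outputs on its own, and jax.lax.with_sharding_constraint lets you pin an intermediate when the inferred choice is suboptimal.
from jax.sharding import NamedSharding, PartitionSpec as P

@jax.jit
def layer(W, x):
    h = x @ W                                        # GSPMD infers a sharding for h
    # pin the activation sharding explicitly instead of relying on inference
    h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P('data', None)))
    return jax.nn.relu(h)

y = layer(W, x)
print(y.sharding)    # the NamedSharding that GSPMD propagated to the output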
Shardy — the successor
Shardy is the new MLIR-based partitioner that's gradually replacing GSPMD as XLA's sharding system. Same conceptual model (tensors annotated with shardings, compiler inserts collectives), more uniform IR representation. New code should expect Shardy semantics; old code keeps working.
07
shard_map — The Manual Counterpart
GSPMD is automatic; shard_map is manual. You write the per-device program yourself, with explicit collectives, and JAX maps it over the mesh, handing each device its shard of the inputs.
manual collectives via shard_map
from jax.experimental.shard_map import shard_map

def tensor_parallel_matmul(W_local, x_local):
    # compute the local partial product on this chip's slice of the contraction dim
    partial = x_local @ W_local
    # sum the partial products across the 'tensor' axis to produce the full result
    return jax.lax.psum(partial, 'tensor')

# map over the mesh, with explicit input/output partition specs:
# W is split along its contraction (row) dim, x along its feature (column) dim
y = shard_map(
    tensor_parallel_matmul,
    mesh=mesh,
    in_specs=(P('tensor', None), P('data', 'tensor')),
    out_specs=P('data', None),
)(W, x)
When to use which
jit + GSPMD — high-level, automatic. Good for most training runs; optimal in 80%+ of cases.
shard_map — manual, explicit. When you want a non-standard collective pattern (e.g. ring-all-reduce overlapped with compute, custom expert routing in MoE), or when GSPMD's inferences are suboptimal.
The two compose: you can shard_map a function and then jit the result; you can have GSPMD-managed regions calling shard_map'd kernels. The choice is per-function, not per-program.
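A tiny sketch of that composition, reusing the shard_map'd matmul from above: wrapping it in jax.jit compiles the per-shard body and its psum into a single XLA executable.
# jit of a shard_map'd function: XLA compiles the per-shard code plus its collectives
tp_matmul = jax.jit(shard_map(
    tensor_parallel_matmul,
    mesh=mesh,
    in_specs=(P('tensor', None), P('data', 'tensor')),
    out_specs=P('data', None),
))
y = tp_matmul(W, x)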
08
Pallas / Mosaic — Low-Level TPU Kernels
Pallas is a JAX-embedded kernel DSL, conceptually equivalent to Triton on GPU. The same Python source can target multiple backends; on TPU, Pallas lowers through Mosaic, an MLIR-based TPU kernel compiler.
a Pallas matmul kernel
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# M, N, K: static matmul dimensions (illustrative values; multiples of the 128 tile)
M, N, K = 1024, 1024, 512

def matmul_kernel(x_ref, w_ref, out_ref):
    # one grid step: multiply a (128, K) tile by a (K, 128) tile, both already in VMEM
    out_ref[...] = x_ref[...] @ w_ref[...]

# BlockSpec tells Pallas how to tile each input from HBM into VMEM
matmul = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
    grid=(M // 128, N // 128),
    in_specs=[
        pl.BlockSpec(block_shape=(128, K), index_map=lambda i, j: (i, 0)),
        pl.BlockSpec(block_shape=(K, 128), index_map=lambda i, j: (0, j)),
    ],
    out_specs=pl.BlockSpec(block_shape=(128, 128), index_map=lambda i, j: (i, j)),
)
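A usage sketch under the same illustrative dimensions: calling the pallas_call-wrapped function launches matmul_kernel once per (i, j) cell of the grid.
x = jnp.ones((M, K), dtype=jnp.bfloat16)
w = jnp.ones((K, N), dtype=jnp.bfloat16)
y = matmul(x, w)     # shape (M, N); each grid step writes one 128×128 output tile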
What Pallas-on-TPU lets you do
Write a kernel with explicit VMEM tile management (the BlockSpec says "bring this HBM region into VMEM for each step").
Control double-buffering: pltpu.emit_pipeline handles HBM↔VMEM software pipelining around an inner kernel.
Issue ICI sends/receives: outer pipelines can communicate between chips for things like sequence-parallel attention.
Get within ~5–10% of hand-written XLA performance on standard ops — and write things XLA can't fuse easily (FlashAttention variants, custom MoE routing).
Pallas-on-GPU exists too (lowers to Triton or Mosaic GPU). The same source compiles to either — one of the few real cross-vendor kernel-DSL stories.
09
PyTorch/XLA → TorchTPU
PyTorch on TPU has been an option since 2018, via the PyTorch/XLA bridge. As of late 2025, Google has announced TorchTPU, a project to run PyTorch natively on TPUs that will eventually replace PyTorch/XLA.
PyTorch/XLA (current)
LazyTensor graph capture: PyTorch ops on XLA tensors aren't executed eagerly, they're recorded into an IR graph.
Materialised at barriers: xm.mark_step(), .item(), print(), etc.
The captured graph is lowered to HLO, compiled by XLA, cached by shape signature.
First-call compile is expensive (sometimes minutes for big LLMs); recompiles trigger on shape changes or graph breaks.
Steady-state performance is fully fused TPU-native execution.
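A minimal sketch of the LazyTensor pattern (model, optimizer and loader are assumed to exist elsewhere): ops on XLA tensors are only recorded, and xm.mark_step() is the barrier that compiles and executes the captured graph.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                  # the TPU-backed XLA device
model = model.to(device)                  # assumed: an existing torch.nn.Module

for inputs, targets in loader:            # assumed: an existing DataLoader
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()                       # still lazy — just extends the captured graph
    xm.optimizer_step(optimizer)          # all-reduces grads, then steps the optimizer
    xm.mark_step()                        # the barrier: compile on first call, then execute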
TorchTPU (announced late 2025)
Native PyTorch backend for TPU — not via the XLA bridge.
Goal: eager-mode performance and native PyTorch ergonomics on TPU, without LazyTensor's compile-cost surprises.
Roadmap published for 2026 GA; PyTorch/XLA remains the production path until then.
Reference implementations under AI-Hypercomputer/torchprime (Llama 3 etc., FSDP, multislice).
Practical advice for May 2026
For new TPU work, JAX is the most-supported, best-documented frontend.
For existing PyTorch codebases, PyTorch/XLA still works; expect TorchTPU to land through 2026 and gradually take over.
For inference serving on Ironwood specifically, JAX is required — the chip's GA path doesn't yet support PyTorch natively.
10
MaxText, Pax / Praxis, Flax NNX
The higher-level libraries that sit on top of JAX for actual model training.
MaxText (the public flagship)
Pure JAX/Flax LLM training reference under AI-Hypercomputer/maxtext.
Used in the Cloud TPU docs as the canonical "how to train an LLM" example.
Most-active open source TPU LLM stack.
Pax / Praxis (Google internal)
Praxis = the model-and-layer library; Pax = the trainer.
Used internally for PaLM-class and Gemini training.
Public release as paxml on GitHub.
More opinionated than MaxText; baked-in patterns for Google-style training runs.
Flax & Flax NNX
Flax is the most widely used JAX neural-net library.
Original API: flax.linen, functional / explicit-state.
Flax NNX (2024): new API with reference semantics, mutable modules, more PyTorch-like ergonomics. Becoming the default for new code.
DeepMind's Haiku is the older alternative; in maintenance, with internal teams migrating to Flax NNX.
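A tiny sketch of the NNX ergonomics (layer sizes illustrative): modules are plain Python objects that own their state, so construction and mutation feel PyTorch-like while staying JAX-transformable.
from flax import nnx
import jax.numpy as jnp

class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(nnx.relu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=nnx.Rngs(0))     # parameters live on the object itself
y = model(jnp.ones((2, 16)))                 # call it like any ordinary Python object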
Tunix — post-training
Google's JAX-native library for LLM post-training (SFT, RLHF, DPO, GRPO). Released 2024; actively developed. The TPU answer to TRL and OpenRLHF in the GPU world.
11
Pathways, Multislice & Cloud Provisioning
Pathways
Reference: Barham, Dean, Ghemawat et al., "Pathways: Asynchronous Distributed Dataflow for ML", MLSys 2022, arXiv:2203.12533.
Authors include Jeff Dean and Sanjay Ghemawat — the architects of Google search itself.
One Python client → thousands of TPU workers across multiple pods.
Trained PaLM 540B on 6,144 TPU v4 chips across multiple pods.
Now exposed externally as Pathways on Cloud.
Multislice
Cloud TPU feature, announced August 2023.
Lets a single training job span multiple TPU slices over Jupiter DCN.
Inside each slice: high-bandwidth ICI.
Between slices: Jupiter DCN at microsecond latency.
Required for any training that exceeds one pod's chip count.
Available on v4, v5e, v5p, Trillium, Ironwood.
Cloud TPU provisioning APIs
Cloud TPU VM: SSH-able host attached to TPU chips. The modern API (replaced "TPU Node").
Queued Resources: async request for a slice shape; queue waits until capacity is free.
GKE TPU node pools: TPUs as Kubernetes resources, with multi-host slice support since GKE 1.27.2-gke.1500.
Pathways on Cloud: for jobs that exceed one slice or need async resharding.
12
Cheat Sheet
XLA (Feb 2017): Google's ML compiler. Turns ML programs into TPU/GPU/CPU code via two IRs — StableHLO (portable) and HLO (target-independent middle-end).
HLO ops: ~100, statically-shaped immutable tensors. Operator fusion + layout assignment do most of the perf work.