Google TPUs Series — Presentation 11

The TPU Software Stack — XLA, JAX, Pallas

From a Python jax.jit down to LLO machine code on the chip. The compiler, the partitioner, the kernel DSL, the orchestration runtime — what makes a 9,216-chip pod look like one machine.

XLA · HLO · StableHLO · JAX · GSPMD · Shardy · Pallas · Mosaic · PyTorch/XLA · MaxText · Pathways
Python (JAX / PyTorch) → jaxpr / LazyTensor IR → StableHLO → XLA HLO → GSPMD / Shardy → LLO → TPU machine code
00

Topics We'll Cover

01

The Whole Stack on One Slide

  • Frontends: JAX (jit, grad, vmap, shard_map) · PyTorch/XLA (LazyTensor, mark_step) · TensorFlow (tf.function, autograph) · Pallas (low-level kernels)
  • StableHLO: portable IR · 5-yr backward / 2-yr forward compatibility · MLIR bytecode
  • XLA HLO: target-independent passes · fusion · layout assignment · GSPMD partitioning
  • TPU backend (LLO codegen): low-level operations · MXU schedule · DMA · cycle-accurate VLIW issue
  • Mosaic (TPU): Pallas-on-TPU codegen
  • TPU chip: MXU · vector unit · scalar unit · VMEM · CMEM · HBM · ICI

Five frontends, two IRs, two compiler backends, one chip. Notice that everything lowers through StableHLO → HLO. HLO is the architectural dividing line — above it is framework / portability concerns, below it is target-specific code generation.

02

XLA & HLO — The Compiler Core

XLA (Accelerated Linear Algebra) was introduced by the TensorFlow team in February 2017 at the first TensorFlow Developer Summit. It is the compiler that turns ML programs into TPU machine code (and CPU/GPU code via separate backends).

HLO — "High Level Operations"

  • Closed opset operating on immutable tensors with statically-known shapes.
  • Operations: dot, convolution, reduce, broadcast, reshape, concatenate, while, conditional, fft, scatter, gather, … (~100 ops).
  • Pure functional style — ops produce new tensors, never mutate.
  • SSA-form, easy to optimise and reason about.

What HLO is good for

  • Common-subexpression elimination, dead-code removal — standard compiler passes apply directly.
  • Operator fusion (next slide) is local rewriting on the HLO graph.
  • Sharding and SPMD are graph-level annotations the partitioner consumes.
  • Layout assignment chooses physical tile shapes for each tensor before backend codegen.

The lowering path
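
To make the path concrete: a hedged sketch of inspecting each stage from Python using the public jax.jit lowering API (the exact textual output varies by JAX / XLA version).

inspecting the lowering from Python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((128, 128), dtype=jnp.bfloat16)

# stage 1: trace + lower: the StableHLO module JAX hands to XLA
lowered = jax.jit(f).lower(x)
print(lowered.as_text())      # StableHLO (MLIR) text

# stage 2: compile: XLA's optimised HLO after fusion and layout passes
compiled = lowered.compile()
print(compiled.as_text())     # optimised HLO, fusion clusters visible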

03

Operator Fusion & Layout Assignment

The two XLA passes that drive 80%+ of measured TPU performance.

Operator fusion

Multiple HLO ops are merged into a single fusion cluster, so intermediate tensors stay in registers / VMEM rather than spilling to HBM. Three flavours (a minimal example follows the list):

  • Loop fusion / element-wise fusion: sequential pointwise ops (relu · add · mul) become one loop.
  • Input fusion: multiple consumers of the same producer share a fused kernel (e.g. a tensor feeding both a reduce and a broadcast).
  • Output fusion: a producer (e.g. dot or conv) fuses with its consumers (often the bias-add, scale, and ReLU after a matmul).
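
A minimal sketch of loop fusion in action, reusing the lowering-inspection pattern from the previous slide (fused kernels appear as fusion ops in the optimised HLO text; exact counts vary by XLA version).

a pointwise chain XLA fuses into one kernel
import jax
import jax.numpy as jnp

def pointwise_chain(x, w, b):
    y = x * w              # mul
    y = y + b              # add
    return jax.nn.relu(y)  # relu: three HLO ops before optimisation

x = w = b = jnp.ones((1024, 1024))
hlo = jax.jit(pointwise_chain).lower(x, w, b).compile().as_text()
print(hlo.count("fusion"))   # expect one fusion cluster, not three kernels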

Layout assignment

Picks the physical layout of each tensor in TPU memory. For TPU specifically:

  • Tiles match the MXU's 128×128 shape to avoid expensive transposes.
  • VMEM tiling chosen so that the working set of one fusion cluster fits in VMEM.
  • Layout decisions interact with sharding choices — the partitioner picks layouts that minimise cross-shard reshapes.
  • A bad layout can mean a 5–10× slowdown from the layout-conversion ops inserted to repair it.

The reference paper

Snider & Liang's "Operator Fusion in XLA: Analysis and Evaluation" (arXiv:2301.13062, 2023) is the most thorough public analysis of XLA's fusion passes. Worth a read if you want to understand why your jax.jit'd function is faster than the equivalent imperative loop — usually the answer is "XLA fused 30 ops into 3 kernels".

04

OpenXLA & StableHLO

OpenXLA (March 2023)

  • Spun out of TensorFlow into the standalone openxla/ GitHub organisation.
  • Multi-vendor: AMD, Apple, Arm, Cerebras, Google, Graphcore, Intel, NVIDIA, SiFive, Hugging Face all collaborated on launch.
  • Community-led; not a Linux Foundation project (despite occasional confusion on this point).
  • Repos: openxla/xla (the compiler), openxla/stablehlo (the IR), openxla/iree (an alternative ML runtime).

StableHLO

  • Portability layer between ML frameworks and ML compilers.
  • Built on the MHLO MLIR dialect, with 94 statically-shaped operations.
  • MLIR-bytecode serialisation — you can save a StableHLO program and reload it later.
  • Compatibility guarantee: 5 years backward, 2 years forward per the StableHLO compatibility RFC.
  • The canonical artefact emerging from jax.jit(f).lower(...) — you can extract it, serialise it, and feed it to other XLA-compatible compilers (a sketch follows this list).
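
A hedged sketch of round-tripping a StableHLO artefact with the jax.export module (available in recent JAX releases; names and signatures may shift between versions).

exporting and reloading a StableHLO program
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sin(jnp.cos(x))

x = jnp.arange(16.0)

# export: trace and lower to a StableHLO module wrapped in an Exported object
exported = export.export(jax.jit(f))(x)
print(exported.mlir_module())    # the StableHLO text

# serialise to MLIR bytecode, reload later (or elsewhere), and call it
blob = exported.serialize()
reloaded = export.deserialize(blob)
print(reloaded.call(x))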

Why this matters

Before StableHLO, the only stable interface to XLA was the Python frontend. Now compilers like IREE, hardware vendors like AMD's ROCm-XLA, and runtime systems can consume StableHLO directly. The TPU stack is increasingly portable — PyTorch can target XLA, JAX programs can be exported and re-run elsewhere, and OpenXLA itself has third-party backends.

05

JAX — Composable Function Transforms

JAX started as a Google Brain research project in 2018. Original team: Matt Johnson, Roy Frostig, Dougal Maclaurin, Chris Leary. Now under the jax-ml/ GitHub organisation (moved from google/jax).

The four primitives

JAX function transformations
import jax
import jax.numpy as jnp

def predict(params, x):
    return jnp.tanh(x @ params['W'] + params['b'])

# jit — trace, lower to StableHLO, compile, cache
fast_predict = jax.jit(predict)

# grad — reverse-mode automatic differentiation
loss_grad = jax.grad(lambda p, x, y: jnp.mean((predict(p, x) - y) ** 2))

# vmap — automatic vectorisation across a batch axis
batched_predict = jax.vmap(predict, in_axes=(None, 0))

# pmap (legacy) / jit + sharding (modern) — multi-device SPMD
sharded_predict = jax.jit(predict, in_shardings=...)

The composability invariant

Any JAX transform can wrap any other JAX transform. jit(vmap(grad(f))) is meaningful: take f's gradient, vectorise it over a batch axis, then JIT-compile the whole thing (see the sketch below). This composability is what makes JAX uniquely productive for research code — you can write a small, clean function, then wrap it in transforms to scale it to a TPU pod.
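
A minimal sketch of that composition: per-example gradients of a scalar loss, batched with vmap, then compiled with jit (toy shapes chosen for illustration).

composing grad, vmap and jit
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # scalar loss for a single example
    return (jnp.dot(w, x) - y) ** 2

# gradient wrt w, vectorised over a batch of (x, y) pairs, then compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(4)
xs = jnp.ones((8, 4))
ys = jnp.zeros(8)
print(per_example_grads(w, xs, ys).shape)   # (8, 4): one gradient per example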

pmap is legacy

The original SPMD primitive in JAX was pmap — "parallel map" across devices. Modern JAX uses jax.jit with explicit shardings (GSPMD-managed) for automatic SPMD, and shard_map for the manual counterpart. pmap still works but isn't the recommended path for new code.

06

Mesh, NamedSharding & GSPMD

This is the abstraction that turns a 9,216-chip pod into a single mesh you can shard tensors across.

a 4-axis pod mesh, with sharded tensors
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np

# a 9,216-chip Ironwood pod arranged as 4 axes
devices = np.asarray(jax.devices()).reshape(8, 8, 12, 12)
mesh = Mesh(devices, axis_names=('data', 'fsdp', 'tensor', 'expert'))

# weight matrix sharded across 'tensor' and 'expert' axes
W = jax.device_put(W, NamedSharding(mesh, P('tensor', 'expert')))

# input batch sharded across 'data' (and replicated across the rest)
x = jax.device_put(x, NamedSharding(mesh, P('data', None)))

# JIT the model — GSPMD will figure out where to insert collectives
y = jax.jit(model)(W, x)

What GSPMD does for you
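
GSPMD propagates the shardings you annotate through the whole program and inserts the collectives (all-gather, all-reduce, reduce-scatter) needed to keep the maths correct. A hedged inspection sketch, reusing the mesh, W, x and placeholder model from the listing above:

seeing the collectives GSPMD inserts
# visualise how an array is laid out across the mesh
jax.debug.visualize_array_sharding(W)

# dump the partitioned, optimised HLO and count the collective ops
hlo = jax.jit(model).lower(W, x).compile().as_text()
for op in ("all-reduce", "all-gather", "reduce-scatter"):
    print(op, hlo.count(op))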

Shardy — the successor

Shardy is the new MLIR-based partitioner that's gradually replacing GSPMD as XLA's sharding system. Same conceptual model (tensors annotated with shardings, compiler inserts collectives), more uniform IR representation. New code should expect Shardy semantics; old code keeps working.

07

shard_map — The Manual Counterpart

GSPMD is automatic; shard_map is manual. You write per-device code with explicit collectives, and JAX runs one instance of that code per device, each on its own shard of the inputs.

manual collectives via shard_map
from jax.experimental.shard_map import shard_map

def tensor_parallel_matmul(W_local, x_local):
    # each device holds a slice of the contracting dimension:
    # W_local: (d_in / T, d_out), x_local: (batch / D, d_in / T)
    partial = x_local @ W_local
    # sum the partial products across the 'tensor' axis for the full result
    return jax.lax.psum(partial, 'tensor')

# map over the mesh, with explicit input/output partition specs;
# the contracting dim of both W and x is sharded across 'tensor'
y = shard_map(
    tensor_parallel_matmul,
    mesh=mesh,
    in_specs=(P('tensor', None), P('data', 'tensor')),
    out_specs=P('data', None),
)(W, x)

When to use which

The two compose: you can shard_map a function and then jit the result; you can have GSPMD-managed regions calling shard_map'd kernels. The choice is per-function, not per-program.

08

Pallas / Mosaic — Low-Level TPU Kernels

Pallas is a JAX-embedded kernel DSL, conceptually equivalent to Triton on GPU. The same Python source can target multiple backends; on TPU, Pallas lowers through Mosaic, an MLIR-based TPU kernel compiler.

a Pallas matmul kernel
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, w_ref, out_ref):
    out_ref[...] = x_ref[...] @ w_ref[...]

# BlockSpec tells Pallas how to tile each input from HBM into VMEM
matmul = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
    grid=(M//128, N//128),
    in_specs=[
        pl.BlockSpec(block_shape=(128, K), index_map=lambda i, j: (i, 0)),
        pl.BlockSpec(block_shape=(K, 128), index_map=lambda i, j: (0, j)),
    ],
    out_specs=pl.BlockSpec(block_shape=(128, 128), index_map=lambda i, j: (i, j)),
)
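
Calling the result is ordinary JAX (a usage sketch assuming x of shape (M, K) and w of shape (K, N) in bfloat16, with M, N, K multiples of 128):

calling the Pallas kernel
x = jnp.zeros((M, K), dtype=jnp.bfloat16)
w = jnp.zeros((K, N), dtype=jnp.bfloat16)
y = matmul(x, w)             # runs the Mosaic-compiled kernel on TPU
y = jax.jit(matmul)(x, w)    # or embed it inside a larger jitted program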

What Pallas-on-TPU lets you do

Pallas-on-GPU exists too (lowers to Triton or Mosaic GPU). The same source compiles to either — one of the few real cross-vendor kernel-DSL stories.

09

PyTorch/XLA → TorchTPU

PyTorch on TPU has been an option since 2018, via the PyTorch/XLA bridge. As of late 2025, Google has announced TorchTPU, a project to run PyTorch natively on TPUs that will eventually replace PyTorch/XLA.

PyTorch/XLA (current)

  • LazyTensor graph capture: PyTorch ops on XLA tensors aren't executed eagerly; they're recorded into an IR graph (a minimal sketch follows this list).
  • Materialised at barriers: xm.mark_step(), .item(), print(), etc.
  • The captured graph is lowered to HLO, compiled by XLA, cached by shape signature.
  • First-call compile is expensive (sometimes minutes for big LLMs); recompiles trigger on shape changes or graph breaks.
  • Steady-state performance is fully fused TPU-native execution.
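
A minimal sketch of the LazyTensor flow (assumes a TPU VM with torch and torch_xla installed; module paths per current PyTorch/XLA releases):

the LazyTensor capture-and-materialise loop
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()             # an XLA (TPU) device

x = torch.randn(128, 128, device=device)
w = torch.randn(128, 128, device=device)

# these ops are recorded into the lazy IR graph, not executed yet
y = torch.relu(x @ w).sum()

# barrier: the captured graph is lowered to HLO, compiled, cached, executed
xm.mark_step()
print(y.item())                      # .item() also forces materialisation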

TorchTPU (announced late 2025)

  • Native PyTorch backend for TPU — not via the XLA bridge.
  • Goal: eager-mode performance and native PyTorch ergonomics on TPU, without LazyTensor's compile-cost surprises.
  • Roadmap published for 2026 GA; PyTorch/XLA remains the production path until then.
  • Reference implementations under AI-Hypercomputer/torchprime (Llama 3 etc., FSDP, multislice).

Practical advice for May 2026

10

MaxText, Pax / Praxis, Flax NNX

The higher-level libraries that sit on top of JAX for actual model training.

MaxText (the public flagship)

  • Pure JAX/Flax LLM training reference under AI-Hypercomputer/maxtext.
  • Targets TPU and GPU equally well.
  • Composes Flax (model definition), Optax (optimisers), Orbax (checkpointing), Grain (data loading), Tunix (post-training).
  • Used in the Cloud TPU docs as the canonical "how to train an LLM" example.
  • Most-active open source TPU LLM stack.

Pax / Praxis (Google internal)

  • Praxis = the model-and-layer library; Pax = the trainer.
  • Used internally for PaLM-class and Gemini training.
  • Public release as paxml on GitHub.
  • More opinionated than MaxText; baked-in patterns for Google-style training runs.

Flax & Flax NNX
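
Flax is the standard neural-network library on top of JAX; NNX is its newer, explicitly-stateful, more Pythonic API. A minimal hedged sketch (API per recent Flax releases; details may shift):

a minimal Flax NNX module
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(jax.nn.relu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=nnx.Rngs(0))
y = model(jnp.ones((8, 16)))     # eager and stateful; compile with nnx.jit
print(y.shape)                   # (8, 4)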

Tunix — post-training

Google's JAX-native library for LLM post-training (SFT, RLHF, DPO, GRPO). Released 2024; actively developed. The TPU answer to TRL and OpenRLHF in the GPU world.

11

Pathways & Multislice

The orchestration layer for very-large training.

Pathways

  • Single-controller, asynchronous distributed dataflow runtime.
  • Reference: Barham, Dean, Ghemawat et al., "Pathways: Asynchronous Distributed Dataflow for ML", MLSys 2022, arXiv:2203.12533.
  • Authors include Jeff Dean and Sanjay Ghemawat — the architects of Google search itself.
  • One Python client → thousands of TPU workers across multiple pods.
  • Trained PaLM 540B on 6,144 TPU v4 chips across multiple pods.
  • Now exposed externally as Pathways on Cloud.

Multislice

  • Cloud TPU feature, announced August 2023.
  • Lets a single training job span multiple TPU slices over Jupiter DCN.
  • Inside each slice: high-bandwidth ICI.
  • Between slices: Jupiter DCN at microsecond latency.
  • Required for any training that exceeds one pod's chip count.
  • Available on v4, v5e, v5p, Trillium, Ironwood.

Cloud TPU provisioning APIs

12

Cheat Sheet

Read next

Deck 12 — TPU vs GPU ties the silicon and software together by contrasting both sides of the divide.