Stanford CS25: Transformers United V6 | The Ultra-Scale Talk: Scaling Training to Thousands of GPUs

| Podcasts | May 11, 2026 | 267 views | 1:01:48

TL;DR

Nouamane Tazi from Hugging Face explains how to scale transformer training to thousands of GPUs using data parallelism strategies, from basic Distributed Data Parallel (DDP) to Fully Sharded Data Parallel (FSDP/ZeRO). He emphasizes memory optimization techniques and the critical importance of overlapping communication with computation to keep GPUs fully utilized.

🚀 The Scaling Imperative 2 insights

Trillion-parameter models define modern AI

Current LLMs such as Kimi K2.6 reach 1 trillion parameters, train on 15 trillion tokens, and handle contexts of up to 1 million tokens, which puts extreme pressure on infrastructure to complete each training iteration in roughly one second.

Memory constraints drive parallelization needs

With global batch sizes of 1-50 million tokens and models too large for single GPU VRAM, training requires distributing both data and model parameters across thousands of accelerators.
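
To make the memory pressure concrete, here is a rough sketch of the training-state footprint. It assumes mixed-precision Adam at about 16 bytes per parameter (2 for bf16 weights, 2 for bf16 gradients, 12 for fp32 master weights plus the two Adam moments). Those byte counts and the 80 GB device size are assumptions, not figures from the talk, and activations are ignored entirely.

```python
# Rough training-state footprint under mixed-precision Adam (assumed accounting).
BYTES_PARAMS = 2   # bf16 weights
BYTES_GRADS = 2    # bf16 gradients
BYTES_OPTIM = 12   # fp32 master weights + Adam momentum + Adam variance


def training_state_bytes(num_params: int) -> int:
    """Bytes for weights, gradients, and optimizer states; activations excluded."""
    return num_params * (BYTES_PARAMS + BYTES_GRADS + BYTES_OPTIM)


def min_gpus_for_state(num_params: int, gpu_bytes: int) -> int:
    """Smallest GPU count whose combined memory could hold the training state."""
    total = training_state_bytes(num_params)
    return -(-total // gpu_bytes)  # ceiling division on integers


if __name__ == "__main__":
    one_trillion = 10**12
    gpu_80gb = 80 * 10**9  # hypothetical 80 GB accelerator
    print(training_state_bytes(one_trillion) / 1e12, "TB of training state")
    print(min_gpus_for_state(one_trillion, gpu_80gb), "GPUs at minimum")
```

Under these assumptions a 1-trillion-parameter model carries far more state than any single device holds, which is why both data and model state have to be spread across many accelerators.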

🔄 Data Parallelism Fundamentals 2 insights

DDP distributes batches but requires gradient synchronization

Distributed Data Parallel feeds different data batches to each GPU and uses all-reduce collective operations to synchronize gradients, ensuring all GPUs maintain identical model copies.
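
The synchronization step can be illustrated with a toy simulation in plain Python. It is not PyTorch's DDP: the one-parameter model, the `all_reduce_mean` stand-in, and the sample data are all hypothetical, chosen only to show that averaging per-rank gradients leaves every replica with the same weights.

```python
# Toy DDP step on a one-parameter model y = w * x with squared-error loss.

def local_gradient(w, batch):
    """Mean gradient of (w*x - y)^2 with respect to w over one rank's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for an all-reduce: every rank receives the mean of all inputs."""
    mean = sum(values) / len(values)
    return [mean for _ in values]


def ddp_step(weights, batches, lr=0.1):
    """One synchronous step: local backward, all-reduce, identical update."""
    grads = [local_gradient(w, b) for w, b in zip(weights, batches)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]


if __name__ == "__main__":
    replicas = [0.0, 0.0]  # two ranks start from identical weights
    batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
    replicas = ddp_step(replicas, batches, lr=0.01)
    print(replicas)  # both entries are equal after the synchronized update
```

Because each rank sees the same number of samples here, the averaged gradient equals the gradient a single device would compute on the combined batch.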

Gradient bucketing overlaps communication with computation

PyTorch DDP divides gradients into buckets and initiates all-reduce operations as soon as each bucket completes during backward pass, preventing GPUs from idling while waiting for communication.
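
A simple timing model shows why bucketing helps. The sketch assumes a single communication stream that handles one bucket at a time and uses made-up per-bucket durations; it models the scheduling idea only and does not reflect PyTorch's internals.

```python
# Toy timeline: compare backward-pass duration with and without overlap.

def backward_time_no_overlap(compute, comm):
    """All gradients are computed first, then every bucket is communicated."""
    return sum(compute) + sum(comm)


def backward_time_overlapped(compute, comm):
    """Each bucket's all-reduce starts once its gradients are ready and the link is free."""
    ready = 0.0      # time at which the current bucket's gradients are complete
    comm_free = 0.0  # time at which the communication stream becomes idle
    for c, m in zip(compute, comm):
        ready += c
        start = max(ready, comm_free)
        comm_free = start + m
    return max(ready, comm_free)


if __name__ == "__main__":
    compute = [4.0, 4.0, 4.0]  # hypothetical backward time per bucket
    comm = [3.0, 3.0, 3.0]     # hypothetical all-reduce time per bucket
    print(backward_time_no_overlap(compute, comm))
    print(backward_time_overlapped(compute, comm))
```

In this model, when communication per bucket is shorter than compute per bucket, only the final bucket's all-reduce is exposed; the rest hides behind the remaining backward computation.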

💾 ZeRO Memory Optimization Stages 3 insights

ZeRO-1 shards optimizer states to eliminate duplication

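ZeRO-1 replaces the gradient all-reduce with a reduce-scatter, so each GPU receives reduced gradients only for its assigned parameter shard and keeps optimizer states for that shard alone. After the local optimizer step, an all-gather redistributes the updated parameters, removing redundant optimizer memory and work across devices.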
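
The data flow of a ZeRO-1 step (reduce-scatter, shard-local optimizer update, all-gather) can be sketched with Python lists standing in for tensors. The function names, the SGD-with-momentum optimizer, and the assumption that the parameter count divides evenly by the number of ranks are all illustrative choices, not any library's API.

```python
# Toy ZeRO-1 step: each rank owns the optimizer state for one parameter shard.

def reduce_scatter_mean(grads_per_rank):
    """Rank r ends up with the averaged gradient for shard r only."""
    world = len(grads_per_rank)
    shard = len(grads_per_rank[0]) // world  # assumes even divisibility
    out = []
    for r in range(world):
        lo, hi = r * shard, (r + 1) * shard
        out.append([sum(g[i] for g in grads_per_rank) / world for i in range(lo, hi)])
    return out


def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def zero1_step(params, grads_per_rank, momentum_shards, lr=0.1, beta=0.9):
    """Update each shard with its owner's momentum, then rebuild full parameters."""
    world = len(grads_per_rank)
    shard = len(params) // world
    grad_shards = reduce_scatter_mean(grads_per_rank)
    new_shards = []
    for r in range(world):
        updated = []
        for j, g in enumerate(grad_shards[r]):
            momentum_shards[r][j] = beta * momentum_shards[r][j] + g
            updated.append(params[r * shard + j] - lr * momentum_shards[r][j])
        new_shards.append(updated)
    return all_gather(new_shards)


if __name__ == "__main__":
    params = [1.0, 2.0, 3.0, 4.0]
    grads = [[1.0] * 4, [3.0] * 4]       # per-rank gradients from different batches
    momentum = [[0.0, 0.0], [0.0, 0.0]]  # each rank stores half the optimizer state
    print(zero1_step(params, grads, momentum))
```

Each rank holds momentum for only two of the four parameters, yet all ranks finish the step with the same full parameter vector.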

ZeRO-3 shards parameters with prefetching tricks

Fully Sharded Data Parallel distributes parameters across GPUs, all-gathering layers just-in-time during forward/backward passes while prefetching next layers to maintain only two FSDP units in memory simultaneously.
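
The gather, compute, and free schedule can be traced with a small bookkeeping sketch. It assumes a forward pass that prefetches exactly one unit ahead and frees each unit right after use; the function and event names are hypothetical and no real communication takes place.

```python
# Toy FSDP forward schedule: track how many unsharded units are resident at once.

def forward_with_prefetch(num_units):
    """Return (peak resident units, event log) for a one-unit-ahead prefetch."""
    resident = set()
    events = []
    peak = 0

    resident.add(0)  # the first unit must be all-gathered before compute starts
    events.append(("gather", 0))

    for i in range(num_units):
        if i + 1 < num_units:
            resident.add(i + 1)  # prefetch the next unit while unit i computes
            events.append(("gather", i + 1))
        peak = max(peak, len(resident))
        events.append(("compute", i))
        resident.discard(i)  # reshard: drop the gathered full parameters
        events.append(("free", i))

    return peak, events


if __name__ == "__main__":
    peak, events = forward_with_prefetch(4)
    print("peak resident units:", peak)
    for event in events:
        print(event)
```

However many units the model has, this schedule never keeps more than two gathered at once: the one being computed and the one being prefetched.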

FSDP2 preserves full tensors for modern optimizers

Unlike FSDP1, which flattens parameters into a single buffer, FSDP2 represents each parameter as a DTensor that keeps its original shape. This enables compatibility with optimizers like Muon that need the full matrix for Newton-Schulz iterations.
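
To see why such an optimizer needs the whole matrix, consider a pure-Python sketch of the textbook cubic Newton-Schulz iteration, X ← 1.5·X − 0.5·X·Xᵀ·X. This is a stand-in: Muon implementations use a differently tuned polynomial, and the helper names and sample matrix below are hypothetical.

```python
import math

# Textbook cubic Newton-Schulz orthogonalization on a small dense matrix.

def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(g, steps=20):
    """Iterate X <- 1.5 X - 0.5 X X^T X after scaling by the Frobenius norm."""
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]  # singular values are now at most 1
    for _ in range(steps):
        # X X^T mixes every row with every other row, so no row (or flattened
        # slice) can be updated without access to the rest of the matrix.
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * x[i][j] - 0.5 * xxt_x[i][j] for j in range(len(x[0]))]
             for i in range(len(x))]
    return x


if __name__ == "__main__":
    grad = [[3.0, 1.0], [0.0, 2.0]]  # hypothetical 2x2 gradient matrix
    q = newton_schulz_orthogonalize(grad)
    print(matmul(q, transpose(q)))  # close to the identity matrix
```

The product X·Xᵀ is what ties the iteration to the matrix structure: an arbitrary slice of a flattened buffer does not carry enough information to compute it.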

⚖️ Implementation Trade-offs 2 insights

Match ZeRO stage to actual memory constraints

Only use FSDP/ZeRO-3 when necessary, as higher stages add communication overhead; ZeRO-1 trains faster than ZeRO-3 if your model fits in memory, so avoid over-sharding.
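
One way to apply this rule is to estimate per-GPU memory for each stage and pick the first that fits. The sketch assumes mixed-precision Adam (2 bytes per parameter for weights, 2 for gradients, 12 for optimizer states), ignores activations and communication buffers, and uses hypothetical function names and budgets.

```python
# Pick the lowest ZeRO stage whose per-GPU training state fits a memory budget.

def per_gpu_state_bytes(num_params, n_gpus, stage, k_optim=12):
    """Approximate bytes per GPU for weights, gradients, and optimizer states."""
    params = 2 * num_params
    grads = 2 * num_params
    optim = k_optim * num_params
    if stage >= 1:
        optim /= n_gpus   # ZeRO-1 shards optimizer states
    if stage >= 2:
        grads /= n_gpus   # ZeRO-2 also shards gradients
    if stage >= 3:
        params /= n_gpus  # ZeRO-3 / FSDP also shards parameters
    return params + grads + optim


def lowest_fitting_stage(num_params, n_gpus, budget_bytes):
    """Return 0 (plain DDP) through 3, or None if even full sharding does not fit."""
    for stage in (0, 1, 2, 3):
        if per_gpu_state_bytes(num_params, n_gpus, stage) <= budget_bytes:
            return stage
    return None


if __name__ == "__main__":
    budget = 80 * 10**9  # hypothetical 80 GB device, activations not counted
    print(lowest_fitting_stage(7 * 10**9, 8, budget))
    print(lowest_fitting_stage(70 * 10**9, 64, budget))
```

In practice the budget should leave headroom for activations, so a real decision would subtract an activation estimate before calling the picker.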

FSDP2 offers superior composability

FSDP2 integrates better with other parallelism forms through DTensor and automatically handles module wrapping, making it preferable to FSDP1 for complex scaling configurations.

Bottom Line

Choose the lowest ZeRO stage that fits your model in memory to minimize communication overhead, and always configure gradient bucketing or prefetching to overlap communication with computation, keeping GPU utilization near 100%.
