Distributed Training
distributed training, data parallelism, model parallelism, pipeline parallelism, all-reduce, gradient synchronization, ray, fsdp, scaling
Introduction
The first time the model didn’t fit, the team thought they had a bug. The training script raised CUDA out of memory before the first batch finished, on a card that had handled every previous model without complaint. There was no bug. The model had simply grown past what 80 gigabytes could hold — parameters, gradients, optimizer state, and activations overflowed the card the moment the backward pass tried to allocate. A different team had the opposite problem: their model fit fine, but a full run was projected at nineteen days on one GPU, and the experiment was due in five. Both teams arrived, from opposite directions, at the same conclusion: one GPU was no longer enough.
So they reached for a second card, and a third, and the speedup did not come. The two-GPU run was barely faster than one; the four-GPU run was slower in wall-clock time than two. The GPUs sat mostly idle, their utilization graphs sawtoothing between bursts of computation and long flat stretches of nothing. They were not computing — they were waiting for each other, waiting on the wire to exchange the gradients that had to be identical on every card before the next step could begin. Bolting on more hardware had produced more hardware and no more throughput.
This is the central lesson, worth stating before any technique: scaling training is a distributed-systems problem wearing an ML hat, and communication, not computation, is usually the bottleneck. The math of a neural network is embarrassingly parallel: every GPU can crunch its share independently. The hard part is keeping those workers consistent, and consistency means moving data across an interconnect far slower than the GPU’s own memory. Get the communication right and adding GPUs adds throughput almost linearly. Get it wrong and you have bought an expensive cluster that runs at the speed of its slowest link.
The Core Insight
There are exactly two reasons to distribute training, and they are not the same problem. Conflating them is the most common way teams pick the wrong tool.
The first reason is the data is too big, or training is too slow. The model fits comfortably on one GPU; you simply have more data than one GPU can chew through in time. The answer is data parallelism: put a complete copy of the model on every GPU, hand each a different shard of the batch, and let them all compute in parallel. The catch is that after each step the copies must agree — a model trained on shard A and one trained on shard B have diverged, and you need them to stay one model. So every step ends with an all-reduce: the GPUs collectively sum their gradients and divide by the worker count, so every replica applies the same averaged gradient and the copies stay bit-for-bit identical. Data parallelism is the common case, and all-reduce is its beating heart.
The second reason is the model is too big to fit on one GPU. No amount of data sharding helps, because the problem isn’t the data — a single replica won’t fit. The answer is model parallelism in one of its forms: split the model itself across devices so each GPU holds only a slice of the parameters — across layers (pipeline), within layers (tensor), or by sharding the optimizer and parameter state while keeping the data-parallel structure (sharded data parallelism — ZeRO, FSDP).
What unites both is the cost that dominates them. In data parallelism it is the gradient all-reduce; in model parallelism it is the activations passed between slices. Either way the GPUs spend real time moving bytes across the interconnect, and any of that time not hidden behind computation is pure overhead, because a GPU stalled on the wire computes nothing. The entire art is overlapping or minimizing communication so the GPUs compute instead of wait.
A mental model
Hold three pictures and the rest of the chapter is detail.
Data parallelism is a room full of identical workers splitting a stack of paperwork. Each takes a different pile, does the same job, and at the end of every round they pool their notes and average them so everyone starts the next round from the same place. The work scales with the number of workers — but only as long as the averaging stays cheap. If pooling notes takes longer than doing the work, hiring more workers makes things worse. That pooling step is all-reduce.
Model parallelism is an assembly line. One product is too big for any single station to build, so each installs one part and passes the half-finished product down the line. Station two cannot start until station one hands off, so the line only runs full when there is a steady stream flowing through it. Start it cold with a single item and most stations stand idle waiting for the front to fill — the “bubble” that pipeline parallelism spends real effort to shrink.
The interconnect decides whether scaling-out actually scales. Inside a single server, GPUs talk over NVLink at hundreds of gigabytes per second and communication is nearly free — scaling is close to linear. The moment a job spans machines, gradients travel over a network an order of magnitude slower and communication starts to dominate. Whether your cluster delivers eight GPUs of throughput or two depends less on the GPUs than on the wire between them.
Choosing a strategy
The decision starts with one question: does a single replica fit on one GPU? Count the memory honestly. For Adam, training costs roughly sixteen bytes per parameter (four for the parameter, four for its gradient, eight for the optimizer’s two moments) before activations. A seven-billion-parameter model in full precision therefore needs about 112 GB to train (7 billion × 16 bytes), which does not fit on an 80 GB card even though its parameters alone are only 28 GB.
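The arithmetic above is easy to script. The sketch below assumes the same fp32 Adam accounting as the text (4 + 4 + 8 bytes per parameter), ignores activations, and uses 1 GB = 10^9 bytes; the function names are illustrative, not from any library.

```python
# Rough static training-state estimate for Adam in full precision (fp32),
# using the per-parameter byte counts from the text; activations are excluded.
BYTES_PARAM = 4  # fp32 weights
BYTES_GRAD = 4   # one fp32 gradient per parameter
BYTES_ADAM = 8   # Adam's two fp32 moment buffers (m and v)


def training_memory_gb(num_params: int) -> float:
    """GB (1e9 bytes) for parameters + gradients + Adam state, before activations."""
    per_param = BYTES_PARAM + BYTES_GRAD + BYTES_ADAM  # 16 bytes
    return num_params * per_param / 1e9


def fits_on_one_gpu(num_params: int, gpu_gb: float = 80.0) -> bool:
    """True if the static training state alone fits on one card (activations need more)."""
    return training_memory_gb(num_params) <= gpu_gb


n = 7_000_000_000
print(f"total state: {training_memory_gb(n):.0f} GB")  # params alone: n * 4 / 1e9 = 28 GB
print(f"fits on one 80 GB GPU: {fits_on_one_gpu(n)}")
```

Because activations are left out, a model that passes this check can still run out of memory; treat a "fits" answer as necessary, not sufficient.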
If it fits and you only want to go faster, use data parallelism — simplest to set up, easiest to debug, near-linear until the interconnect saturates. If it does not fit, your first move is usually sharded data parallelism (ZeRO / FSDP), which keeps the data-parallel programming model but shards the optimizer, gradients, and parameters so each GPU holds only a fraction. When even that isn’t enough — models many times larger than a single card — you reach for model parallelism proper: pipeline across devices, tensor within layers, and at the largest scale all three combined. Figure 40.1 contrasts the two foundational shapes.
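To see why sharding changes the answer, here is a rough per-GPU estimate of the same static state under the three ZeRO stages. It reuses the text's fp32 byte counts (the ZeRO paper itself works through a mixed-precision variant with different constants). Stage 1 shards the optimizer state across the data-parallel workers, stage 2 also shards gradients, and stage 3 also shards the parameters; the helper is hypothetical.

```python
def per_gpu_state_gb(num_params: int, world_size: int, stage: int) -> float:
    """Static training state per GPU, in GB, for a ZeRO-style sharding stage 0-3."""
    param_b, grad_b, optim_b = 4.0, 4.0, 8.0  # fp32 bytes per parameter, as in the text
    if stage >= 1:
        optim_b /= world_size  # stage 1: shard Adam's moment buffers
    if stage >= 2:
        grad_b /= world_size   # stage 2: also shard gradients
    if stage >= 3:
        param_b /= world_size  # stage 3: also shard the parameters themselves
    return num_params * (param_b + grad_b + optim_b) / 1e9


for stage in range(4):
    print(f"stage {stage}: {per_gpu_state_gb(7_000_000_000, 8, stage):.1f} GB per GPU")
```

For the 7B model on eight GPUs this works out to about 112, 63, 38.5, and 14 GB per GPU for stages 0 through 3. Stage 3 also needs temporary buffers to gather each layer's full parameters just before they are used, and activations come on top, so real peak memory is somewhat higher than these figures.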
What you’ll learn
- How data parallelism replicates a model, shards the batch, and uses all-reduce to keep every replica in sync — and why that delivers near-linear speedup only when communication overlaps computation
- Why distributing the batch changes the effective batch size, and the learning-rate scaling rule that follows from it
- How pipeline and tensor parallelism split a too-large model across devices, and where the pipeline “bubble” comes from
- How sharded data parallelism (ZeRO / FSDP) fits bigger models without abandoning the data-parallel model — the modern default
- Why interconnect bandwidth and communication overlap, not raw GPU count, set the ceiling on scaling efficiency
- How a general-purpose distributed framework like Ray orchestrates training, tuning, and serving on a cluster with a task-and-actor model
Prerequisites
- Deep learning and the training loop: forward pass, backward pass, gradients, and what an optimizer step does (the Deep Learning Frameworks material)
- Distributed-systems basics: processes versus threads, what a network round-trip costs, and why coordinating state across machines is hard
- Comfort with GPUs as devices: that a GPU has its own memory, and that moving data on and off it is not free
Data parallelism and all-reduce
Data parallelism is the workhorse, so it earns the most attention. The setup is almost disappointingly simple: launch one process per GPU, give each a complete copy of the model and optimizer, and feed each a different, non-overlapping shard of every batch. The thing that trips up everyone learning it is that this is not one program driving many GPUs. It is N independent programs running the same code, each aware of its own rank, that coordinate at exactly one moment per step.
That moment is the backward pass. The forward pass is entirely local: each GPU runs its shard through its own model copy and computes its own loss, with no communication. The backward pass computes local gradients — and then the magic happens. Because each replica saw different data, each computed a different gradient, and applying them locally would drift into N different models. Instead they perform an all-reduce: every GPU contributes its gradient, the framework sums across all workers and divides by the worker count, and every GPU receives the same averaged gradient back. Each then runs its optimizer step locally — and because the gradient was identical and the optimizer is deterministic, the copies stay perfectly in sync without any further communication.
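A toy, single-process simulation makes the argument concrete: four replicas start identical, each computes a gradient on its own shard, and all of them apply the same averaged gradient. The quadratic loss and the helper names are hypothetical, chosen only so the gradients are easy to compute by hand; this is not how DDP is implemented.

```python
from typing import List

Vector = List[float]


def local_gradient(weights: Vector, shard: List[float]) -> Vector:
    """Gradient of a hypothetical toy loss: mean over x in shard of sum_i (w_i - x)^2."""
    return [2 * sum(w - x for x in shard) / len(shard) for w in weights]


def all_reduce_mean(grads: List[Vector]) -> Vector:
    """What all-reduce hands back to every worker: the element-wise average."""
    n = len(grads)
    return [sum(col) / n for col in zip(*grads)]


def data_parallel_step(replicas: List[Vector], shards: List[List[float]], lr: float = 0.1) -> List[Vector]:
    grads = [local_gradient(w, s) for w, s in zip(replicas, shards)]  # local, no communication
    avg = all_reduce_mean(grads)                                      # the one sync point
    # every replica applies the same averaged gradient with the same update rule
    return [[w - lr * g for w, g in zip(ws, avg)] for ws in replicas]


init = [0.0, 1.0]
replicas = [list(init) for _ in range(4)]                  # four identical copies
shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # disjoint, equal-size shards
for _ in range(3):
    replicas = data_parallel_step(replicas, shards)
print(all(r == replicas[0] for r in replicas))             # the replicas are still one model
```

With equal-size shards the averaged gradient equals the gradient over the whole combined batch, which is why data parallelism trains the same model a single GPU would on that batch. Drop the averaging and each replica steps along its own shard's gradient, so the copies drift apart after the very first step.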
In PyTorch this is DistributedDataParallel, and the API is small precisely because the framework hides the all-reduce inside .backward(). You wrap the model, give each process a sampler that hands it the right shard, and train almost exactly as you would on one GPU.
```python
# Illustrative: each process runs this; only .backward() talks across GPUs.
# Assumes torch.distributed is initialized (e.g. launched with torchrun) and that
# model, dataset, criterion, optimizer, epochs, rank, and world_size already exist.
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

model = DDP(model.to(rank), device_ids=[rank])  # wrap; registers all-reduce hooks
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # this rank's shard only
for epoch in range(epochs):
    sampler.set_epoch(epoch)  # reshuffle identically across ranks
    for batch, target in loader:
        batch, target = batch.to(rank), target.to(rank)
        optimizer.zero_grad()
        loss = criterion(model(batch), target)
        loss.backward()   # local grads + all-reduce, overlapped
        optimizer.step()  # local; replicas stay in sync
```

The reason this scales near-linearly is the overlap. A well-built framework does not wait for the entire backward pass to finish and then start one giant all-reduce. Gradients become available in reverse layer order as the backward pass proceeds, so the framework groups them into buckets and launches an asynchronous all-reduce for each bucket as soon as it is full: the gradients of the layers nearest the output are on the wire while the layers nearest the input are still computing theirs. When that overlap is perfect the all-reduce is effectively free, hidden entirely behind compute the GPUs were doing anyway. When it is not (a model whose compute per step is small relative to its gradient volume, or a slow interconnect), the all-reduce pokes out from behind the compute and becomes the visible bottleneck.
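Bucketing can be sketched in a few lines. PyTorch's DDP exposes the bucket size through its `bucket_cap_mb` argument; the grouping below is a simplified illustration of the idea under assumed layer sizes, not DDP's exact bucket-assignment logic.

```python
from typing import List, Tuple


def build_buckets(params: List[Tuple[str, float]], cap_mb: float) -> List[List[str]]:
    """Group parameters (name, gradient size in MB) into buckets in backward order.

    `params` is in forward (input-to-output) order; gradients become ready in
    reverse, so buckets fill from the output end. A parameter larger than the
    cap gets a bucket to itself.
    """
    buckets: List[List[str]] = []
    current: List[str] = []
    current_mb = 0.0
    for name, size_mb in reversed(params):
        if current and current_mb + size_mb > cap_mb:
            buckets.append(current)  # bucket full: its async all-reduce can launch now
            current, current_mb = [], 0.0
        current.append(name)
        current_mb += size_mb
    if current:
        buckets.append(current)      # last bucket, launched when backward finishes
    return buckets


layers = [("embed", 40.0), ("block1", 10.0), ("block2", 10.0), ("block3", 10.0), ("head", 5.0)]
print(build_buckets(layers, cap_mb=25.0))
```

The output is `[['head', 'block3', 'block2'], ['block1'], ['embed']]`: the bucket holding the output-side layers is ready first, so its all-reduce can run while the backward pass is still working through `block1` and the embedding. Bigger buckets mean fewer, more efficient collectives but a later first launch; smaller buckets start communicating sooner at the cost of more per-call overhead.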
One consequence is easy to miss and important to get right. The effective batch size is no longer the per-GPU batch; it is the per-GPU batch times the number of GPUs, because every step now consumes one shard from each worker. Eight GPUs at a local batch of 32 train on an effective batch of 256. A larger batch yields a smoother, lower-variance gradient, so each step can afford a larger learning rate. The linear scaling rule says to scale the learning rate roughly in proportion to the effective batch size, usually with a warmup to keep the early, large-LR steps stable. Forget this and your eight-GPU run keeps the one-GPU learning rate while consuming eight times the data per step, so each epoch takes one-eighth as many steps of the same size and the model quietly underfits: you still see the wall-clock speedup, but accuracy is worse and no error message explains why.
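The bookkeeping fits in three small functions. This sketch assumes a linear warmup from the single-GPU learning rate to the scaled one, roughly the gradual-warmup scheme described by Goyal et al., and holds the rate constant afterwards; a real schedule would usually decay it later. The base learning rate of 0.1 is a made-up value for illustration.

```python
def effective_batch(local_batch: int, world_size: int) -> int:
    """Every step consumes one local batch from each worker."""
    return local_batch * world_size


def scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: the learning rate grows in proportion to the batch size."""
    return base_lr * new_batch / base_batch


def lr_at_step(step: int, base_lr: float, target_lr: float, warmup_steps: int) -> float:
    """Ramp linearly from base_lr to target_lr over warmup_steps, then hold target_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr + (target_lr - base_lr) * step / warmup_steps
    return target_lr


base_lr = 0.1                                    # hypothetical single-GPU LR at batch 32
batch = effective_batch(local_batch=32, world_size=8)
target = scaled_lr(base_lr, base_batch=32, new_batch=batch)
print(batch, round(target, 6))
print([round(lr_at_step(s, base_lr, target, warmup_steps=4), 6) for s in range(6)])
```

Here the effective batch is 256 and the target learning rate is 0.8, reached after four warmup steps; real runs typically warm up over far more steps than this toy example.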
Build it → A from-scratch, Horovod-style data-parallel engine — AllReduce, gradient bucketization, and computation/communication overlap, the exact mechanics described above — is the subject of Project 40: Distributed Autograd Engine.
Model, tensor, and pipeline parallelism
When a single replica won’t fit, data parallelism has nothing to offer — replicating something that doesn’t fit just fails N times in parallel. You have to split the model itself, and there are two orthogonal ways to cut it.
Pipeline parallelism splits the model across layers: GPU 0 holds layers 1–8, GPU 1 holds layers 9–16, and so on. A batch flows through like an assembly line — GPU 0 processes it and passes its activations to GPU 1, which passes to GPU 2. The communication is light (just the activations between adjacent stages) and tolerant of slower links, which makes pipeline parallelism the natural choice for splitting across machines. Its weakness is the bubble: at startup GPU 1 idles until GPU 0 finishes the first chunk, GPU 2 until GPU 1 finishes, and so on — with four naive stages, three-quarters of the hardware can be idle at any instant. The fix is micro-batching: slice each batch into many small pieces so that once the pipeline is full, every stage works on a different micro-batch simultaneously. The more micro-batches, the smaller the bubble’s share of total time.
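The bubble's share of time has a simple closed form for a GPipe-style fill-and-drain schedule, assuming every stage takes the same time per micro-batch: with p stages and m micro-batches the schedule spans m + p - 1 time slots, of which each stage is busy for m, so the idle fraction is (p - 1)/(m + p - 1). The function name below is illustrative.

```python
from fractions import Fraction


def bubble_fraction(stages: int, micro_batches: int) -> Fraction:
    """Idle share of each stage's time in a fill-and-drain (GPipe-style) schedule.

    Assumes equal time per stage per micro-batch: the schedule spans
    micro_batches + stages - 1 slots and each stage is busy in micro_batches of them.
    """
    return Fraction(stages - 1, micro_batches + stages - 1)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:>2} micro-batches: {float(bubble_fraction(4, m)):.1%} idle")
```

With four stages and a single micro-batch this gives 3/4, the three-quarters-idle figure above; 4 micro-batches bring it to 3/7 (about 43%), and 64 bring it under 5%.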
Tensor parallelism splits within a layer: a single large matrix multiply is partitioned across GPUs, each computing a slice of the output, with a collective to recombine the result. There is no bubble — every GPU works on every step — but the communication is heavy and happens inside every layer, so tensor parallelism is only viable over a fast interconnect. The firm rule of thumb: tensor-parallel within a node (where NVLink is fast), pipeline-parallel across nodes (where the network is slow but per-stage traffic is light).
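A column split is the simplest form to demonstrate. The sketch below simulates it in one process with plain Python lists: the weight matrix W of Y = XW is divided by columns across hypothetical devices, each computes its slice of Y independently, and concatenating the slices (the job an all-gather does on real hardware) recovers the full result.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix multiply, standing in for one GPU's GEMM."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w: Matrix, parts: int) -> List[Matrix]:
    """Give each of `parts` devices a contiguous block of W's columns."""
    cols = len(w[0])
    if cols % parts:
        raise ValueError("columns must divide evenly across devices")
    width = cols // parts
    return [[row[k * width:(k + 1) * width] for row in w] for k in range(parts)]


def column_parallel_matmul(x: Matrix, w: Matrix, devices: int) -> Matrix:
    partial = [matmul(x, shard) for shard in split_columns(w, devices)]  # independent work
    # all-gather: stitch each device's output columns back together, row by row
    return [sum((p[i] for p in partial), []) for i in range(len(x))]


X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
W = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0], [3.0, 1.0, 0.0, 1.0]]
print(column_parallel_matmul(X, W, devices=2) == matmul(X, W))
```

As described in the Megatron-LM paper, a column-split layer can be paired with a row-split one so a transformer MLP needs only a single all-reduce in its forward pass; even so, there is a collective in every such block, which is why this traffic wants NVLink.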
A team splitting a transformer across four GPUs did the natural thing: divided the layers evenly, eight per stage. Throughput barely improved over a single GPU. Profiling showed one GPU pegged at 100% while the other three idled — because “eight layers” is not “one quarter of the work.” The cheap embedding and projection layers shared a stage; the dense middle stages were expensive; one GPU got all the expensive ones. In a pipeline the slowest stage caps the entire line — the others can only run as fast as the bottleneck feeds them, and the rest of their time is bubble. Balancing stages by measured FLOPs and memory (not layer count), and feeding enough micro-batches, jumped throughput from roughly 1.2× to 3.4× on the same four GPUs. The lesson generalizes: in any pipeline, the bottleneck stage is the only number that matters, and “equal count” is rarely “equal work.”
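Balancing by measured cost is a classic contiguous-partition problem. The sketch below assumes you have already profiled a single cost per layer (for example milliseconds per micro-batch) and finds the split into k consecutive stages that minimizes the most expensive stage, using a binary search on the stage budget with a greedy feasibility check. It balances one dimension only; a real partitioner would also respect each GPU's memory limit. The function names and layer costs are hypothetical.

```python
from typing import List


def stages_needed(costs: List[int], budget: int) -> int:
    """Greedily pack consecutive layers; count stages whose total stays within budget."""
    stages, load = 1, 0
    for c in costs:
        if load + c > budget:
            stages, load = stages + 1, 0
        load += c
    return stages


def balance(costs: List[int], k: int) -> List[List[int]]:
    """Split costs into at most k contiguous stages, minimizing the largest stage total."""
    lo, hi = max(costs), sum(costs)
    while lo < hi:                      # smallest budget that fits in k stages
        mid = (lo + hi) // 2
        if stages_needed(costs, mid) <= k:
            hi = mid
        else:
            lo = mid + 1
    parts: List[List[int]] = []
    current: List[int] = []
    for c in costs:
        if current and sum(current) + c > lo:
            parts.append(current)
            current = []
        current.append(c)
    parts.append(current)
    return parts


costs = [1, 1, 5, 5, 5, 5, 1, 1]                # hypothetical profiled cost per layer
by_count = [costs[i:i + 2] for i in range(0, len(costs), 2)]
balanced = balance(costs, k=4)
print(max(map(sum, by_count)), max(map(sum, balanced)))
```

Splitting by count (two layers per stage) leaves a worst stage of 10 units, while the balanced split `[[1, 1, 5], [5], [5], [5, 1, 1]]` caps it at 7. Since a full pipeline finishes one micro-batch per slowest-stage time, that alone is roughly a 1.4× throughput gain. The greedy pass can return fewer than k stages when costs allow, in which case a real partitioner would split further so no GPU sits empty.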
Real systems at the largest scale combine all three dimensions — tensor parallelism within a node, pipeline parallelism across nodes, and data parallelism on top for throughput — a configuration usually called 3D parallelism. But that complexity is a last resort, justified only when a model is many times larger than a single card.
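Frameworks that do 3D parallelism have to decide which global ranks form each group. One common convention, assumed in the sketch below, lets the tensor-parallel index vary fastest, so consecutive ranks (which a launcher such as torchrun typically places on the same node) share a tensor-parallel group. The exact ordering differs between frameworks, and the helper names are hypothetical.

```python
from typing import NamedTuple


class Coord(NamedTuple):
    dp: int  # which data-parallel replica
    pp: int  # which pipeline stage within that replica
    tp: int  # which tensor-parallel slice within that stage


def rank_to_coord(rank: int, tp: int, pp: int, dp: int) -> Coord:
    """Decompose a global rank, with the tensor-parallel index varying fastest."""
    world = tp * pp * dp
    if not 0 <= rank < world:
        raise ValueError(f"rank {rank} is outside a world of size {world}")
    return Coord(dp=rank // (tp * pp), pp=(rank // tp) % pp, tp=rank % tp)


def coord_to_rank(c: Coord, tp: int, pp: int) -> int:
    """Inverse of rank_to_coord."""
    return c.dp * (pp * tp) + c.pp * tp + c.tp


TP, PP, DP = 8, 2, 2                              # 32 GPUs, e.g. four 8-GPU nodes
coords = [rank_to_coord(r, TP, PP, DP) for r in range(TP * PP * DP)]
tp_group_of_8 = [r for r, c in enumerate(coords) if (c.dp, c.pp) == (coords[8].dp, coords[8].pp)]
pipeline_of_0 = [r for r, c in enumerate(coords) if (c.dp, c.tp) == (coords[0].dp, coords[0].tp)]
print(tp_group_of_8)  # one node's worth of consecutive ranks
print(pipeline_of_0)  # stages of one replica, on different nodes
```

With tp = 8, pp = 2, dp = 2, ranks 8 to 15 form one tensor-parallel group on a single node, talking over NVLink, while ranks 0 and 8 hold successive pipeline stages of the same replica on different nodes and exchange only activations over the network.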
Build it → The model-sharding side of this problem — splitting parameters across machines, with push/pull update pipelines and multiple consistency models — is implemented in Project 30: Large-Scale Parameter Server, the parameter-server alternative to the all-reduce approach.
The communication bottleneck
Every strategy in this chapter is ultimately a different answer to the same question: how do you keep the GPUs computing instead of waiting on the wire? It is worth making the bottleneck explicit, because it explains every rule of thumb above.
The dominant collective, all-reduce, is commonly implemented as a ring: each GPU talks only to its two neighbors, passing chunks around so that each GPU sends 2(N-1)/N times the gradient size, just under twice the gradient size whatever the number of GPUs N. That near-independence is the good news: adding GPUs doesn’t increase per-GPU traffic. The bad news is that the traffic still traverses the interconnect, and interconnects differ by orders of magnitude. Inside one server, NVLink at ~600 GB/s lets the all-reduce hide behind compute. Across servers on InfiniBand, at maybe 25 GB/s, communication starts to show. Over plain Ethernet at ~1 GB/s, all-reduce dominates so completely that large-model data parallelism becomes impractical.
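The ring algorithm is easiest to trust after running it. The simulation below, a single-process sketch with hypothetical helper names, performs a sum all-reduce in two phases: a reduce-scatter, after which each worker owns the complete sum of one chunk, and an all-gather that circulates those finished chunks. It also counts the elements each worker sends, and for simplicity assumes the vector length divides evenly by the number of workers.

```python
from typing import List, Tuple


def ring_all_reduce(data: List[List[float]]) -> Tuple[List[List[float]], List[int]]:
    """Sum-all-reduce over n simulated workers arranged in a ring.

    Returns each worker's final vector and the number of elements each worker sent.
    """
    n = len(data)
    length = len(data[0])
    assert length % n == 0, "vector length must divide evenly across workers"
    size = length // n
    buf = [list(v) for v in data]  # each worker's local buffer (copied)
    sent = [0] * n

    def chunk(c: int) -> slice:
        return slice(c * size, (c + 1) * size)

    # Phase 1, reduce-scatter: after n-1 steps worker i owns the full sum of chunk (i+1) % n.
    for step in range(n - 1):
        msgs = [(i, (i - step) % n, buf[i][chunk((i - step) % n)]) for i in range(n)]
        for i, c, payload in msgs:  # payloads captured first: all sends happen "at once"
            dst, s = (i + 1) % n, chunk(c)
            buf[dst][s] = [a + b for a, b in zip(buf[dst][s], payload)]
            sent[i] += len(payload)

    # Phase 2, all-gather: pass the finished chunks around so every worker has all of them.
    for step in range(n - 1):
        msgs = [(i, (i + 1 - step) % n, buf[i][chunk((i + 1 - step) % n)]) for i in range(n)]
        for i, c, payload in msgs:
            buf[(i + 1) % n][chunk(c)] = payload
            sent[i] += len(payload)
    return buf, sent


data = [[float(10 * w + j) for j in range(8)] for w in range(4)]  # 4 workers, 8 values each
result, sent = ring_all_reduce(data)
print(result[0])  # element-wise sum across workers, identical on every worker
print(sent)       # elements sent per worker
```

With 4 workers and 8 elements each, every worker sends 12 elements, which is 2(N-1)/N = 1.5 times its own vector; dividing the result by N gives the average that data parallelism uses. Scaled up, a 28 GB fp32 gradient (7 billion parameters) on 8 GPUs means each GPU sends about 2 × 7/8 × 28 = 49 GB per step: roughly 0.08 s at ~600 GB/s, about 2 s at ~25 GB/s, and close to a minute at ~1 GB/s, ignoring latency and assuming the full quoted bandwidth is usable.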
Two levers fight back. The first is overlap, already described: firing communication as gradients become ready so it hides behind compute. The second is reducing the volume on the wire. Mixed-precision training (bf16) can help twice over: it speeds up the matrix math on Tensor Cores, and when gradients are also communicated in bf16 it halves the bytes each gradient occupies in the all-reduce. More aggressive schemes compress gradients further (top-k sparsification sends only the largest values; methods like PowerSGD send a low-rank approximation), trading a little fidelity for a lot less traffic. These are specialized tools for when the interconnect is genuinely the wall; for most single-node and well-networked multi-node jobs, overlap plus bf16 is enough.
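Top-k sparsification is simple to sketch. The version below keeps the k largest-magnitude gradient entries as (index, value) pairs and returns the unsent remainder, which implementations commonly carry into the next step's gradient (error feedback) so small updates are delayed rather than lost. The function names are hypothetical. Because workers pick different indices, the sparse pieces are typically exchanged with an all-gather rather than a plain ring all-reduce.

```python
from typing import List, Tuple

Sparse = List[Tuple[int, float]]


def topk_compress(grad: List[float], k: int) -> Tuple[Sparse, List[float]]:
    """Return (the k largest-magnitude entries as (index, value), the unsent residual)."""
    by_size = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)
    keep = sorted(by_size[:k])  # indices to send, in ascending order
    sent = [(i, grad[i]) for i in keep]
    residual = list(grad)
    for i in keep:
        residual[i] = 0.0       # sent entries are no longer owed
    return sent, residual


def decompress(sent: Sparse, length: int) -> List[float]:
    """Rebuild a dense gradient with zeros everywhere that was not sent."""
    dense = [0.0] * length
    for i, value in sent:
        dense[i] = value
    return dense


grad = [0.1, -3.0, 0.5, 2.0, -0.2]
sent, residual = topk_compress(grad, k=2)
print(sent)  # the two largest-magnitude entries
print(decompress(sent, len(grad)))
```

Each kept entry costs an index as well as a value, so with 4-byte indices and fp32 values the message is smaller than the dense gradient only when k is below half its length; practical settings keep a far smaller fraction.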
Orchestration with Ray
Everything so far has been about splitting one model across GPUs. There is a layer above it: the framework that places those processes on a cluster, restarts them when a node dies, runs many jobs for hyperparameter search, and stands the result up for serving. Ray is the general-purpose distributed-execution framework that fills that role for Python and ML workloads.
Ray’s model is two primitives. A task is a stateless function you mark @ray.remote and invoke with .remote(); Ray schedules it on whatever worker has capacity and immediately returns a future (an object reference). An actor is a stateful object that lives on one worker and holds mutable state across calls, which is exactly the shape of a parameter server or a model replica. Results live in a shared object store and pass by reference, which matters more than it sounds: hand a multi-gigabyte model to a hundred tasks by value and Ray serializes and ships it a hundred times; ray.put it once and pass the reference, and the job can run dramatically faster. That distinction (by reference, not by value) is the most common Ray pitfall and its most common fix.
On those primitives, Ray Train wraps the distributed-training setup from earlier (it configures DDP across its workers), Ray Tune runs many training trials in parallel for hyperparameter search, and Ray Serve deploys the result with autoscaling. The point is the shape, not the specific libraries: the same task-and-actor substrate that scales training scales the work around it, so the whole ML lifecycle runs on one cluster abstraction rather than a pile of stitched-together systems.
Boundary note. This chapter teaches the general parallelism concepts — they apply to a vision model, a recommender, or any large network. Training a foundation model end-to-end (the data pipelines, the curriculum, the megawatt-scale orchestration that goes into a frontier LLM) is its own discipline, covered in the companion Complete AI Engineer book. The parallelism mechanics here are the substrate that work is built on; the foundation-model-specific concerns sit a layer above.
Build it → The orchestration layer — a scheduler that manages training jobs, GPU resources, distributed coordination, and checkpointing — is built in Project 04: ML Training Orchestrator, which drives both the all-reduce and parameter-server backends.
Practical exercise
Difficulty: Level I · Level II · Level III
- Level I — Pick the strategy and justify it. Three scenarios: (a) a ResNet using 12 GB on an 80 GB card but a week to train; (b) a 40-billion-parameter transformer; (c) a 200-million-parameter model that trains in two hours on one GPU. For each, say whether you’d use data parallelism, sharded data parallelism, full model parallelism, or no distribution, and name the one fact (fits-or-not, slow-or-not) that drove it.
- Level II — Reason about data-parallel scaling. You move a model from one GPU (local batch 64, learning rate 1e-3) to eight GPUs at the same local batch. State the new effective batch size and the learning rate the linear scaling rule suggests, and why a warmup matters there. Then explain why, fully configured, you might measure only a 6× speedup rather than 8× — name the cost that caps it and the two levers that narrow the gap.
- Level III — Design a multi-node strategy for a too-large model. A model needs roughly 320 GB to train; you have a cluster of 8-GPU nodes (80 GB each, NVLink within, InfiniBand between). Sketch a combined strategy: which dimension of parallelism goes within a node versus across nodes and why, where sharded data parallelism fits, and the two places communication or pipeline bubbles will most likely bottleneck it — plus how you’d spot each in a utilization trace.
Summary
Distributing training is what you do when one GPU is no longer enough — either because training is too slow or because the model no longer fits — and the two reasons demand different tools. Data parallelism replicates the model on every GPU, shards the batch, and keeps the replicas identical with an all-reduce of gradients every step; it scales near-linearly when that communication overlaps compute, and it changes the effective batch size in a way that forces a learning-rate adjustment. Model parallelism — pipeline across layers, tensor within them — splits a too-large model across devices at the cost of pipeline bubbles and heavy in-layer traffic. Sharded data parallelism (ZeRO / FSDP) is the modern middle ground, fitting bigger models by sharding optimizer, gradient, and parameter state while keeping the data-parallel programming model. Across all of them the dominant cost is communication over the interconnect, so the entire discipline reduces to overlapping and minimizing it — and a framework like Ray wraps the whole lifecycle in a task-and-actor cluster abstraction.
Key takeaways
- Two distinct problems: too slow / data too big → data parallelism; model too big → model or sharded-data parallelism. Diagnose which one you have before picking a tool.
- Data parallelism’s correctness rests on the per-step all-reduce of gradients; its speed rests on overlapping that all-reduce with the backward pass.
- Distributing the batch multiplies the effective batch size, so scale the learning rate (linear rule, with warmup) or quietly lose accuracy.
- Pipelines are capped by their slowest stage; balance by measured work, not layer count, and feed micro-batches to shrink the bubble.
- Communication, not computation, is the bottleneck. Interconnect bandwidth plus overlap set the scaling ceiling; bf16 and gradient compression reduce the volume.
- ZeRO / FSDP is the default for large models — most of model parallelism’s memory win with most of data parallelism’s simplicity.
Connections to other chapters
- Deep Learning Frameworks (prerequisite): every strategy here transforms the single-GPU training loop taught there (the forward pass, the .backward() that triggers the all-reduce, the optimizer step). You cannot reason about where the all-reduce hides until you know what the backward pass is doing.
- GPU Programming and CUDA (extension): this chapter treats the GPU as a black box that computes and communicates; the next one opens it, covering the kernels that do the matrix multiplies, the memory hierarchy that makes data movement expensive, and the Tensor Cores that make bf16 fast. The communication cost that dominates scaling is, at bottom, a hardware story.
- Concurrency and Parallelism Models (foundation): all-reduce, ring topologies, and keeping replicas consistent are distributed-systems problems first. Coordination primitives, synchronization cost, and contention on a shared resource are the same patterns that chapter teaches comparatively across languages — this is one high-stakes instance.
- Streaming and Real-Time Data (Part II): both are about coordinating work across many machines under a throughput constraint, where the bottleneck is moving data between nodes rather than the per-record compute. The tools for reasoning about backpressure, overlap, and slowest-link bottlenecks transfer directly.
- Note that training a foundation model end-to-end — the data, scale, and orchestration specific to frontier LLMs — is the domain of the companion Complete AI Engineer book; this chapter teaches the parallelism substrate underneath it.
Further reading
Essential
- PyTorch — Distributed Data Parallel and Fully Sharded Data Parallel (FSDP) documentation — the canonical, runnable references for the two strategies most teams actually use; the DDP internals notes explain the gradient bucketing and overlap.
- “How to Scale Your Model”-style practitioner guides (the scaling playbooks from major labs) — practical decision frameworks for picking and combining parallelism dimensions at a given model and cluster size.
Deep dives
- Rajbhandari et al., “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models” (SC, 2020) — the paper behind sharded data parallelism; reading it makes the stage-1/2/3 memory math concrete.
- Shoeybi et al., “Megatron-LM” (2019) — the reference for tensor parallelism, and Huang et al., “GPipe” (NeurIPS, 2019) — the reference for pipeline parallelism with micro-batching and the bubble analysis.
Historical context
- Moritz et al., “Ray: A Distributed Framework for Emerging AI Applications” (OSDI, 2018) — the system paper that introduced the task-and-actor model used in the orchestration section.
- Goyal et al., “Accurate, Large Minibatch SGD” (2017) — the work that established the linear learning-rate scaling rule and warmup for large-batch data-parallel training.