ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
Paper: Rajbhandari et al., Microsoft, 2020 (DeepSpeed)
Core contribution: the Zero Redundancy Optimizer partitions model states (optimizer states, gradients, parameters) across data-parallel ranks instead of replicating them, achieving memory reduction linear in the DP degree while preserving the communication volume of standard data parallelism.
Fig 1: System Overview Block Diagram
┌──────────────────────────────────────────────────────────────┐
│ ZeRO Training System │
│ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Data Parallel Process Group (N_d GPUs) │ │
│ │ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ GPU-0 │ │ GPU-1 │ ... │ GPU-N_d-1│ │ │
│ │ │ │ │ │ │ │ │ │
│ │ │ shard-0 │ │ shard-1 │ │ shard-k │ │ │
│ │ │ opt_st │ │ opt_st │ │ opt_st │ │ │
│ │ │ grads │ │ grads │ │ grads │ │ │
│ │ │ params │ │ params │ │ params │ │ │
│ │ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ │
│ │ └─────────────┴─────────────────┘ │ │
│ │ reduce-scatter / all-gather (InfiniBand) │ │
│ └───────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────┐ ┌──────────────────────────────────┐ │
│ │ ZeRO-DP │ │ ZeRO-R │ │
│ │ (model states) │ │ (residual states) │ │
│ │ │ │ │ │
│ │ P_os (stage1) │ │ P_a (activation partitioning) │ │
│ │ P_os+g (stage2)│ │ C_B (constant-size buffers) │ │
│ │ P_os+g+p(st3) │ │ M_D (memory defragmentation) │ │
│ └─────────────────┘ └──────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
▲ Fig 1: ZeRO full system — ZeRO-DP partitions model states across
N_d data-parallel GPUs; ZeRO-R manages residual memory sources
Fig 2: Key Architecture Diagram — Three-Stage Memory Partitioning
MEMORY PER GPU: model with Ψ=7.5B params, K=12 (Adam), N_d=64
┌────────────────────────────────────────────────────────┐
│ Baseline (standard DP) — every GPU holds everything │
│ │
│ [ params fp16 (2Ψ) ][ grads fp16 (2Ψ) ] │
│ [ optimizer states (K·Ψ=12Ψ) ] │
│ = 120 GB │
└────────────────────────────────────────────────────────┘
│
│ Stage 1: P_os — partition optimizer states
▼
┌────────────────────────────────────────────────────────┐
│ P_os: each GPU owns 1/N_d of optimizer states │
│ │
│ [ params fp16 (2Ψ) ][ grads fp16 (2Ψ) ] │
│ [ opt_states (KΨ/N_d) ] │
│ = 31.4 GB │
│ Communication: same as baseline DP (all-reduce) │
└────────────────────────────────────────────────────────┘
│
│ Stage 2: P_os+g — also partition gradients
▼
┌────────────────────────────────────────────────────────┐
│ P_os+g: each GPU owns 1/N_d of grads + opt_states │
│ │
│ [ params fp16 (2Ψ) ][ grads (2Ψ/N_d) ] │
│ [ opt_states (KΨ/N_d) ] │
│ = 16.6 GB │
│ Communication: reduce-scatter then all-gather = 2Ψ │
└────────────────────────────────────────────────────────┘
│
│ Stage 3: P_os+g+p — also partition parameters
▼
┌────────────────────────────────────────────────────────┐
│ P_os+g+p: each GPU owns 1/N_d of everything │
│ │
│ [ params (2Ψ/N_d) ][ grads (2Ψ/N_d) ] │
│ [ opt_states (KΨ/N_d) ] │
│ = 1.9 GB │
│ Communication: 1.5x baseline (3Ψ total volume) │
└────────────────────────────────────────────────────────┘
▲ Fig 2: ZeRO-DP three stages. Memory shrinks by 4x, 8x, then N_d x.
Stages 1 and 2 incur zero additional communication vs baseline DP.
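The Fig 2 numbers follow directly from the per-GPU memory formulas: (2+2+K)Ψ bytes at baseline, with each partitioned term divided by N_d. A small Python check, assuming the paper's accounting (2 bytes each for fp16 params and grads, K=12 bytes per parameter of fp32 Adam state); the helper name is ours:

```python
# Sketch: per-GPU model-state memory for each ZeRO-DP stage, using the paper's
# accounting (fp16 params + fp16 grads + K bytes/param of optimizer state).
# Reproduces the Fig 2 numbers for psi=7.5B, K=12, N_d=64.

def zero_dp_memory_gb(psi, k=12, n_d=64):
    """Per-GPU model-state memory (GB) for baseline DP and ZeRO stages 1-3."""
    gb = 1e9  # GB as 10^9 bytes matches the figure's quoted numbers
    baseline = (2 + 2 + k) * psi / gb                         # everything replicated
    stage1   = (2 + 2) * psi / gb + k * psi / (n_d * gb)      # P_os
    stage2   = 2 * psi / gb + (2 + k) * psi / (n_d * gb)      # P_os+g
    stage3   = (2 + 2 + k) * psi / (n_d * gb)                 # P_os+g+p
    return baseline, stage1, stage2, stage3

if __name__ == "__main__":
    for name, mem in zip(["baseline", "P_os", "P_os+g", "P_os+g+p"],
                         zero_dp_memory_gb(psi=7.5e9, k=12, n_d=64)):
        print(f"{name:10s} {mem:7.1f} GB")
    # baseline 120.0, P_os 31.4, P_os+g 16.6, P_os+g+p 1.9
```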
Fig 3: Control / Data Flow Diagram — ZeRO-DP Training Step
START: mini-batch arrives, each GPU holds its param shard
│
▼
① [Forward pass — P_os+g+p mode]
│
├── for each layer i (its params sharded 1/N_d across all GPUs):
│ ├── all-gather the layer's param shards from every GPU
│ │ (total Ψ communicated over the forward pass)
│ ├── all GPUs compute forward with the full layer params
│ └── each GPU discards the non-owned params after use
│
▼
② [Backward pass — gradient accumulation]
│
├── for each layer i (reverse order):
│ ├── all-gather params again for recomputation / grad compute
│ │ (total Ψ over the backward pass)
│ ├── compute gradients locally
│ └── reduce-scatter gradients:
│ each GPU keeps only its own gradient shard
│ (total Ψ over the backward pass)
│
▼
③ [Optimizer step]
│
├── each GPU updates its own parameter shard
│ using its local gradient shard and opt_state shard
│ (purely local compute, no communication)
│
▼
④ [Updated parameter shards stay partitioned]
│ no post-step all-gather in P_os+g+p: the next forward
│ pass re-gathers the updated params layer by layer
│
▼
⑤ [ZeRO-R: memory housekeeping]
│ ├── move short-lived activation grads to pre-alloc buffers
│ ├── release long-lived activation checkpoints via P_a
│ └── defragment memory via M_D
│
END — total communication per step = 3Ψ (1.5x baseline)
▲ Fig 3: ZeRO-DP (P_os+g+p) forward+backward+optimizer step.
Each GPU communicates exactly its param shard, not the full model.
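The collective ordering in Fig 3 can be sketched with raw torch.distributed primitives. This is a conceptual sketch only, not DeepSpeed's implementation: the per-layer shard list, the grads_fn and optimizer_step_fn callbacks, and the equal-shard assumption are simplifications introduced here, and it assumes dist.init_process_group has already been called (e.g. under torchrun).

```python
import torch
import torch.distributed as dist

def zero3_step(param_shards, grads_fn, optimizer_step_fn):
    """param_shards: one local 1-D shard per 'layer'; callbacks are placeholders."""
    world = dist.get_world_size()

    # ① Forward: per layer, all-gather the shards, compute, then discard.
    for shard in param_shards:
        gathered = [torch.empty_like(shard) for _ in range(world)]
        dist.all_gather(gathered, shard)              # ~ psi total over the pass
        full = torch.cat(gathered)
        # ... this layer's forward compute would consume `full` here ...
        del full                                      # non-owned params discarded

    # ② Backward (reverse layer order in practice; order elided here):
    grad_shards = []
    for shard in param_shards:
        gathered = [torch.empty_like(shard) for _ in range(world)]
        dist.all_gather(gathered, shard)              # second all-gather, ~ psi total
        full_grad = grads_fn(torch.cat(gathered))     # local grad of the full layer
        my_grad = torch.empty_like(shard)
        dist.reduce_scatter(my_grad, list(full_grad.chunk(world)))  # ~ psi total
        grad_shards.append(my_grad)

    # ③ Optimizer step: purely local on (param shard, grad shard, opt-state shard).
    for shard, grad in zip(param_shards, grad_shards):
        optimizer_step_fn(shard, grad)
    # ④ Updated shards stay partitioned; the next forward re-gathers them on demand.
```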
Fig 4: State Machine — ZeRO-DP Parameter Lifecycle
broadcast (all-gather)
[PARTITIONED] ─────────────────────────► [MATERIALIZED]
▲ │
│ │ compute fwd/bwd
│ │ with full shard
│ ▼
[DISCARDED] ◄─── discard after use ───── [IN USE]
▲ │
│ │ grad ready
│ ▼
└─────── reduce-scatter grads ──── [GRAD AVAILABLE]
(bucket full trigger) │
│ optimizer step
▼
[UPDATED SHARD]
│
all-gather ──► [MATERIALIZED]
▲ Fig 4: Parameter lifecycle per ZeRO-DP rank. Parameters exist in
full form only during compute; at all other times they stay partitioned.
Fig 5: ZeRO-R Residual Memory Components
┌────────────────────────────────────────────────────────┐
│ ZeRO-R Subsystem │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ P_a: Partitioned Activation Checkpointing │ │
│ │ │ │
│ │ Forward pass: │ │
│ │ checkpoint activations ──► partition across │ │
│ │ model-parallel GPUs (each holds 1/MP_degree) │ │
│ │ │ │
│ │ Backward pass: │ │
│ │ all-gather partitions ──► recompute locally │ │
│ │ (1 all-gather per transformer block, ~10% │ │
│ │ of MP communication volume) │ │
│ └──────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ C_B: Constant-Size Fused Buffers │ │
│ │ │ │
│ │ [cap at fixed size] ──► gradient fused buffer │ │
│ │ prevents O(model_size) buffer blowup │ │
│ │ trades some efficiency for memory safety │ │
│ └──────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ M_D: Memory Defragmentation │ │
│ │ │ │
│ │ Forward: long-lived (checkpoints) interleaved │ │
│ │ with short-lived (recomputed activ.) │ │
│ │ Backward: long-lived (param grads) interleaved │ │
│ │ with short-lived (activ. grads) │ │
│ │ │ │
│ │ Solution: pre-allocate contiguous chunks, │ │
│ │ copy tensors on creation to maintain layout │ │
│ └──────────────────────────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
▲ Fig 5: ZeRO-R three residual memory optimizations. P_a eliminates
activation replication in model parallelism; C_B caps buffer size;
M_D prevents OOM from fragmentation despite sufficient free memory.
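The C_B mechanism reduces to a pre-allocated fused buffer that is flushed with a collective whenever the next gradient would overflow it. A minimal sketch; the cap value, the flush callback, and the assumption that no single gradient exceeds the cap are illustrative choices rather than DeepSpeed defaults.

```python
import torch

class ConstantSizeFusedBuffer:
    """Fuse gradients into a fixed-size buffer so temp memory stays O(cap),
    not O(model size); flush with a collective when the cap is reached."""

    def __init__(self, cap_elems, flush_fn):
        self.buf = torch.empty(cap_elems)   # allocated once, never grows
        self.fill = 0
        self.flush_fn = flush_fn            # e.g. launch a reduce-scatter on the slice

    def add(self, grad):
        flat = grad.reshape(-1)
        if self.fill + flat.numel() > self.buf.numel():
            self.flush()                    # cap reached: fire the collective early
        self.buf[self.fill:self.fill + flat.numel()].copy_(flat)
        self.fill += flat.numel()

    def flush(self):
        if self.fill:
            self.flush_fn(self.buf[:self.fill])
            self.fill = 0

# Toy usage: three 400k-element grads against a 1M-element cap -> two flushes.
buf = ConstantSizeFusedBuffer(1_000_000, flush_fn=lambda t: print("flush", t.numel()))
for g in (torch.randn(400_000), torch.randn(400_000), torch.randn(400_000)):
    buf.add(g)
buf.flush()   # prints: flush 800000, then flush 400000
```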
Fig 6: Communication Volume Analysis
Standard DP all-reduce decomposition:
┌──────────────────────────────────────────────────────┐
│ AllReduce = ReduceScatter + AllGather │
│ │
│ reduce-scatter: Ψ data moved per GPU │
│ all-gather: Ψ data moved per GPU │
│ total: 2Ψ per GPU per step │
└──────────────────────────────────────────────────────┘
ZeRO-DP P_os and P_os+g:
┌──────────────────────────────────────────────────────┐
│ Same reduce-scatter + all-gather structure │
│ Communication volume = 2Ψ (identical to DP) │
│ But memory per device reduced 4x / 8x │
└──────────────────────────────────────────────────────┘
ZeRO-DP P_os+g+p:
┌──────────────────────────────────────────────────────┐
│ Fwd: all-gather params = Ψ (pipelined over fwd pass) │
│ Bwd: all-gather params = Ψ (pipelined over bwd pass) │
│ Bwd: reduce-scatter grads = Ψ │
│ total = 3Ψ (1.5x baseline DP) │
│ But memory per device reduced N_d times │
└──────────────────────────────────────────────────────┘
▲ Fig 6: ZeRO communication volume. Stages 1 and 2 are free in
communication cost; Stage 3 costs 1.5x for N_d x memory reduction.
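The same accounting as a trivial helper, in units of Ψ per GPU per step (the function name and dict layout are ours):

```python
# Worked check of Fig 6: per-GPU collective volume per step, in units of psi
# (model parameters), following the paper's accounting.

def comm_volume_in_psi(stage):
    if stage in ("baseline_dp", "P_os", "P_os+g"):
        # all-reduce decomposes into reduce-scatter (1 psi) + all-gather (1 psi)
        return {"reduce_scatter": 1.0, "all_gather": 1.0, "total": 2.0}
    if stage == "P_os+g+p":
        # params all-gathered in fwd and again in bwd, grads reduce-scattered once
        return {"all_gather_fwd": 1.0, "all_gather_bwd": 1.0,
                "reduce_scatter": 1.0, "total": 3.0}
    raise ValueError(f"unknown stage: {stage}")

print(comm_volume_in_psi("P_os+g"))     # total 2.0 -> identical to baseline DP
print(comm_volume_in_psi("P_os+g+p"))   # total 3.0 -> 1.5x baseline
```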
Design Trade-off Analysis
| Design Decision | Alternative A | Alternative B (ZeRO choice) | Winner | Why |
|---|---|---|---|---|
| Model state storage | Replicate all states (DP) | Partition states across ranks | B | Linear memory reduction with N_d; DP communication efficiency retained |
| Gradient reduction | All-reduce (DP baseline) | Reduce-scatter to owner, then all-gather | B | Same 2Ψ volume but enables gradient partitioning; map to ReduceScatter + AllGather |
| Parameter availability | Always replicated (wasteful) | On-demand all-gather, discard after use | B | Enables full N_d x memory reduction; temporal nature of param need exploited |
| Communication overhead of P_p | All-gather full params up front each pass | Pipeline per-layer all-gather across fwd/bwd pass | B | Volume stays at 1.5x baseline; spreading the all-gather across layers overlaps it with compute and avoids materializing the full model |
| Activation memory in MP | Replicate activations (MP default) | Partition activations, all-gather on demand | B | MP degree x reduction in activation memory; <10% communication overhead |
| Temporary buffer sizing | Scale with model (unbounded) | Cap at fixed size C_B | B | Prevents OOM for large models; minor efficiency trade-off for large N |
| Memory fragmentation | Accept fragmentation | Pre-allocate contiguous buffers, copy on create | B | Eliminates OOM from fragmentation even when free memory exists |
| Usability vs MP | Requires model code changes (Megatron-style) | Wraps any torch.nn.Module with no changes | B | Democratizes large model training; no domain expertise required |
For DynamICCL context: ZeRO transforms the collective pattern from a single large all-reduce into a reduce-scatter followed by an all-gather, both of which are available as distinct NCCL primitives. This is directly actionable for DynamICCL's collective selection.
What to Borrow for DynamICCL
1. ReduceScatter + AllGather as the canonical large-scale gradient pattern. ZeRO shows that the modern data-parallel training pattern is not a single AllReduce but ReduceScatter followed by AllGather (the two primitives that together implement AllReduce). DynamICCL must treat these as separate collectives, each with its own optimal NCCL config. ReduceScatter is bandwidth-bound and benefits from the ring algorithm and Simple protocol at large message sizes. The AllGather of updated parameters is also bandwidth-bound, but its latency is less critical because it sits off the backward critical path. These two collectives should be tuned independently by the Config Agent; sketch 1 after this list shows the decomposition.
2. Bucketed gradient reduction maps to message-size scheduling. ZeRO uses gradient bucketization (accumulate a full bucket before reducing) to control the message size of each reduce-scatter operation. DynamICCL's Trigger Agent should monitor the effective message size crossing the NCCL boundary: when ZeRO's bucket size increases (due to model size or ZeRO stage change), the optimal NCCL config shifts from LL protocol (small, latency-optimized) to Simple protocol (large, bandwidth-optimized). The bucket size is a direct input to DynamICCL's state representation. Sketch 2 after this list illustrates the size-to-protocol mapping.
3. Stage-dependent collective mix as a workload fingerprint. ZeRO-DP stage 1 (P_os) generates only standard all-reduces; stage 2 (P_os+g) generates reduce-scatter + all-gather pairs; stage 3 (P_os+g+p) generates all-gather in the forward pass too. Each stage has a distinct collective fingerprint. DynamICCL's LSTM encoder should be trained to distinguish these fingerprints from the sequence of collective types and sizes, and apply distinct optimal configs for each regime without requiring explicit user labeling of the ZeRO stage. Sketch 3 after this list outlines such an encoder.
4. Super-linear scaling from larger effective batch size. ZeRO's super-linear speedup arises because freeing memory allows larger batch sizes, increasing arithmetic intensity. DynamICCL should factor this into its reward model: when memory pressure drops (observable via GPU memory usage monitoring), the system may switch to larger batch sizes that in turn change collective message sizes. The Config Agent's state should include a memory-pressure signal to anticipate this regime transition. Sketch 4 after this list shows one possible memory-pressure signal.
5. On-demand parameter all-gather as a new latency-critical collective. ZeRO stage 3's per-layer parameter all-gather during the forward pass is on the critical path — it blocks compute at each transformer layer. This is analogous to Megatron's intra-layer all-reduce: a small, latency-critical, high-frequency collective. DynamICCL should classify these parameter all-gathers as LL-protocol candidates with low channel count to minimize latency, distinct from the large gradient reduce-scatter operations. Sketch 5 after this list shows the classification rule.
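Sketch 1 (item 1): the AllReduce = ReduceScatter + AllGather identity, expressed with torch.distributed. It assumes an initialized process group and a tensor whose length divides evenly by the world size. Treating the two halves as independently tunable collectives is the DynamICCL-side idea; NCCL itself does not expose per-call tuning knobs, so this helper only demonstrates the functional equivalence.

```python
import torch
import torch.distributed as dist

def allreduce_as_rs_ag(grad):
    """Functionally equivalent to dist.all_reduce(grad, op=SUM) on even splits."""
    world = dist.get_world_size()
    chunks = list(grad.chunk(world))             # one equal chunk per rank
    my_chunk = torch.empty_like(chunks[0])
    dist.reduce_scatter(my_chunk, chunks)        # each rank owns one reduced chunk
    gathered = [torch.empty_like(my_chunk) for _ in range(world)]
    dist.all_gather(gathered, my_chunk)          # reassemble the fully reduced grad
    return torch.cat(gathered)
```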
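Sketch 2 (item 2): a size-to-protocol heuristic of the kind the Trigger Agent could apply. The crossover threshold is an illustrative assumption, not a measured NCCL crossover point, which in practice depends on topology, link speed, and NCCL version.

```python
def suggest_nccl_proto(message_bytes, crossover_bytes=4 * 1024 * 1024):
    """Small, latency-bound messages favor LL; large, bandwidth-bound favor Simple."""
    return "LL" if message_bytes < crossover_bytes else "Simple"

# ZeRO bucket growth moves the gradient reduce-scatter across the crossover:
print(suggest_nccl_proto(256 * 1024))          # LL
print(suggest_nccl_proto(200 * 1024 * 1024))   # Simple
```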
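Sketch 3 (item 3): a minimal PyTorch fingerprint encoder over a trace of (collective type, message size) events. The vocabulary, dimensions, and the log-size feature are illustrative assumptions, not DynamICCL's actual encoder.

```python
import torch
import torch.nn as nn

COLLECTIVE_VOCAB = {"all_reduce": 0, "reduce_scatter": 1, "all_gather": 2, "broadcast": 3}

class CollectiveFingerprintEncoder(nn.Module):
    def __init__(self, embed_dim=16, hidden_dim=64):
        super().__init__()
        self.type_embed = nn.Embedding(len(COLLECTIVE_VOCAB), embed_dim)
        self.lstm = nn.LSTM(embed_dim + 1, hidden_dim, batch_first=True)

    def forward(self, types, sizes_bytes):
        # types: (B, T) int64 indices; sizes_bytes: (B, T) float
        x = torch.cat([self.type_embed(types),
                       torch.log1p(sizes_bytes).unsqueeze(-1)], dim=-1)
        _, (h, _) = self.lstm(x)
        return h[-1]          # (B, hidden_dim): one fingerprint per trace

# A ZeRO-2 step emits reduce-scatter/all-gather pairs; a ZeRO-3 step adds
# forward-pass all-gathers -- distinct sequences yield distinct fingerprints.
enc = CollectiveFingerprintEncoder()
types = torch.tensor([[1, 2, 1, 2]])                      # rs, ag, rs, ag
sizes = torch.tensor([[2.0e8, 2.0e8, 2.0e8, 2.0e8]])
print(enc(types, sizes).shape)                            # torch.Size([1, 64])
```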
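Sketch 4 (item 4): a memory-pressure feature built from standard torch.cuda statistics. Treating it as a Config Agent state input is the proposal here; the helper name is ours.

```python
import torch

def gpu_memory_pressure(device=0):
    """Fraction of device memory currently allocated to tensors."""
    total = torch.cuda.get_device_properties(device).total_memory
    return torch.cuda.memory_allocated(device) / total

# A drop in pressure (e.g. after moving to a higher ZeRO stage) hints that
# batch size, and with it collective message sizes, may be about to change.
if torch.cuda.is_available():
    print(f"memory pressure: {gpu_memory_pressure():.2%}")
```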
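Sketch 5 (item 5): a classification rule separating latency-critical forward-path all-gathers from bandwidth-bound gradient reduce-scatters. The class names and size cutoff are illustrative assumptions.

```python
def classify_collective(op, message_bytes, phase):
    if op == "all_gather" and phase == "forward":
        return "latency_critical"    # candidate: LL protocol, few channels
    if op == "reduce_scatter" and message_bytes > 64 * 1024 * 1024:
        return "bandwidth_bound"     # candidate: Simple protocol, ring, more channels
    return "default"

print(classify_collective("all_gather", 8 * 1024 * 1024, "forward"))          # latency_critical
print(classify_collective("reduce_scatter", 200 * 1024 * 1024, "backward"))   # bandwidth_bound
```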