Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
Paper: Shoeybi et al., NVIDIA, 2020.
Core contribution: intra-layer (tensor) model parallelism for transformer blocks, implemented via a handful of targeted all-reduce insertions in native PyTorch; no new compiler, no graph rewriting, and orthogonal to pipeline parallelism.
Fig 1: System Overview Block Diagram
┌─────────────────────────────────────────────────────────────┐
│ Megatron-LM Training System │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Data Parallel Groups │ │
│ │ (64-way DP: each group holds one model replica) │ │
│ │ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Model Parallel Group (8 GPUs) │ │ │
│ │ │ │ │ │
│ │ │ ┌────────┐ ┌────────┐ ┌────────┐ ┌──────┐ │ │ │
│ │ │ │ GPU-1 │ │ GPU-2 │ │ GPU-3 │ │GPU-8 │ │ │ │
│ │ │ │(shard1)│ │(shard2)│ │(shard3)│ │(sh8) │ │ │ │
│ │ │ └───┬────┘ └───┬────┘ └───┬────┘ └──┬───┘ │ │ │
│ │ │ └──────────┴──────────┴──────────┘ │ │ │
│ │ │ NVLink all-reduce bus │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ ║ │ │
│ │ InfiniBand (inter-node) │ │
│ │ gradient all-reduce across DP groups │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ ┌────────────────┐ ┌─────────────────┐ │
│ │ Training Data │══►│ PyTorch DDP │ │
│ │ (174 GB text) │ │ (data loader) │ │
│ └────────────────┘ └────────┬────────┘ │
│ │ mini-batch │
│ ▼ │
│ ┌───────────────────────┐ │
│ │ Transformer Layers │ │
│ │ (N layers, each │ │
│ │ model-parallelized) │ │
│ └───────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
▲ Fig 1: Megatron-LM full system — 512 GPUs, 8-way model parallel
x 64-way data parallel for 8.3B parameter GPT-2 training
This design places all model-parallel communication within a single NVLink/NVSwitch-connected DGX-2H server, so the 8-way model-parallel all-reduces never leave the node, while the gradient all-reduces for data parallelism cross node boundaries over InfiniBand. Matching the communication pattern to the topology keeps the frequent, latency-sensitive intra-layer traffic on the high-bandwidth NVLink fabric and minimizes inter-node communication.
Fig 2: Key Architecture Diagram — Intra-Layer Tensor Parallelism
MLP BLOCK PARALLELISM (column-then-row split)
┌───────────────────────────────────────────────────────┐
│ Input X (full, replicated on each GPU) │
└──────────────────────┬────────────────────────────────┘
│
┌──────────────┴──────────────┐
▼ ▼
┌───────────────┐ ┌───────────────┐
│ GPU-1 │ │ GPU-2 │
│ A₁ (cols) │ │ A₂ (cols) │
│ GeLU(XA₁) │ │ GeLU(XA₂) │
│ ↓ │ │ ↓ │
│ Y₁ = out₁ │ │ Y₂ = out₂ │
│ B₁ (rows) │ │ B₂ (rows) │
│ Y₁B₁ → Z₁ │ │ Y₂B₂ → Z₂ │
└───────┬───────┘ └───────┬───────┘
└──────────┬──────────────────┘
▼
[g: all-reduce] ← single sync point
│
▼
Z = dropout(Z₁ + Z₂)
▲ Fig 2: MLP block tensor parallelism — column-parallel first GEMM,
row-parallel second GEMM, one all-reduce in forward pass
SELF-ATTENTION BLOCK PARALLELISM (head-parallel split)
┌───────────────────────────────────────────────────────┐
│ Input X (full, replicated) │
└──────────────────────┬────────────────────────────────┘
│
┌──────────────┴──────────────┐
▼ ▼
┌───────────────┐ ┌───────────────┐
│ GPU-1 │ │ GPU-2 │
│ Q₁, K₁, V₁ │ │ Q₂, K₂, V₂ │
│ (heads 1..h/2)│ │(heads h/2+1..h)│
│ Attention₁ │ │ Attention₂ │
│ → Y₁ │ │ → Y₂ │
│ Linear B₁ │ │ Linear B₂ │
└───────┬───────┘ └───────┬───────┘
└──────────┬──────────────────┘
▼
[g: all-reduce] ← single sync point
│
▼
Z = dropout(Y₁B₁ + Y₂B₂)
▲ Fig 3: Self-attention tensor parallelism — attention heads split
across GPUs, single all-reduce after output projection
The key insight is that both the MLP and the self-attention block are arranged so that the non-linearity (GeLU) and the per-head softmax can be applied locally on each GPU without synchronization, deferring the single required forward-pass all-reduce in each sub-block until after the second, row-parallel GEMM.
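A minimal PyTorch sketch of the MLP split above, assuming torch.distributed is already initialized and `mp_group` is the 8-GPU model-parallel process group. This is an illustration, not the Megatron-LM library code; for brevity the backward-pass all-reduce on the input gradient (the paper's f operator) is omitted here and sketched after Fig 4.

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class ParallelMLP(nn.Module):
    """Column-parallel first GEMM + GeLU, row-parallel second GEMM, one all-reduce."""

    def __init__(self, hidden: int, ffn_hidden: int, mp_group):
        super().__init__()
        self.mp_group = mp_group
        world = dist.get_world_size(mp_group)
        # A split by columns: each rank holds a hidden x (ffn_hidden / world) shard.
        self.A_shard = nn.Linear(hidden, ffn_hidden // world, bias=False)
        # B split by rows: each rank holds a (ffn_hidden / world) x hidden shard.
        self.B_shard = nn.Linear(ffn_hidden // world, hidden, bias=False)

    def forward(self, x):
        # x is replicated on every model-parallel rank.
        y_local = F.gelu(self.A_shard(x))   # GeLU applied locally, no sync needed
        z_partial = self.B_shard(y_local)   # row-parallel GEMM yields a partial sum
        # The single forward-pass sync point (the paper's g operator). Real Megatron
        # wraps this communication in an autograd.Function so the backward pass is
        # handled correctly as well.
        dist.all_reduce(z_partial, group=self.mp_group)
        return z_partial
```

The same pattern applies to self-attention: the Q/K/V projections are column-parallel per head, the output projection is row-parallel, and one all-reduce follows it.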
Fig 4: Control / Data Flow Diagram — One Forward + Backward Pass
START: mini-batch arrives at each GPU in model parallel group
│
▼
① [Input embedding lookup]
│ X replicated on all model-parallel GPUs
▼
② [Layer 1 .. N: for each transformer layer]
│
├─► [Self-Attention block]
│ │
│ ├── local Q,K,V GEMMs (column-parallel, no sync)
│ ├── local softmax + attention (no sync)
│ ├── local output projection (row-parallel)
│ └── all-reduce(Z_attn) ◄── sync point 1 (fwd)
│
├─► [LayerNorm + residual] (duplicated on each GPU, no sync)
│
├─► [MLP block]
│ │
│ ├── local GeLU GEMM (column-parallel, no sync)
│ ├── local second GEMM (row-parallel, no sync)
│ └── all-reduce(Z_mlp) ◄── sync point 2 (fwd)
│
└─► [LayerNorm + residual + Dropout] (local)
│
▼
③ [Output embedding GEMM — column-parallel over the vocab]
│ partial logits Y₁, Y₂ (never gathered)
└── fused cross-entropy loss ◄── only b×s per-token losses
communicated across GPUs, not b×s×v logits
│
▼
④ [Backward pass — same 2 all-reduces per layer, conjugate ops]
│ f: identity fwd / all-reduce bwd
│ g: all-reduce fwd / identity bwd
▼
⑤ [Gradient all-reduce across data-parallel replicas]
│ (InfiniBand, between servers)
▼
⑥ [Optimizer step — each GPU updates its own parameter shard]
│
END
▲ Fig 4: Control and data flow for one Megatron-LM training step.
Exactly 4 all-reduces per transformer layer (2 fwd + 2 bwd).
The design limits synchronization to exactly two all-reduces per layer per direction by chaining a column-parallel GEMM directly into a row-parallel GEMM within each sub-block, so no communication is needed between the two GEMMs. A naive split that synchronized after every GEMM would instead require four all-reduces per direction per layer.
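The f and g operators referenced in Fig 4 can be expressed as a pair of conjugate autograd functions. A minimal sketch, assuming torch.distributed is initialized and `mp_group` is the model-parallel group (simplified relative to the actual Megatron-LM implementation):

```python
import torch
import torch.distributed as dist


class _CopyToModelParallel(torch.autograd.Function):      # the paper's "f" operator
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        return x                                          # identity in forward

    @staticmethod
    def backward(ctx, grad_out):
        grad = grad_out.clone()
        dist.all_reduce(grad, group=ctx.group)            # all-reduce in backward
        return grad, None


class _ReduceFromModelParallel(torch.autograd.Function):  # the paper's "g" operator
    @staticmethod
    def forward(ctx, x, group):
        out = x.clone()
        dist.all_reduce(out, group=group)                 # all-reduce in forward
        return out

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None                             # identity in backward


# Usage inside a column-parallel linear:  x = _CopyToModelParallel.apply(x, mp_group)
# Usage after a row-parallel linear:      z = _ReduceFromModelParallel.apply(z, mp_group)
```

Placing f before the column-parallel GEMM and g after the row-parallel GEMM yields exactly one all-reduce per sub-block in each direction, i.e. the 4 per layer counted in Fig 4.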
Fig 5: Layered Software Stack
┌───────────────────────────────────────────────────┐
│ User Training Script (GPT-2 / BERT config) │
│ (defines layers, hidden size, attention heads) │
├───────────────────────────────────────────────────┤
│ Megatron-LM Model Parallelism Layer │
│ (f/g conjugate ops, column/row-parallel linear) │
├───────────────────────────────────────────────────┤
│ PyTorch DDP (data parallelism, gradient sync) │
├───────────────────────────────────────────────────┤
│ NCCL (all-reduce primitives) │
├───────────────────────────────────────────────────┤
│ NVLink (intra-node) / InfiniBand (inter-node) │
│ (hardware transport layer) │
└───────────────────────────────────────────────────┘
▲ Fig 5: Software stack — Megatron parallelism inserted between
PyTorch model definition and DDP without compiler changes
Fig 6: Hybrid Model + Data Parallelism GPU Grouping
512 GPUs total: 8-way model parallel x 64-way data parallel
┌────────────────────────────────────────────────────────┐
│ Server 1 (DGX-2H) Server 2 (DGX-2H) │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ Model Parallel │ │ Model Parallel │ │
│ │ Group 1 │ │ Group 2 │ │
│ │ GPU-1 .. GPU-8 │ │ GPU-9 .. GPU-16 │ │
│ │ (same model shard│ │ (same model shard│ │
│ │ different data) │ │ different data) │ │
│ └────────┬─────────┘ └────────┬─────────┘ │
└───────────┼─────────────────────────┼─────────────────┘
│ Data Parallel │
│ all-reduce │
│ (gradient sync) │
│◄────────────────────────►│
│ │
│ GPU-1 (grp1) ◄────────► GPU-9 (grp2)
│ (same model position, same shard)
│ communicate gradients only
▲ Fig 6: GPU grouping for hybrid parallelism. GPUs at same position
in different model-parallel groups form a data-parallel group.
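A hedged sketch of how this grouping could be set up with torch.distributed process groups, assuming ranks are numbered so that consecutive ranks share a server (this mirrors Fig 6; the real Megatron-LM initialization differs in details):

```python
import torch.distributed as dist

WORLD_SIZE = 512
MP_SIZE = 8                      # tensor-parallel degree (intra-node, NVLink)
DP_SIZE = WORLD_SIZE // MP_SIZE  # 64 data-parallel replicas

# Every rank must create every group, in the same order.
mp_groups = []   # ranks [0..7], [8..15], ...: all-reduces stay inside one server
for i in range(DP_SIZE):
    mp_groups.append(dist.new_group(list(range(i * MP_SIZE, (i + 1) * MP_SIZE))))

dp_groups = []   # ranks [0, 8, 16, ...]: gradient sync crosses InfiniBand
for j in range(MP_SIZE):
    dp_groups.append(dist.new_group(list(range(j, WORLD_SIZE, MP_SIZE))))

rank = dist.get_rank()
my_mp_group = mp_groups[rank // MP_SIZE]   # which model replica this rank belongs to
my_dp_group = dp_groups[rank % MP_SIZE]    # which shard position this rank holds
```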
Design Trade-off Analysis
| Design Decision | Alternative A | Alternative B (Megatron choice) | Winner | Why |
|---|---|---|---|---|
| Parallelism granularity | Layer-wise pipeline (GPipe) | Intra-layer tensor split | B | No pipeline bubbles, no microbatch logic, simpler implementation |
| MLP first-GEMM split | Row-parallel (requires sync before GeLU) | Column-parallel (GeLU applied locally) | B | Eliminates one all-reduce per MLP block; GeLU is element-wise, separable |
| Attention head split | Replicate all Q,K,V (no split) | Column-parallel per head | B | Each head computed independently; no cross-head communication needed |
| Output embedding | Compute full logits, all-gather | Parallel GEMMs, fuse with cross-entropy | B | Avoids communicating b × s × v logits; only b × s losses cross GPUs |
| LayerNorm / residual placement | Model-parallel (split across GPUs) | Duplicated on every GPU | B | LayerNorm needs the full hidden dim; duplicating the cheap ops avoids an extra broadcast |
| Framework coupling | Custom compiler (Mesh-TensorFlow) | Native PyTorch + few custom ops | B | No rewriting, no compilation step, easy adoption |
| Communication fabric | PCIe (inter-GPU) | NVLink intra-node, IB inter-node | B | Model-parallel all-reduces stay on NVSwitch/NVLink (300 GB/s per GPU); DP gradient sync on 100 GB/s InfiniBand |
| BERT layer norm position | Post-attention (original BERT) | Pre-attention (rearranged) | B | Eliminates training instability at large scale; loss decreases monotonically |
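The last table row is simply a reordering of residual connections and layer normalization. A generic pre-LN vs post-LN sketch (not the exact Megatron BERT code; dropout omitted):

```python
def post_ln_block(x, attn, mlp, ln1, ln2):
    # Original BERT ordering: LayerNorm applied after each residual add.
    x = ln1(x + attn(x))
    x = ln2(x + mlp(x))
    return x


def pre_ln_block(x, attn, mlp, ln1, ln2):
    # Megatron's rearrangement: LayerNorm moved to the sublayer input, leaving a
    # clean residual path; this is the variant that trains stably at large scale.
    x = x + attn(ln1(x))
    x = x + mlp(ln2(x))
    return x
```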
For DynamICCL context: the Megatron-LM workload generates exactly 4 model-parallel all-reduces per transformer layer per training step (2 forward, 2 backward), each of fixed size set by the micro-batch size, sequence length, and hidden dimension. This is highly predictable traffic, a prime candidate for static NCCL configuration tuning.
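A back-of-envelope sizing sketch for those intra-layer all-reduces, using the 8.3B configuration's hidden size of 3072 and GPT-2's 1024-token context from the paper; the per-GPU micro-batch size and fp16 activations are assumptions for illustration:

```python
hidden = 3072            # 8.3B GPT-2 configuration from the paper
seq_len = 1024           # GPT-2 context length
micro_batch = 8          # assumed per-GPU micro-batch size (illustrative)
bytes_per_elem = 2       # assumed fp16 activations

# Each forward all-reduce moves a full (batch x seq x hidden) activation tensor;
# the two backward all-reduces (f operators) move tensors of the same shape.
msg_bytes = micro_batch * seq_len * hidden * bytes_per_elem
per_layer = 4 * msg_bytes                     # 2 forward + 2 backward all-reduces

print(f"per all-reduce:     {msg_bytes / 2**20:.0f} MiB")   # 48 MiB
print(f"per layer per step: {per_layer / 2**20:.0f} MiB")   # 192 MiB
```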
What to Borrow for DynamICCL
1. Predictable, periodic collective pattern as a scheduling signal. Megatron-LM's training loop generates all-reduces at fixed intervals and fixed sizes (hidden_dim x batch_size x seq_len). DynamICCL's Trigger Agent (LSTM+CUSUM) can exploit this periodicity: if it observes that collective sizes are constant across iterations, it can suppress re-probing and lock in the optimal NCCL config for longer, reducing tuning overhead.
2. NVLink-vs-InfiniBand traffic separation maps to algorithm selection. Megatron model-parallel all-reduces (small, intra-node, latency-sensitive) use NVLink; data-parallel gradient all-reduces (large, inter-node, bandwidth-sensitive) use InfiniBand. DynamICCL should treat these as two distinct collective classes: intra-node small all-reduces should prefer NCCL LL or LL128 protocol with ring algorithm on NVLink channels; inter-node large gradient all-reduces should prefer Simple protocol with ring or collnet_direct depending on switch topology.
3. Communication-compute overlap opportunity. Megatron notes that model-parallel all-reduces (4 per layer) block the forward/backward compute. DynamICCL's Config Agent should minimize latency (not just throughput) for these blocking all-reduces — the reward function should weight latency heavily for small intra-layer collectives and throughput for large gradient collectives.
4. Fixed message size bins per collective class. The intra-layer all-reduce size in Megatron is fixed by the activation shape, batch x seq x hidden_size; each rank reduces a full-size partial sum, so the message does not shrink as the model-parallel degree grows. Together with the much larger data-parallel gradient all-reduces, these form natural message-size bins that DynamICCL can use as state features: when the Config Agent observes a collective in the "small, intra-layer" bin, it should immediately select the LL/LL128 protocol; for the "large, gradient-sync" bin, Simple + ring (see the sketch after this list).
5. Scaling efficiency as an optimization target signal. Megatron reports 76% scaling efficiency at 512 GPUs for the 8.3B model. Much of that inefficiency is time spent in, or waiting on, collective communication. DynamICCL can use scaling efficiency (throughput / (N x single-GPU throughput)) as a secondary reward signal, aligning config selection with the practitioner's actual goal.
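A sketch of the size-bin lookup from point 4, as a prior the Config Agent could apply before any learned policy. The thresholds and the mapping are illustrative assumptions, not tuned values; NCCL_PROTO and NCCL_ALGO are real NCCL environment knobs, but applying them per-collective would likely require NCCL's tuner-plugin interface rather than per-job environment variables.

```python
def classify_collective(msg_bytes: int, crosses_node: bool) -> dict:
    """Map an observed all-reduce to an initial NCCL protocol/algorithm choice."""
    if not crosses_node and msg_bytes < 64 * 2**20:
        # Small-to-medium intra-layer all-reduces on NVLink: latency-sensitive.
        return {"NCCL_PROTO": "LL128", "NCCL_ALGO": "Ring"}
    if crosses_node and msg_bytes >= 64 * 2**20:
        # Large inter-node gradient all-reduces: bandwidth-sensitive.
        return {"NCCL_PROTO": "Simple", "NCCL_ALGO": "Ring"}
    # Anything else: defer to NCCL's built-in heuristics (no override).
    return {}


# Example: the ~48 MiB intra-layer all-reduce from the sizing sketch above.
print(classify_collective(48 * 2**20, crosses_node=False))
```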