Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

Paper: Shoeybi et al., NVIDIA, 2020 (arXiv:1909.08053).
Core contribution: intra-layer (tensor) model parallelism for transformer blocks, implemented via targeted all-reduce insertions in native PyTorch: no new compiler, no graph rewriting, and orthogonal to pipeline parallelism.


Fig 1: System Overview Block Diagram

┌─────────────────────────────────────────────────────────────┐
│                  Megatron-LM Training System                │
│                                                             │
│  ┌──────────────────────────────────────────────────────┐   │
│  │              Data Parallel Groups                    │   │
│  │   (64-way DP: each group holds one model replica)    │   │
│  │                                                      │   │
│  │  ┌─────────────────────────────────────────────┐    │   │
│  │  │        Model Parallel Group (8 GPUs)         │    │   │
│  │  │                                              │    │   │
│  │  │  ┌────────┐ ┌────────┐ ┌────────┐ ┌──────┐  │    │   │
│  │  │  │ GPU-1  │ │ GPU-2  │ │ GPU-3  │ │GPU-8 │  │    │   │
│  │  │  │(shard1)│ │(shard2)│ │(shard3)│ │(sh8) │  │    │   │
│  │  │  └───┬────┘ └───┬────┘ └───┬────┘ └──┬───┘  │    │   │
│  │  │      └──────────┴──────────┴──────────┘      │    │   │
│  │  │           NVLink all-reduce bus               │    │   │
│  │  └─────────────────────────────────────────────┘    │   │
│  │                      ║                               │   │
│  │              InfiniBand (inter-node)                 │   │
│  │        gradient all-reduce across DP groups          │   │
│  └──────────────────────────────────────────────────────┘   │
│                                                             │
│  ┌────────────────┐   ┌─────────────────┐                  │
│  │  Training Data │══►│  PyTorch DDP    │                  │
│  │  (174 GB text) │   │  (DP wrapper)   │                  │
│  └────────────────┘   └────────┬────────┘                  │
│                                │ mini-batch                 │
│                                ▼                            │
│                    ┌───────────────────────┐                │
│                    │  Transformer Layers   │                │
│                    │  (N layers, each      │                │
│                    │   model-parallelized) │                │
│                    └───────────────────────┘                │
└─────────────────────────────────────────────────────────────┘
▲ Fig 1: Megatron-LM full system — 512 GPUs, 8-way model parallel
  x 64-way data parallel for 8.3B parameter GPT-2 training

This design places the model-parallel communication entirely within a single NVLink-connected server: each 8-GPU model-parallel group fits inside one DGX-2H node (16 GPUs per node, connected via NVSwitch), while the gradient all-reduces for data parallelism cross node boundaries over InfiniBand. Matching the communication pattern to the topology keeps the frequent, latency-sensitive intra-layer traffic on high-bandwidth NVLink and reserves the slower inter-node links for the once-per-step gradient sync.
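A minimal sketch of this grouping with torch.distributed follows. This is an illustrative reconstruction based on the figure, not Megatron's actual initialization code; the variable names are mine.

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) was already called with
    # world_size = 512 and ranks laid out so consecutive ranks share a node.
    MP = 8                                    # model-parallel group size
    world, rank = dist.get_world_size(), dist.get_rank()

    # Model-parallel groups: consecutive ranks [0..7], [8..15], ...
    # Their activation all-reduces stay on NVLink within one server.
    for start in range(0, world, MP):
        ranks = list(range(start, start + MP))
        g = dist.new_group(ranks)             # every rank must create every group
        if rank in ranks:
            model_parallel_group = g

    # Data-parallel groups: ranks at the same position in each model-parallel
    # group (0, 8, 16, ...), i.e. the holders of the same model shard.
    # Their gradient all-reduces cross nodes over InfiniBand.
    for offset in range(MP):
        ranks = list(range(offset, world, MP))
        g = dist.new_group(ranks)
        if rank in ranks:
            data_parallel_group = g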


Figs 2-3: Key Architecture Diagrams — Intra-Layer Tensor Parallelism

  MLP BLOCK PARALLELISM (column-then-row split)
  ┌───────────────────────────────────────────────────────┐
  │  Input X (full, replicated on each GPU)               │
  └──────────────────────┬────────────────────────────────┘
                         │
          ┌──────────────┴──────────────┐
          ▼                             ▼
  ┌───────────────┐             ┌───────────────┐
  │   GPU-1       │             │   GPU-2       │
  │  A₁ (cols)   │             │  A₂ (cols)   │
  │  GeLU(XA₁)  │             │  GeLU(XA₂)  │
  │     ↓         │             │     ↓         │
  │  Y₁ = out₁   │             │  Y₂ = out₂   │
  │  B₁ (rows)   │             │  B₂ (rows)   │
  │  Y₁B₁ → Z₁   │             │  Y₂B₂ → Z₂   │
  └───────┬───────┘             └───────┬───────┘
          └──────────┬──────────────────┘
                     ▼
              [g: all-reduce]          ← single sync point
                     │
                     ▼
               Z = dropout(Z₁ + Z₂)
▲ Fig 2: MLP block tensor parallelism — column-parallel first GEMM,
  row-parallel second GEMM, one all-reduce in forward pass

  SELF-ATTENTION BLOCK PARALLELISM (head-parallel split)
  ┌───────────────────────────────────────────────────────┐
  │  Input X (full, replicated)                           │
  └──────────────────────┬────────────────────────────────┘
                         │
          ┌──────────────┴──────────────┐
          ▼                             ▼
  ┌───────────────┐             ┌───────────────┐
  │   GPU-1       │             │   GPU-2       │
  │  Q₁, K₁, V₁  │             │  Q₂, K₂, V₂  │
  │  (heads 1..h/2)│            │(heads h/2+1..h)│
  │  Attention₁   │             │  Attention₂   │
  │  → Y₁         │             │  → Y₂         │
  │  Linear B₁    │             │  Linear B₂    │
  └───────┬───────┘             └───────┬───────┘
          └──────────┬──────────────────┘
                     ▼
              [g: all-reduce]          ← single sync point
                     │
                     ▼
              Z = dropout(Y)
▲ Fig 3: Self-attention tensor parallelism — attention heads split
  across GPUs, single all-reduce after output projection

The key insight is that both the MLP and the self-attention block are structured so that the non-linearity (GeLU) or the softmax operates on locally complete values on each GPU, with no synchronization needed; the single required all-reduce per sub-block is deferred until after the second GEMM.
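A minimal PyTorch sketch of the f/g conjugate operators and the column-then-row MLP split above. Names and shapes are mine and the model-parallel group is assumed from the earlier sketch; this is not Megatron's actual source.

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    MP_GROUP = None   # set to the model-parallel group created earlier

    class _F(torch.autograd.Function):
        """f: identity in forward, all-reduce in backward."""
        @staticmethod
        def forward(ctx, x):
            return x
        @staticmethod
        def backward(ctx, grad):
            dist.all_reduce(grad, group=MP_GROUP)
            return grad

    class _G(torch.autograd.Function):
        """g: all-reduce in forward, identity in backward (conjugate of f)."""
        @staticmethod
        def forward(ctx, x):
            dist.all_reduce(x, group=MP_GROUP)
            return x
        @staticmethod
        def backward(ctx, grad):
            return grad

    def parallel_mlp(x, a_shard, b_shard, p_drop=0.1):
        # a_shard: (h, 4h/mp) column slice of A -> GeLU sees complete columns
        # b_shard: (4h/mp, h) row slice of B   -> output is a partial sum
        x = _F.apply(x)
        y = F.gelu(x @ a_shard)           # local, no sync
        z = _G.apply(y @ b_shard)         # the single forward all-reduce
        return F.dropout(z, p=p_drop)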


Fig 4: Control / Data Flow Diagram — One Forward + Backward Pass

  START: mini-batch arrives at each GPU in model parallel group
    │
    ▼
① [Input embedding lookup]
    │  X replicated on all model-parallel GPUs
    ▼
② [Layer 1 .. N: for each transformer layer]
    │
    ├─► [Self-Attention block]
    │       │
    │       ├── local Q,K,V GEMMs (column-parallel, no sync)
    │       ├── local softmax + attention (no sync)
    │       ├── local output projection (row-parallel)
    │       └── all-reduce(Z_attn)  ◄── sync point 1 (fwd)
    │
    ├─► [LayerNorm + residual]  (duplicated on each GPU, no sync)
    │
    ├─► [MLP block]
    │       │
    │       ├── local GeLU GEMM (column-parallel, no sync)
    │       ├── local second GEMM (row-parallel, no sync)
    │       └── all-reduce(Z_mlp)  ◄── sync point 2 (fwd)
    │
    └─► [LayerNorm + residual + Dropout] (local)
    │
    ▼
③ [Output embedding GEMM — column-parallel (vocab split)]
    │  parallel logit shards Y₁, Y₂ (one vocab slice per GPU)
    └── fused vocab-parallel cross-entropy loss
        (b x s loss terms communicated, not b x s x v logits)
    │
    ▼
④ [Backward pass — same 2 all-reduces per layer, conjugate ops]
    │  f: identity fwd / all-reduce bwd
    │  g: all-reduce fwd / identity bwd
    ▼
⑤ [Gradient all-reduce across data-parallel replicas]
    │  (InfiniBand, between servers)
    ▼
⑥ [Optimizer step — each GPU updates its own parameter shard]
    │
  END
▲ Fig 4: Control and data flow for one Megatron-LM training step.
  Exactly 4 all-reduces per transformer layer (2 fwd + 2 bwd).

The design minimizes synchronization to exactly 2 all-reduces per layer per direction by chaining a column-parallel GEMM directly into a row-parallel GEMM within each sub-block. This is half the communication of the naive approach of one all-reduce per GEMM, which with 4 GEMMs per layer would cost 4 all-reduces per direction.


Fig 5: Layered Software Stack

┌───────────────────────────────────────────────────┐
│  User Training Script (GPT-2 / BERT config)       │
│  (defines layers, hidden size, attention heads)   │
├───────────────────────────────────────────────────┤
│  Megatron-LM Model Parallelism Layer              │
│  (f/g conjugate ops, column/row-parallel linear)  │
├───────────────────────────────────────────────────┤
│  PyTorch DDP (data parallelism, gradient sync)    │
├───────────────────────────────────────────────────┤
│  NCCL (all-reduce primitives)                     │
├───────────────────────────────────────────────────┤
│  NVLink (intra-node) / InfiniBand (inter-node)    │
│  (hardware transport layer)                       │
└───────────────────────────────────────────────────┘
▲ Fig 5: Software stack — Megatron parallelism inserted between
  PyTorch model definition and DDP without compiler changes

Fig 6: Hybrid Model + Data Parallelism GPU Grouping

  512 GPUs total: 8-way model parallel x 64-way data parallel

  ┌────────────────────────────────────────────────────────┐
  │  Server 1 (DGX-2H)          Server 2 (DGX-2H)         │
  │  ┌──────────────────┐       ┌──────────────────┐       │
  │  │ Model Parallel   │       │ Model Parallel   │       │
  │  │ Group 1          │       │ Group 2          │       │
  │  │ GPU-1 .. GPU-8   │       │ GPU-9 .. GPU-16  │       │
  │  │ (same model shard│       │ (same model shard│       │
  │  │  different data) │       │  different data) │       │
  │  └────────┬─────────┘       └────────┬─────────┘       │
  └───────────┼─────────────────────────┼─────────────────┘
              │   Data Parallel          │
              │   all-reduce             │
              │   (gradient sync)        │
              │◄────────────────────────►│
              │                          │
              GPU-1 (grp1)  ◄──────────►  GPU-9 (grp2)
              (same model position, same shard;
               they exchange gradients only)
▲ Fig 6: GPU grouping for hybrid parallelism. GPUs at same position
  in different model-parallel groups form a data-parallel group.

Design Trade-off Analysis

Each row contrasts a rejected alternative (A) with Megatron's choice (B); B wins every row.

1. Parallelism granularity
   A: layer-wise pipeline (GPipe)    B: intra-layer tensor split
   Why B: no pipeline bubbles, no microbatch scheduling logic, simpler implementation.

2. MLP first-GEMM split
   A: row-parallel (sync before GeLU)    B: column-parallel (GeLU applied locally)
   Why B: eliminates one all-reduce per MLP block. GeLU is element-wise but not
   linear, so it cannot be applied to the partial sums a row split would produce;
   a column split gives each GPU complete output columns.

3. Attention head split
   A: replicate all Q,K,V (no split)    B: column-parallel per head
   Why B: each head is computed independently; no cross-head communication needed.

4. Output embedding
   A: compute full logits, all-gather    B: parallel GEMMs fused with cross-entropy
   Why B: avoids communicating batch x seq x vocab logits; only batch x seq loss
   terms cross GPUs.

5. LayerNorm / residual placement
   A: model-parallel (split across GPUs)    B: duplicated on every GPU
   Why B: LayerNorm operates on the full hidden dimension; duplicating the cheap
   op avoids an extra all-reduce.

6. Framework coupling
   A: custom compiler (Mesh-TensorFlow)    B: native PyTorch + a few custom ops
   Why B: no model rewriting, no compilation step, easy adoption.

7. Communication fabric
   A: PCIe between GPUs    B: NVLink/NVSwitch intra-node, InfiniBand inter-node
   Why B: model-parallel all-reduces stay on NVSwitch (300 GB/s per GPU); only
   the data-parallel gradient sync crosses InfiniBand.

8. BERT layer norm position
   A: post-LN (original BERT order)    B: pre-LN (layer norm moved to sub-layer input)
   Why B: eliminates training instability at large scale; loss decreases monotonically.

For DynamICCL context: the Megatron-LM workload generates exactly 4 all-reduces per transformer layer per training step (2 forward, 2 backward), all of fixed size determined by the batch size, sequence length, and hidden dimension. This is highly predictable traffic and a prime candidate for static NCCL configuration tuning.
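A back-of-envelope traffic profile under these assumptions, using the paper's 8.3B configuration (hidden 3072, 72 layers, seq 1024, fp16) and a per-group batch of 8 (global batch 512 / 64 DP groups); the script is illustrative only.

    # Model-parallel: each all-reduce moves the full (batch, seq, hidden)
    # activation; 4 of them per layer per step (2 fwd + 2 bwd).
    batch, seq, hidden, layers = 8, 1024, 3072, 72
    act_ar = batch * seq * hidden * 2               # bytes, fp16
    mp_per_step = 4 * layers * act_ar
    # Data-parallel: one gradient all-reduce per step over each GPU's
    # parameter shard (8.3B params / 8-way model parallel).
    dp_per_step = (8_300_000_000 // 8) * 2
    print(f"MP: {act_ar/2**20:.0f} MiB per all-reduce, "
          f"{mp_per_step/2**30:.1f} GiB per step")   # 48 MiB, 13.5 GiB
    print(f"DP: {dp_per_step/2**30:.1f} GiB per step")   # ~1.9 GiB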


What to Borrow for DynamICCL

1. Predictable, periodic collective pattern as a scheduling signal. Megatron-LM's training loop generates all-reduces at fixed intervals and fixed sizes (hidden_dim x batch_size x seq_len). DynamICCL's Trigger Agent (LSTM+CUSUM) can exploit this periodicity: if it observes that collective sizes are constant across iterations, it can suppress re-probing and lock in the optimal NCCL config for longer, reducing tuning overhead.
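A toy stand-in for that suppression logic: a constant-size check rather than the LSTM+CUSUM detector, with hypothetical class and method names, not DynamICCL's API.

    from collections import deque

    class ReprobeSuppressor:
        """Suppress config re-probing while collective sizes stay constant."""
        def __init__(self, window=64):
            self.sizes = deque(maxlen=window)

        def observe(self, msg_bytes):
            self.sizes.append(msg_bytes)

        def should_reprobe(self):
            # Megatron traffic repeats the same sizes every iteration, so the
            # window fills with one distinct value and probing can be locked out.
            if len(self.sizes) < self.sizes.maxlen:
                return True
            return len(set(self.sizes)) > 1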

2. NVLink-vs-InfiniBand traffic separation maps to algorithm selection. Megatron model-parallel all-reduces (small, intra-node, latency-sensitive) use NVLink; data-parallel gradient all-reduces (large, inter-node, bandwidth-sensitive) use InfiniBand. DynamICCL should treat these as two distinct collective classes: intra-node small all-reduces should prefer the NCCL LL or LL128 protocol with the ring algorithm on NVLink channels; inter-node large gradient all-reduces should prefer the Simple protocol with ring or CollNetDirect depending on switch topology.
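Expressed as a lookup, this might look as follows. Illustrative only: NCCL reads NCCL_ALGO/NCCL_PROTO once per process at init, so a real tuner would apply these per-communicator through NCCL's tuner/plugin path rather than env vars, and value spellings should be checked against the NCCL version in use.

    import os

    # Hypothetical class -> knob mapping for the two Megatron traffic types.
    CLASS_CONFIG = {
        "intra_layer_small": {"NCCL_ALGO": "Ring", "NCCL_PROTO": "LL128"},
        "gradient_large":    {"NCCL_ALGO": "Ring", "NCCL_PROTO": "Simple"},
    }

    def apply_class(cls):
        # Set the knobs for the given collective class before comm init.
        os.environ.update(CLASS_CONFIG[cls])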

3. Communication-compute overlap opportunity. Megatron notes that model-parallel all-reduces (4 per layer) block the forward/backward compute. DynamICCL's Config Agent should minimize latency (not just throughput) for these blocking all-reduces — the reward function should weight latency heavily for small intra-layer collectives and throughput for large gradient collectives.
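One hedged way to encode that in the reward; the 64 MiB cutoff and the shaping itself are placeholders, not tuned values from either paper.

    def config_reward(msg_bytes, latency_us, bus_bw_gbs, cutoff=64 * 2**20):
        """Latency-dominated reward for small blocking all-reduces,
        bandwidth-dominated for large gradient syncs. Illustrative only."""
        if msg_bytes < cutoff:
            return -latency_us      # blocking intra-layer collective
        return bus_bw_gbs           # overlappable gradient collective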

4. Fixed message size bins per collective class. The model-parallel all-reduce size in Megatron is the full activation tensor, batch x seq x hidden_size elements, independent of the model-parallel degree (the row-parallel GEMM yields partial sums over the full hidden dimension), while the data-parallel all-reduce covers each GPU's parameter shard. These create natural message-size bins that DynamICCL can use as state features: when the Config Agent observes a collective with message size in the "small, intra-layer" bin, it should immediately select LL protocol; for the "large, gradient-sync" bin, select Simple + ring. A sketch of this binning follows.
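A sketch of the binning as a state feature; sizes reuse the 8.3B traffic profile above, and the bin names and tolerance are assumptions.

    BINS = {
        "intra_layer_small": 8 * 1024 * 3072 * 2,       # ~48 MiB activation
        "gradient_large": (8_300_000_000 // 8) * 2,     # ~1.9 GiB shard
    }

    def classify(msg_bytes, tol=0.05):
        """Map an observed collective size onto the nearest known bin."""
        for name, size in BINS.items():
            if abs(msg_bytes - size) <= tol * size:
                return name
        return "unknown"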

5. Scaling efficiency as an optimization target signal. Megatron reports 76% weak scaling efficiency at 512 GPUs. Each percentage point of scaling inefficiency corresponds to wasted collective communication time. DynamICCL can use scaling efficiency (throughput / (N x single-GPU throughput)) as a secondary reward signal, aligning config selection with the practitioner's actual goal.