MegaScale: Scaling Large Language Model Training to More Than 10,000 GPUs
Ziheng Jiang, Haibin Lin, Yinmin Zhong, Qi Huang, Yangrui Chen, Zhi Zhang, Yanghua Peng, Xiang Li, Cong Xie, Shibiao Nong, Yulu Jia, Sun He, Hongmin Chen, Zhihao Bai, Qi Hou, Shipeng Yan, Ding Zhou, Yiyao Sheng, Zhuo Jiang, Haohan Xu, Haoran Wei, Zhang Zhang, Pengfei Nie, Leqi Zou, Sida Zhao, Liang Xiang, Zherui Liu, Zhe Li, Xiaoying Jia, Jianxi Ye, Xin Jin, Xin Liu | ByteDance, Peking University | NSDI '24
Problem
Training LLMs of GPT-3 / PaLM scale (175B-540B parameters) requires distributing the work across more than 10,000 GPUs for weeks at a time. At this scale two distinct problems become first-order: training efficiency, measured by Model FLOPs Utilization (MFU = observed throughput / theoretical maximum), and training stability, the fraction of wall-clock time the job spends making forward progress rather than recovering from failures or straggling. A single failure on a 12k-GPU job is proportionally far more expensive than on a small cluster, so even rare faults dominate end-to-end training duration. Public reports on giant-model training (GPT-4, PaLM, LLaMA) describe the model but rarely share the infrastructure that makes them run; this paper fills that gap with a production-tested systems story.
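To make the MFU definition concrete, here is a rough estimator. It uses the common approximation of about 6N training FLOPs per token for a dense model with N parameters. That approximation is an assumption and may not match MegaScale's exact FLOP accounting. The per-GPU peak is a hypothetical input.

```python
# Minimal MFU estimator (a sketch, not the paper's exact FLOP accounting).
# Assumption: ~6 FLOPs per parameter per token for forward + backward,
# ignoring attention-score FLOPs and activation recomputation.

def model_flops_per_token(n_params: float) -> float:
    """Approximate training FLOPs per token (forward + backward)."""
    return 6.0 * n_params


def mfu(tokens_per_s: float, n_params: float, n_gpus: int,
        peak_flops_per_gpu: float) -> float:
    """Model FLOPs Utilization = achieved model FLOP/s / theoretical peak FLOP/s."""
    achieved = tokens_per_s * model_flops_per_token(n_params)
    theoretical = n_gpus * peak_flops_per_gpu
    return achieved / theoretical


if __name__ == "__main__":
    # Hypothetical numbers purely for illustration.
    print(f"{mfu(tokens_per_s=1_000, n_params=1e9, n_gpus=1, peak_flops_per_gpu=1e13):.2f}")  # 0.60
```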
Core Insight
Two tightly linked principles must work together to scale LLM training to 10k+ GPUs. The first is algorithm-system co-design: the model and optimizer are modified in concert with the parallelism schedule, the operators, the data pipeline, and the network. The second is in-depth observability (CUDA-event-level monitoring, heartbeat-driven diagnostics, and a 3D-parallel visualization layer), so that the root causes of failures and stragglers can be located and remediated within minutes rather than days. Operator-level optimization alone is insufficient. In the ablation, efficient kernels contribute only +1.7 MFU points, while the algorithmic and communication-overlap changes together contribute +14.8 points.
Method
The paper is a full-stack production system. Major components:
Algorithmic optimizations.
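- Parallel Transformer Block (PTB): replaces the serial formulation $y = x + \mathrm{MLP}(\mathrm{LN}(x + \mathrm{Attention}(\mathrm{LN}(x))))$ with the parallel formulation $y = x + \mathrm{MLP}(\mathrm{LN}(x)) + \mathrm{Attention}(\mathrm{LN}(x))$. The attention and MLP branches can then execute concurrently, and only one LayerNorm sits on the critical path instead of two.
- Sliding Window Attention (SWA): each token attends only to a window of $w$ preceding tokens, costing $O(s \cdot w)$ instead of $O(s^2)$ for sequence length $s$. Stacking layers recovers a large effective receptive field.
- LAMB optimizer: scales the global batch by 4x without accuracy loss. For $p$ pipeline stages, $v$ interleaved model chunks per stage, and $m$ micro-batches, this reduces the pipeline bubble from $4(p-1)/(v \cdot m)$ (four steps at 1x batch) to $(p-1)/(v \cdot 4m)$ (one step at 4x batch). The paper reports this as an 87.5% bubble reduction.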
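A small mask sketch can check the SWA cost claim. It assumes causal (decoder-only) attention and uses a hypothetical window size `window`, because the summary does not state the window MegaScale uses. It counts the computed query-key scores and the receptive field after stacking layers.

```python
# Causal sliding-window attention mask (a sketch; assumes a decoder-only,
# causal model where query i may attend to keys j with i - window < j <= i).

def swa_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True if query i may attend to key j."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


def attended_pairs(mask: list[list[bool]]) -> int:
    """Number of (query, key) scores computed: O(s*w) instead of O(s^2)."""
    return sum(sum(row) for row in mask)


def receptive_field(n_layers: int, window: int, seq_len: int) -> int:
    """Positions visible to the last token after stacking n_layers SWA layers.
    Each additional layer reaches back window - 1 more positions."""
    return min(seq_len, 1 + n_layers * (window - 1))


if __name__ == "__main__":
    print(attended_pairs(swa_mask(8, 3)))   # 21 (full causal attention would be 36)
    print(receptive_field(2, 3, 8))         # 5
```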
3D-parallelism communication overlap.
- DP: all-gather/reduce-scatter scheduled per model chunk; the first model chunk's all-gather is prefetched at the start of each iteration.
- PP: Send/Receive in 1F1B's warm-up and cool-down decoupled and launched asynchronously, overlapping with adjacent forward/backward.
- TP/SP: all-gather/reduce-scatter fused with parallel Linears; GEMMs chunked into A0..An so chunk i+1's compute pipelines with chunk i's communication on a separate CUDA stream.
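The TP/SP chunking idea can be pictured with a simple two-stream timing model. This is not MegaScale's kernel code. It assumes a GEMM split into equal chunks, with each chunk's output communicated (for example reduce-scattered) on a separate stream as soon as that chunk is computed, and each stream handling one item at a time.

```python
# Timing model for chunked compute/communication overlap (a sketch).
# Assumptions: n equal GEMM chunks; chunk i's output is communicated on a
# separate stream once computed; each stream processes one item at a time.

def serial_time(n_chunks: int, t_compute: float, t_comm: float) -> float:
    """No overlap: every chunk's compute and communication run back to back."""
    return n_chunks * (t_compute + t_comm)


def overlapped_time(n_chunks: int, t_compute: float, t_comm: float) -> float:
    """Chunk i's communication overlaps chunk i+1's compute."""
    compute_end = 0.0
    comm_end = 0.0
    for _ in range(n_chunks):
        compute_end += t_compute                  # compute stream never waits
        comm_start = max(compute_end, comm_end)   # need the data and a free comm stream
        comm_end = comm_start + t_comm
    return comm_end


if __name__ == "__main__":
    print(serial_time(4, 1.0, 1.0), overlapped_time(4, 1.0, 1.0))  # 8.0 5.0
```

In this model the overlapped time equals t_compute + (n - 1) * max(t_compute, t_comm) + t_comm, so only the first compute chunk and the last communication chunk stay exposed.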
Efficient operators. FlashAttention-2 plus fused LayerNorm and GeLU kernels.
Data pipeline. Asynchronous preprocessing (next iteration's data prep overlaps current iteration's gradient sync); single per-machine dataloader that exposes data via shared memory to all 8 workers, eliminating redundant disk reads.
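The asynchronous preprocessing pattern can be sketched with a single background worker that prepares batch i+1 while step i runs. `preprocess` and `train_step` are hypothetical stand-ins, not MegaScale APIs.

```python
# Asynchronous data preprocessing (a sketch): prepare batch i+1 on a
# background thread while step i runs. preprocess() and train_step() are
# hypothetical stand-ins, not MegaScale APIs.
from concurrent.futures import ThreadPoolExecutor


def preprocess(i: int) -> list[int]:
    return [i * 10 + k for k in range(3)]      # stand-in for tokenization/packing


def train_step(batch: list[int]) -> int:
    return sum(batch)                           # stand-in for forward/backward/grad sync


def run(n_steps: int) -> list[int]:
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(preprocess, 0)     # prefetch the first batch
        for step in range(n_steps):
            batch = future.result()             # wait only if prep is not done yet
            if step + 1 < n_steps:
                future = pool.submit(preprocess, step + 1)  # overlap next prep with this step
            results.append(train_step(batch))
    return results


if __name__ == "__main__":
    print(run(3))  # [3, 33, 63]
```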
Collective communication initialization. Replaces PyTorch's TCPStore with Redis and reduces global-barrier complexity from O(n^2) to O(n). Initialization falls from ~1,047s to <5s on 2,048 GPUs and to <30s on >10,000 GPUs.
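The O(n^2) to O(n) barrier change can be illustrated by counting key-value store operations in two toy barrier designs. This is an illustrative model, not PyTorch's or MegaScale's actual barrier code. `CountingStore` is a hypothetical dict-backed stand-in for TCPStore/Redis, and in this sequential model every rank arrives before any rank checks, so polling is not modeled.

```python
# Counting store operations for two barrier designs (an illustrative model).
# CountingStore is a hypothetical dict-backed stand-in for TCPStore/Redis.

class CountingStore:
    def __init__(self) -> None:
        self.data: dict[str, int] = {}
        self.ops = 0

    def set(self, key: str, value: int) -> None:
        self.ops += 1
        self.data[key] = value

    def get(self, key: str) -> int:
        self.ops += 1
        return self.data[key]

    def add(self, key: str, amount: int) -> int:
        """Atomic increment, analogous to Redis INCRBY or TCPStore.add."""
        self.ops += 1
        self.data[key] = self.data.get(key, 0) + amount
        return self.data[key]


def all_to_all_barrier(store: CountingStore, n: int) -> None:
    """Every rank publishes its own key, then checks every rank's key: O(n^2)."""
    for r in range(n):
        store.set(f"arrived/{r}", 1)
    for _ in range(n):
        for peer in range(n):
            assert store.get(f"arrived/{peer}") == 1


def counter_barrier(store: CountingStore, n: int) -> None:
    """Every rank increments one shared counter, then reads it once: O(n)."""
    for _ in range(n):
        store.add("arrived", 1)
    for _ in range(n):
        assert store.get("arrived") == n


if __name__ == "__main__":
    a, b = CountingStore(), CountingStore()
    all_to_all_barrier(a, 4)
    counter_barrier(b, 4)
    print(a.ops, b.ops)  # 20 8
```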
Network performance tuning. Three-layer CLOS topology with Broadcom Tomahawk-4 chips (25.6 Tbps, 64x400G); 400G ports split into 2x200G to reduce ECMP hashing collisions; 8 NICs per node connected to 8 ToRs (multi-rail); nodes in the same DP group placed under the same ToR; in-house congestion control fusing Swift (RTT-based) with DCQCN (ECN-based) to avoid PFC head-of-line blocking; NCCL retransmit parameters tuned and NIC adap_retrans enabled.
Fault tolerance. Driver/executor architecture on Kubernetes; a heartbeat daemon reports IP, hardware info, logs, and RDMA traffic metrics to the driver; on a suspected failure, lightweight diagnostics run (intra-host loopback, RNIC-to-RNIC, NCCL all-to-all/all-reduce); two-stage checkpointing (Stage 1: GPU to host RAM in seconds, blocking; Stage 2: host RAM to HDFS, asynchronous); on recovery, one worker per DP group reads the shared checkpoint partition from HDFS and broadcasts it to its DP peers over RDMA, avoiding HDFS saturation.
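The two-stage checkpoint idea can be sketched in plain Python. In this sketch, a deep copy stands in for the blocking GPU-to-host copy and a background thread writing a local file stands in for the asynchronous HDFS upload. The function names are hypothetical.

```python
# Two-stage checkpointing (a sketch): stage 1 blocks training only long enough
# to snapshot state into host memory; stage 2 persists the snapshot in the
# background. A local file stands in for HDFS; names are hypothetical.
import copy
import os
import pickle
import tempfile
import threading


def stage1_snapshot(state: dict) -> dict:
    """Blocking copy (stands in for GPU -> host RAM)."""
    return copy.deepcopy(state)


def stage2_persist_async(snapshot: dict, path: str) -> threading.Thread:
    """Non-blocking write (stands in for host RAM -> HDFS)."""
    def _write() -> None:
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)
    t = threading.Thread(target=_write, daemon=True)
    t.start()
    return t


if __name__ == "__main__":
    state = {"step": 100, "weights": [0.1, 0.2]}
    path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
    writer = stage2_persist_async(stage1_snapshot(state), path)
    state["step"] += 1                 # training resumes immediately
    state["weights"][0] = 9.9          # live state diverges from the snapshot
    writer.join()
    with open(path, "rb") as f:
        print(pickle.load(f))          # {'step': 100, 'weights': [0.1, 0.2]}
```

The design point is that training waits only for the in-memory copy; the slow write to remote storage no longer sits on the critical path.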
Troubleshooting. CUDA event monitor produces heat-maps and traces without global synchronization; distributed tracer; 3D-parallel visualizer maps logs to (TP, PP, DP) topology to localize the rank that triggered cascading timeouts.
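Mapping logs onto the 3D-parallel topology starts with converting a global rank into (TP, PP, DP) coordinates. The sketch below assumes TP is the fastest-varying dimension, then DP, then PP, which is Megatron-LM's default rank layout. MegaScale's exact layout is not stated in the summary, and the helper names are hypothetical.

```python
# Mapping a global rank to its (TP, PP, DP) coordinates (a sketch).
# Assumption: TP varies fastest, then DP, then PP (Megatron-LM's default layout).
from typing import NamedTuple


class Coord(NamedTuple):
    tp: int
    pp: int
    dp: int


def rank_to_coord(rank: int, tp: int, pp: int, dp: int) -> Coord:
    assert 0 <= rank < tp * pp * dp
    return Coord(tp=rank % tp, pp=rank // (tp * dp), dp=(rank // tp) % dp)


def coord_to_rank(c: Coord, tp: int, dp: int) -> int:
    return c.pp * tp * dp + c.dp * tp + c.tp


def pipeline_peers(rank: int, tp: int, pp: int, dp: int) -> list[int]:
    """Ranks in the same pipeline (same TP and DP index). A stalled stage
    propagates timeouts along this list."""
    c = rank_to_coord(rank, tp, pp, dp)
    return [coord_to_rank(Coord(c.tp, s, c.dp), tp, dp) for s in range(pp)]


if __name__ == "__main__":
    print(rank_to_coord(5, tp=2, pp=4, dp=2))   # Coord(tp=1, pp=1, dp=0)
    print(pipeline_peers(5, tp=2, pp=4, dp=2))  # [1, 5, 9, 13]
```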
Experimental Setup
| Component | Value |
|---|---|
| Cluster scale | up to 12,288 GPUs (175B), 11,200 GPUs (530B) |
| Switch fabric | three-layer CLOS, Broadcom Tomahawk-4 (25.6 Tbps, 64x400G), 1:1 oversubscription |
| NICs | 8 per node, multi-rail to 8 ToRs |
| Congestion control | Custom Swift + DCQCN |
| Models | 175B (128 heads, 12,288 hidden, 96 layers), 530B (160 heads, 20,480 hidden, 105 layers) |
| Sequence length / vocab | 2,048 / 64,000 |
| Baseline | Megatron-LM |
| Metrics | MFU, iteration time, throughput (tokens/s), aggregate PFlops/s, training days for 300B tokens |
| Ablation | 256 GPUs, batch 256 |
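As a sanity check on the model rows, a standard approximate parameter count reproduces the 175B and 530B labels. The formula is an assumption: GPT-style blocks with a 4h FFN give about 12h^2 parameters per layer, plus one vocab-by-hidden embedding matrix. Biases and LayerNorm weights are ignored.

```python
# Rough parameter-count check for the two model configs (a sketch).
# Assumption: ~12*h^2 parameters per layer (4h^2 attention + 8h^2 FFN with a
# 4h inner size) plus one vocab*h embedding matrix; biases/LayerNorm ignored.

def approx_params(n_layers: int, hidden: int, vocab: int) -> int:
    return 12 * n_layers * hidden**2 + vocab * hidden


if __name__ == "__main__":
    print(f"{approx_params(96, 12_288, 64_000) / 1e9:.1f}B")    # 174.7B
    print(f"{approx_params(105, 20_480, 64_000) / 1e9:.1f}B")   # 529.8B
```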
Headline Quantitative Results
Strong scaling (175B, batch 6,144):
| GPUs | Megatron-LM MFU | MegaScale MFU | Speedup |
|---|---|---|---|
| 3,072 | 48.7% | 59.1% | 1.21x |
| 6,144 | 47.8% | 57.3% | 1.19x |
| 8,192 | 43.3% | 54.9% | 1.26x |
| 12,288 | 41.2% | 55.2% | 1.34x |
Strong scaling (175B, batch 768): MFU 65.3% at 256 GPUs falling gracefully to 59.0% at 1,024 GPUs (1.32x over Megatron-LM).
Weak scaling (530B): MegaScale 54.3% MFU at 11,200 GPUs vs. Megatron-LM 48.2% — a 6.1 percentage-point absolute MFU advantage.
Aggregate compute: 2,166.3 PFlops/s at 12,288 GPUs.
Ablation (175B / 256 GPUs / batch 256): baseline 47.7% MFU → +PTB 52.3 → +SWA 53.3 → +TP overlap 55.5 → +PP overlap 58.0 → +DP overlap 59.5 → +efficient ops 61.2 → +misc 62.3 → +LAMB 65.3% (cumulative +17.6 percentage points).
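The per-technique gains follow from differencing consecutive MFU values in the ablation chain. The sketch below groups them into algorithmic changes (PTB, SWA, LAMB) and communication overlap (TP, PP, DP). The grouping is an editorial choice.

```python
# Per-step gains in the ablation chain (a sketch over the numbers quoted above).
steps = [
    ("baseline", 47.7), ("PTB", 52.3), ("SWA", 53.3), ("TP overlap", 55.5),
    ("PP overlap", 58.0), ("DP overlap", 59.5), ("efficient ops", 61.2),
    ("misc", 62.3), ("LAMB", 65.3),
]

# Gain of each technique = MFU after it minus MFU before it (percentage points).
gains = {name: round(mfu - prev, 1)
         for (_, prev), (name, mfu) in zip(steps, steps[1:])}

algorithmic = gains["PTB"] + gains["SWA"] + gains["LAMB"]
overlap = gains["TP overlap"] + gains["PP overlap"] + gains["DP overlap"]
print(round(algorithmic, 1), round(overlap, 1), round(algorithmic + overlap, 1))  # 8.6 6.2 14.8
```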
Cluster initialization: 2,048 GPUs from ~1,047s to <5s; >10,000 GPUs in <30s.
Production stability:
- Multi-trillion-token training run lasting several weeks on 10,000+ GPUs.
- Auto-recovered more than 100 times.
- >90% effective training time rate.
- Failure detection + diagnostics: <10 minutes.
- Catch-up to last checkpoint: <15 minutes.
Stragglers: 0.5% of machines were ~10% slower in forward; removing them recovered ~0.7% MFU.
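Synchronous training advances at the pace of the slowest rank, so a handful of slow machines gate every step. A straggler check can therefore compare each machine's forward time against the fleet median. The 5% cutoff and the host names below are hypothetical.

```python
# Flagging straggler machines from per-machine forward times (a sketch).
# Assumption: a machine is a straggler if its forward time exceeds the
# median by more than `threshold` (5% here is a hypothetical cutoff).
from statistics import median


def find_stragglers(forward_ms: dict[str, float], threshold: float = 0.05) -> list[str]:
    m = median(forward_ms.values())
    return sorted(host for host, t in forward_ms.items() if t > m * (1 + threshold))


if __name__ == "__main__":
    times = {"host-a": 100.0, "host-b": 101.0, "host-c": 110.0, "host-d": 99.0}
    print(find_stragglers(times))  # ['host-c']
```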
Limitations
- Reactive, not predictive, fault tolerance. Failures are detected after they occur; predicting them at this scale remains hard.
- Staleness deliberately avoided. Accuracy preserved by serial semantics, but balanced-staleness designs that could expose more concurrency are not explored.
- Hardware-specific tuning. Network and operator gains are tuned for ByteDance's Tomahawk-4 / 8x400G NIC topology; transferability to different switch radix or NIC counts is not characterized.
- No full open-source release at publication time; a partial release is planned via veScale.
Open Problems
- Predictive fault tolerance. Move from reactive recovery to anticipating failures (e.g., from RDMA traffic anomalies, link signal drift, GPU ECC trajectories) before they cause cluster-wide stalls.
- Balanced staleness for further overlap. Combine MegaScale's accuracy-preserving overlap with bounded-staleness designs without compromising final model quality.
- Topology-portable tuning. Current network and operator optimizations are co-designed with one specific datacenter; principled methods to retarget them to other clusters are an open question.
- Tail-latency-aware collective scheduling. The mid-training MFU-drop case study localized the bottleneck to a single collective (the last DP reduce-scatter), whose tail latency blew up under cross-rank time skew; principled tail-aware collective scheduling at 10k-GPU scale remains an open problem.
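A toy model shows why cross-rank skew appears as collective tail latency even when bandwidth is unchanged. The model assumes a collective cannot complete until the last rank arrives, so each rank observes its own wait plus the transfer time. It is a simplification that ignores the collective's internal algorithm.

```python
# Why cross-rank skew shows up as collective tail latency (a sketch).
# Assumption: a collective finishes only after the last rank arrives, so the
# time each rank observes = (last arrival - own arrival) + transfer time.

def collective_times(arrivals_ms: list[float], transfer_ms: float) -> list[float]:
    last = max(arrivals_ms)
    return [last - a + transfer_ms for a in arrivals_ms]


if __name__ == "__main__":
    # Three ranks on time, one rank 30 ms late; the transfer itself takes 10 ms.
    print(collective_times([0.0, 0.0, 0.0, 30.0], 10.0))  # [40.0, 40.0, 40.0, 10.0]
```

The on-time ranks see a 4x slower reduce-scatter, and the late rank looks fast. Per-rank collective timings alone can therefore point away from the true culprit.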
Note on NCCL Tuning
MegaScale tunes NCCL itself, not just the schedule above it: it raises the NCCL retransmit timer and retry count, sets NCCL timeouts explicitly large enough to absorb multi-second link flapping, and enables NIC adap_retrans. The mid-training MFU-drop case study localized the bottleneck to the last reduce-scatter in the DP step, whose tail latency blew up under cross-rank time skew. This is evidence that collective tail latency, not aggregate bandwidth, is the failure mode that matters at 10k+ GPUs, and that NCCL chunking and timeout parameters deserve first-class status as tuning dimensions.
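The retransmit-timer knob can be sketched with NCCL's documented InfiniBand/RoCE environment variables. `NCCL_IB_TIMEOUT` and `NCCL_IB_RETRY_CNT` are real NCCL settings, and NCCL documents the per-attempt timeout as 4.096 µs × 2^value. The 5 s outage target and the helper functions are hypothetical, not values MegaScale reports. The variables must be set before the NCCL communicator is created.

```python
# Raising NCCL's InfiniBand/RoCE retransmit settings (a sketch).
# NCCL_IB_TIMEOUT and NCCL_IB_RETRY_CNT are real NCCL environment variables;
# the values chosen here are hypothetical, not the ones MegaScale used.
import math
import os


def ib_timeout_seconds(exponent: int) -> float:
    """Per-attempt IB transport timeout per NCCL docs: 4.096 us * 2**exponent."""
    return 4.096e-6 * 2**exponent


def min_exponent_for(outage_s: float) -> int:
    """Smallest NCCL_IB_TIMEOUT whose single-attempt timeout covers outage_s
    (a simplification that ignores retries)."""
    return math.ceil(math.log2(outage_s / 4.096e-6))


exp = min_exponent_for(5.0)                  # ride out a hypothetical 5 s link flap
os.environ["NCCL_IB_TIMEOUT"] = str(exp)     # must be set before NCCL init
os.environ["NCCL_IB_RETRY_CNT"] = "7"        # max value of the 3-bit IB retry field
print(exp, round(ib_timeout_seconds(exp), 2))  # 21 8.59
```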