nnScaler: Constraint-Guided Parallelization Plan Generation for Deep Learning Training
Paper: Lin et al., Microsoft Research, USENIX OSDI 2024.
Core contribution: A constraint-guided framework that generates parallelization plans from three primitives: op-trans (operator transformation), op-assign (device placement), and op-order (temporal scheduling). It introduces vTensor/pTensor dependency tracking to detect cycles and materialize communication, and it discovers previously unknown plans (Coshard, Interlaced Pipeline, 3F1B) that outperform DeepSpeed, Megatron-LM, and Alpa by up to 3.5x.
Fig 1: System Overview Block Diagram
┌──────────────────────────────────────────────────────────────┐
│ nnScaler Framework │
│ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Input: DNN Model (PyTorch) │ │
│ │ Compute Graph: ops + data dependencies (vTensors) │ │
│ └──────────────────────┬────────────────────────────────┘ │
│ │ compute graph │
│ ▼ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Plan Generator (three primitive operations) │ │
│ │ │ │
│ │ op-trans ── partition operator across devices │ │
│ │ op-assign ── assign partitioned op to device(s) │ │
│ │ op-order ── impose temporal ordering constraints │ │
│ └──────────────────────┬────────────────────────────────┘ │
│ │ candidate plan │
│ ▼ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Dependency Checker (vTensor → pTensor tracker) │ │
│ │ │ │
│ │ Detects cycles in partitioned dependency graph │ │
│ │ Materializes required communication ops │ │
│ │ (AllReduce / AllGather / ReduceScatter / P2P) │ │
│ └──────────────────────┬────────────────────────────────┘ │
│ │ validated plan │
│ ▼ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Placement Optimizer (ILP or cost-model search) │ │
│ │ + Tessel (temporal op-order optimizer) │ │
│ │ │ │
│ │ Minimizes: memory footprint per device │ │
│ │ Maximizes: compute-communication overlap │ │
│ └──────────────────────┬────────────────────────────────┘ │
│ │ optimized plan │
│ ▼ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Code Generator │ │
│ │ Emits: per-device Python/CUDA kernels │ │
│ │ + communication calls (NCCL / custom collectives) │ │
│ └───────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
▲ Fig 1: nnScaler full system — compute graph → plan generation →
dependency validation → placement + temporal optimization → codegen.
Fig 2: Key Architecture Diagram — vTensor / pTensor Abstraction
BEFORE PARTITIONING (logical graph):
┌──────────────┐ vTensor A ┌──────────────┐
│ Op X │ ════════════►│ Op Y │
│ (produces A)│ │ (consumes A) │
└──────────────┘ └──────────────┘
AFTER op-trans on Op X (column-split) and Op Y (row-split):
┌────────────────────────────────────────────────────────┐
│ Device 0 Device 1 │
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Op X shard0 │ │ Op X shard1 │ │
│ │ produces │ │ produces │ │
│ │ pTensor A₀ │ │ pTensor A₁ │ │
│ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │
│ ┌───────▼─────────────────────────▼──────────────────┐ │
│ │ Dependency Tracker │ │
│ │ │ │
│ │ Checks: can Op Y shard0 consume pTensor A₀ only? │ │
│ │ Case A: YES → no communication needed │ │
│ │ Case B: NO → insert AllGather / ReduceScatter │ │
│ │ Case C: CYCLE detected → plan invalid, backtrack │ │
│ └────────────────────────────────────────────────────┘ │
│ │ │ │
│ ┌──────▼───────┐ ┌──────▼───────┐ │
│ │ Op Y shard0 │ │ Op Y shard1 │ │
│ │ (consumes │ │ (consumes │ │
│ │ pTensor A₀) │ │ pTensor A₁) │ │
│ └──────────────┘ └──────────────┘ │
└────────────────────────────────────────────────────────┘
▲ Fig 2: vTensor-to-pTensor dependency tracking. After partitioning,
the tracker determines whether communication must be materialized.
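The three cases in Fig 2 reduce to a range-coverage question for each consumer shard. The sketch below is a minimal illustration, not nnScaler's implementation: it models each pTensor as a half-open index range [lo, hi) along one partitioned dimension, and `PTensor`, its `partial` flag, and `plan_comm` are hypothetical names. Value-split (partial-sum) shards are handled separately because they need a reduction rather than a gather.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PTensor:
    """One physical shard of a logical vTensor, covering [lo, hi) of a single dim.

    `partial` (hypothetical field) marks a value-split shard that holds a partial
    sum and must be reduced before any consumer can use it.
    """
    device: int
    lo: int
    hi: int
    partial: bool = False


def plan_comm(produced: list[PTensor], device: int, lo: int, hi: int) -> str:
    """Classify how a consumer shard on `device` can read range [lo, hi).

    Returns "local" (Case A), or "gather" / "reduce" (Case B). Cycle detection
    (Case C) is a whole-graph check and is not modeled here.
    """
    # Partial sums are not final on any device: a reduction must be inserted.
    if any(p.partial for p in produced):
        return "reduce"  # e.g. AllReduce or ReduceScatter
    # Case A: a shard on the same device already covers the needed range.
    if any(p.device == device and p.lo <= lo and hi <= p.hi for p in produced):
        return "local"
    # Case B: sweep shards by start index to confirm their union covers [lo, hi).
    covered = lo
    for p in sorted(produced, key=lambda s: s.lo):
        if p.lo <= covered < p.hi:
            covered = p.hi
    if covered >= hi:
        return "gather"  # e.g. AllGather or P2P send/recv of the remote pieces
    raise ValueError(f"range [{lo}, {hi}) is not fully produced by any shard")
```

For the column-split X / row-split Y pair in Fig 2 on two devices, `plan_comm([PTensor(0, 0, 4), PTensor(1, 4, 8)], 0, 0, 4)` returns "local", while a consumer on device 0 that needs all of A (`lo=0, hi=8`) gets "gather".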
Fig 3: Control Flow Diagram — Plan Generation Loop
START: DNN compute graph G, device count N, memory budget M
│
▼
① [Enumerate op-trans candidates for each op in G]
│ For each operator: list valid partition strategies
│ (batch-split, row-split, column-split, no-split)
▼
② [Apply op-assign: map each op partition to device(s)]
│ ILP or heuristic assigns pTensors to devices
│ respecting memory constraint M per device
▼
③ [Check data dependencies: vTensor → pTensor consistency]
│
├── No conflict → plan is valid, continue
│
├── Conflict resolvable → insert communication op
│ (AllReduce, AllGather, ReduceScatter, or P2P send)
│ Estimate comm cost → add to plan cost
│
└── Cycle detected → discard plan, backtrack to ②
│
▼
④ [Apply op-order via Tessel temporal optimizer]
│ Reorder ops within device to maximize overlap:
│ communication of batch k overlaps compute of batch k+1
▼
⑤ [Evaluate plan cost: estimated step time]
│
├── Cost < best_so_far → update best plan
│
└── Search budget exhausted? → YES → emit best plan
│
NO → goto ① (next candidate)
▼
⑥ [Code generation: emit per-device CUDA + comm calls]
│
OUTPUT: parallelized training program
▲ Fig 3: nnScaler plan generation control flow. Three primitives drive
a search over valid plans; Tessel optimizes temporal ordering.
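The Fig 3 loop can be sketched as exhaustive enumeration with a memory filter, a cycle check, and cost minimization. This is a toy under stated assumptions, not nnScaler's search: the placement optimizer (ILP or cost model) and Tessel are collapsed into a fixed per-strategy cost, every op is assumed to run on every device, and `search`, the `order_edges` hook, and the tuple layout of `candidates` are hypothetical. Cycle detection uses Kahn's topological sort, a standard technique chosen here for illustration.

```python
import itertools
from collections import defaultdict, deque


def has_cycle(nodes, edges):
    """Kahn's algorithm: True if the directed graph (nodes, edges) contains a cycle."""
    indeg = {n: 0 for n in nodes}
    succ = defaultdict(list)
    for u, v in edges:
        succ[u].append(v)
        indeg[v] += 1
    ready = deque(n for n in nodes if indeg[n] == 0)
    visited = 0
    while ready:
        n = ready.popleft()
        visited += 1
        for m in succ[n]:
            indeg[m] -= 1
            if indeg[m] == 0:
                ready.append(m)
    # Nodes that never reach in-degree 0 sit on (or behind) a cycle.
    return visited < len(nodes)


def search(ops, candidates, data_edges, order_edges, mem_budget, budget=1000):
    """Toy version of the Fig 3 loop (exhaustive enumeration, no ILP).

    candidates[op]: list of (strategy, mem_per_device, compute, comm) tuples.
    order_edges(plan): extra (before, after) edges imposed by the plan's op-order.
    Assumes every op's shard lives on every device, so per-device memory is
    the sum over ops.
    """
    best, best_cost = None, float("inf")
    combos = itertools.product(*(candidates[op] for op in ops))
    for tried, choice in enumerate(combos):
        if tried >= budget:                          # step 5: budget exhausted
            break
        plan = dict(zip(ops, choice))                # step 1: one op-trans per op
        if sum(c[1] for c in choice) > mem_budget:   # step 2: memory constraint M
            continue
        if has_cycle(ops, list(data_edges) + order_edges(plan)):  # step 3: cycle
            continue
        cost = sum(c[2] + c[3] for c in choice)      # step 5: compute + comm cost
        if cost < best_cost:
            best, best_cost = plan, cost
    return best, best_cost
```

With two ops that each offer a `no-split` and a split strategy, a memory budget that rules out the unsplit pair, and an `order_edges` hook that adds a back edge for one mixed plan, the loop skips the over-budget and cyclic combinations and keeps the cheapest remaining plan.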
Fig 4: Data Flow Diagram — Coshard Novel Plan
Standard Tensor Parallel (Megatron-style):
┌─────────────┐ AllReduce ┌─────────────┐ AllReduce
│ Op A shard0 │════════════► │ Op B shard0 │════════════►...
│ (device 0) │◄════════════ │ (device 0) │
└─────────────┘ └─────────────┘
each op independently synced → 2 AllReduces per layer pair
Coshard Plan (nnScaler discovery):
┌─────────────┐ ┌─────────────┐
│ Op A shard0 │══ pTensor ══►│ Op B shard0 │
│ (device 0) │ (local, │ (device 0) │
│ │ no comm) │ │
└─────────────┘ └─────────────┘
┌─────────────┐ ┌─────────────┐
│ Op A shard1 │══ pTensor ══►│ Op B shard1 │ ══ AllReduce ══►
│ (device 1) │ (local, │ (device 1) │ (once, after
│ │ no comm) │ │ both ops)
└─────────────┘ └─────────────┘
KEY: Op A and Op B share the same partition dimension.
pTensor flows locally without AllReduce between them.
One AllReduce at the end replaces two; this halves comm volume if the replaced AllReduces are the same size.
▲ Fig 4: Coshard data flow. Compatible partition dimensions allow
chaining ops without intermediate AllReduce insertions.
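The "halves comm volume" claim in Fig 4 can be checked with the standard per-rank traffic of a ring AllReduce, 2(N-1)/N × S bytes for an S-byte buffer on N ranks (a reduce-scatter phase plus an all-gather phase). The sketch assumes the two AllReduces in the standard plan are each the same size S as Coshard's single boundary call; the 64 MiB size and N = 4 are hypothetical values for illustration.

```python
MIB = 1 << 20


def ring_allreduce_bytes(size_bytes: int, n: int) -> float:
    """Bytes each rank sends in a ring AllReduce of `size_bytes` over n ranks.

    The reduce-scatter and all-gather phases each move (n - 1)/n of the buffer.
    """
    return 2 * (n - 1) / n * size_bytes


def per_rank_comm(size_bytes: int, n: int, num_allreduces: int) -> float:
    """Per-rank bytes for a layer pair that issues `num_allreduces` equal-size calls."""
    return num_allreduces * ring_allreduce_bytes(size_bytes, n)


if __name__ == "__main__":
    s, n = 64 * MIB, 4  # hypothetical activation size and tensor-parallel degree
    standard = per_rank_comm(s, n, num_allreduces=2)  # Fig 4, Megatron-style
    coshard = per_rank_comm(s, n, num_allreduces=1)   # Fig 4, Coshard chain
    print(f"standard: {standard / MIB:.0f} MiB, coshard: {coshard / MIB:.0f} MiB")
```

Running it prints `standard: 192 MiB, coshard: 96 MiB`. If the boundary AllReduce were larger than S, the saving would shrink accordingly.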
Fig 5: Data Flow Diagram — Interlaced Pipeline for T5 Embedding
PROBLEM: T5 embedding layer is used by both encoder and decoder.
Standard 1F1B assigns it to one device → bottleneck.
INTERLACED PIPELINE PLAN:
TIME →
┌───────────────────────────────────────────────────────────┐
│ Device 0 [emb fwd mb1][enc fwd mb1] [enc bwd mb1]│
│ │
│ Device 1 [enc fwd mb2][emb fwd mb2] │
│ [enc bwd mb2][emb bwd]│
│ │
│ Device 2 [dec fwd mb1][dec bwd mb1] │
│ │
│ Embedding layer distributed across devices: │
│ ┌────────────────────────────────────────────────────┐ │
│ │ op-assign: embed shard i → device (i mod N) │ │
│ │ op-order: embed fwd of mb_k scheduled between │ │
│ │ enc fwd of mb_(k-1) and mb_(k+1) │ │
│ └────────────────────────────────────────────────────┘ │
└───────────────────────────────────────────────────────────┘
RESULT: embedding compute distributed; no single-device
bottleneck; idle time reduced vs. standard pipeline.
▲ Fig 5: Interlaced Pipeline for T5. op-order interleaves embedding
forward passes across devices to eliminate the embedding bottleneck.
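The two rules in Fig 5 can be expressed as a placement function plus an ordering check. This forward-only sketch (backward passes omitted) assigns embedding shard i to device i mod N and places each device's embedding forward for micro-batch k directly before its own stage forward for k, which keeps it between the stage forwards of mb_(k-1) and mb_(k+1). All function names and the (op, micro-batch) tuple format are hypothetical.

```python
from __future__ import annotations


def assign_embed_shards(num_shards: int, num_devices: int) -> dict[int, list[int]]:
    """op-assign: embedding shard i goes to device (i mod N)."""
    placement: dict[int, list[int]] = {d: [] for d in range(num_devices)}
    for i in range(num_shards):
        placement[i % num_devices].append(i)
    return placement


def interlaced_forward(device: int, stage: str, placement: dict[int, list[int]],
                       num_microbatches: int) -> list[tuple[str, int]]:
    """op-order: run this device's embedding shards for mb_k right before its
    stage forward for mb_k, i.e. after the stage forward of mb_(k-1)."""
    shards = ",".join(map(str, placement[device]))
    seq: list[tuple[str, int]] = []
    for k in range(1, num_microbatches + 1):
        if shards:  # a device holding no embedding shard skips the embedding op
            seq.append((f"emb_fwd[{shards}]", k))
        seq.append((f"{stage}_fwd", k))
    return seq


def respects_interlace(seq: list[tuple[str, int]], stage: str) -> bool:
    """True if every emb_fwd of mb_k lies between stage fwd of mb_(k-1) and mb_(k+1)."""
    pos = {(op, k): i for i, (op, k) in enumerate(seq)}
    for (op, k), i in pos.items():
        if op.startswith("emb_fwd"):
            before = pos.get((f"{stage}_fwd", k - 1))
            after = pos.get((f"{stage}_fwd", k + 1))
            if (before is not None and before > i) or (after is not None and after < i):
                return False
    return True
```

For example, with 6 embedding shards on 3 devices, device 0 holds shards 0 and 3, and its sequence starts with `("emb_fwd[0,3]", 1), ("enc_fwd", 1)`; reversing that sequence violates the interlace constraint.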
Fig 6: Layered Software Stack
┌──────────────────────────────────────────────────────────┐
│ User DNN model (PyTorch nn.Module) │
├──────────────────────────────────────────────────────────┤
│ nnScaler API │
│ (parallelize() call: specifies device count, budget) │
├──────────────────────────────────────────────────────────┤
│ nnScaler Plan Generator │
│ (op-trans / op-assign / op-order primitives) │
│ (vTensor-pTensor dependency tracker) │
│ (Tessel temporal ordering optimizer) │
├──────────────────────────────────────────────────────────┤
│ nnScaler Code Generator │
│ (per-device Python, communication stub calls) │
├──────────────────────────────────────────────────────────┤
│ NCCL / custom collectives │
│ (AllReduce, AllGather, ReduceScatter, P2P send/recv) │
├──────────────────────────────────────────────────────────┤
│ NVLink (intra-node) / InfiniBand (inter-node) │
└──────────────────────────────────────────────────────────┘
▲ Fig 6: nnScaler software stack. The plan generator and code
generator sit between the user model and NCCL/transport layer.
Fig 7: State Machine — Per-Device Execution Under nnScaler Plan
new_plan_emitted
[IDLE] ──────────────────────► [LOADING_PLAN]
│
plan code loaded
│
▼
┌─────────────────────────────── [EXECUTING] ──────────────┐
│ │
│ [LOCAL_COMPUTE] │
│ │ pTensor ready on device │
│ ▼ │
│ [COMM_WAIT?] │
│ │ │
│ ┌────┴──────────────────────────────────────────┐ │
│ │ yes: blocking AllReduce / AllGather / RS │ │
│ │ → wait for NCCL completion │ │
│ │ no: pTensor flows locally → continue compute │ │
│ └────┬──────────────────────────────────────────┘ │
│ ▼ │
│ [TESSEL_OVERLAP] │
│ │ schedule next micro-batch compute while │
│ │ current micro-batch comm is in flight │
│ ▼ │
│ [UPDATE_WEIGHTS] (optimizer step, local shard) │
│ │ │
│ └──────────────────────────────► [LOCAL_COMPUTE] │
└──────────────────────────────────────────────────────────┘
│
end of training
│
[IDLE]
▲ Fig 7: Per-device execution state machine. Tessel overlap state
hides communication latency behind local compute of next micro-batch.
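The Fig 7 state machine maps onto a transition table. In the sketch below the state names follow the figure, while the event strings (`ptensor_ready`, `nccl_complete`, `no_comm_needed`, and so on) are hypothetical labels for the figure's arrows; `end_of_training` is accepted from any state inside the EXECUTING box, as the figure draws it.

```python
from enum import Enum, auto


class State(Enum):
    IDLE = auto()
    LOADING_PLAN = auto()
    LOCAL_COMPUTE = auto()
    COMM_WAIT = auto()
    TESSEL_OVERLAP = auto()
    UPDATE_WEIGHTS = auto()


# States drawn inside the EXECUTING box of Fig 7.
EXECUTING = {State.LOCAL_COMPUTE, State.COMM_WAIT,
             State.TESSEL_OVERLAP, State.UPDATE_WEIGHTS}

TRANSITIONS = {
    (State.IDLE, "new_plan_emitted"): State.LOADING_PLAN,
    (State.LOADING_PLAN, "plan_code_loaded"): State.LOCAL_COMPUTE,
    (State.LOCAL_COMPUTE, "ptensor_ready"): State.COMM_WAIT,
    # COMM_WAIT? is a decision: either NCCL finished, or no comm was needed.
    (State.COMM_WAIT, "nccl_complete"): State.TESSEL_OVERLAP,
    (State.COMM_WAIT, "no_comm_needed"): State.TESSEL_OVERLAP,
    (State.TESSEL_OVERLAP, "next_compute_scheduled"): State.UPDATE_WEIGHTS,
    (State.UPDATE_WEIGHTS, "optimizer_step_done"): State.LOCAL_COMPUTE,
}


def step(state: State, event: str) -> State:
    """Return the next state, or raise ValueError for an event the figure does not allow."""
    if event == "end_of_training" and state in EXECUTING:
        return State.IDLE
    try:
        return TRANSITIONS[(state, event)]
    except KeyError:
        raise ValueError(f"illegal event {event!r} in state {state.name}") from None
```

Keeping the table explicit makes illegal orderings (for example, computing before a plan is loaded) fail loudly instead of silently advancing.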
Design Trade-off Analysis
| Design Decision | Alternative A | Alternative B (nnScaler choice) | Winner | Why |
|---|---|---|---|---|
| Plan generation approach | Manual expert design (Megatron, ZeRO) | Constraint-guided search over primitives | B | Discovers non-obvious plans (Coshard, Interlaced) that manual design missed; generalizes across models |
| Dependency representation | Tensor-level (full materialized tensors) | vTensor → pTensor abstraction | B | vTensor tracks partition intent; pTensor tracks physical shard; cycle detection operates on the mapping |
| Communication insertion | Require user to specify comm (Megatron) | Auto-materialize from dependency check | B | Reduces user burden; correctness guaranteed by dependency tracker |
| Temporal scheduling | Fixed execution order (standard pipeline) | op-order + Tessel dynamic reordering | B | Up to 40% additional overlap from reordering; especially effective for AlphaFold2 3F1B |
| Placement search | Greedy heuristic (fast, suboptimal) | ILP-based exact search (slow, optimal) | Context-dependent | ILP optimal for small graphs; cost-model heuristic used for large models |
| Plan novelty | Reuse known plans (1F1B, ring-allreduce) | Generate new plans from primitives | B | 3.5x speedup on SwinTransformer from Coshard; 2.1x on AlphaFold2 from 3F1B |
| Cycle detection cost | None (fail at runtime) | Static graph analysis before codegen | B | Eliminates invalid plans before expensive profiling; critical for large search space |
| User API surface | Model-parallel API (requires rewrite) | Single parallelize() call on nn.Module | B | Zero model code changes; same philosophy as ZeRO's wrapping approach |
For DynamICCL context: nnScaler's plan search generates a diverse mix of collective types and sizes depending on which plan is selected — the same model under Coshard generates fewer AllReduces than under standard tensor parallelism, fundamentally changing the collective workload. DynamICCL must adapt its config to the active parallelization plan, not just the model.
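One way to make this plan dependence observable is a fingerprint over recent collective calls. The sketch below is a simple, non-learned stand-in for the LSTM encoder proposed in item 1 below: it keeps a sliding window of (collective type, power-of-two size bucket) pairs and flags a plan change when the normalized histogram moves by more than an L1 threshold. `PlanFingerprint`, the 64-call window, and the 0.5 threshold are hypothetical choices.

```python
from collections import Counter, deque


def size_bucket(nbytes: int) -> int:
    """ceil(log2(nbytes)), so 3 MiB and 4 MiB share bucket 22."""
    return (nbytes - 1).bit_length() if nbytes > 0 else 0


class PlanFingerprint:
    """Sliding window of recent (collective type, size bucket) observations."""

    def __init__(self, window: int = 64) -> None:
        self.calls: deque = deque(maxlen=window)

    def observe(self, coll: str, nbytes: int) -> None:
        self.calls.append((coll, size_bucket(nbytes)))

    def histogram(self) -> dict:
        """Normalized frequency of each (collective, bucket) pair in the window."""
        total = len(self.calls) or 1
        return {key: n / total for key, n in Counter(self.calls).items()}


def l1_distance(h1: dict, h2: dict) -> float:
    """L1 distance between two normalized histograms; ranges over [0, 2]."""
    return sum(abs(h1.get(k, 0.0) - h2.get(k, 0.0)) for k in set(h1) | set(h2))


def plan_shifted(reference: dict, current: dict, threshold: float = 0.5) -> bool:
    """Hypothetical trigger: re-explore configs once the collective mix moves."""
    return l1_distance(reference, current) > threshold
```

A window that is half AllReduce and half AllGather sits at distance 1.0 from an all-AllReduce reference, so it would already cross the 0.5 threshold.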
What to Borrow for DynamICCL
1. Plan-fingerprint as a collective workload identifier. nnScaler's different plans (Coshard vs. standard tensor parallel vs. Interlaced Pipeline) generate fundamentally different collective sequences, with different op types, sizes, and frequencies. DynamICCL's LSTM encoder should treat the sequence of recent collective calls as a fingerprint that identifies the active parallelization plan. When the fingerprint shifts (e.g., the AllReduce count drops, as under Coshard, or the AllGather count rises, as after a ZeRO-stage change), the Config Agent should trigger re-exploration rather than continue with the previously learned config. The plan fingerprint is a higher-level signal than message size alone.
2. Coshard's communication reduction as a config selection signal. Coshard chains compatible op partitions without intermediate AllReduces, then issues a single AllReduce at the chain boundary. For DynamICCL, this means the boundary AllReduce can be larger than a per-op AllReduce in standard tensor parallelism, because one call now carries the synchronization that several per-op calls used to split. The Config Agent should prefer the Simple protocol with the Ring algorithm for these larger boundary AllReduces, even within a single node, because the message size has moved past the range where the low-latency LL protocol pays off. DynamICCL must not assume that intra-node collectives are always small.
3. op-order (Tessel) overlap pattern as a numChannels selection heuristic. Tessel's temporal reordering achieves compute-communication overlap by scheduling the next micro-batch's compute while the current micro-batch's AllReduce is in flight. This overlap pays off only if the AllReduce completes within the compute window, so AllReduce throughput must be high enough. DynamICCL should increase numChannels for collective calls that appear in Tessel-overlapped positions (identifiable by a very short inter-call interval), because a higher channel count raises AllReduce throughput and lets the AllReduce finish before the next compute step needs its result. The trade-off is that each NCCL channel runs as its own CUDA block, so extra channels take SMs away from the compute kernels being overlapped.
4. Cycle detection maps to deadlock avoidance in NCCL config selection. nnScaler's dependency tracker rejects plans that create cycles in the partitioned dependency graph — cycles would require two devices to wait for each other simultaneously. The analogous failure mode in NCCL is deadlock from misconfigured ring orderings when ranks have asymmetric collective participation (e.g., in expert-parallel all-to-all with non-power-of-2 expert counts). DynamICCL's Config Agent should maintain a constraint table of (collective_type, rank_count) pairs that are known to produce NCCL errors with specific algorithm choices, and treat these as hard exclusions in the action space.
5. 3F1B scheduling's distinct memory-communication pattern. nnScaler's 3F1B for AlphaFold2 runs three forward passes before one backward, increasing the number of in-flight activations relative to standard 1F1B. This raises peak memory and also changes the temporal spacing of gradient AllReduces: they arrive in bursts of three rather than one at a time. DynamICCL's Trigger Agent should recognize bursty AllReduce patterns (short inter-arrival times between multiple large collectives) as a regime distinct from periodic single-collective patterns, and apply a burst-tolerant config: lower numThreads per collective to avoid NIC contention between the burst members, versus high numThreads for a single isolated large collective.
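Items 2, 4, and 5 above can be combined into a rule-based stand-in for the Config Agent: mask the action space with hard exclusions first, then choose the protocol by message size and the thread count by arrival pattern. `Ring`/`Tree` and `LL`/`LL128`/`Simple` are real NCCL algorithm and protocol names (settable via `NCCL_ALGO` and `NCCL_PROTO`), and 128/256/512 are valid `NCCL_NTHREADS` values. The 1 MiB size cutoff, the 200 us burst gap, the thread count chosen per regime, and the single exclusion entry are hypothetical placeholders. The code only returns a config dict and does not call NCCL.

```python
ALGOS = ("Ring", "Tree")
PROTOS = ("LL", "LL128", "Simple")

# Hypothetical hard-exclusion table (item 4): (collective, rank_count) -> banned
# algorithms. The entry below is a placeholder, NOT a documented NCCL failure.
HARD_EXCLUSIONS = {("AllReduce", 6): {"Ring"}}


def legal_actions(coll: str, ranks: int) -> list:
    """Action space with hard exclusions removed before the agent ever samples."""
    banned = HARD_EXCLUSIONS.get((coll, ranks), set())
    return [(a, p) for a in ALGOS if a not in banned for p in PROTOS]


def is_burst(arrivals_us: list, min_calls: int = 3, max_gap_us: float = 200.0) -> bool:
    """Item 5: the last `min_calls` collectives arrived within `max_gap_us` of each other."""
    if len(arrivals_us) < min_calls:
        return False
    recent = arrivals_us[-min_calls:]
    return all(b - a <= max_gap_us for a, b in zip(recent, recent[1:]))


def pick_config(coll: str, ranks: int, nbytes: int, arrivals_us: list,
                large_bytes: int = 1 << 20) -> dict:
    """Rule-based stand-in for the Config Agent (all thresholds hypothetical)."""
    actions = legal_actions(coll, ranks)
    if not actions:
        raise ValueError(f"every algorithm is excluded for {coll} on {ranks} ranks")
    large = nbytes >= large_bytes
    proto = "Simple" if large else "LL"  # item 2: large messages move off LL
    algos = [a for a, p in actions if p == proto]
    algo = "Ring" if "Ring" in algos else algos[0]
    if is_burst(arrivals_us):
        nthreads = 128  # item 5: burst members compete, so use fewer threads each
    elif large:
        nthreads = 512  # item 5: an isolated large collective gets more threads
    else:
        nthreads = 256
    return {"algo": algo, "proto": proto, "nthreads": nthreads}
```

Applying exclusions as a mask, rather than as a penalty in the reward, means a learned agent can never sample a known-bad action even while exploring.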