Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

Paper: Fedus, Zoph, Shazeer (Google), JMLR 2022.

Core contribution: Simplifies Mixture-of-Experts to top-1 routing (one expert per token), eliminating the k>1 requirement while preserving model quality, reducing routing compute, and halving expert-capacity requirements. Achieves a 7x pre-training speedup over T5-Base at the same FLOP budget.


Fig 1: System Overview Block Diagram

┌──────────────────────────────────────────────────────────────┐
│              Switch Transformer Training System              │
│                                                              │
│  ┌────────────────────────────────────────────────────────┐  │
│  │  N Cores (TPUs / GPUs), organized as n x m mesh        │  │
│  │                                                        │  │
│  │  ┌──────────────────────────────────────────────────┐  │  │
│  │  │  Switch Transformer Model (encoder or decoder)   │  │  │
│  │  │                                                  │  │  │
│  │  │  ┌────────────────────────────────────────────┐  │  │  │
│  │  │  │  Standard Transformer Layer (every other)  │  │  │  │
│  │  │  │  [Self-Attention] → [Add+Norm] → [FFN]     │  │  │  │
│  │  │  │                   → [Add+Norm]              │  │  │  │
│  │  │  └────────────────────────────────────────────┘  │  │  │
│  │  │                          │                        │  │  │
│  │  │  ┌────────────────────────────────────────────┐  │  │  │
│  │  │  │  Switch Layer (every other FFN replaced)   │  │  │  │
│  │  │  │  [Self-Attention] → [Add+Norm]             │  │  │  │
│  │  │  │  [Router] → [Expert dispatch]              │  │  │  │
│  │  │  │  [E experts, 1 per device] → [Add+Norm]   │  │  │  │
│  │  │  └────────────────────────────────────────────┘  │  │  │
│  │  └──────────────────────────────────────────────────┘  │  │
│  └────────────────────────────────────────────────────────┘  │
│                                                              │
│  Parallelism axes:                                           │
│  ├── Expert parallelism (E experts across E cores)           │
│  ├── Data parallelism (n-way token batch split)              │
│  └── Model parallelism (m-way weight split, optional)        │
└──────────────────────────────────────────────────────────────┘
▲ Fig 1: Switch Transformer full system. Expert layers replace every
  other FFN layer. Three parallelism axes combine multiplicatively.

Fig 2: Key Architecture Diagram — Switch Layer (Top-1 Routing)

  INPUT: token representations x₁, x₂, ..., xₜ
  (one vector per token, shape [batch, seq_len, d_model])
    │
    ▼
┌────────────────────────────────────────────────────────┐
│  Router  (learned weight matrix W_r: d_model → E)      │
│                                                        │
│  h(x) = W_r · x                  (logits over experts) │
│  p_i(x) = softmax(h(x))_i        (routing probability) │
│  selected = argmax p_i(x)         (top-1 expert index) │
│  gate_value = p_{selected}(x)     (scalar weight)      │
│                                                        │
│  Multiplicative jitter on input x (training only)     │
│  Router ops cast to float32 for stability              │
└────────────────────────────────────────────────────────┘
    │
    │ dispatch tensor [n, B/n, E, C]  (boolean)
    ▼
┌───────────────────────────────────────────────────────┐
│  Expert Dispatch (All-to-All communication)           │
│                                                       │
│  Each core sends token slices to the expert owner     │
│  Size: E × C × d_model  (C = expert capacity)        │
│  expert_capacity = (tokens_per_batch / E) × cap_factor│
│                                                       │
│  If tokens_per_expert > capacity: TOKEN DROPPED       │
│  (passed through residual, no FFN processing)         │
└───────────────────────────────────────────────────────┘
    │                            │
    ▼ (per expert device)        ▼ (per expert device)
┌──────────────┐         ┌──────────────┐
│  Expert FFN-1│         │  Expert FFN-E│
│  (unique     │   ...   │  (unique     │
│   weights)   │         │   weights)   │
│  feed_forward│         │  feed_forward│
└──────┬───────┘         └──────┬───────┘
       └──────────────┬──────────┘
                      ▼  (all-to-all #2: return to source cores)
             combine_tensor × gate_value
                      │
                      ▼
             y = Σ p_i(x)·E_i(x)   (top-1: only the selected term is nonzero)
▲ Fig 2: Switch layer — top-1 routing sends each token to exactly one
  expert. Two all-to-all collectives bracket the expert FFN compute.
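
A minimal NumPy sketch of the routing math above (illustrative only; the
paper's implementation is in Mesh TensorFlow). Shapes follow Fig 2, and
jitter_eps is an assumed name for the paper's eps = 1e-2 jitter width:

  import numpy as np

  def top1_route(x, W_r, train=True, jitter_eps=1e-2, rng=np.random):
      """x: [tokens, d_model] activations; W_r: [d_model, E] router weights."""
      if train:
          # multiplicative jitter on the router *input*, training only
          x = x * rng.uniform(1 - jitter_eps, 1 + jitter_eps, size=x.shape)
      # router math runs in float32 even when the model runs in bfloat16
      logits = x.astype(np.float32) @ W_r.astype(np.float32)  # h(x): [tokens, E]
      probs = np.exp(logits - logits.max(-1, keepdims=True))
      probs /= probs.sum(-1, keepdims=True)                   # p_i(x)
      expert_idx = probs.argmax(-1)                           # top-1 expert index
      gate = probs[np.arange(probs.shape[0]), expert_idx]     # p_selected(x)
      return expert_idx, gate, probs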

Fig 3: Control / Data Flow Diagram — One Switch Layer Forward Pass

  START: input token batch [B, seq_len, d_model]
    │
    ▼
① [Self-Attention block]  (standard, no routing)
    │  output: [B, seq_len, d_model]
    ▼
② [Add + LayerNorm]
    │
    ▼
③ [Router: compute logits and select top-1 expert per token]
    │
    ├── cast to float32  (stability for softmax)
    ├── add multiplicative jitter noise  (training only)
    ├── softmax → probabilities p_i(x)
    ├── argmax → expert index per token
    ├── compute expert_capacity = (B·T/E) × capacity_factor
    ├── enforce capacity limit → produce dispatch_tensor
    └── compute auxiliary load-balancing loss
    │
    ▼
④ [All-to-All #1: dispatch tokens to expert devices]
    │  tokens routed locally to correct expert (core-local gather)
    │  reshape [n,1,1,1] → [1,n,1,1]  (switch expert dim)
    │  communication: E × C × d_model  bfloat16 tensors
    ▼
⑤ [Expert FFN compute: each expert processes its token batch]
    │  standard 2-layer FFN with unique parameters per expert
    │  expert_outputs shape: [E, n, C, d_model]
    ▼
⑥ [All-to-All #2: return processed tokens to source cores]
    │  reshape [1,n,1,1] → [n,1,1,1]
    │  communication: E × C × d_model  bfloat16 tensors
    ▼
⑦ [Combine: multiply by gate value, sum over experts]
    │  expert_outputs_combined = einsum(outputs, combine_tensor)
    ▼
⑧ [Residual add + LayerNorm]
    │
  END: output [B, seq_len, d_model], same shape as input
▲ Fig 3: One Switch layer forward pass. Two all-to-all collectives
  are the only cross-device communication per Switch layer.
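
A single-core sketch of the capacity enforcement in step ③, consuming the
expert_idx from the router sketch after Fig 2. Tokens are admitted in
sequence order (a simplification; the full dispatch tensor has the 4-D
[n, B/n, E, C] layout shown in Fig 2):

  import numpy as np

  def build_dispatch(expert_idx, num_experts, capacity_factor=1.0):
      tokens = expert_idx.shape[0]
      capacity = int(tokens / num_experts * capacity_factor)  # expert_capacity
      dispatch = np.zeros((tokens, num_experts, capacity), dtype=bool)
      fill = np.zeros(num_experts, dtype=int)                 # slots used per expert
      for t, e in enumerate(expert_idx):
          if fill[e] < capacity:
              dispatch[t, e, fill[e]] = True
              fill[e] += 1
          # else: token is dropped and rides the residual path only
      return dispatch, capacity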

Fig 4: Parallelism Topology for Expert + Data + Model

  N total cores = n (data-parallel) × m (model-parallel)
  E experts, one per core in expert-parallel mode

  ┌─────────────────────────────────────────────────────┐
  │  Core layout (2D mesh: data-parallel × model-       │
  │               parallel dimensions)                   │
  │                                                     │
  │  DP=4, MP=1, E=4                                    │
  │  ┌────┐  ┌────┐  ┌────┐  ┌────┐                    │
  │  │ E1 │  │ E2 │  │ E3 │  │ E4 │  ← each expert     │
  │  │core│  │core│  │core│  │core│    on one core      │
  │  │ 0  │  │ 1  │  │ 2  │  │ 3  │                    │
  │  └────┘  └────┘  └────┘  └────┘                    │
  │                                                     │
  │  Attention layers: all-reduce (size [B/n, d_model]) │
  │  Switch layers:   all-to-all (size E·C·d_model)     │
  │                                                     │
  │  With m>1 (model parallel added):                   │
  │  all-to-all costs PLUS all-reduce per attn layer    │
  │  → communication increases, batch size decreases    │
  └─────────────────────────────────────────────────────┘
▲ Fig 4: Expert and data parallelism topology. Each expert lives on
  one core; the all-to-all dispatches tokens across core boundaries.
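
Back-of-envelope payload per core for the Switch-layer all-to-all; the
example shapes (E=64, C=16, d_model=768) are assumed for illustration,
not taken from the paper's tables:

  def all_to_all_bytes(num_experts, expert_capacity, d_model, bytes_per_elem=2):
      # E x C x d_model activations in bfloat16 (2 bytes per element)
      return num_experts * expert_capacity * d_model * bytes_per_elem

  # e.g. E=64 experts, C=16 token slots, d_model=768 -> exactly 1.5 MiB
  print(all_to_all_bytes(64, 16, 768) / 2**20, "MiB per core per all-to-all")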

Fig 5: Load Balancing — Auxiliary Loss Mechanism

  For each Switch layer with N experts and batch B of T tokens:

  ┌──────────────────────────────────────────────────────┐
  │  f_i = (1/T) Σ_{x∈B} 1{argmax p(x) = i}            │
  │        fraction of tokens dispatched to expert i     │
  │                                                      │
  │  P_i = (1/T) Σ_{x∈B} p_i(x)                        │
  │        fraction of router probability for expert i   │
  │                                                      │
  │  aux_loss = α · N · Σᵢ fᵢ · Pᵢ                     │
  │                                                      │
  │  α = 1e-2  (small enough not to dominate main loss)  │
  │  Minimized when fᵢ = Pᵢ = 1/N  (uniform routing)   │
  │                                                      │
  │  f is NOT differentiable (argmax)                    │
  │  P IS differentiable (soft softmax)                  │
  │  Gradient flows only through P                       │
  └──────────────────────────────────────────────────────┘
▲ Fig 5: Auxiliary load-balancing loss. Encourages uniform token
  distribution across experts via a differentiable surrogate.
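
The same loss as a NumPy sketch. As the figure notes, the argmax term f
carries no gradient, so learning pressure flows through P alone:

  import numpy as np

  def load_balance_loss(probs, alpha=1e-2):
      """probs: [T, N] router softmax outputs for one batch."""
      T, N = probs.shape
      f = np.bincount(probs.argmax(-1), minlength=N) / T  # f_i: dispatch fraction
      P = probs.mean(axis=0)                              # P_i: mean router prob
      return alpha * N * np.sum(f * P)                    # equals alpha at uniform routing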

Fig 6: Layered Software Stack

┌────────────────────────────────────────────────────────────┐
│  Training script (T5-style seq2seq objective on C4)        │
├────────────────────────────────────────────────────────────┤
│  Switch Transformer model (standard + Switch layers)       │
│  (top-1 router, load-balance loss, expert dropout)         │
├────────────────────────────────────────────────────────────┤
│  Mesh TensorFlow (MTF)                                     │
│  (logical mesh abstraction, named dimensions)              │
│  (handles data/model/expert sharding automatically)        │
├────────────────────────────────────────────────────────────┤
│  XLA Compiler                                              │
│  (static shape compilation, TPU optimization)              │
├────────────────────────────────────────────────────────────┤
│  TPUv3 hardware (all-to-all, all-reduce primitives)        │
└────────────────────────────────────────────────────────────┘
▲ Fig 6: Software stack. MTF provides the mesh abstraction used to
  implement expert, data, and model parallelism uniformly.

Design Trade-off Analysis

1. Number of experts per token
   A: Top-k (k≥2, MoE default)  vs  B: Top-1 (Switch routing)
   Winner: B. Halves the expert capacity needed, reduces routing compute,
   simplifies gradients, and empirically matches top-k quality.

2. Router precision
   A: bfloat16 throughout  vs  B: Cast router computation to float32 only
   Winner: B. A bfloat16 softmax diverges; float32 everywhere wastes
   communication bandwidth; the selective cast gets both stability and speed.

3. Expert capacity overflow
   A: Drop tokens silently  vs  B: No-Token-Left-Behind (2-stage re-route)
   Winner: A (base). Re-routing showed no empirical benefit; the
   load-balancing loss is sufficient; keep the capacity factor in 1.0–1.5.

4. Load balancing mechanism
   A: Two separate losses (Shazeer MoE)  vs  B: Single combined aux loss
   Winner: B. Simpler, one hyperparameter α, and empirically equally
   effective.

5. Expert placement
   A: Colocate with model-parallel shards  vs  B: One expert per
   data-parallel core
   Winner: B. Avoids an all-reduce inside the expert computation; the
   all-to-all costs less than an all-reduce plus a per-expert all-reduce.

6. Expert capacity factor
   A: 1.0 (tight, drops tokens)  vs  B: 1.5 (buffer, wastes memory)
   Winner: Context-dependent. Larger models need a lower CF (memory is
   scarce); Switch performs best at CF = 1.0–1.25.

7. Initialization scale
   A: Default Transformer (s = 1.0)  vs  B: Reduced scale (s = 0.1)
   Winner: B. 68x lower variance in training quality; stable from 223M to
   1T parameters.

8. Router exploration
   A: Deterministic argmax  vs  B: Multiplicative input jitter
   Winner: B. Jitter reaches -1.468 neg. log-perplexity vs -1.471 for plain
   argmax, and the noise prevents expert collapse.
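
Row 7 is easy to make concrete. A sketch of the reduced-scale init,
assuming the paper's truncated-normal scheme of std = sqrt(s / fan_in)
truncated at two standard deviations, with s lowered from 1.0 to 0.1
(the function name is illustrative):

  import numpy as np
  from scipy.stats import truncnorm

  def switch_init(fan_in, fan_out, scale=0.1):
      std = np.sqrt(scale / fan_in)
      # truncnorm bounds are in units of std, so (-2, 2) truncates at 2*std
      return truncnorm.rvs(-2, 2, loc=0.0, scale=std, size=(fan_in, fan_out))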

In the DynamICCL context: Switch Transformers make all-to-all the dominant collective at expert layers, displacing all-reduce as the primary communication primitive. This is a fundamentally different workload pattern from Megatron/ZeRO and requires distinct NCCL configuration.


What to Borrow for DynamICCL

1. All-to-All as a first-class collective requiring dedicated tuning. Switch Transformers generate two all-to-all collectives per Switch layer per forward pass (dispatch + combine), in addition to the all-reduce collectives from the attention mechanism. NCCL composes all-to-all from grouped point-to-point sends, so each rank issues O(N) messages to distinct peers, whereas ring all-reduce moves data in O(N) steps between fixed neighbors; the two stress the network very differently. For DynamICCL, the Config Agent must have a distinct action-space entry for all-to-all collectives, since the optimal algorithm and protocol may differ from the all-reduce config. The message size for each all-to-all is E × C × d_model elements at 2 bytes each (bfloat16), which DynamICCL should track as a state feature.
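
One way to make that distinct action-space entry concrete. The protocol
names below are NCCL's documented LL / LL128 / Simple; the all-to-all
algorithm options are assumed placeholders, since NCCL builds all-to-all
from grouped point-to-point sends rather than exposing named algorithms
for it:

  # hypothetical Config Agent action space, keyed by collective type
  ACTION_SPACE = {
      "all_reduce": {"algorithm": ["ring", "tree"],
                     "protocol":  ["LL", "LL128", "Simple"]},
      "all_to_all": {"algorithm": ["pairwise", "hierarchical"],  # placeholders
                     "protocol":  ["LL", "LL128", "Simple"]},
  }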

2. Capacity factor as a dynamic congestion lever. The expert capacity factor directly controls how many tokens each expert processes and therefore the all-to-all message size. When DynamICCL's Trigger Agent detects congestion (CUSUM threshold crossed), one response is to recommend reducing the capacity factor slightly — this shrinks the all-to-all message and reduces network pressure at the cost of slightly more dropped tokens. This trade-off is analogous to back-pressure in network flow control: reduce injection rate when the network is congested.
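
A sketch of that back-pressure rule; the function name, step size, and
bounds are assumptions rather than a DynamICCL API, with the bounds
echoing the CF = 1.0–1.25 range where Switch performs best:

  def adjust_capacity_factor(cf, congested, floor=1.0, ceil=1.25, step=0.05):
      if congested:
          return max(floor, cf - step)  # shrink all-to-all payload, drop more tokens
      return min(ceil, cf + step)       # recover quality headroom when quiet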

3. Selective precision as a protocol selection analogy. Switch Transformers selectively cast the router to float32 while keeping the rest in bfloat16. The transport analogy in DynamICCL: for the all-to-all at expert layers (which is latency-sensitive due to blocking the forward pass), prefer NCCL's LL protocol (optimized for small-to-medium messages) even if the message is larger than the typical LL threshold, because the latency reduction outweighs the bandwidth cost for blocking collectives.
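
The bluntest version of this override uses NCCL's documented NCCL_PROTO
environment variable. Note it must be set before NCCL initializes and
applies process-wide, not per collective, so a true per-collective policy
needs deeper integration than this sketch:

  import os
  os.environ["NCCL_PROTO"] = "LL"  # force the low-latency protocol everywhere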

4. Load imbalance as a network congestion signal. When tokens route unevenly to experts (poor load balance), some cores send more data in the all-to-all than others, creating incast congestion at the destination expert. DynamICCL's Trigger Agent can use the variance of received data per rank across all-to-all operations as a congestion signal — high variance correlates with network hotspots. This extends the CUSUM detector beyond pure latency signals to include data skew signals.
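
A minimal form of that signal, normalized as a coefficient of variation so
it can share thresholds with latency-based CUSUM statistics; recv_bytes
would come from instrumentation not shown here:

  import numpy as np

  def skew_signal(recv_bytes):
      """recv_bytes: [ranks] bytes each rank received in one all-to-all."""
      mean = recv_bytes.mean()
      return recv_bytes.std() / mean if mean > 0 else 0.0  # high value = hotspot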

5. Three-axis parallelism as distinct collective workload classes. Expert parallelism generates all-to-all; model parallelism generates all-reduce within layer; data parallelism generates all-reduce at gradient sync. Each axis has different message sizes, frequencies, and latency requirements. DynamICCL's state representation should encode which parallelism axis generated each collective (via message size fingerprinting and call frequency), and the Config Agent should maintain separate Q-tables or policy heads for each workload class rather than a single universal policy.
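
A toy fingerprint for the three workload classes; the 64 MiB threshold
separating per-layer model-parallel all-reduce from gradient-sync
all-reduce is an assumed placeholder:

  def workload_class(primitive, msg_bytes):
      if primitive == "all_to_all":
          return "expert_parallel"   # Switch dispatch/combine
      if primitive == "all_reduce" and msg_bytes < 64 * 2**20:
          return "model_parallel"    # small, frequent, per layer
      return "data_parallel"         # large, once per step for gradient sync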