Detailed Summary: Switch Transformers

Full title: Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
Authors: William Fedus, Barret Zoph, Noam Shazeer (Google Brain)
Year: JMLR 23 (2022), submitted August 2021, published April 2022
Venue: Journal of Machine Learning Research
Code: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py


Abstract (paraphrased)

Dense neural networks reuse all parameters for every input. Mixture-of-Experts (MoE) models instead activate a different subset of parameters for each example, which gives a high parameter count at constant computational cost. Switch Transformers simplify the MoE routing algorithm so that each token goes to a single expert (k=1). They also reduce communication costs, address training instability through selective precision and a smaller initialization scale, and regularize fine-tuning with expert dropout. Using the T5 architecture as a base, Switch-Base achieves up to 7× pre-training speedup over T5-Base with the same computational resources. Switch Transformer models scale to 1.6 trillion parameters and achieve a 4× speedup over T5-XXL. Gains extend to multilingual settings (all 101 languages improved), and the model can be distilled to 1/20th its size while preserving 30% of the quality gain.


Motivation

Scaling laws (Kaplan et al., 2020) show that neural language model performance improves as a power law with model size, dataset size, and compute. The standard ways to grow a dense model are more depth and more width. Both increase the FLOPs spent on every token in proportion to the parameter count. MoE offers another axis: increase the parameter count (and thus knowledge capacity) while keeping FLOPs per token approximately fixed, because only a subset of parameters is activated for each input.

Prior MoE work (Shazeer et al., 2017; Lepikhin et al., 2020 / GShard) showed strong results in machine translation but had three key obstacles:

  1. Complexity: Top-k routing (k≥2) with separate load-balancing and importance-weighting loss terms.
  2. Communication cost: All-to-all operations to route tokens to experts across devices.
  3. Training instability: with bfloat16 precision, roundoff in the router's softmax can make training diverge. Lepikhin et al. therefore trained in float32 throughout their MoE Transformer, which incurs expensive communication of float32 tensors.

Background

Mixture of Experts (MoE) Routing

For N experts {E_i}, a router with weight matrix W_r computes logits h(x) = W_r · x. Gate probabilities:

p_i(x) = exp(h(x)_i) / sum_j exp(h(x)_j)         [Equation 1]

Standard MoE selects top-k experts and computes the output as a weighted sum:

y = sum_{i in top-k} p_i(x) · E_i(x)              [Equation 2]

Switch routing uses k=1: route to argmax(p(x)). The output is simply p_{i*}(x) · E_{i*}(x) where i* = argmax p(x).
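A minimal pure-Python sketch of Equations 1 and 2 with k = 1 for a single token follows. The router weights, the toy two-expert setup, and the names `softmax` and `switch_route` are illustrative assumptions, not the paper's Mesh TensorFlow code. A real implementation batches tokens and relies on autodiff for the gate gradient.

```python
import math


def softmax(logits):
    """Numerically stable softmax over router logits (Equation 1)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def switch_route(x, w_r, experts):
    """Top-1 (Switch) routing for a single token.

    x:       token representation, a list of d_model floats
    w_r:     router weights, one row of d_model floats per expert
    experts: list of callables, each mapping a vector to a vector
    Returns (chosen expert index, gate value, scaled expert output).
    """
    # Router logits h(x) = W_r . x, one per expert.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in w_r]
    probs = softmax(logits)
    # k = 1: only the most probable expert is evaluated.
    i_star = max(range(len(probs)), key=probs.__getitem__)
    gate = probs[i_star]
    # Scaling by the gate value is what lets gradients reach the router
    # (in a framework with autodiff); the other experts are never run.
    y = [gate * v for v in experts[i_star](x)]
    return i_star, gate, y


# Toy example: two hypothetical experts (identity and doubling).
experts = [lambda v: list(v), lambda v: [2.0 * a for a in v]]
w_r = [[1.0, 0.0], [0.0, 1.0]]
idx, gate, y = switch_route([0.2, 0.9], w_r, experts)
print(idx, round(gate, 4), [round(a, 4) for a in y])
```

Only one expert's FFN is evaluated per token, so compute per token does not grow with the number of experts. The router matmul is the only cost that scales with N.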

Expert Capacity

Each expert can process a fixed number of tokens per batch. The capacity is:

expert_capacity = (tokens_per_batch / num_experts) × capacity_factor   [Equation 3]

A capacity_factor > 1.0 provides a buffer for uneven routing. Tokens exceeding an expert's capacity are "dropped" — passed through the residual connection without expert processing. Token dropping rate is typically <1% with the auxiliary load-balancing loss.
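The capacity rule and token dropping can be sketched as follows. The sketch assumes that tokens claim expert slots in batch order and that Equation 3 is rounded up to a whole number of slots. Both are illustrative choices, and the helper names are hypothetical.

```python
import math


def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    """Equation 3; rounding up to whole token slots is an assumption."""
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)


def assign_with_capacity(expert_choice, num_experts, capacity):
    """Give each token a slot in its chosen expert's buffer, in batch order.

    expert_choice: argmax expert index for every token in the batch.
    Returns (kept, dropped): kept maps token position -> (expert, slot);
    dropped lists positions that found the buffer full and therefore only
    flow through the residual connection.
    """
    fill = [0] * num_experts
    kept, dropped = {}, []
    for pos, e in enumerate(expert_choice):
        if fill[e] < capacity:
            kept[pos] = (e, fill[e])
            fill[e] += 1
        else:
            dropped.append(pos)
    return kept, dropped


# 8 tokens, 4 experts: capacity factor 1.0 gives 2 slots per expert.
choices = [0, 0, 0, 1, 2, 2, 3, 1]
cap = expert_capacity(len(choices), 4, 1.0)
kept, dropped = assign_with_capacity(choices, 4, cap)
print(cap, dropped)  # -> 2 [2]  (the third token routed to expert 0 is dropped)

# A capacity factor of 1.25 rounds up to 3 slots, so nothing is dropped.
print(assign_with_capacity(choices, 4, expert_capacity(8, 4, 1.25))[1])  # -> []
```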


System Design

Switch Transformer Architecture

Standard Transformer Layer:
┌───────────────────────────────────┐
│  Self-Attention                   │
│  Add + Normalize                  │
│  FFN (dense, same for all tokens) │
│  Add + Normalize                  │
└───────────────────────────────────┘

Switch Transformer Layer (replaces the dense FFN in every other layer or in every layer, depending on the model's expert frequency):
┌─────────────────────────────────────────────┐
│  Self-Attention                             │
│  Add + Normalize                            │
│  Switch FFN Layer:                          │
│    Router → argmax → Expert_i* (per token)  │
│    Output = p_i*(x) · E_i*(x)               │
│    Add + Normalize                          │
└─────────────────────────────────────────────┘

Experts are distributed one-per-device:
Device 0: Expert 0 (unique FFN weights)
Device 1: Expert 1
...
Device N-1: Expert N-1

Distributed Execution (Expert + Data Parallelism)

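In expert-and-data parallel mode with n cores and E experts, each core holds one expert (as above, so n = E) and its own shard of the batch.

Communication per Switch layer (forward pass) = 2 all-to-alls (dispatch, then combine). In each one, every core sends E × C × d_model values, where C is the per-core expert capacity from Equation 3.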
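As a rough sanity check on the communication formula, the sketch below computes the bytes one core sends per all-to-all. It assumes bfloat16 values (2 bytes), the capacity from Equation 3 rounded up, and made-up sizes (2048 tokens per core, 64 experts, d_model = 768). One of the E buffers is addressed to the core's own expert and may never leave the device.

```python
import math


def all_to_all_bytes_per_core(tokens_per_core, num_experts, capacity_factor,
                              d_model, bytes_per_value=2):
    """Bytes one core sends in a single dispatch (or combine) all-to-all.

    Assumes fixed-capacity buffers: num_experts buffers of C token slots,
    each d_model wide, in bfloat16 (2 bytes per value). Empty slots are
    padding but are still sent.
    """
    capacity = math.ceil(tokens_per_core / num_experts * capacity_factor)  # C
    return num_experts * capacity * d_model * bytes_per_value


# Hypothetical sizes: 2048 tokens per core, 64 experts, d_model = 768.
per_a2a = all_to_all_bytes_per_core(2048, 64, 1.0, 768)
# Dispatch + combine: two all-to-alls per Switch layer in the forward pass.
print(per_a2a, 2 * per_a2a)  # -> 3145728 6291456 (3 MiB and 6 MiB)
```

Because E × C is about tokens_per_core × capacity_factor, the volume grows with the capacity factor but is roughly independent of the number of experts.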

Load Balancing Loss

Auxiliary loss added to total loss at each Switch layer. For N experts, batch B, T tokens:

f_i = (1/T) sum_{x in B} 1{argmax p(x) = i}    [fraction of tokens to expert i]
P_i = (1/T) sum_{x in B} p_i(x)                  [mean router probability for expert i]

auxiliary_loss = alpha * N * sum_i (f_i * P_i)    [Equation 4]

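With alpha = 10^{-2}, this loss is minimized when all f_i = P_i = 1/N (uniform routing). Multiplying by N keeps the loss constant as the number of experts varies: under uniform routing, sum_i (1/N)(1/N) = 1/N, so the loss equals alpha for any N. The f vector is not differentiable (it uses argmax), but P is. Because each P_i is weighted by f_i, the product still provides gradient signal to W_r, and that signal pushes probability mass away from over-used experts.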
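Below is a direct transcription of Equation 4 for one batch, written as a plain-Python sketch. It has no autodiff, so it only computes the value. A balanced batch gives exactly alpha regardless of N, while a batch that routes every token to one expert is penalized.

```python
def load_balancing_loss(router_probs, alpha=1e-2):
    """Auxiliary loss of Equation 4 for one batch (value only, no gradients).

    router_probs: T rows, each the softmax distribution p(x) over N experts.
    """
    T = len(router_probs)
    N = len(router_probs[0])
    f = [0.0] * N  # fraction of tokens whose argmax is expert i (not differentiable)
    P = [0.0] * N  # mean router probability for expert i (differentiable)
    for p in router_probs:
        i_star = max(range(N), key=p.__getitem__)
        f[i_star] += 1.0 / T
        for i in range(N):
            P[i] += p[i] / T
    return alpha * N * sum(fi * Pi for fi, Pi in zip(f, P))


# Balanced batch: every expert wins exactly one of four tokens.
balanced = [[0.7, 0.1, 0.1, 0.1], [0.1, 0.7, 0.1, 0.1],
            [0.1, 0.1, 0.7, 0.1], [0.1, 0.1, 0.1, 0.7]]
# Collapsed batch: every token is routed to expert 0.
collapsed = [[0.7, 0.1, 0.1, 0.1]] * 4
print(load_balancing_loss(balanced), load_balancing_loss(collapsed))  # ~0.01, ~0.028
```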

Training Stabilization Techniques

Selective Precision: Only the router computation (W_r · x → softmax) is cast to float32. The resulting dispatch and combine tensors are recast to bfloat16 before the all-to-all communication. This eliminates the expensive float32 inter-device communication while still providing float32 numerical stability in the softmax.

Configuration                        Quality (neg. log perp.)   Speed (examples/sec)
Switch-Base (float32)                -1.718                     1160
Switch-Base (bfloat16)               diverged                   1390
Switch-Base (selective precision)    -1.716                     1390
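The toy below shows why the softmax is the sensitive spot. It emulates bfloat16 by rounding a float32 bit pattern to its top 16 bits (round-to-nearest-even; NaN and infinity handling omitted). The logits are hypothetical values chosen for illustration. Rounding the logits before the softmax shifts the router probability noticeably. Computing the softmax at higher precision and casting only its outputs (the selective-precision pattern) keeps the error within bfloat16's relative precision of about 2^-8.

```python
import math
import struct


def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits,
    so the float32 bit pattern is rounded to its upper 16 bits.
    NaN/infinity handling is omitted in this sketch.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [100.0, 100.3]  # hypothetical large router logits

# Router math at higher precision (Python floats stand in for float32).
full = softmax(logits)
# Logits rounded to bfloat16 before the softmax: 100.3 becomes 100.5.
low = softmax([to_bfloat16(z) for z in logits])
# Selective precision: compute in high precision, cast only the outputs.
selective = [to_bfloat16(p) for p in full]

print([round(p, 3) for p in full])  # -> [0.426, 0.574]
print([round(p, 3) for p in low])   # -> [0.378, 0.622]
print(selective)
```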

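Reduced Initialization Scale: weights are drawn from a truncated normal distribution with mean 0 and σ = √(s/n), where n is the fan-in (number of input units) and s is a scale hyperparameter. The standard Transformer setting is s = 1.0. Switch Transformer uses s = 0.1, which reduces the scale hyperparameter by 10× and σ itself by √10 ≈ 3.2×.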

Init scale (s)   Avg. quality (neg. log perp.)   Std. dev. of quality
0.1×             -2.72                           0.01
1.0×             -3.60                           0.68
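Here is a sketch of the initializer described above. It assumes truncation at two standard deviations with resampling (TensorFlow's truncated-normal convention) and a hypothetical fan-in of 768. The printed ratio shows the distinction concretely: a 10× smaller s gives a √10 ≈ 3.16× smaller σ.

```python
import math
import random


def init_std(fan_in, scale):
    """sigma = sqrt(s / n): s is the scale hyperparameter, n the fan-in."""
    return math.sqrt(scale / fan_in)


def truncated_normal_init(fan_in, fan_out, scale, rng):
    """[fan_in][fan_out] weight matrix from N(0, sigma^2), redrawing any sample
    farther than 2 sigma from 0 (truncation rule assumed from TensorFlow's
    truncated normal)."""
    sigma = init_std(fan_in, scale)

    def draw():
        while True:
            w = rng.gauss(0.0, sigma)
            if abs(w) <= 2.0 * sigma:
                return w

    return [[draw() for _ in range(fan_out)] for _ in range(fan_in)]


rng = random.Random(0)
W = truncated_normal_init(768, 16, scale=0.1, rng=rng)  # hypothetical d_model = 768
# Shrinking s by 10x shrinks sigma by sqrt(10), roughly 3.16x.
print(init_std(768, 1.0) / init_std(768, 0.1))
```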

Expert Dropout: during fine-tuning, the dropout rate inside the expert FFN layers is raised to 0.4, while non-expert layers stay at 0.1. This counters overfitting. Switch models have far more parameters than their FLOP-matched dense counterparts, and many fine-tuning datasets are small.
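Expert dropout amounts to choosing a per-layer dropout rate at fine-tuning time. The sketch below applies inverted dropout to the hidden activations of a toy ReLU FFN. The `dropout_for` helper and its layer-kind strings are hypothetical. Only the two rates (0.4 inside experts, 0.1 elsewhere) come from the text.

```python
import random

DENSE_DROPOUT = 0.1   # non-expert layers during fine-tuning (from the text)
EXPERT_DROPOUT = 0.4  # inside the expert FFNs during fine-tuning (from the text)


def dropout_for(layer_kind):
    """Hypothetical helper: pick the fine-tuning dropout rate for a layer."""
    return EXPERT_DROPOUT if layer_kind == "expert_ffn" else DENSE_DROPOUT


def ffn(x, w_in, w_out, rate, training, rng):
    """Toy ReLU FFN with inverted dropout on the hidden activations.

    x: d_model floats; w_in: [d_model][d_ff]; w_out: [d_ff][d_model].
    """
    d_ff = len(w_in[0])
    hidden = [max(0.0, sum(x[i] * w_in[i][j] for i in range(len(x))))
              for j in range(d_ff)]
    if training and rate > 0.0:
        keep = 1.0 - rate
        # Each hidden unit is zeroed with probability `rate`; survivors are
        # rescaled so the expected activation is unchanged.
        hidden = [h / keep if rng.random() < keep else 0.0 for h in hidden]
    return [sum(hidden[j] * w_out[j][k] for j in range(d_ff))
            for k in range(len(w_out[0]))]


rng = random.Random(0)
x = [1.0, -0.5]
w_in = [[0.5, 1.0, -1.0], [0.2, 0.0, 1.0]]
w_out = [[1.0], [1.0], [1.0]]
rate = dropout_for("expert_ffn")
print(ffn(x, w_in, w_out, rate, training=False, rng=rng))  # no dropout at eval
print(ffn(x, w_in, w_out, rate, training=True, rng=rng))   # random units dropped
```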

Router Exploration

Token routing is deterministic (argmax), which is purely exploitative. To introduce exploration, the paper applies multiplicative jitter noise to the router inputs during training:

router_inputs *= uniform(minval=1-eps, maxval=1+eps)  [input jitter, applied before computing the router logits]

Input jitter gives the best quality (-1.468 neg. log perp.) vs. argmax (-1.471), sample softmax (-1.570), and input dropout (-1.480).
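The following is a minimal sketch of input jitter. It assumes eps = 0.01 and a toy two-expert router with nearly tied logits, both chosen for illustration. Because the noise is applied to the router input, close routing decisions occasionally flip, which is the exploration effect the technique targets. At evaluation time the input passes through unchanged.

```python
import random


def jitter_router_input(x, rng, eps=1e-2, training=True):
    """Multiplicative input jitter: during training, scale every element of the
    router input by an independent draw from uniform(1 - eps, 1 + eps)."""
    if not training:
        return list(x)
    return [xi * rng.uniform(1.0 - eps, 1.0 + eps) for xi in x]


def route(x, w_r):
    """Argmax routing on logits W_r . x (a softmax would not change the argmax)."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in w_r]
    return max(range(len(logits)), key=logits.__getitem__)


rng = random.Random(0)
w_r = [[1.0, 0.0], [0.0, 1.0]]   # toy 2-expert router
x = [1.0, 0.995]                 # nearly tied logits: expert 0 wins without noise
baseline = route(x, w_r)
flips = sum(route(jitter_router_input(x, rng), w_r) != baseline for _ in range(1000))
print(baseline, flips)  # flips counts jittered passes that chose expert 1 instead
```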


Scaling Properties

Scaling Along the Expert Axis (Step-Basis)

For a fixed computational budget (same FLOPs per token), increasing the number of experts (from 1 to 256) consistently improves model quality on a per-step basis. The 64-expert Switch-Base model reaches in 60k steps the quality that T5-Base reaches in 450k steps, a 7.5× step speedup.

Scaling on a Time-Basis

Additional communication overhead from all-to-all operations means step speedup does not fully translate to wall-clock speedup. However, Switch-Base 64 experts still achieves the same quality as T5-Base in 1/7th the wall-clock time. Against T5-Large (3.5× more FLOPs/token), Switch-Base achieves a 2.5× wall-clock speedup.

Trillion-Parameter Models

Combining expert, model, and data parallelism, the paper designs two massive models:

Model        Parameters   FLOPs/seq   Expert freq.   Num. experts   Neg. log perp. @500k
T5-XXL       11B          6.3T        n/a            n/a            -1.095
Switch-XXL   395B         6.3T        1/2            64             -1.008
Switch-C     1571B        890B        1              2048           -1.043

Switch-C reaches a fixed pre-training perplexity 4× faster than T5-XXL. Switch-XXL is FLOP-matched to T5-XXL. It outperforms T5-XXL in pre-training quality and achieves better downstream results, although its training is sometimes unstable.


Evaluation Methodology

Pre-training: masked language modeling on the C4 corpus (Colossal Clean Crawled Corpus, 180B+ target tokens). Metric: negative log perplexity.

Fine-tuning: models are pre-trained on 576B tokens, then fine-tuned on a diverse benchmark suite that includes GLUE, SuperGLUE, SQuAD, Winogrande, and ARC.

Hardware: TPUv3 pods (high-bandwidth interconnect). Comparisons are FLOP-matched and run on the same hardware.


Results

Pre-training Speed vs. T5 Baselines (Table 1)

Model                   Quality @100k steps (neg. log perp.)   Time to quality -1.50 (hours)   Speed (examples/sec)
T5-Base                 -1.731                                 not achieved                    1600
Switch-Base (cap=1.0)   -1.561                                 62.8                            1000
MoE-Base (cap=1.0)      -1.572                                 80.1                            860

Switch-Base reaches the quality threshold in about 22% less time than MoE-Base (62.8 vs. 80.1 hours) and in 1/3 the time of T5-Large.

Fine-tuning Results (FLOP-matched, Table 5)

Model          GLUE   SQuAD   SuperGLUE   Winogrande
T5-Base        84.3   85.5    75.1        66.6
Switch-Base    86.7   87.2    79.5        73.3
T5-Large       87.8   88.1    82.7        79.1
Switch-Large   88.5   88.6    84.7        83.0

Distillation (Table 6–7)

Distilling Switch-Base (3.8B) → T5-Base (223M) preserves roughly 30% of the sparse teacher's quality gain over T5-Base (see Abstract).

Multilingual Pre-training

mSwitch-Base vs. mT5-Base on 101 languages: mSwitch-Base improves over mT5-Base on all 101 languages (see Abstract).


Limitations

  1. All-to-all communication bottleneck: each Switch layer requires two all-to-all operations. On inter-node links (InfiniBand/Ethernet), this can become the dominant cost. Unlike ring AllReduce, it stresses bisection bandwidth, because every device exchanges data with every other device. The paper targets TPU pods with a high-bandwidth interconnect.

  2. Token dropping: Tokens routed to an overloaded expert are dropped (passed through residual unprocessed). Drop rates of 1% are typical but degrade training slightly. Capacity factor >1.0 reduces drops at the cost of wasted compute/memory.

  3. Training instability at scale: Switch-XXL (395B) is sometimes unstable. Stability techniques (selective precision, reduced init scale, input jitter) are effective for Switch-Base/-Large/-C but insufficient for Switch-XXL.

  4. Fine-tuning gap: Switch models do not consistently outperform T5 on all downstream tasks. On ARC Challenge, T5-Large outperforms Switch-Large. The largest models' pre-training perplexity advantage does not fully transfer to reasoning-heavy tasks (SuperGLUE).

  5. Hardware specificity: Implementation in Mesh-TensorFlow with statically-shaped tensors. GPU deployments require additional engineering (dynamic shapes, sparse kernels). Sparse GPU kernels (Gray et al., 2017; Gale et al., 2020) were not yet mature enough for training at time of publication.

  6. Expert load imbalance cannot be fully eliminated: The auxiliary loss encourages but does not guarantee uniform routing. Heavy-hitter tokens (common subwords, BOS/EOS) can consistently overload certain experts.

  7. Poorly understood fine-tuning vs. pre-training trade-off: Switch-C (1.6T parameters) has fewer FLOPs/token than Switch-XXL (395B parameters, about 10× more FLOPs/token). Switch-C reaches similar pre-training perplexity but a lower SQuAD score (87.7 vs. 89.6). This suggests that FLOPs per token and parameter count affect fine-tuning quality in ways that are not yet well understood.



RL Formulation Table

This paper contains no reinforcement learning. Not applicable.

The paper does note an analogy to RL in Appendix C: token routing is a discrete, non-differentiable decision (like taking an action) where the router receives no counterfactual information about alternative experts (like a bandit with partial feedback). The paper frames this as an exploration-exploitation dilemma and tests various exploration strategies (input jitter wins).


Relevance to DynamICCL

Switch Transformers are relevant to DynamICCL on multiple levels: as a distinct workload class that generates a qualitatively different collective communication pattern, and as a motivating example for dynamic communication optimization.

New collective type — All-to-All:

Per Switch layer (forward pass):
  All-to-All #1: dispatch tokens to expert devices
                 [num_experts, num_cores, expert_capacity, d_model]
                 (packed locally via a [num_cores, tokens_per_core, num_experts, expert_capacity] dispatch mask)
                 each core sends E×C×d_model bfloat16 values
  Expert FFN computation (local, no communication)
  All-to-All #2: return expert outputs
                 same tensor size

Data-parallel gradient AllReduce (end of backward):
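  AllReduce(gradients of the replicated, non-expert parameters) as in standard DP;
  each expert's weights live on a single device, so they are not all-reduced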
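The dispatch/combine pattern above can be simulated in-process. In this sketch, `all_to_all` is a hypothetical stand-in for a real collective: it just transposes the [source core][destination core] nesting. Each core packs its tokens into fixed-size, zero-padded [num_experts][capacity][d_model] buffers, and a toy expert scales its inputs. The received layout matches [num_experts, num_cores, expert_capacity, d_model], and every core sends the same amount of data no matter how many slots are filled.

```python
def dispatch_buffers(tokens, expert_choice, num_experts, capacity, d_model):
    """Pack one core's tokens into fixed-size [num_experts][capacity][d_model]
    buffers. Unused slots stay zero (padding); tokens that find their expert's
    buffer full are dropped and would only take the residual path."""
    buffers = [[[0.0] * d_model for _ in range(capacity)] for _ in range(num_experts)]
    fill = [0] * num_experts
    for tok, e in zip(tokens, expert_choice):
        if fill[e] < capacity:
            buffers[e][fill[e]] = list(tok)
            fill[e] += 1
    return buffers


def all_to_all(send):
    """In-process stand-in for an all-to-all over n cores (one expert per core).

    send[src][dst] is what core src sends to core dst; the result recv[dst][src]
    is what core dst receives from core src, i.e. a transpose of the nesting.
    """
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


# Two cores, two experts, capacity 2, d_model 1 (toy sizes).
core0 = dispatch_buffers([[1.0], [2.0], [3.0]], [0, 1, 0], 2, 2, 1)
core1 = dispatch_buffers([[4.0], [5.0], [6.0]], [1, 1, 1], 2, 2, 1)  # token 6.0 dropped

# All-to-All #1 (dispatch): recv[expert][source_core][slot][d_model]
recv = all_to_all([core0, core1])

# Local expert FFN stand-in: expert e multiplies its inputs by 10 * (e + 1).
expert_out = [[[[10.0 * (e + 1) * v for v in slot] for slot in per_core]
               for per_core in recv[e]] for e in range(2)]

# All-to-All #2 (combine): back[core][expert][slot][d_model]
back = all_to_all(expert_out)
print(recv[1])     # expert 1's inputs: [[[2.0], [0.0]], [[4.0], [5.0]]]
print(back[0][1])  # core 0's results from expert 1: [[40.0], [0.0]]
```

The combine step reuses the same exchange because transposing twice restores the original nesting. Each core then scales the returned slots by its gate values.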

Implications for DynamICCL:

  1. All-to-All is not currently in DynamICCL's action space. The Config Agent's collective types (AllReduce, AllGather, ReduceScatter) do not include all-to-all. Extending DynamICCL to support Switch Transformer workloads requires adding all-to-all to the tunable collective set. With NCCL, all-to-all is typically composed from grouped ncclSend/ncclRecv point-to-point calls (for example, in PyTorch's NCCL backend).

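  2. Expert load imbalance creates irregular all-to-all traffic. When the auxiliary loss does not perfectly balance routing, some expert devices receive more tokens than others. With the paper's fixed-capacity, statically shaped buffers, message sizes stay constant and the imbalance shows up as padding and dropped tokens. Implementations that exchange only the routed tokens (variable-size all-to-all) instead see asymmetric message sizes. Either way, the result is the kind of irregular, burst-prone traffic pattern that DynamICCL's LSTM+CUSUM Trigger Agent is designed to detect.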

  3. Capacity factor is a dynamic parameter. The capacity_factor hyperparameter controls buffer size and thus communication volume per all-to-all. At congestion, reducing capacity_factor reduces communication volume (at the cost of more token dropping). This is a potential action for DynamICCL's Config Agent in a Switch Transformer context — dynamically trading off token quality for network relief.

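  4. Coexistence of all-to-all and AllReduce. In expert + data parallel training, the per-layer all-to-alls and the data-parallel gradient AllReduce can overlap in time. This happens, for example, when gradient AllReduce is bucketed and overlapped with the backward pass, which contains its own all-to-alls. DynamICCL must identify which collective is responsible for observed congestion and select NCCL configurations for each separately.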

  5. TPU vs. GPU generalizability. The paper's results are on TPUs. GPU implementations of Switch Transformers (in Fairseq, Tutel, etc.) typically run over heterogeneous interconnects (NVLink within a node, InfiniBand or Ethernet across nodes). This makes them more sensitive to network configuration, and DynamICCL's configuration optimization potentially more impactful on GPU clusters like Chameleon Cloud.
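Implications 2 and 3 can be made concrete with a small calculation. The sketch below uses hypothetical routing decisions, d_model = 768, and bfloat16 values. It first compares the per-destination message sizes of a padded, fixed-capacity exchange against a variable-size exchange that sends only routed tokens. It then sweeps capacity_factor to show the trade-off between communication volume and dropped tokens that a Config Agent might act on. None of these functions belong to DynamICCL or NCCL; they are illustrative helpers.

```python
import math
from collections import Counter


def message_sizes(expert_choice, num_experts, capacity=None, d_model=768, bytes_per_value=2):
    """Bytes one core sends to each expert's device in one dispatch all-to-all.

    capacity=None: variable-size exchange carrying exactly the routed tokens.
    capacity=int:  fixed, padded buffers as in the paper (overflow is dropped).
    """
    counts = Counter(expert_choice)
    if capacity is None:
        slots = [counts.get(e, 0) for e in range(num_experts)]
    else:
        slots = [capacity] * num_experts
    return [s * d_model * bytes_per_value for s in slots]


def capacity_tradeoff(expert_choice, num_experts, capacity_factor, d_model=768, bytes_per_value=2):
    """(bytes sent per all-to-all, tokens dropped) for one core under fixed capacity."""
    capacity = math.ceil(len(expert_choice) / num_experts * capacity_factor)
    counts = Counter(expert_choice)
    dropped = sum(max(0, counts.get(e, 0) - capacity) for e in range(num_experts))
    return num_experts * capacity * d_model * bytes_per_value, dropped


# Skewed routing: expert 0 is a heavy hitter (hypothetical decisions).
skewed = [0] * 6 + [1, 2]
print(message_sizes(skewed, 4))              # -> [9216, 1536, 1536, 0]
print(message_sizes(skewed, 4, capacity=2))  # -> [3072, 3072, 3072, 3072], 4 tokens dropped

# Sweeping the capacity factor trades communication volume against dropped tokens.
routing = [0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 0, 1, 2, 3]
for cf in (0.5, 1.0, 1.25, 2.0):
    print(cf, capacity_tradeoff(routing, 4, cf))
# -> 0.5 (12288, 8), 1.0 (24576, 1), 1.25 (30720, 0), 2.0 (49152, 0)
```

Under fixed capacity, lowering capacity_factor shrinks every message uniformly but increases drops. The variable-size scheme avoids padding at the cost of irregular, load-dependent message sizes.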