AI Training
Covers distributed training strategies, mixed precision training, gradient management, and the systems challenges of training at scale.
- Compare data parallelism, model parallelism, and pipeline parallelism strategies for distributed training
- Explain how mixed precision training reduces memory usage and increases throughput with minimal quality loss
- Analyze the communication overhead in distributed training and evaluate gradient compression techniques
- Design a distributed training configuration that matches hardware topology and model requirements
- Implement fault-tolerant training with checkpointing and automatic recovery mechanisms
01 Distributed Training Fundamentals
As models grow larger and datasets expand, training on a single accelerator becomes impractical. Distributed training splits the workload across multiple devices, potentially across multiple machines, to reduce training time from weeks to hours. Understanding distributed training is essential for any engineer working with modern large-scale ML systems.
Distributed Training
The practice of splitting a model training workload across multiple accelerators (GPUs/TPUs) and potentially multiple machines, coordinating computation and communication to reduce total training time while maintaining mathematical correctness.
The fundamental challenge of distributed training is maintaining the mathematical equivalence of optimization while dividing work across devices. Communication overhead between devices, synchronization strategies, and fault tolerance all add complexity that does not exist in single-device training.
Figure: Distributed Training Patterns
Training GPT-3 (175B parameters) on a single NVIDIA A100 GPU would take approximately 355 years. With 1,024 A100 GPUs using efficient distributed training, it takes roughly 34 days. Without distribution strategies, the largest foundation models would simply be impossible to train in any practical timeframe.
Failure Modes in Distributed Systems
Distributed training introduces new failure modes that single-device training does not have. Network partitions, device failures, stragglers (slow devices), and communication deadlocks can all disrupt training. Production distributed training systems must handle these failures gracefully, typically through checkpointing and automatic restart mechanisms.
- Network partitions — Devices lose connectivity, causing communication operations to hang or timeout.
- Device failures — A GPU may crash or produce incorrect results due to hardware faults.
- Stragglers — One slow device forces all others to wait during synchronization barriers.
- Communication deadlocks — Incorrect ordering of collective operations blocks all progress.
- Checkpoint corruption — Failed writes during checkpointing can corrupt the saved state.
In synchronous distributed training, all devices must wait for the slowest device before proceeding to the next step. A single device running 20% slower (due to thermal throttling, background processes, or hardware degradation) slows the entire cluster by 20%. At scale, straggler mitigation through redundant computation or asynchronous methods becomes essential.
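The usual defense against the failure modes listed above is frequent checkpointing with automatic restart. Below is a minimal sketch, assuming a standard PyTorch model and optimizer; the file path and helper names are illustrative, and the atomic rename guards against the checkpoint-corruption failure mode.
import os
import torch

def save_checkpoint(model, optimizer, step, path="checkpoint.pt"):
    tmp_path = path + ".tmp"
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, tmp_path)
    os.replace(tmp_path, path)  # Atomic rename: a crash mid-write cannot corrupt the live checkpoint

def load_checkpoint(model, optimizer, path="checkpoint.pt"):
    if not os.path.exists(path):
        return 0  # No checkpoint yet: start training from step 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"]  # Resume from the saved step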
Choosing a Distribution Strategy
| Strategy | When to Use | Key Requirement |
|---|---|---|
| Data Parallelism | Model fits in one device's memory | High inter-device bandwidth for gradient sync |
| Model Parallelism | Model too large for one device | Very high bandwidth (NVLink within a node) |
| Pipeline Parallelism | Model too large, many sequential layers | Moderate bandwidth across nodes |
| Hybrid (3D Parallelism) | Largest models (100B+ params) | Multi-level network topology awareness |
Table 8.1: Choosing a distribution strategy based on model size and hardware.
Always start with data parallelism. It is the simplest approach and sufficient for most models. Only move to model parallelism or pipeline parallelism when the model genuinely does not fit in a single device's memory. Premature complexity in distribution strategy is a common source of bugs and wasted engineering effort.
When should you use model parallelism instead of data parallelism for distributed training?
Think about the fundamental constraint that forces you to split the model itself.
Distributed Training
Training a model across multiple accelerators or machines to reduce total training time by parallelizing computation.
Straggler Problem
The performance degradation caused by the slowest device in a distributed training setup, which forces all other devices to wait during synchronization.
02 Data Parallelism
Data parallelism is the most common distributed training strategy, where each device holds a complete copy of the model and processes a different subset of the training data. After computing local gradients, devices synchronize by aggregating (typically averaging) their gradients before updating model parameters.
Data Parallelism
A distributed training strategy where every device maintains a complete copy of the model. The training dataset is partitioned across devices, each computes gradients on its partition, and gradients are aggregated (averaged) across all devices before updating the shared model parameters.
Synchronous Data Parallelism
Synchronous data parallelism uses an all-reduce operation to aggregate gradients from all devices before any device updates its parameters. This maintains mathematical equivalence with single-device training (up to floating-point precision) and is the default approach in most frameworks. The all-reduce operation can be implemented using ring all-reduce, which achieves optimal bandwidth utilization.
g_{\text{global}} = \frac{1}{N} \sum_{i=1}^{N} g_i
Ring all-reduce arranges N devices in a logical ring. Each device sends a chunk of its gradient to the next device in the ring while receiving from the previous. After 2(N-1) communication steps, all devices have the fully reduced gradient. This achieves optimal bandwidth utilization: each device sends and receives exactly 2(N-1)/N of the total gradient data, approaching the theoretical minimum as N grows.
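As a concrete illustration, the sketch below averages gradients by hand with torch.distributed; it assumes the process group has already been initialized (for example via torchrun with the NCCL backend). Frameworks such as PyTorch DDP perform this step automatically.
import torch.distributed as dist

def average_gradients(model):
    # After loss.backward(), every rank holds its local gradients in param.grad
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # Sum gradients across all ranks
            param.grad /= world_size                           # Convert the sum into a mean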
Asynchronous Data Parallelism
Asynchronous data parallelism allows devices to update parameters independently without waiting for all gradients to arrive. This eliminates the synchronization bottleneck but introduces gradient staleness, where devices may use slightly outdated parameters. Asynchronous methods converge faster in wall-clock time for some problems but can be less stable and harder to reproduce.
| Property | Synchronous | Asynchronous |
|---|---|---|
| Mathematical equivalence | Preserved (up to FP precision) | Approximate (stale gradients) |
| Straggler sensitivity | High (all wait for slowest) | Low (no synchronization barrier) |
| Convergence stability | Deterministic | Less stable, harder to reproduce |
| Implementation complexity | Moderate (all-reduce) | Higher (parameter servers, staleness) |
| Preferred for | Most use cases, default choice | Very large clusters with heterogeneous devices |
Table 8.2: Synchronous vs. asynchronous data parallelism.
Scaling: Batch Size and Learning Rate
Scaling data parallelism effectively requires careful attention to the relationship between batch size and learning rate. The linear scaling rule suggests increasing the learning rate proportionally to the number of devices, but this breaks down at very large batch sizes.
\eta_{\text{scaled}} = \eta_{\text{base}} \times N
The linear scaling rule works well up to a point, but beyond a critical batch size, training becomes unstable or model quality degrades. The LARS (Layer-wise Adaptive Rate Scaling) and LAMB (Layer-wise Adaptive Moments) optimizers address this by adapting learning rates per layer based on the ratio of weight norms to gradient norms, enabling stable training with batch sizes of 32K or more.
When scaling to many devices, always use a learning rate warmup period. Start with a small learning rate and linearly increase it over the first few hundred or thousand steps. This prevents the large initial gradients (from randomly initialized or not-yet-adapted weights) from destabilizing training when multiplied by a large scaled learning rate.
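A minimal sketch of the linear scaling rule combined with warmup is shown below; the function name and its default values are illustrative, not taken from a specific codebase.
def scaled_lr(step, base_lr=1e-3, num_devices=8, warmup_steps=1000):
    peak_lr = base_lr * num_devices  # Linear scaling rule: eta_scaled = eta_base * N
    if step < warmup_steps:
        # Linearly ramp from base_lr up to peak_lr over the warmup period
        return base_lr + (peak_lr - base_lr) * step / warmup_steps
    return peak_lr  # Hold the scaled rate after warmup (decay schedules omitted)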
Data Parallelism
A distributed training strategy where each device holds a full model copy and processes different data subsets, synchronizing gradients across devices.
All-Reduce
A collective communication operation that aggregates values from all devices and distributes the result back to all devices, commonly used for gradient synchronization.
03 Model Parallelism and Pipeline Parallelism
When a model is too large to fit in a single device's memory, model parallelism distributes different parts of the model across devices. Tensor parallelism splits individual layers across devices, while pipeline parallelism assigns different layers to different devices. Both approaches are essential for training the largest foundation models.
Model Parallelism
A distributed training strategy that partitions the model itself across multiple devices, with each device responsible for a subset of the model's parameters and computation. Required when a model's memory footprint exceeds the capacity of a single accelerator.
Tensor Parallelism
Tensor parallelism partitions weight matrices across devices and uses collective communication to combine partial results. For Transformer models, the attention and feed-forward layers can be split across devices along specific dimensions. Megatron-LM popularized efficient tensor parallelism strategies for large language models.
Consider a feed-forward layer with weight matrix W of shape [4096, 16384]. With 4-way tensor parallelism, W is split column-wise into four shards of shape [4096, 4096], one per device. Each device computes its partial output Y_i = X * W_i, and the partial outputs are concatenated (an all-gather) to reconstruct the full result. This distributes both memory and computation evenly.
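The shard arithmetic can be checked on a single device. The sketch below simulates the 4-way column split with the shapes from the example, run on one machine purely for illustration.
import torch

X = torch.randn(2, 4096)                  # [batch, hidden] activations
W = torch.randn(4096, 16384)              # Full weight; in real tensor parallelism no single device holds it
shards = torch.chunk(W, chunks=4, dim=1)  # Four [4096, 4096] column shards, one per "device"

partials = [X @ W_i for W_i in shards]    # Each device computes its [2, 4096] slice of the output
Y = torch.cat(partials, dim=1)            # All-gather: concatenate slices into the full [2, 16384] output

assert torch.allclose(Y, X @ W, atol=1e-3)  # Matches the unsharded computation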
Pipeline Parallelism
Pipeline parallelism assigns sequential groups of layers to different devices, creating a pipeline where micro-batches flow through devices in sequence. GPipe and PipeDream introduced techniques to reduce the "pipeline bubble" (idle time when devices wait for data), including micro-batch interleaving and asynchronous scheduling.
Pipeline Bubble
The idle time at the beginning and end of a pipeline parallel training step when some devices have no micro-batch to process. The bubble fraction decreases as the number of micro-batches increases relative to the number of pipeline stages.
With P pipeline stages and M micro-batches, the bubble overhead is approximately (P-1)/M of the useful compute time. For 4 stages and 4 micro-batches, the bubble is 75% as large as the compute time, meaning roughly 43% of the total step is spent idle. Increasing to 32 micro-batches reduces the overhead to under 10%. Always use many more micro-batches than pipeline stages to maintain high device utilization.
\text{Bubble fraction} \approx \frac{P - 1}{M}
Hybrid (3D) Parallelism
In practice, the largest models use a combination of all three parallelism strategies. A typical setup might use tensor parallelism within a node (where NVLink provides high bandwidth), pipeline parallelism across nodes, and data parallelism across groups of nodes.
| Framework | Parallelism Support | Notable Features |
|---|---|---|
| Megatron-LM | Tensor + Pipeline + Data | NVIDIA's reference for large language model training |
| DeepSpeed (ZeRO) | Data + Pipeline (ZeRO-3) | Memory-efficient optimizer state sharding |
| FSDP (PyTorch) | Sharded Data Parallel | Native PyTorch, shards params/grads/optimizer states |
| Alpa | Automated Tensor + Pipeline | Automatic parallelism strategy search via ILP |
Table 8.3: Major frameworks for hybrid parallelism.
Use tensor parallelism within a single node where NVLink provides 600+ GB/s bandwidth. Use pipeline parallelism across nodes connected by InfiniBand (100-400 Gb/s). Use data parallelism across groups of nodes. This hierarchical mapping minimizes communication over the slowest links.
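One way to make this mapping concrete is to assign each global rank a (data, pipeline, tensor) coordinate, with the tensor-parallel rank varying fastest so that tensor-parallel peers share a node. The layout below is a simplified sketch; real frameworks such as Megatron-LM build these process groups with their own utilities.
def rank_to_coords(rank, tp_size=8, pp_size=4, dp_size=4):
    # Tensor-parallel rank varies fastest, so ranks 0..7 share a node (and its NVLink)
    tp = rank % tp_size
    pp = (rank // tp_size) % pp_size
    dp = rank // (tp_size * pp_size)
    return dp, pp, tp

# Example: 128 GPUs = 4 (data) x 4 (pipeline) x 8 (tensor)
print(rank_to_coords(0))   # (0, 0, 0)
print(rank_to_coords(13))  # (0, 1, 5): second pipeline stage, tensor rank 5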
Model Parallelism
A strategy that distributes different parts of a model across devices, enabling training of models that exceed single-device memory capacity.
Pipeline Parallelism
A form of model parallelism that assigns sequential groups of layers to different devices, processing micro-batches in a pipeline fashion.
04 Mixed Precision Training
Mixed precision training uses lower-precision number formats (typically FP16 or BF16) for most computations while maintaining FP32 for critical operations. This approach can nearly double training throughput on modern GPUs with tensor cores while using significantly less memory.
Mixed Precision Training
A training technique that uses lower-precision floating-point formats (FP16 or BF16) for the computationally intensive forward and backward passes while retaining FP32 precision for the weight update step. This yields up to 2x throughput improvement with minimal accuracy loss.
The Mixed Precision Recipe
The key technique is maintaining a FP32 master copy of weights while performing forward and backward passes in half precision. Gradients are computed in FP16/BF16 and then cast to FP32 for the weight update. This preserves the precision of the update step, which is most sensitive to numerical errors.
- Maintain FP32 master weights (the "source of truth" for parameters).
- Cast weights to FP16/BF16 for the forward pass.
- Compute loss and backward pass in FP16/BF16.
- Cast gradients back to FP32.
- Update FP32 master weights using the FP32 gradients.
Weight updates are typically tiny (learning rate * gradient, e.g., 1e-4 * 1e-3 = 1e-7). In FP16, values smaller than ~6e-8 round to zero, meaning many weight updates would be silently lost. By accumulating updates in FP32 (minimum representable: ~1.4e-45), no update is lost. The FP32 master weights are the true parameters; the FP16 copies are just for fast computation.
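The effect is easy to reproduce. The snippet below (illustrative values only) shows a tiny update vanishing when applied to an FP16 weight while surviving in FP32.
import torch

update = 1e-4 * 1e-3  # learning_rate * gradient = 1e-7

w16 = torch.tensor(1.0, dtype=torch.float16)
w32 = torch.tensor(1.0, dtype=torch.float32)

print((w16 + update) == w16)  # tensor(True): the FP16 weight did not move at all
print((w32 + update) == w32)  # tensor(False): the FP32 master weight retained the update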
Loss Scaling for FP16
Loss scaling is essential for FP16 training to prevent gradient underflow. Small gradient values that are representable in FP32 may round to zero in FP16. Loss scaling multiplies the loss by a large factor before backpropagation (scaling up gradients) and then divides the gradients after the backward pass. Dynamic loss scaling automatically adjusts the scale factor during training.
\hat{L} = s \cdot L, \quad \hat{g} = \frac{1}{s} \nabla_{\theta} \hat{L}
BFloat16: The Simpler Alternative
BFloat16 (BF16), developed for Google's TPUs and now supported on NVIDIA Ampere and newer GPUs, uses the same exponent range as FP32 but with reduced mantissa precision. This eliminates the need for loss scaling because BF16 can represent the same range of values as FP32.
| Format | Sign Bits | Exponent Bits | Mantissa Bits | Range | Loss Scaling Needed |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | ~1.4e-45 to ~3.4e38 | No |
| FP16 | 1 | 5 | 10 | ~6e-8 to 65504 | Yes |
| BF16 | 1 | 8 | 7 | ~1.4e-45 to ~3.4e38 | No |
| FP8 (E4M3) | 1 | 4 | 3 | ~0.015 to 448 | Yes |
Table 8.4: Comparison of floating-point formats used in ML training.
If your hardware supports BF16 (NVIDIA Ampere/Hopper GPUs or Google TPUs), prefer BF16 over FP16 for training. BF16 eliminates the need for loss scaling, simplifying the training pipeline and reducing a common source of training instability. Reserve FP16 for inference on older hardware that lacks BF16 support.
class="tok-comment"># PyTorch automatic mixed precision (AMP)
scaler = torch.cuda.amp.GradScaler() class="tok-comment"># Only needed for FP16
for batch in dataloader:
optimizer.zero_grad()
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
output = model(batch)
loss = loss_fn(output, targets)
class="tok-comment"># With BF16, no scaler needed:
loss.backward()
optimizer.step()Mixed Precision Training
A technique that uses lower-precision formats (FP16/BF16) for most computation while maintaining FP32 for numerically sensitive operations, improving speed and reducing memory.
Loss Scaling
A technique used in FP16 training that scales the loss value up before backpropagation to prevent small gradients from underflowing to zero.
05 Gradient Management and Communication
Efficient gradient communication is often the bottleneck in distributed training. For a model with billions of parameters, each gradient synchronization step must transfer gigabytes of data across the network. Optimizing this communication is essential for achieving good scaling efficiency.
A model with 7 billion FP32 parameters requires transferring 28 GB of gradient data per synchronization step. On a 100 Gbps InfiniBand network, this alone takes ~2.2 seconds per step — potentially exceeding the computation time. At 175B parameters, the problem becomes 25x worse. Communication optimization is not optional at scale; it is the primary engineering challenge.
Gradient Compression
Gradient compression reduces communication volume by transmitting approximate gradients. Techniques include gradient quantization (reducing precision of transmitted gradients), sparsification (transmitting only the largest gradients), and error feedback (accumulating compression error for future correction). These methods can reduce communication by 10-100x with minimal impact on convergence.
Gradient Sparsification
A compression technique that transmits only the top-k largest gradient values (by magnitude), typically 0.1-1% of the total. The remaining small gradients are accumulated locally in an error feedback buffer and added to the next step's gradients, ensuring no gradient information is permanently lost.
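A per-tensor sketch of top-k sparsification with error feedback follows; the function name and the 1% default are illustrative. Only the selected values are transmitted, and everything not sent is carried forward in the error buffer.
import torch

def topk_with_error_feedback(grad, error_buffer, k_frac=0.01):
    corrected = grad + error_buffer                # Re-inject previously untransmitted gradient mass
    k = max(1, int(k_frac * corrected.numel()))
    threshold = torch.topk(corrected.abs().flatten(), k).values.min()  # Magnitude of the k-th largest entry
    mask = corrected.abs() >= threshold
    to_send = corrected * mask                     # Values to communicate this step
    new_error_buffer = corrected * (~mask)         # Everything else waits for a later step
    return to_send, new_error_buffer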
| Technique | Compression Ratio | Convergence Impact | Complexity |
|---|---|---|---|
| Gradient Quantization (1-bit) | 32x | Minimal with error feedback | Low |
| Top-k Sparsification (0.1%) | 1000x | Minimal with error feedback | Moderate |
| Random Sparsification | 10-100x | Moderate, needs larger learning rate | Low |
| PowerSGD (low-rank) | 10-100x | Minimal | Moderate |
Table 8.5: Gradient compression techniques and their trade-offs.
Gradient Accumulation
Gradient accumulation is a practical technique for simulating larger batch sizes without additional devices. Instead of synchronizing gradients after every micro-batch, gradients are accumulated locally over multiple micro-batches before synchronization. This reduces communication frequency and allows training with effectively larger batch sizes on limited hardware.
class="tok-comment"># Gradient accumulation: simulate batch_size * accum_steps
accum_steps = class="tok-number">4
for i, batch in enumerate(dataloader):
loss = model(batch) / accum_steps class="tok-comment"># Scale loss
loss.backward() class="tok-comment"># Accumulate gradients
if (i + class="tok-number">1) % accum_steps == class="tok-number">0:
optimizer.step() class="tok-comment"># Update after N steps
optimizer.zero_grad() class="tok-comment"># Reset gradientsIf you cannot afford many GPUs, gradient accumulation lets you simulate large-batch training on limited hardware. For example, 4 GPUs with 8 accumulation steps gives an effective batch size equivalent to 32 GPUs — at the cost of 8x longer wall-clock time per step but with the same convergence behavior.
In the gradient accumulation code, why is the loss divided by accum_steps before calling backward()?
What happens to the gradient magnitude if you call backward() 4 times without scaling?
Overlapping Communication and Computation
Overlapping computation with communication is critical for hiding communication latency. By starting gradient communication for earlier layers while later layers are still computing, the communication overhead can be partially or fully masked. Modern frameworks implement this through bucket-based gradient reduction, where gradients are grouped into buckets that are communicated as soon as all gradients in the bucket are available.
PyTorch DDP groups gradients into 25 MB buckets by default. As the backward pass progresses from the last layer toward the first, each bucket triggers an all-reduce as soon as all its gradients are computed. While the network transfers bucket N, the GPU continues computing gradients for bucket N+1. With well-tuned bucket sizes, communication can be almost entirely hidden behind computation.
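In PyTorch this overlap comes essentially for free with DistributedDataParallel, and the bucket size is tunable. A minimal sketch, assuming the process group, local_rank, model, and inputs are already set up (for example by torchrun):
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# model has already been moved to the local GPU, e.g. model = model.to(local_rank)
ddp_model = DDP(model, device_ids=[local_rank], bucket_cap_mb=25)  # 25 MB gradient buckets (the default)

# The backward pass launches an all-reduce per bucket as soon as that bucket's
# gradients are ready, overlapping communication with the remaining computation.
loss = ddp_model(inputs).sum()
loss.backward()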
Gradient Compression
Techniques that reduce the volume of gradient data communicated between devices during distributed training, including quantization and sparsification.
Gradient Accumulation
A technique that sums gradients over multiple micro-batches before performing a weight update, effectively simulating larger batch sizes.
Key Takeaways
1. Data parallelism is the simplest and most common distributed training approach, but model parallelism is essential for the largest models.
2. Mixed precision training can nearly double throughput and halve memory usage with minimal impact on model quality.
3. The linear scaling rule for learning rate breaks down at very large batch sizes, requiring warmup and specialized optimizers.
4. Gradient communication is often the bottleneck in distributed training, motivating compression and overlap techniques.
5. Production training systems combine data, tensor, and pipeline parallelism in hybrid configurations matched to hardware topology.