Training Loop — End-to-End Mental Model
What actually happens in one training step, tensor by tensor.
A training step is not magic. It is a precise sequence of tensor operations. Every step: data enters, activations flow forward, loss is computed, gradients flow backward, weights change. Repeat ~millions of times.
Input → Tokenization / Preprocessing
Raw data (text, images, audio) is converted into tensors. For text: tokenizer maps strings to integer IDs → embedding lookup creates float tensors of shape [batch, seq_len, d_model]. For images: normalization, resizing → [B, C, H, W]. Memory: input tensors and targets are loaded to GPU.
Forward Pass
Inputs flow through all layers. Each layer computes output and stores intermediate activations in memory because they'll be needed for backprop. For a 7B LLM: 32 transformer blocks, each caching Q, K, V, attention scores, MLP outputs. This is 60–80% of training memory usage.
Loss Computation
Compare model output (logits) to ground truth targets. Returns a scalar loss value. For LLMs: cross-entropy over next-token predictions. The loss is the error signal — all learning flows from this single number.
Backpropagation
PyTorch auto-differentiates through the computation graph. Starting from loss, chain rule propagates gradients backward through every operation. Each parameter p gets a gradient p.grad — the direction of steepest increase in loss (we want to go opposite).
Gradient Accumulation (optional)
If using gradient accumulation over N steps: gradients from step k are added to existing p.grad without zeroing. Only after N micro-steps do you take an optimizer step. Effect: effective batch size = micro-batch × N × n_gpus.
Optimizer Step
The optimizer reads p.grad for every parameter and computes a parameter update. For Adam: maintains running mean (m) and variance (v) of gradients. Update formula: p ← p − lr × m̂/(√v̂ + ε). This is where learning actually happens.
Weight Update + Zero Gradients
Weights are now updated. Critical: call optimizer.zero_grad() before next step (or set set_to_none=True for memory efficiency). If you forget, gradients accumulate across steps — a common bug that looks like exploding gradients.
What Lives in Memory During a Step
| Tensor | Size | When Present | Notes |
|---|---|---|---|
| Model weights | N_params × dtype_bytes | Always | 7B @ FP16 = ~14GB |
| Optimizer state | 2–4× weights | Always | Adam: m + v = 2× weights |
| Activations | Scales with batch × seq_len | Forward → Backward | 60–80% of training mem |
| Gradients | Same as weights | After loss.backward() | Can be FP32 even in FP16 training |
| Input/Target batch | Small | During step | Usually negligible |
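Rule of thumb: with Adam, persistent training state (weights + gradients + optimizer state, excluding activations) is ≈ 16 bytes per parameter. In full FP32 that is 4 (weights) + 4 (gradients) + 8 (m and v). In mixed precision it is typically also ≈ 16: 2 (FP16 weights) + 2 (FP16 gradients) + 4 (FP32 master weights) + 8 (m and v). A 7B model therefore needs ~112GB for full FP32 training before counting activations.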
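The bytes-per-parameter arithmetic can be sketched as a small calculator. This is a rough estimate under stated assumptions, not a profiler: the function name and its defaults are hypothetical, activations are ignored, and the mixed-precision breakdown assumes FP16 weights and gradients plus an FP32 master copy and FP32 Adam state.

```python
def training_state_bytes(n_params, weight_bytes=4, grad_bytes=4,
                         optimizer_bytes=8, master_bytes=0):
    """Bytes of persistent training state (activations are NOT included)."""
    per_param = weight_bytes + grad_bytes + optimizer_bytes + master_bytes
    return n_params * per_param

N = 7_000_000_000  # 7B parameters

# Full FP32 with Adam: 4 (weights) + 4 (grads) + 8 (m and v) = 16 bytes/param
fp32 = training_state_bytes(N)

# Assumed mixed-precision layout: 2 + 2 + 8 + 4 (FP32 master copy) = 16 bytes/param
mixed = training_state_bytes(N, weight_bytes=2, grad_bytes=2,
                             optimizer_bytes=8, master_bytes=4)

print(f"FP32 + Adam:            {fp32 / 1e9:.0f} GB")
print(f"Mixed precision + Adam: {mixed / 1e9:.0f} GB")
```

Mixed precision leaves the persistent state at about the same size under these assumptions. Its memory savings come mainly from smaller activations.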
Forward Pass Deep Dive
What actually happens layer by layer — and why activations eat your memory.
The forward pass is deterministic given weights and inputs. It is a composed function: output = f_L(f_{L-1}(...f_1(input))). Each f_i is a layer. Each intermediate result is called an activation.
Linear Layers
Most fundamental operation. Given input x: [B, in], weight W: [out, in], bias b: [out]:
y = xWᵀ + b, with y: [B, out]
During backprop, you need the input x to compute ∂L/∂W = (∂L/∂y)ᵀ x, which has shape [out, in] like W. This is why x must be stored.
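To make the shapes concrete, here is a minimal pure-Python sketch of the weight gradient of a linear layer, assuming the convention x: [B, in], W: [out, in], y = xWᵀ + b. The helper names are illustrative, and the loss is taken to be L = sum(y) so that ∂L/∂y is all ones.

```python
def linear_forward(x, W, b):
    """y[i][o] = sum_k x[i][k] * W[o][k] + b[o]  (x: [B, in], W: [out, in])."""
    return [[sum(xi[k] * Wo[k] for k in range(len(xi))) + bo
             for Wo, bo in zip(W, b)] for xi in x]

def linear_backward_W(x, dy):
    """dW[o][k] = sum_i dy[i][o] * x[i][k], i.e. dW = dy^T x with shape [out, in]."""
    n_out, n_in = len(dy[0]), len(x[0])
    return [[sum(dy[i][o] * x[i][k] for i in range(len(x)))
             for k in range(n_in)] for o in range(n_out)]

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # B=2, in=3
W = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]   # out=2, in=3
b = [0.0, 0.0]
dy = [[1.0, 1.0], [1.0, 1.0]]            # dL/dy for L = sum(y)

dW = linear_backward_W(x, dy)
print(dW)  # [[5.0, 7.0, 9.0], [5.0, 7.0, 9.0]]
```

The backward function reads x but not W, which is why the forward pass has to keep x in memory until the backward pass reaches this layer.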
Attention (Transformer)
Self-attention is the expensive operation. For each layer:
Q = XW_Q, K = XW_K, V = XW_V
Attn = softmax(QKᵀ / √d_k) · V
Memory cost: The attention matrix QKᵀ is [B, H, S, S], quadratic in sequence length S. For S=8192: that's 8192² = 67M floats per layer per head. FlashAttention avoids materializing this by fusing the kernel.
Attention cost is O(S²) in memory and compute. Doubling sequence length = 4× more attention cost. This is the core bottleneck for long-context LLMs.
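The quadratic growth is easy to check with a short calculation. The function below is a hypothetical helper that counts only the materialized [B, H, S, S] score matrix for one layer. It assumes 2 bytes per element (FP16/BF16) and ignores Q, K, V and everything else.

```python
def attention_scores_bytes(batch, heads, seq_len, dtype_bytes=2):
    """Bytes needed to materialize the [B, H, S, S] attention matrix of one layer."""
    return batch * heads * seq_len * seq_len * dtype_bytes

elements_per_head = 8192 * 8192
print(elements_per_head)  # 67108864, i.e. ~67M entries per head

# Assumed example configuration: batch 1, 32 heads
for s in (2048, 4096, 8192):
    gib = attention_scores_bytes(1, 32, s) / 1024**3
    print(f"S={s}: {gib:.2f} GiB per layer")
```

Each doubling of S multiplies the result by 4, which matches the O(S²) statement above.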
Why Activations Are Stored
The Fundamental Trade-off
During backprop, the gradient of layer k needs the output of layer k-1 (chain rule). So every intermediate activation must be stored from the forward pass. This is activation memory — it scales with batch size and model depth.
Gradient checkpointing trades compute for memory: recompute activations during backward instead of storing them. ~33% more compute, ~10× less activation memory.
| Layer Type | Forward Operation | What's Stored | Backprop Needs |
|---|---|---|---|
| Linear | y = Wx + b | x | x to compute dW |
| ReLU | y = max(0, x) | mask (x > 0) | mask for gradient gating |
| BatchNorm | normalize + scale | x, μ, σ² | μ, σ² for gradient |
| Attention | softmax(QKᵀ/√d)V | Q, K, V, attn weights | All 4 tensors |
| Embedding | lookup table | input indices | indices to scatter gradients |
Loss Functions
Choosing the right error signal is as important as the architecture.
Cross Entropy (CE)
L = -Σ y_true · log(softmax(logits))
Use for: classification, language modeling (next token prediction). Why: maximizes log-probability of correct class. Punishes confident wrong predictions severely (predicted probability of the true class near 0 → log goes to −∞ → huge loss).
Mean Squared Error (MSE)
L = (1/N) Σ (y_pred - y_true)²
Use for: regression tasks, predicting continuous values. Warning: very sensitive to outliers (quadratic). Alternatives: MAE (L1) is more robust, Huber loss is in-between.
Binary Cross Entropy (BCE)
L = -(y·log(p) + (1-y)·log(1-p))
Use for: binary classification, multi-label classification. Each output is an independent Bernoulli. Use BCEWithLogitsLoss in PyTorch for numerical stability.
Contrastive / NT-Xent
L = -log[exp(sim(z,z⁺)/τ) / Σ exp(sim(z,zₙ)/τ)]
Use for: VLMs (CLIP), self-supervised learning. Pulls positive pairs together, pushes negatives apart in embedding space. Temperature τ controls sharpness.
Logits vs Probabilities
Logits = raw unnormalized scores from the model's last linear layer. They can be any real number. Probabilities = after softmax, sum to 1.
Always use logits + fused loss. CrossEntropyLoss in PyTorch applies log-softmax internally using the log-sum-exp trick: log Σ exp(xᵢ) = max(x) + log Σ exp(xᵢ - max(x)). This prevents overflow/underflow from exp(large number). Never do softmax → log → NLLLoss manually — you lose numerical stability.
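The log-sum-exp trick can be illustrated without any framework. The sketch below is pure Python with illustrative function names. It is not how PyTorch implements CrossEntropyLoss internally, only a demonstration of the same identity.

```python
import math

def log_softmax_naive(logits):
    """softmax then log: exp() overflows for large logits."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def log_softmax_stable(logits):
    """log-sum-exp trick: subtract the max before exponentiating."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def cross_entropy(logits, target):
    """Negative log-probability of the target class, computed from raw logits."""
    return -log_softmax_stable(logits)[target]

print(cross_entropy([1000.0, 0.0], target=1))  # 1000.0

try:
    log_softmax_naive([1000.0, 0.0])
except OverflowError as err:
    print("naive version failed:", err)
```

In Python floats the naive version raises an exception. On a GPU the same computation would typically produce inf or NaN silently, which is harder to spot.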
Backpropagation
The engine of learning — how error signals flow backward through a network.
Chain Rule Intuition
Backprop is just the chain rule applied to composed functions. If L = f(g(h(x))), then:
dL/dx = (dL/dg) · (dg/dh) · (dh/dx)
In a neural network: loss depends on outputs which depend on layer n-1 which depends on layer n-2... The gradient of the loss with respect to each weight is a product of local gradients chained together.
Layer-by-Layer Gradient Computation
What PyTorch Does
During forward, PyTorch builds a computation graph (DAG). Every tensor records the operation that created it. loss.backward() traverses this graph in reverse, calling each operation's backward function, accumulating gradients via the chain rule.
Each node has: a forward function and a backward function. The backward function computes local gradients and passes them upstream.
Vanishing & Exploding Gradients
The gradient at layer k involves a product of N weight matrices and activation derivatives. This product can shrink to near-zero (vanishing) or blow up (exploding).
Vanishing Gradients
Cause: Sigmoid/Tanh saturate → derivatives near 0. Deep networks → product of many small numbers → gradient = ~0 at early layers. Early layers learn nothing.
Fix: Use ReLU/GELU, residual connections (ResNet/Transformer), BatchNorm, careful init (Xavier/He).
Exploding Gradients
Cause: Product of weight matrices with large singular values → gradient norm grows exponentially. Common in RNNs, very deep networks, high LR.
Fix: Gradient clipping (clip_grad_norm_), lower LR, BatchNorm, careful init, residual connections.
Residual Connections are the Key Insight
Residual connections (y = F(x) + x) create gradient highways. During backprop, ∂L/∂x = ∂L/∂y · (∂F/∂x + 1). The +1 term ensures gradients always have a direct path to early layers, regardless of how small ∂F/∂x gets. This is why 100+ layer networks can be trained.
Optimizers
How gradients get translated into weight updates — and the hidden state they carry.
SGD (Stochastic Gradient Descent)
θ ← θ - lr · g
The simplest update rule. One gradient step in the negative gradient direction. Problem: sensitive to gradient noise, oscillates in narrow valleys. When to use: CNNs with careful tuning, when you have very large batches with low gradient noise. SGD with momentum is more commonly used.
SGD + Momentum
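v ← β·v + g
θ ← θ - lr·v
Accumulates a velocity vector in the direction of consistent gradients. Dampens oscillations. β=0.9 is standard: each step keeps 90% of the previous velocity and adds the full new gradient, so under a constant gradient g the velocity approaches g/(1-β) = 10g. (The dampened variant v ← β·v + (1-β)·g is the one that mixes 90% old velocity with 10% new gradient.) Physical intuition: ball rolling downhill with inertia.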
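A scalar sketch of the momentum update v ← β·v + g, θ ← θ - lr·v shows how the velocity builds up. The function name and the constant-gradient setup are illustrative assumptions.

```python
def sgd_momentum(grads, lr=0.1, beta=0.9, theta=0.0):
    """Apply v <- beta*v + g, theta <- theta - lr*v for a sequence of gradients."""
    v = 0.0
    for g in grads:
        v = beta * v + g
        theta -= lr * v
    return theta, v

# Constant gradient of 1.0 for 200 steps
theta, v = sgd_momentum([1.0] * 200)
print(f"final velocity: {v:.4f}")  # approaches 1 / (1 - 0.9) = 10

# Alternating gradients largely cancel in the velocity
_, v_osc = sgd_momentum([1.0, -1.0] * 100)
print(f"velocity under oscillating gradients: {v_osc:.4f}")
```

With a consistent gradient the effective step grows to roughly 1/(1-β) times the plain SGD step. Gradients that flip sign are damped instead.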
Adam (Adaptive Moment Estimation)
m ← β₁·m + (1-β₁)·g (1st moment: mean)
v ← β₂·v + (1-β₂)·g² (2nd moment: uncentered variance)
m̂ = m/(1-β₁ᵗ), v̂ = v/(1-β₂ᵗ) (bias correction)
θ ← θ - lr · m̂/(√v̂ + ε)
Per-parameter adaptive learning rates. Dividing by √v̂ normalizes the step size by the gradient's recent root-mean-square magnitude. Effect: parameters with noisy gradients take smaller steps; parameters with consistent gradients take larger steps. Defaults: β₁=0.9, β₂=0.999, ε=1e-8.
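The Adam update can be written out for a single scalar parameter. This is a minimal sketch of the formulas above with a hypothetical function name, not the PyTorch implementation.

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g          # running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g      # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# The first step has size ~lr whether the gradient is large or small
for g in (100.0, 0.01):
    theta, m, v = adam_step(theta=1.0, g=g, m=0.0, v=0.0, t=1)
    print(f"g={g}: step size = {1.0 - theta:.6f}")
```

Because m̂/√v̂ is roughly the gradient divided by its own magnitude, the raw gradient scale drops out. The step is governed by lr and by how consistent the gradient sign is over time.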
AdamW
θ ← θ - lr · m̂/(√v̂ + ε) - lr·λ·θ
AdamW ≠ Adam + L2 regularization. In Adam, adding L2 to the loss puts λθ into the gradient, so the decay term gets divided by √v̂ and becomes effectively smaller for parameters with large gradient magnitudes. AdamW decouples weight decay from the gradient update: it's applied directly to weights. Always use AdamW for LLMs/Transformers.
Optimizer State Memory
| Optimizer | Extra State per Param | Memory Multiplier | Notes |
|---|---|---|---|
| SGD | None | 1× | + 1× if momentum |
| SGD + Momentum | velocity v | 2× | — |
| Adam / AdamW | m, v (both FP32) | 3× (weights + m + v, all FP32) | State usually kept in FP32 for stability, even when weights are FP16/BF16 |
| 8-bit Adam | m, v (quantized) | ~1.5× | bitsandbytes library |
Learning Rate & Schedulers
The single most important hyperparameter. Get this wrong and nothing else matters.
Why LR Is #1
Learning rate controls how large a step you take in the loss landscape. Too high: overshoot minima, loss diverges or oscillates. Too low: trains too slowly, gets stuck in poor local minima. There is a "Goldilocks zone" — and it changes during training.
LR Too High
Loss spikes or diverges. Training unstable. Weight norms blow up. First thing to check when you see NaN loss. Fix: halve LR, use warmup.
LR Too Low
Loss decreases but very slowly. May plateau early. Model underfits. Signs: training loss barely moving after hours. Fix: LR finder, increase LR, check scheduler.
LR Warmup
Start with very low LR, linearly increase to target LR over N steps. Why: at training start, parameters are random, gradients are large and noisy. High LR on random init → divergence. Warmup lets the model stabilize before taking aggressive steps.
lr(t) = lr_target × (t / warmup_steps) for t ≤ warmup_steps
Rule of thumb: 1–5% of total steps as warmup. For LLMs: 1000–2000 warmup steps common.
Schedulers
| Scheduler | Formula | Shape | Best For |
|---|---|---|---|
| Step Decay | lr × γ^⌊epoch/step⌋ | Staircase | CNNs, simple baselines |
| Cosine Decay | lr × 0.5(1 + cos(π·t/T)) | Smooth curve | LLMs, most modern models |
| Linear Decay | lr × (1 - t/T) | Straight line down | Fine-tuning, simple schedules |
| OneCycleLR | Up then down | Triangle | Fast training (super-convergence) |
| Constant + Cooldown | Flat then drop | L-shape | LLM pre-training (common now) |
Cosine Schedule with Warmup (Modern Default)
import math

def get_lr(step, warmup_steps, max_steps, max_lr, min_lr=1e-5):
    # Linear warmup from 0 to max_lr
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Cosine decay from max_lr to min_lr; clamp so steps past max_steps stay at min_lr
    progress = min(1.0, (step - warmup_steps) / (max_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
Debug bad LR: Log loss at every step. If loss: (a) spikes immediately → LR too high at start, add warmup. (b) decreases then plateaus early → LR too high later, check schedule. (c) barely moves → LR too low. (d) oscillates with no trend → LR too high, use momentum.
Gradient Mechanics
Controlling how gradients flow — the difference between a stable run and NaN hell.
Gradient Accumulation
When your GPU can't fit a large batch, split it into micro-batches and accumulate gradients before updating weights:
for i, (x, y) in enumerate(loader):
    loss = model(x, y) / accum_steps   # divide to normalize
    loss.backward()                    # grads accumulate in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
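Effective batch size = micro_batch_size × accum_steps × n_gpus. Because each micro-batch loss is divided by accum_steps, the accumulated gradient equals the gradient of one large batch (up to floating-point summation order; BatchNorm statistics are still computed per micro-batch). Choose the LR for the effective batch size, not the micro-batch size.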
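The equivalence between accumulated micro-batch gradients and one large-batch gradient can be verified on a toy problem. The sketch below assumes a one-parameter model y = w·x with mean-squared-error loss and equal-sized micro-batches. The data values are made up for illustration.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5
accum_steps, micro_batch = 4, 2

full_batch_grad = mse_grad(w, xs, ys)

accumulated = 0.0     # plays the role of p.grad
unnormalized = 0.0    # what happens if you forget to divide
for k in range(accum_steps):
    sl = slice(k * micro_batch, (k + 1) * micro_batch)
    g = mse_grad(w, xs[sl], ys[sl])
    accumulated += g / accum_steps
    unnormalized += g

print(full_batch_grad, accumulated, unnormalized)
```

Dropping the division leaves a gradient that is accum_steps times too large, which behaves like multiplying the learning rate by accum_steps.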
Gradient Clipping
Cap the gradient norm to prevent exploding gradients from destroying your training:
if ||g|| > max_norm: g ← g × (max_norm / ||g||)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
| Method | What It Does | When to Use |
|---|---|---|
| Norm clipping | Scale entire gradient vector so its L2 norm ≤ threshold | LLMs, Transformers — standard practice |
| Value clipping | Clip each gradient element to [-c, c] | Less common, use for specific known distributions |
| Adaptive clipping | Track gradient norms over time, set threshold dynamically | When you don't know a good threshold upfront |
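The difference between norm clipping and value clipping is easiest to see on a small gradient vector. The following is a pure-Python sketch with hypothetical helper names; the small constant in the denominator is an assumption added to guard against division by zero.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale the whole gradient vector so its L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

def clip_by_value(grads, c):
    """Clamp each gradient element independently to [-c, c]."""
    return [max(-c, min(c, g)) for g in grads]

grads = [3.0, 4.0]                      # L2 norm = 5
norm_clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
value_clipped = clip_by_value(grads, c=1.0)

print(norm_before)     # 5.0
print(norm_clipped)    # roughly [0.6, 0.8]: same direction, norm ~1
print(value_clipped)   # [1.0, 1.0]: direction changed
```

Norm clipping preserves the direction of the update and only shortens it, while value clipping can rotate it.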
Monitor gradient norms. Log grad_norm at every step. Typical range: 0.5–2.0. If it spikes to 10+ intermittently, you have gradient instability. If it's constantly high (5+), your LR is too high or architecture has instability. If it trends to 0, vanishing gradients.
Activations
Non-linearity is what makes neural networks universal approximators.
Without activation functions, stacking linear layers gives you a single linear transformation — no matter how many layers. Activations introduce non-linearity, enabling the network to learn complex functions.
| Activation | Formula | Range | Use Case | Pitfalls |
|---|---|---|---|---|
| ReLU | max(0, x) | [0, ∞) | CNNs, hidden layers | Dying ReLU: if x always < 0, neuron dead forever |
| GELU | x·Φ(x) | (-0.17, ∞) | Transformers, BERT, GPT | Slightly more compute than ReLU |
| SiLU/Swish | x·σ(x) | (-0.28, ∞) | LLaMA, modern LLMs | Similar to GELU |
| Sigmoid | 1/(1+e⁻ˣ) | (0, 1) | Binary classification output | Saturates at ±∞ → vanishing gradients |
| Tanh | (eˣ-e⁻ˣ)/(eˣ+e⁻ˣ) | (-1, 1) | RNNs, normalized outputs | Saturates at ±1, vanishing gradients |
| Softmax | exp(xᵢ)/Σexp(xⱼ) | (0,1), sum=1 | Final layer, classification | Never use in hidden layers; numerically unstable raw |
Why GELU Won in Transformers
GELU = Gaussian Error Linear Unit. Unlike ReLU's hard gate (0 or pass), GELU has a smooth probabilistic gate: inputs are weighted by how likely they are to be positive under a Gaussian. This smooth gradient behavior empirically outperforms ReLU in transformers. LLaMA uses SiLU (nearly identical) in SwiGLU: SiLU(W₁x) ⊙ W₃x.
Saturation = vanishing gradients. Sigmoid derivative = σ(x)(1-σ(x)) → max 0.25 at x=0, nearly 0 at |x| > 3. Stack 10 sigmoid layers: gradient shrinks by 0.25¹⁰ ≈ 0.000001. This is why deep networks failed before ReLU + residual connections.
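The saturation numbers can be reproduced directly. This sketch evaluates the sigmoid derivative at a few points and multiplies the best-case derivative across 10 layers; it ignores weight matrices, so it only illustrates the activation's contribution.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    """Derivative of the sigmoid: s(x) * (1 - s(x))."""
    s = sigmoid(x)
    return s * (1.0 - s)

for x in (0.0, 3.0, 6.0):
    print(f"sigmoid'({x}) = {sigmoid_grad(x):.5f}")

# Best case: every layer sits exactly at x = 0, where the derivative peaks at 0.25
best_case_factor = sigmoid_grad(0.0) ** 10
print(f"gradient factor through 10 sigmoid layers: {best_case_factor:.2e}")
```

Even in the best case the activation derivatives alone shrink the gradient by about six orders of magnitude over 10 layers.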
Precision & Training Efficiency
The difference between FP32 and BF16 is not just speed — it fundamentally changes what you can train.
| Format | Bits | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | ~7 decimal digits |
| FP16 | 16 | 5 | 10 | ±65,504 | ~3 decimal digits |
| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | ~2 decimal digits |
| INT8 | 8 | — | — | [-128, 127] | Integers only |
FP16 vs BF16 — The Critical Distinction
FP16 Problems
Max value 65,504. Gradient values or activations > 65,504 → overflow → inf, which then propagates as NaN. Values below ~6×10⁻⁵ fall into the subnormal range and lose precision; below ~6×10⁻⁸ they underflow to 0. Requires loss scaling to keep gradients in representable range.
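FP16 range limits can be demonstrated with Python's struct module, whose "e" format is IEEE 754 half precision. This is a sketch with a hypothetical helper name; note that struct raises an exception on overflow, whereas GPU casts typically produce inf. The scale factor 65536 is an arbitrary example value.

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

print(to_fp16(65504.0))   # 65504.0, the largest finite FP16 value
print(to_fp16(1e-8))      # 0.0, too small even for a subnormal

try:
    to_fp16(70000.0)
except OverflowError as err:
    print("overflow:", err)

# Loss scaling: multiply before casting to FP16, divide afterwards in higher precision
tiny_grad = 1e-8
S = 65536.0
recovered = to_fp16(tiny_grad * S) / S
print(recovered)          # close to 1e-8 instead of 0.0
```

Scaling moves the small gradient into FP16's normal range for the cast, and the division restores its magnitude once the value is back in a wider format.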
BF16 Advantages
Same exponent range as FP32 → no overflow/underflow issues. Drop-in replacement for FP32 without loss scaling. Lower precision (7 mantissa bits vs 23) but in practice models don't care. Default for modern GPU training (A100+, H100).
Mixed Precision Training
Store model weights in FP16/BF16 for speed, but compute loss and keep master copy of weights in FP32 for precision during optimizer step.
Loss Scaling (for FP16)
Multiply loss by a large scalar S before backward → gradients are S× larger → less underflow. After backward, divide gradients by S before optimizer step. PyTorch GradScaler does this automatically with dynamic scaling — if overflow detected, it halves S and skips the step.
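import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler()
optimizer.zero_grad(set_to_none=True)
with torch.autocast("cuda", dtype=torch.float16):  # forward pass in FP16 where safe
    output = model(x)
    loss = criterion(output, y)
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
scaler.update()                # adjusts the scale factor S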
Tensor Cores: Nvidia Tensor Cores accelerate matrix multiplications in FP16/BF16/TF32. On A100: 312 TFLOPS peak in BF16 vs 19.5 TFLOPS in FP32 without Tensor Cores. That's a 16× difference in peak throughput. Matrix dimensions should be multiples of 8 (ideally 64) for good Tensor Core utilization. This is one reason model dimensions are usually 512, 1024, 4096, etc.
Memory Optimization Techniques
How to train models that don't fit in GPU memory — the trade-offs that define production LLM training.
| Technique | What It Does | Memory Saved | Compute Cost | Use When |
|---|---|---|---|---|
| Gradient Checkpointing | Recompute activations during backward instead of storing | ~10× activation mem | +30% compute | Large batches, deep models |
| Mixed Precision | FP16/BF16 weights + activations | 2× activation + weight mem | ~0% (often faster) | Always |
| ZeRO Stage 1 | Shard optimizer state across GPUs | 4× optimizer mem | Small comm overhead | Multi-GPU, Adam optimizer |
| ZeRO Stage 2 | Shard optimizer state + gradients | 8× total | Moderate comm | Large models, many GPUs |
| ZeRO Stage 3 | Shard optimizer + grads + parameters | ~linear with N GPUs | Higher comm | Models too large for one GPU |
| CPU Offloading | Move optimizer state to CPU RAM | GPU: optimizer mem freed | PCIe bandwidth bottleneck | Limited GPU mem, PCIe 4.0+ |
| Activation Offloading | Move activations to CPU during forward | GPU: activation mem freed | High bandwidth cost | Extreme memory pressure |
Gradient Checkpointing Deep Dive
How It Works
Without checkpointing: ALL activations from forward pass stored in GPU memory for backward pass. With checkpointing: only store activations at checkpoint boundaries (e.g., every N transformer blocks). During backward, recompute activations between checkpoints on the fly.
In PyTorch:
from torch.utils.checkpoint import checkpoint
output = checkpoint(block, x, use_reentrant=False)  # recomputes block during backward
ZeRO Optimizer (DeepSpeed)
ZeRO = Zero Redundancy Optimizer. In standard DDP, every GPU holds a full copy of optimizer states. For 7B param model with Adam: 7B×4 (FP32 weights) + 7B×4 (m) + 7B×4 (v) = 84GB just for optimizer. ZeRO shards this across N GPUs. With 8 GPUs, that's 84/8 = 10.5GB per GPU.
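The sharding arithmetic above can be sketched as a one-line estimate. The function name is hypothetical, and it assumes three FP32 values per parameter (master weights, m, v) split evenly across GPUs with no padding or communication buffers.

```python
def adam_state_gb(n_params, n_gpus=1, bytes_per_value=4, values_per_param=3):
    """Per-GPU size in GB of FP32 master weights + Adam m + v, sharded evenly."""
    return n_params * bytes_per_value * values_per_param / n_gpus / 1e9

N = 7_000_000_000
for gpus in (1, 8, 64):
    print(f"{gpus:>2} GPUs: {adam_state_gb(N, gpus):.2f} GB per GPU")
```

The per-GPU cost falls linearly with the number of GPUs, which is the whole point of sharding state that plain DDP would replicate.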
Batch Size & Scaling
Batch size is not just a throughput knob — it fundamentally changes training dynamics.
What Batch Size Controls
Each gradient update is computed on a sample of the data. Larger batches = more accurate gradient estimate (less noise), but also fewer updates per epoch. The gradient noise from small batches acts as a regularizer and helps escape sharp minima.
Small Batches (32–256)
Noisy gradients → acts like regularization. Often finds flatter minima (better generalization). More optimizer steps per epoch. Works poorly on GPUs (low utilization). Good for: fine-tuning, small models, generalization-sensitive tasks.
Large Batches (2048+)
Accurate gradient → fast convergence. Lower variance → may overfit or find sharp minima (worse generalization unless LR scaled). Excellent GPU utilization. May need LR warmup. Good for: LLM pre-training, high-throughput training.
Linear Scaling Rule
When you increase batch size by k×, scale LR by k× to maintain the same training dynamics. Intuition: larger batch → lower gradient noise → gradient is more accurate → you can afford to take a larger step. Caveat: this holds for moderate scaling (≤32×). Very large batches need additional warmup and learning rate tuning.
Effective Batch Size
effective_batch = micro_batch × grad_accum_steps × n_gpus
If micro_batch=4, accum_steps=8, n_gpus=8: effective_batch = 4×8×8 = 256. This is what matters for training dynamics. You can trade any of these three factors against each other based on hardware constraints.
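Both the effective batch size and the linear scaling rule are simple enough to encode as helpers. The function names and the base learning rate below are illustrative assumptions, and the scaling rule is only a starting point for moderate batch-size changes.

```python
def effective_batch(micro_batch, accum_steps, n_gpus):
    """Number of samples contributing to each optimizer step."""
    return micro_batch * accum_steps * n_gpus

def linearly_scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule: LR grows in proportion to batch size."""
    return base_lr * new_batch / base_batch

batch = effective_batch(micro_batch=4, accum_steps=8, n_gpus=8)
print(batch)  # 256

# Same effective batch from a different hardware layout
print(effective_batch(micro_batch=16, accum_steps=8, n_gpus=2))  # 256

# Assumed baseline: lr=1e-4 tuned at batch 256, now moving to batch 1024
print(linearly_scaled_lr(1e-4, base_batch=256, new_batch=1024))
```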
Regularization
Overfitting detection and prevention — keeping your model generalizable.
Detecting Overfitting
The signature: training loss continues to decrease but validation loss stagnates or increases. Monitor both curves throughout training. The gap between train and val loss = degree of overfitting.
| Technique | How It Works | When to Use | Key Params |
|---|---|---|---|
| Dropout | Randomly zero activations during training (force network to not rely on specific neurons) | Dense layers, not attention (usually) | p=0.1–0.3 |
| Weight Decay | Shrinks weights toward zero each step: as an L2 penalty λ‖θ‖² added to the loss (SGD) or decoupled from the gradient (AdamW) | Always, use AdamW | λ=0.01–0.1 |
| Data Augmentation | Artificially expand training data with transforms | Vision: always. NLP: less common | Task-specific |
| Early Stopping | Stop when val loss stops improving | When compute is limited | patience=5–20 epochs |
| Label Smoothing | Soften hard targets: instead of [0,1,0], use [0.05, 0.9, 0.05] | Classification, LLMs | ε=0.1 |
Dropout in Transformers
Modern LLMs (LLaMA, GPT-4) often use very low or zero dropout during pre-training (large dataset + short training = less overfitting risk). Dropout rates of 0.0–0.1 are common. Fine-tuning on small datasets: increase to 0.1–0.3.
Weight Decay Intuition
Without weight decay, nothing prevents weights from growing arbitrarily large (as long as loss decreases). Large weights make the model sensitive to small input perturbations → poor generalization. Weight decay pulls weights back toward zero at each step, keeping the model in a well-behaved regime.
Training Stability & Debugging
The practical guide to why training broke — and how to fix it.
| Issue | Symptoms | Root Cause | Fix |
|---|---|---|---|
| Loss not decreasing | Flat loss from step 1, or plateaus early | LR too low; wrong loss function; data/label bug; gradient not flowing | Check p.requires_grad; verify data pipeline; LR finder; check for bugs in loss |
| NaN in training | Loss = NaN, gradients = NaN, everything breaks | FP16 overflow; LR too high; bad input (inf/nan in data); log(0) | Check data for NaN/inf; add loss scaling; clip gradients; reduce LR; switch to BF16 |
| Exploding gradients | Loss spikes suddenly; grad norm >>10 | LR too high; bad init; no clipping; specific batch with large loss | clip_grad_norm_(params, 1.0); reduce LR; add warmup |
| Loss oscillates | Loss goes up and down, no trend | LR too high; inconsistent batch sizes; unstable architecture | Halve LR; add momentum; check BatchNorm usage |
| Validation >> Train loss | Big gap between train and val | Overfitting | More data; more augmentation; increase dropout; add weight decay; early stop |
| Both losses high | Model not learning at all | Underfitting: model too small, LR too low, too few epochs | Increase model capacity; increase LR; train longer; check data quality |
| GPU OOM | CUDA out of memory | Batch too large; activations not freed; large sequence length | Reduce batch size; gradient checkpointing; mixed precision; ZeRO |
| Training too slow | GPU util < 70% | Data loading bottleneck; small batch; Python GIL; not using Tensor Cores | More DataLoader workers; pin_memory=True; increase batch; use BF16; persistent_workers=True |
Debugging Checklist
Before You Change Anything
- Verify loss on random weights (should be ~log(num_classes) for CE)
- Overfit on a single batch (loss should → 0)
- Check data loader outputs visually
- Verify labels are correct and not shifted
- Log grad_norm, loss, lr at every step
- Check if model is in train mode (model.train())
- Confirm optimizer.zero_grad() is called
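The first checklist item follows from the fact that a model with no preference among classes assigns probability 1/K to each, giving a loss of log(K). A pure-Python sketch, with illustrative names and an assumed vocabulary size of 32,000:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy from raw logits, using the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]

def expected_initial_loss(num_classes):
    """Loss of a uniform prediction over num_classes classes."""
    return math.log(num_classes)

# Equal logits stand in for an untrained model with no class preference
print(cross_entropy([0.0] * 10, target=3))   # ~2.3026 = log(10)
print(expected_initial_loss(32_000))         # ~10.37 for a 32k-token vocabulary
```

If the first logged loss is far above this value, suspect the initialization or the labels before touching the learning rate.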
Distributed Training
Scaling training across multiple GPUs and machines — the architecture of modern LLM training.
| Strategy | What's Split | Communication | Use When |
|---|---|---|---|
| Data Parallelism (DDP) | Data batch split across GPUs, full model on each | All-reduce gradients after each step | Model fits in 1 GPU, want more throughput |
| Tensor Parallelism | Weight matrices split across GPUs (column/row) | All-reduce within forward/backward | Single layer too large for 1 GPU (e.g., linear 65536→65536) |
| Pipeline Parallelism | Model layers split across GPUs (stages) | Activations passed between stages | Deep models, many GPUs per node |
| Sequence Parallelism | Sequence dimension split across GPUs | All-reduce on sequence | Very long sequences (>16k), attention memory |
| ZeRO (FSDP) | Optimizer state, gradients, params sharded | All-gather params when needed | Memory-constrained distributed training |
Data Parallelism (DDP) — The Standard
Each GPU has a full model copy. Each GPU processes a subset of the batch. After backward, gradients are all-reduced (summed + divided) across GPUs — every GPU ends up with the same average gradient. Effectively multiplies batch size by N GPUs.
# PyTorch DDP setup
from torch.nn.parallel import DistributedDataParallel as DDP
model = DDP(model.to(rank), device_ids=[rank])
# All-reduce happens automatically in .backward()
Communication Overhead
The bottleneck in multi-node training is often not compute but gradient synchronization. NVLink (within a node): 600 GB/s. InfiniBand (across nodes): 200–400 Gb/s, i.e. 25–50 GB/s. For a 7B model, the gradients are 14GB at FP16 (28GB at FP32), so moving them once over a single IB link takes roughly 0.3–0.6s at FP16 before any all-reduce overhead. With 10-step gradient accumulation, this cost is paid once per 10 micro-steps. Overlap compute and communication wherever possible.
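A lower bound on synchronization time comes from dividing the payload by the link rate. The helper below is a hypothetical back-of-the-envelope estimate: it assumes a single link running at its nominal rate and ignores latency, all-reduce traffic patterns, and overlap with compute.

```python
def transfer_seconds(payload_bytes, link_gbit_per_s):
    """Time to push payload_bytes once through a link rated in gigabits per second."""
    return payload_bytes * 8 / (link_gbit_per_s * 1e9)

grad_bytes_fp16 = 7_000_000_000 * 2   # 7B parameters, 2 bytes each = 14 GB

for rate in (200, 400):
    t = transfer_seconds(grad_bytes_fp16, rate)
    print(f"{rate} Gb/s: {t:.2f} s per full gradient transfer")

# With gradient accumulation, the cost is paid once per optimizer step
accum_steps = 10
print(f"amortized per micro-step at 400 Gb/s: "
      f"{transfer_seconds(grad_bytes_fp16, 400) / accum_steps:.3f} s")
```

Keep the units straight here: link rates are quoted in gigabits per second, while tensor sizes are in gigabytes.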
3D Parallelism: Large LLM training (GPT-4, LLaMA 65B+) typically uses all three in combination: Tensor Parallel (within node, 8 GPUs), Pipeline Parallel (across nodes, 4–8 stages), Data Parallel (remaining GPUs). This is called "3D parallelism."
Training LLMs vs CNNs vs VLMs
Key differences in training dynamics across model families.
| Property | CNN (ResNet, ConvNeXt) | LLM (GPT, LLaMA) | VLM (CLIP, LLaVA) |
|---|---|---|---|
| Input | Fixed-size image [B,C,H,W] | Token sequence [B,S] | Image + text |
| Main op | Convolution | Causal attention | Cross-modal attention |
| Memory cost driver | Activation maps (large H×W) | KV cache, O(S²) attention | Both + modality projectors |
| Batch size | 256–4096 | 256–4M tokens total | Varies (often smaller) |
| Precision | FP16 or BF16 | BF16 (critical) | BF16 for both modalities |
| Optimizer | SGD (with momentum) or AdamW | AdamW (always) | AdamW |
| LR schedule | Step / Cosine | Warmup + Cosine | Warmup + Cosine, staged |
| Grad stability | Generally stable | Can spike, needs clipping | Modal misalignment causes spikes |
| Dominant cost | Spatial convolutions | Attention O(S²), FFN | Image encoding + alignment |
Sequence Length Impact in LLMs
Going from S=2048 to S=8192: attention memory scales as O(S²) → 16× more attention memory. KV cache during inference = 2 × n_layers × n_heads × d_head × S × dtype_bytes. For LLaMA-7B at S=4096 in FP16: ~2GB KV cache per request. This is why long-context is expensive.
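The KV cache formula can be checked numerically. The configuration below (32 layers, 32 heads, head dimension 128) is an assumed LLaMA-7B-style shape, and the function name is illustrative.

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, dtype_bytes=2, batch=1):
    """2 (K and V) x layers x heads x head_dim x sequence length x bytes per value."""
    return 2 * n_layers * n_heads * d_head * seq_len * dtype_bytes * batch

size = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=4096)
print(f"{size / 1024**3:.1f} GiB per request")  # 2.0 GiB per request

# KV cache grows linearly in S, unlike the S x S attention matrix
ratio = kv_cache_bytes(32, 32, 128, 8192) / kv_cache_bytes(32, 32, 128, 2048)
print(ratio)  # 4.0
```

The cache grows only linearly with sequence length, but it is held per request, so it often limits how many requests can be served concurrently.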
VLM Training Specifics
VLM objectives differ by family: dual-encoder models such as CLIP use contrastive loss, while generative VLMs such as LLaVA and Flamingo are trained with next-token cross-entropy on text conditioned on image features. Key challenge: vision and text encoders are pre-trained separately and must be aligned. Training is often staged: (1) freeze both encoders, train projection layer; (2) unfreeze text encoder, joint fine-tuning; (3) optionally unfreeze vision encoder. A large number of negatives is critical in contrastive learning: CLIP used batch sizes of 32,768.
Practical Training Pipeline
Production-ready PyTorch training loop with all the essentials.
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

# ─── Setup ───────────────────────────────────────────────────
# MyModel and dataset are placeholders for your own model and data.
device = torch.device("cuda")
model = MyModel().to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.1,
    betas=(0.9, 0.95),  # β2=0.95 is a common choice for LLMs
)
USE_FP16 = False  # BF16 is used below; set True (and dtype=torch.float16) for FP16
scaler = GradScaler(enabled=USE_FP16)  # loss scaling is only needed for FP16
criterion = nn.CrossEntropyLoss()
loader = DataLoader(dataset, batch_size=32, num_workers=4,
                    pin_memory=True, persistent_workers=True)
ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0
# One scheduler tick per optimizer step (cosine decay chosen here as an example)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=len(loader) // ACCUM_STEPS
)

# ─── Training Loop ───────────────────────────────────────────
model.train()
optimizer.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)
    # Forward pass with mixed precision
    with torch.autocast("cuda", dtype=torch.bfloat16):
        logits = model(x)
        loss = criterion(logits, y) / ACCUM_STEPS
    # Backward (scale() returns the loss unchanged when the scaler is disabled)
    scaler.scale(loss).backward()
    if (step + 1) % ACCUM_STEPS == 0:
        # Unscale before clipping!
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), MAX_GRAD_NORM
        )
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        # Logging (once per optimizer step, when grad_norm is defined)
        print(f"step={step}, loss={loss.item()*ACCUM_STEPS:.4f}, "
              f"grad_norm={grad_norm:.3f}, lr={scheduler.get_last_lr()[0]:.6f}")
DataLoader Best Practices
loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=8,            # parallel data loading
    pin_memory=True,          # faster CPU→GPU transfer
    persistent_workers=True,  # avoid worker re-init overhead
    prefetch_factor=2,        # pre-load batches ahead
    drop_last=True,           # avoid partial batch issues
)
Engineering Tips
High-value production insights — what separates a struggling run from a clean one.
✓ Always Do These
- Log: loss, grad_norm, lr, GPU utilization, tokens/sec, every step or every 10 steps
- Set seeds: torch.manual_seed(42), numpy.random.seed(42), random.seed(42), and torch.backends.cudnn.deterministic=True for debugging
- Use set_to_none=True in optimizer.zero_grad(): frees memory instead of zeroing
- Use torch.compile(model) in PyTorch 2.0+: often 30–50% speedup for free
- Profile before optimizing: use torch.profiler to find the actual bottleneck
- Use pin_memory=True in DataLoader when training on GPU
- Save checkpoints regularly with both model state AND optimizer state
✗ Common Mistakes
- Forgetting optimizer.zero_grad() → accumulated gradients = effectively higher LR → exploding
- Calling model.eval() during training → dropout off, BatchNorm uses running stats → silent bug
- Moving tensors to CPU for logging without .detach() → keeps computation graph → memory leak
- Using loss.item() inside training loop every step → triggers CPU-GPU sync → 10–30% slowdown
- Forgetting to unscale before gradient clipping in mixed precision → clips scaled gradients
- Using SGD when Adam is clearly better (Transformers always want Adam/AdamW)
- Not normalizing loss by accumulation steps → effective LR = ACCUM_STEPS × intended LR
Speed-Up Tricks
| Trick | Speedup | Cost |
|---|---|---|
| torch.compile() | 20–50% | First step slow (~30s compile) |
| BF16 / Mixed Precision | 2–3× | Minimal accuracy loss |
| FlashAttention-2 | 2–4× for attention | Requires installation |
| Fused optimizers | 10–20% | Use fused=True in AdamW |
| Increase batch size | Linear with GPU util | May need LR scaling |
| Channels-last memory format | 10–40% for CNNs | model.to(memory_format=torch.channels_last) |
Mental Models & Intuition
The way to think about training that makes everything else click.
Training = Error Correction Loop
The model makes a prediction. We measure how wrong it was (loss). We figure out which weights were responsible and to what degree (backprop). We nudge each weight in the direction that would have reduced the error (optimizer step). Repeat. Training is the model progressively learning from its own mistakes.
Gradients = Direction of Improvement
A gradient is not "how much to move" — it's "which direction increases loss fastest." We move in the opposite direction. The magnitude tells us the sensitivity: a large gradient means this weight matters a lot for the current predictions. A zero gradient means this weight doesn't matter right now.
Optimizer = Step Size Strategy
SGD: same step everywhere. Momentum: bigger steps in consistent directions. Adam: smaller steps for noisy parameters, bigger for consistent ones. The optimizer answers: "we know which direction to move, but how boldly should we move?" AdamW adds: "and let's not let weights grow too large."
Loss Landscape = Terrain
Imagine the loss as a mountainous terrain in N-dimensional space. Training is hiking downhill. Flat regions = saddle points, learning stalls. Sharp valleys = good convergence, poor generalization. Flat wide minima = better generalization. The optimizer is your hiking strategy; LR is stride length.
More Key Intuitions
Batch Size = Signal Quality
Small batch = one hiker taking a step based on what they can see locally. Large batch = many hikers averaging their observations. Large batches give more accurate steps but lose the "noisy exploration" that helps find better minima.
Activations = What the Network "Knows"
Each layer's output is a representation of the input at that level of abstraction. Early layers detect edges/phonemes. Middle layers detect shapes/words. Late layers detect objects/semantics. Training = learning which representations are useful for the task.
Code Snippets Reference
Drop-in code for the most common training operations.
Mixed Precision (BF16 — Modern Best Practice)
import torch
# BF16: no loss scaling needed (A100+, H100)
with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model(x)
    loss = criterion(output, y)
loss.backward()
optimizer.step()
Gradient Clipping
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0  # standard for LLMs; try 0.5–2.0
)
optimizer.step()
# grad_norm tells you the norm BEFORE clipping
AdamW Optimizer Setup
# Standard LLM setup
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),  # β2=0.95 for LLMs (faster adaptation)
    eps=1e-8,
    weight_decay=0.1,
    fused=True,         # fused CUDA kernel, 20% faster
)
# Or: parameter groups (no weight decay on embeddings/norms)
no_decay = ["bias", "LayerNorm.weight", "embedding"]
params = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 0.1},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(params, lr=3e-4)
Gradient Checkpointing
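import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

# Option 1: wrap individual expensive blocks
class TransformerBlock(nn.Module):
    # self._forward holds the block's actual computation (not shown here)
    def forward(self, x):
        # Activations inside _forward are recomputed during backward
        return checkpoint(self._forward, x, use_reentrant=False)

# Option 2: enable via HuggingFace
model.gradient_checkpointing_enable()

# Option 3: FSDP with activation checkpointing
# (this import is the auto-wrap policy that decides which modules FSDP wraps)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy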
Checkpoint Saving/Loading
# Save
torch.save({
    "step": step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "scaler": scaler.state_dict(),  # FP16 only
    "loss": loss.item(),
}, f"ckpt_step_{step}.pt")

# Load
ckpt = torch.load("ckpt_step_1000.pt", map_location=device)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])  # restore position in the LR schedule
scaler.load_state_dict(ckpt["scaler"])        # FP16 only
start_step = ckpt["step"]
Quick training sanity checklist: (1) Loss at random init ≈ log(N_classes) for CE? (2) Overfit single batch in <100 steps? (3) Grad norm in 0.1–2.0 range? (4) GPU utilization >85%? (5) No NaN anywhere? If all 5 pass, your training setup is correct — now focus on hyperparameters and data.