Deep Learning Systems · v1.0

Training in Deep Learning

An engineer's internal reference for understanding, debugging, and optimizing model training — from a single forward pass to distributed multi-GPU systems.

19 sections · Depth: production-grade
01

Training Loop — End-to-End Mental Model

What actually happens in one training step, tensor by tensor.

A training step is not magic. It is a precise sequence of tensor operations: data enters, activations flow forward, the loss is computed, gradients flow backward, and weights change. A training run repeats this millions of times.

1
Input → Tokenization / Preprocessing

Raw data (text, images, audio) is converted into tensors. For text: a tokenizer maps strings to integer IDs, and the model's embedding layer looks them up to produce float tensors of shape [batch, seq_len, d_model]. For images: normalization and resizing produce [B, C, H, W] tensors. Memory: input and target tensors are copied to the GPU.

2
Forward Pass

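Inputs flow through all layers. Each layer computes its output and keeps the intermediate activations that backprop will need later. For a 7B LLM (32 transformer blocks), each block caches tensors such as Q, K, V, attention scores (unless a fused kernel such as FlashAttention recomputes them), and MLP intermediates. At large batch × sequence length, activations are often the largest single consumer of training memory.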

3
Loss Computation

Compare model output (logits) to ground truth targets. Returns a scalar loss value. For LLMs: cross-entropy over next-token predictions. The loss is the error signal — all learning flows from this single number.

4
Backpropagation

PyTorch auto-differentiates through the computation graph. Starting from loss, chain rule propagates gradients backward through every operation. Each parameter p gets a gradient p.grad — the direction of steepest increase in loss (we want to go opposite).

5
Gradient Accumulation (optional)

If using gradient accumulation over N steps: gradients from step k are added to existing p.grad without zeroing. Only after N micro-steps do you take an optimizer step. Effect: effective batch size = micro-batch × N × n_gpus.

6
Optimizer Step

The optimizer reads p.grad for every parameter and computes a parameter update. Adam maintains running averages of the gradient (m) and of the squared gradient (v). Update formula: p ← p − lr × m̂/(√v̂ + ε), where m̂ and v̂ are bias-corrected. This is where learning actually happens.

7
Weight Update + Zero Gradients

Weights are now updated. Critical: call optimizer.zero_grad() before the next backward pass. With set_to_none=True (the default since PyTorch 2.0), gradients are freed instead of filled with zeros, which saves memory. If you forget, gradients accumulate across steps: a common bug that looks like exploding gradients.

What Lives in Memory During a Step

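| Tensor | Size | When present | Notes |
|---|---|---|---|
| Model weights | N_params × dtype_bytes | Always | 7B @ FP16 ≈ 14 GB |
| Optimizer state | 2–4× weights (depends on weight dtype) | Always | Adam: m + v = 2 × N_params × 4 bytes (FP32) |
| Activations | Scales with batch × seq_len × depth | Forward → backward | Often the largest term at long sequences / large batches |
| Gradients | Same count as weights | After loss.backward() | Can be FP32 even in FP16 training |
| Input/target batch | Small | During the step | Usually negligible |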
💡

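Rule of thumb (excluding activations): training with Adam needs ≈ 16 bytes per parameter. Full FP32: 4 (weights) + 4 (grads) + 8 (Adam m, v). Mixed precision: 2 (FP16/BF16 weights) + 2 (grads) + 4 (FP32 master weights) + 8 (m, v), which is still ≈ 16 bytes; mixed precision saves activation memory and time, not model-state memory. A 7B model therefore needs ~112 GB before activations.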
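The byte counts above turn into a quick estimator. This is a sketch assuming Adam/AdamW with FP32 optimizer state; it ignores activations, buffers, and framework overhead, and the function name is illustrative.

```python
def training_state_bytes(n_params: float, mixed_precision: bool) -> float:
    """Approximate bytes for weights + grads + Adam state (activations NOT included)."""
    if mixed_precision:
        # 2 (FP16/BF16 weights) + 2 (grads) + 4 (FP32 master) + 4 (m) + 4 (v)
        per_param = 2 + 2 + 4 + 4 + 4
    else:
        # 4 (FP32 weights) + 4 (grads) + 4 (m) + 4 (v)
        per_param = 4 + 4 + 4 + 4
    return n_params * per_param


GB = 1e9
print(training_state_bytes(7e9, mixed_precision=False) / GB)  # 112.0
print(training_state_bytes(7e9, mixed_precision=True) / GB)   # 112.0
```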

02

Forward Pass Deep Dive

What actually happens layer by layer — and why activations eat your memory.

The forward pass is deterministic given weights and inputs (ignoring dropout and nondeterministic GPU kernels). It is a composed function: output = f_L(f_{L-1}(...f_1(input))). Each f_i is a layer. Each intermediate result is called an activation.

Linear Layers

Most fundamental operation. Given input x: [B, in], weight W: [out, in], bias b: [out]:

y = xWᵀ + b → shape: [B, out]

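During backprop, you need the input x to compute the weight gradient ∂L/∂W = (∂L/∂y)ᵀ x, shape [out, in]. This is why x must be stored. (The input gradient ∂L/∂x = (∂L/∂y) W needs W rather than x, and ∂L/∂b sums ∂L/∂y over the batch.)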
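To make the shapes concrete, here is a pure-Python sketch (lists instead of tensors, no framework) of a linear layer's backward pass with W: [out, in] as above. It checks one entry of ∂L/∂W against a finite difference; the helper names and the random test data are illustrative.

```python
import random

def linear(x, W, b):
    """y[n][o] = sum_i x[n][i] * W[o][i] + b[o], i.e. y = x Wᵀ + b."""
    return [[sum(xi * wi for xi, wi in zip(row, W[o])) + b[o]
             for o in range(len(W))] for row in x]

def linear_backward(x, W, dy):
    """Given dL/dy [B, out], return dL/dW [out, in], dL/db [out], dL/dx [B, in]."""
    B, n_out, n_in = len(x), len(W), len(W[0])
    dW = [[sum(dy[n][o] * x[n][i] for n in range(B)) for i in range(n_in)]
          for o in range(n_out)]                     # (dL/dy)ᵀ x: needs the stored input x
    db = [sum(dy[n][o] for n in range(B)) for o in range(n_out)]
    dx = [[sum(dy[n][o] * W[o][i] for o in range(n_out)) for i in range(n_in)]
          for n in range(B)]                         # (dL/dy) W: needs W, not x
    return dW, db, dx

random.seed(0)
B, n_in, n_out = 3, 4, 2
x = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(B)]
W = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
b = [0.0] * n_out
dy = [[random.uniform(-1, 1) for _ in range(n_out)] for _ in range(B)]  # stand-in upstream grad

def loss(W_):
    """A scalar loss whose gradient w.r.t. y is exactly dy: L = sum(y * dy)."""
    y = linear(x, W_, b)
    return sum(y[n][o] * dy[n][o] for n in range(B) for o in range(n_out))

dW, db, dx = linear_backward(x, W, dy)

eps = 1e-6
W_plus = [row[:] for row in W]
W_plus[1][2] += eps
numeric = (loss(W_plus) - loss(W)) / eps
print(abs(numeric - dW[1][2]) < 1e-4)   # True
```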

Attention (Transformer)

Self-attention is the expensive operation. For each layer:

Q = XW_Q,  K = XW_K,  V = XW_V
Attn = softmax(QKᵀ / √d_k) · V

Memory cost: the attention-score matrix QKᵀ is [B, H, S, S], quadratic in sequence length S. For S=8192, that is 8192² ≈ 67M elements per sequence, per head, per layer. FlashAttention avoids materializing this matrix by computing attention in tiles inside a fused kernel.

⚠️

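Standard attention is O(S²) in compute and, when the score matrix is materialized, in memory. Doubling sequence length means 4× more attention cost. FlashAttention brings the memory term down to O(S), but compute stays O(S²): this is the core bottleneck for long-context LLMs.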
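The quadratic term is easy to estimate. This sketch counts only one materialized score matrix per layer for naive attention, assuming FP16 (2 bytes per element); real implementations may keep more than one S×S tensor per layer. The function name is illustrative.

```python
def attn_scores_bytes(batch, heads, seq_len, bytes_per_el=2):
    """Bytes for one layer's materialized [B, H, S, S] attention-score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_el

MiB, GiB = 2**20, 2**30
print(attn_scores_bytes(1, 1, 8192) / MiB)    # 128.0 MiB: one sequence, one head
print(attn_scores_bytes(1, 32, 8192) / GiB)   # 4.0 GiB: 32 heads, still one layer
print(attn_scores_bytes(1, 1, 16384) / attn_scores_bytes(1, 1, 8192))  # 4.0: 2x S -> 4x memory
```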

Why Activations Are Stored

The Fundamental Trade-off

During backprop, computing the gradients of layer k requires its input, i.e., the output of layer k-1 (chain rule). So intermediate activations must be kept from the forward pass. This is activation memory; it scales with batch size, sequence length, and model depth.

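Gradient checkpointing trades compute for memory: recompute activations during backward instead of storing them. Typically ~33% more compute (roughly one extra forward pass); activation memory can shrink by an order of magnitude, depending on where checkpoints are placed.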
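A toy version of this trade-off, written in plain Python rather than with torch.utils.checkpoint: a chain of scalar "layers" y = tanh(w·x). The checkpointed variant keeps only every k-th activation and re-runs the forward for one segment at a time during backward. All names are illustrative; the point is that gradients are identical while fewer activations are kept.

```python
import math

def layer(w, x):
    """One toy scalar layer: y = tanh(w * x)."""
    return math.tanh(w * x)

def layer_backward(w, x, grad_out):
    """Local backward for y = tanh(w*x); needs the layer's INPUT x. Returns (dL/dx, dL/dw)."""
    y = math.tanh(w * x)
    d_pre = grad_out * (1.0 - y * y)
    return d_pre * w, d_pre * x

def grads_full(ws, x0):
    """Standard backprop: keep every activation from the forward pass."""
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    grad, dws = 1.0, [0.0] * len(ws)          # loss L = final output, so dL/dy = 1
    for i in reversed(range(len(ws))):
        grad, dws[i] = layer_backward(ws[i], acts[i], grad)
    return dws, len(acts)                     # (gradients, activations kept)

def grads_checkpointed(ws, x0, every):
    """Keep every `every`-th activation only; recompute each segment during backward."""
    n = len(ws)
    ckpts, x = {0: x0}, x0
    for i, w in enumerate(ws):
        x = layer(w, x)
        if (i + 1) % every == 0:
            ckpts[i + 1] = x
    grad, dws = 1.0, [0.0] * n
    for start in sorted((k for k in ckpts if k < n), reverse=True):
        end = min(start + every, n)
        seg = [ckpts[start]]                  # re-run the forward for this segment only
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
        for i in reversed(range(start, end)):
            grad, dws[i] = layer_backward(ws[i], seg[i - start], grad)
    return dws, len(ckpts)

ws = [0.9 + 0.01 * i for i in range(16)]
full, kept_full = grads_full(ws, 0.5)
ckpt, kept_ckpt = grads_checkpointed(ws, 0.5, every=4)
print(kept_full, kept_ckpt)                           # 17 5
print(max(abs(a - b) for a, b in zip(full, ckpt)))    # 0.0
```

The extra cost is one additional forward through each segment, which is where the "roughly one extra forward pass" estimate comes from.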

| Layer type | Forward operation | What's stored | Backprop needs |
|---|---|---|---|
| Linear | y = xWᵀ + b | x | x to compute dW |
| ReLU | y = max(0, x) | mask (x > 0) | mask for gradient gating |
| BatchNorm | normalize + scale | x, μ, σ² | μ, σ² for gradient |
| Attention | softmax(QKᵀ/√d)V | Q, K, V, attention weights | all four tensors |
| Embedding | lookup table | input indices | indices to scatter gradients |
03

Loss Functions

Choosing the right error signal is as important as the architecture.

Cross Entropy (CE)

L = -Σ y_true · log(softmax(logits))

Use for: classification, language modeling (next-token prediction). Why: maximizes the log-probability of the correct class. Punishes confident wrong predictions severely: the loss is −log p for the true class, which grows without bound as p → 0.
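A framework-free sketch of the per-example loss, assuming an integer class target. It checks the claims above: a uniform prediction costs log(num_classes), and a confident wrong prediction costs far more than a confident right one. The function name is illustrative.

```python
import math

def cross_entropy(logits, target):
    """CE for one example: -log softmax(logits)[target], via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]

print(cross_entropy([0.0] * 10, 3))        # ≈ 2.3026 = log(10): uniform prediction at init
print(cross_entropy([5.0, 0.0, 0.0], 0))   # ≈ 0.0134: confident and right
print(cross_entropy([5.0, 0.0, 0.0], 1))   # ≈ 5.0134: confident and wrong
```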

Mean Squared Error (MSE)

L = (1/N) Σ (y_pred - y_true)²

Use for: regression tasks, predicting continuous values. Warning: very sensitive to outliers (quadratic). Alternatives: MAE (L1) is more robust, Huber loss is in-between.
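A small illustration of the outlier sensitivity, with made-up data and illustrative function names. The Huber form below (quadratic up to δ, linear beyond) matches the common definition, e.g. the one PyTorch's HuberLoss documents.

```python
def mse(pred, true):
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred)

def mae(pred, true):
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)

def huber(pred, true, delta=1.0):
    """0.5 * e^2 for |e| <= delta, delta * (|e| - 0.5 * delta) beyond it."""
    total = 0.0
    for p, t in zip(pred, true):
        e = abs(p - t)
        total += 0.5 * e * e if e <= delta else delta * (e - 0.5 * delta)
    return total / len(pred)

true = [1.0, 2.0, 3.0, 4.0]
pred = [1.1, 1.9, 3.2, 14.0]        # the last prediction is an outlier (error = 10)
print(mse(pred, true))              # ≈ 25.015: the outlier contributes ~99.9%
print(mae(pred, true))              # ≈ 2.6
print(huber(pred, true))            # ≈ 2.3825
```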

Binary Cross Entropy (BCE)

L = -(y·log(p) + (1-y)·log(1-p))

Use for: binary classification, multi-label classification. Each output is an independent Bernoulli. Use BCEWithLogitsLoss in PyTorch for numerical stability.
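Why the logits version matters, sketched with Python's math module. The stable expression is the standard identity max(x, 0) − x·y + log(1 + e^(−|x|)); PyTorch's kernel may differ in detail, and the function names are illustrative.

```python
import math

def bce_naive(logit, y):
    """sigmoid first, then log: breaks when p rounds to exactly 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(logit, y):
    """Stable form computed directly from the logit."""
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))

print(bce_with_logits(2.0, 1.0))    # ≈ 0.1269, same as bce_naive(2.0, 1.0)
print(bce_with_logits(40.0, 0.0))   # ≈ 40.0: confident and wrong, but finite
# bce_naive(40.0, 0.0) raises ValueError: p rounds to 1.0, so log(1 - p) = log(0)
```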

Contrastive / NT-Xent

L = -log[exp(sim(z,z⁺)/τ) / Σ exp(sim(z,zₙ)/τ)]

Use for: VLMs (CLIP), self-supervised learning. Pulls positive pairs together, pushes negatives apart in embedding space. Temperature τ controls sharpness.
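A single-anchor sketch of the loss above, assuming cosine similarity for sim(·,·) and a denominator that sums over the positive plus all negatives. Full NT-Xent averages this over every view in the batch; the vectors here are made up.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def info_nce(anchor, positive, negatives, tau=0.1):
    """-log( exp(sim(z, z+)/tau) / sum over {z+} ∪ negatives of exp(sim(z, z_n)/tau) )."""
    sims = [cosine(anchor, positive)] + [cosine(anchor, n) for n in negatives]
    logits = [s / tau for s in sims]
    m = max(logits)                                    # log-sum-exp for stability
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[0]                           # the positive sits at index 0

z, z_pos = [1.0, 0.0], [0.9, 0.1]
negs = [[0.0, 1.0], [-1.0, 0.0]]
print(info_nce(z, z_pos, negs, tau=0.1))   # tiny: the positive clearly wins
print(info_nce(z, z_pos, negs, tau=1.0))   # ≈ 0.41: higher tau softens the distribution
```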

Logits vs Probabilities

Logits = raw unnormalized scores from the model's last linear layer. They can be any real number. Probabilities = after softmax, sum to 1.

💡

Always use logits + fused loss. CrossEntropyLoss in PyTorch applies log-softmax internally using the log-sum-exp trick: log Σ exp(xᵢ) = max(x) + log Σ exp(xᵢ - max(x)). This prevents overflow/underflow from exp(large number). Never do softmax → log → NLLLoss manually — you lose numerical stability.
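The overflow and the fix can be reproduced with Python's math module alone. This demonstrates the identity above; it is not PyTorch's kernel.

```python
import math

def logsumexp_naive(xs):
    return math.log(sum(math.exp(x) for x in xs))

def logsumexp_stable(xs):
    m = max(xs)                      # subtract the max so the largest exponent is exp(0) = 1
    return m + math.log(sum(math.exp(x - m) for x in xs))

logits = [1000.0, 999.0, 0.0]
print(logsumexp_stable(logits))      # ≈ 1000.3133
# logsumexp_naive(logits) raises OverflowError: math.exp(1000.0) is out of float range

# log-softmax of the correct class without ever forming probabilities
print(logits[0] - logsumexp_stable(logits))   # ≈ -0.3133
```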

04

Backpropagation

The engine of learning — how error signals flow backward through a network.

Chain Rule Intuition

Backprop is just the chain rule applied to composed functions. If L = f(g(h(x))), then:

dL/dx = (dL/dv) · (dv/du) · (du/dx), where u = h(x) and v = g(u)

In a neural network: loss depends on outputs which depend on layer n-1 which depends on layer n-2... The gradient of the loss with respect to each weight is a product of local gradients chained together.
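A numeric check of the chain rule for one concrete composition, chosen for illustration: h(x) = x², g(u) = sin(u), f(v) = eᵛ. The local derivatives are evaluated at the intermediates saved during the forward computation, the scalar analogue of stored activations.

```python
import math

def L(x):
    """L = f(g(h(x))) with h(x) = x**2, g(u) = sin(u), f(v) = exp(v)."""
    return math.exp(math.sin(x ** 2))

def dL_dx_chain(x):
    u = x ** 2            # forward: keep the intermediates
    v = math.sin(u)
    dL_dv = math.exp(v)   # f'(v)
    dv_du = math.cos(u)   # g'(u)
    du_dx = 2 * x         # h'(x)
    return dL_dv * dv_du * du_dx

x, eps = 0.7, 1e-6
numeric = (L(x + eps) - L(x - eps)) / (2 * eps)   # central finite difference
print(dL_dx_chain(x), numeric)                    # agree to many decimal places
```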

Layer-by-Layer Gradient Computation

What PyTorch Does

During forward, PyTorch builds a computation graph (DAG). Every tensor records the operation that created it. loss.backward() traverses this graph in reverse, calling each operation's backward function, accumulating gradients via the chain rule.

Each node has: a forward function and a backward function. The backward function computes local gradients and passes them upstream.
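To see the mechanism without PyTorch's C++ machinery, here is a toy scalar autograd in plain Python (in the spirit of small educational engines like micrograd; not how PyTorch is implemented internally). Each node records its parents and a backward closure, and backward() visits the graph in reverse topological order. The Value class is illustrative.

```python
import math

class Value:
    """A scalar node in a computation graph."""
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # nodes this one was computed from
        self._backward_fn = None         # pushes self.grad to the parents

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward_fn():
            self.grad += out.grad        # d(a+b)/da = 1
            other.grad += out.grad       # d(a+b)/db = 1
        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward_fn():
            self.grad += other.data * out.grad   # d(a*b)/da = b
            other.grad += self.data * out.grad   # d(a*b)/db = a
        out._backward_fn = backward_fn
        return out

    def tanh(self):
        t = math.tanh(self.data)
        out = Value(t, (self,))
        def backward_fn():
            self.grad += (1 - t * t) * out.grad
        out._backward_fn = backward_fn
        return out

    def backward(self):
        order, seen = [], set()
        def visit(node):                 # topological sort of the DAG
            if node not in seen:
                seen.add(node)
                for p in node._parents:
                    visit(p)
                order.append(node)
        visit(self)
        self.grad = 1.0                  # dL/dL = 1
        for node in reversed(order):     # outputs before inputs
            if node._backward_fn is not None:
                node._backward_fn()

# loss = tanh(w * x + b); gradients accumulate with += like p.grad in PyTorch
w, x, b = Value(0.5), Value(2.0), Value(-0.3)
loss = (w * x + b).tanh()
loss.backward()
print(w.grad, x.grad, b.grad)            # ≈ 1.2695 0.3174 0.6347
```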

Vanishing & Exploding Gradients

The gradient at layer k involves a product of the Jacobians (weight matrices and activation derivatives) of all the layers above it. This product can shrink to near-zero (vanishing) or blow up (exploding).

Vanishing Gradients

Cause: Sigmoid/Tanh saturate → derivatives near 0. Deep networks → product of many small numbers → gradient = ~0 at early layers. Early layers learn nothing.

Fix: Use ReLU/GELU, residual connections (ResNet/Transformer), BatchNorm, careful init (Xavier/He).

Exploding Gradients

Cause: Product of weight matrices with large singular values → gradient norm grows exponentially. Common in RNNs, very deep networks, high LR.

Fix: Gradient clipping (clip_grad_norm_), lower LR, BatchNorm, careful init, residual connections.

Residual Connections are the Key Insight

🔑

Residual connections (y = F(x) + x) create gradient highways. During backprop, ∂L/∂x = ∂L/∂y · (∂F/∂x + I), where I is the identity. The identity term gives gradients a direct path to early layers, however small ∂F/∂x gets. This is why 100+ layer networks can be trained.
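A scalar sketch of the "+1" effect: 50 blocks whose own derivatives are small (±0.05, an arbitrary illustrative choice). Without the skip connection the backward product vanishes; with it, each factor is close to 1.

```python
def grad_through_stack(local_derivs, residual):
    """dL/dx0 for a chain of scalar blocks, given dL/d(output) = 1.
    Plain block:    y = F(x)      -> dy/dx = F'(x)
    Residual block: y = F(x) + x  -> dy/dx = F'(x) + 1
    """
    g = 1.0
    for d in local_derivs:
        g *= (d + 1.0) if residual else d
    return g

derivs = [0.05 if i % 2 == 0 else -0.05 for i in range(50)]
print(abs(grad_through_stack(derivs, residual=False)))  # ≈ 8.9e-66: vanished
print(grad_through_stack(derivs, residual=True))        # ≈ 0.94: still usable
```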

05

Optimizers

How gradients get translated into weight updates — and the hidden state they carry.

SGD (Stochastic Gradient Descent)

θ ← θ - lr · g

The simplest update rule. One gradient step in the negative gradient direction. Problem: sensitive to gradient noise, oscillates in narrow valleys. When to use: CNNs with careful tuning, when you have very large batches with low gradient noise. SGD with momentum is more commonly used.

SGD + Momentum

v ← β·v + g
θ ← θ - lr·v

Accumulates a velocity vector in the direction of consistent gradients. Dampens oscillations. β=0.9 is standard: each step keeps 90% of the previous velocity and adds the new gradient, so under a steady gradient g the velocity grows toward g/(1−β) = 10g. Physical intuition: ball rolling downhill with inertia.
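This sketch uses the update exactly as written above, which matches PyTorch's SGD with its default dampening=0. With a constant gradient the velocity converges to g/(1 − β); the function name and values are illustrative.

```python
def momentum_step(theta, v, grad, lr=0.1, beta=0.9):
    """v <- beta*v + g;  theta <- theta - lr*v."""
    v = beta * v + grad
    return theta - lr * v, v

theta, v = 0.0, 0.0
for _ in range(200):
    theta, v = momentum_step(theta, v, grad=1.0)   # a constant gradient
print(v)   # ≈ 10.0 = 1 / (1 - beta): up to 10x larger steps along a consistent direction
```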

Adam (Adaptive Moment Estimation)

m ← β₁·m + (1-β₁)·g        (1st moment: mean)
v ← β₂·v + (1-β₂)·g²       (2nd moment: uncentered variance)
m̂ = m/(1-β₁ᵗ),  v̂ = v/(1-β₂ᵗ)   (bias correction)
θ ← θ - lr · m̂/(√v̂ + ε)

Per-parameter adaptive learning rates. Dividing by √v̂ normalizes the step size by the root-mean-square of recent gradients. Effect: parameters with noisy gradients (m̂ small relative to √v̂) take smaller steps; parameters with consistent gradients take steps close to lr. Defaults: β₁=0.9, β₂=0.999, ε=1e-8.

AdamW

θ ← θ - lr · m̂/(√v̂ + ε) - lr·λ·θ
💡

AdamW ≠ Adam + L2 regularization. In Adam, adding L2 to the loss routes the weight decay term through the adaptive update, so it gets divided by √v̂ and parameters with large historical gradients are decayed less than intended. AdamW decouples weight decay from the gradient update: it's applied directly to the weights. Use AdamW as the default for LLMs/Transformers.
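A scalar sketch of the update rules above (illustrative names, not torch.optim code). As in PyTorch's AdamW, the decay multiplies θ by (1 − lr·λ) outside the adaptive term; weight_decay=0 gives plain Adam. It shows two properties: bias correction makes the first step ≈ lr regardless of gradient scale, and decoupled decay still shrinks weights when the gradient is zero.

```python
import math

def adamw_step(theta, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.0):
    """One AdamW step for a scalar parameter (weight_decay=0 -> plain Adam)."""
    b1, b2 = betas
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * grad          # 1st moment
    state["v"] = b2 * state["v"] + (1 - b2) * grad * grad   # 2nd (uncentered) moment
    m_hat = state["m"] / (1 - b1 ** state["t"])             # bias correction
    v_hat = state["v"] / (1 - b2 ** state["t"])
    theta = theta - lr * weight_decay * theta               # decoupled: never divided by sqrt(v_hat)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)

def new_state():
    return {"t": 0, "m": 0.0, "v": 0.0}

for g in (1000.0, 1.0, 0.001):
    print(g, 1.0 - adamw_step(1.0, g, new_state()))   # step ≈ 0.001 = lr each time

print(adamw_step(1.0, 0.0, new_state(), lr=1e-3, weight_decay=0.1))   # ≈ 0.9999
```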

Optimizer State Memory

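| Optimizer | Extra state per param | Memory (× FP32 weights) | Notes |
|---|---|---|---|
| SGD | None | 1× | +1× if momentum |
| SGD + Momentum | velocity v | 2× | |
| Adam / AdamW | m, v (both FP32) | 3× (weights + m + v) | State kept in FP32 for stability; mixed precision adds an FP32 master copy |
| 8-bit Adam | m, v (quantized to 8 bits) | ~1.5× | bitsandbytes library |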
06

Learning Rate & Schedulers

The single most important hyperparameter. Get this wrong and nothing else matters.

Why LR Is #1

Learning rate controls how large a step you take in the loss landscape. Too high: overshoot minima, loss diverges or oscillates. Too low: trains too slowly, gets stuck in poor local minima. There is a "Goldilocks zone" — and it changes during training.

LR Too High

Loss spikes or diverges. Training unstable. Weight norms blow up. First thing to check when you see NaN loss. Fix: halve LR, use warmup.

LR Too Low

Loss decreases but very slowly. May plateau early. Model underfits. Signs: training loss barely moving after hours. Fix: LR finder, increase LR, check scheduler.

LR Warmup

Start with very low LR, linearly increase to target LR over N steps. Why: at training start, parameters are random, gradients are large and noisy. High LR on random init → divergence. Warmup lets the model stabilize before taking aggressive steps.

lr(t) = lr_target × (t / warmup_steps) for t ≤ warmup_steps

Rule of thumb: 1–5% of total steps as warmup. For LLMs: 1000–2000 warmup steps common.

Schedulers

| Scheduler | Formula | Shape | Best for |
|---|---|---|---|
| Step Decay | lr × γ^⌊epoch / step_size⌋ | Staircase | CNNs, simple baselines |
| Cosine Decay | lr × 0.5(1 + cos(π·t/T)) | Smooth curve | LLMs, most modern models |
| Linear Decay | lr × (1 - t/T) | Straight line down | Fine-tuning, simple schedules |
| OneCycleLR | Up then down | Triangle | Fast training (super-convergence) |
| Constant + Cooldown | Flat then drop | L-shape | LLM pre-training (common now) |

Cosine Schedule with Warmup (Modern Default)

import math

def get_lr(step, warmup_steps, max_steps, max_lr, min_lr=1e-5):
    # Linear warmup from 0 to max_lr
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Cosine decay from max_lr to min_lr; hold at min_lr after max_steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
📌

Debug bad LR: Log loss at every step. If loss: (a) spikes immediately → LR too high at start, add warmup. (b) decreases then plateaus early → LR too high later, check schedule. (c) barely moves → LR too low. (d) oscillates with no trend → LR too high, use momentum.

07

Gradient Mechanics

Controlling how gradients flow — the difference between a stable run and NaN hell.

Gradient Accumulation

When your GPU can't fit a large batch, split it into micro-batches and accumulate gradients before updating weights:

for i, (x, y) in enumerate(loader):
    loss = criterion(model(x), y) / accum_steps   # average over micro-batches
    loss.backward()                               # grads add up in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

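Effective batch size = micro_batch_size × accum_steps × n_gpus. Because each micro-batch loss is divided by accum_steps, the accumulated gradient equals the gradient of one large batch of that size (up to floating-point error, and except for batch-dependent layers such as BatchNorm). No LR change is needed relative to training on that large batch directly; apply the linear scaling rule only when the effective batch grows relative to your previous setup.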
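A minimal pure-Python check of that equivalence, assuming a one-parameter model ŷ = w·x with mean-squared-error loss (both illustrative): summing the gradients of loss / accum_steps over equal-sized micro-batches reproduces the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.2, 6.8, 8.1]
w = 0.3

full_batch = grad_mse(w, xs, ys)          # one batch of 8

accum_steps, micro = 4, 2                 # 4 micro-batches of 2
accumulated = 0.0                         # plays the role of p.grad
for k in range(accum_steps):
    sl = slice(k * micro, (k + 1) * micro)
    accumulated += grad_mse(w, xs[sl], ys[sl]) / accum_steps   # backward of loss / accum_steps

print(abs(full_batch - accumulated) < 1e-12)   # True
```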

Gradient Clipping

Cap the gradient norm to prevent exploding gradients from destroying your training:

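if ‖g‖ > max_norm:  g ← g × (max_norm / ‖g‖), where ‖g‖ is the L2 norm over all parameters' gradients combined

In PyTorch (after backward, before the optimizer step):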
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
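A pure-Python sketch of global-norm clipping (illustrative, not PyTorch's code). It behaves like torch.nn.utils.clip_grad_norm_: one norm over all parameters, a small 1e-6 term in the denominator, every gradient scaled by the same factor, and the pre-clip norm returned.

```python
import math

def clip_grad_norm(grads, max_norm):
    """grads: list of per-parameter gradient lists; scaled in place. Returns the pre-clip norm."""
    total_norm = math.sqrt(sum(g * g for p in grads for g in p))
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:                       # only ever shrink, never amplify
        for p in grads:
            for i in range(len(p)):
                p[i] *= scale
    return total_norm

grads = [[3.0, 0.0], [0.0, 4.0]]          # two "parameters"; global L2 norm = 5
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for p in grads for g in p))
print(norm_before, round(norm_after, 4))  # 5.0 1.0
print(grads)                              # direction preserved: ≈ [[0.6, 0.0], [0.0, 0.8]]
```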
| Method | What it does | When to use |
|---|---|---|
| Norm clipping | Scale the entire gradient vector so its L2 norm ≤ threshold | LLMs, Transformers: standard practice |
| Value clipping | Clip each gradient element to [-c, c] | Less common; for specific known gradient distributions |
| Adaptive clipping | Track gradient norms over time; set the threshold dynamically | When you don't know a good threshold upfront |
💡

Monitor gradient norms. Log grad_norm at every step. A common range is roughly 0.1–2.0, though it is model-dependent. If it spikes to 10+ intermittently, you have gradient instability. If it's constantly high (5+), your LR may be too high or the architecture may be unstable. If it trends toward 0, suspect vanishing gradients.

08

Activations

Non-linearity is what makes neural networks universal approximators.

Without activation functions, stacking linear layers gives you a single linear transformation — no matter how many layers. Activations introduce non-linearity, enabling the network to learn complex functions.
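A small illustration with arbitrary matrices: two linear layers with nothing in between equal a single linear layer whose matrix is W₂W₁, and inserting a ReLU breaks that equivalence. Helper names are illustrative.

```python
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def apply(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, t) for t in v]

W1 = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]   # 2 -> 3
W2 = [[1.0, -1.0, 2.0]]                        # 3 -> 1
x = [0.3, -0.7]

two_layers = apply(W2, apply(W1, x))           # no activation in between
one_layer = apply(matmul(W2, W1), x)           # one equivalent 1x2 matrix
with_relu = apply(W2, relu(apply(W1, x)))      # non-linearity inserted
print(two_layers, one_layer)                   # same value (≈ -0.15)
print(with_relu)                               # ≈ 0.95: no longer a single linear map
```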

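| Activation | Formula | Range | Use case | Pitfalls |
|---|---|---|---|---|
| ReLU | max(0, x) | [0, ∞) | CNNs, hidden layers | Dying ReLU: if a neuron's input is negative for every example, its gradient is 0 and it stops learning |
| GELU | x·Φ(x) | [≈-0.17, ∞) | Transformers, BERT, GPT | Slightly more compute than ReLU |
| SiLU/Swish | x·σ(x) | [≈-0.28, ∞) | LLaMA, modern LLMs | Similar to GELU |
| Sigmoid | 1/(1+e⁻ˣ) | (0, 1) | Binary classification output | Saturates for large positive or negative x → vanishing gradients |
| Tanh | (eˣ-e⁻ˣ)/(eˣ+e⁻ˣ) | (-1, 1) | RNNs, normalized outputs | Saturates near ±1 → vanishing gradients |
| Softmax | exp(xᵢ)/Σexp(xⱼ) | (0, 1), sums to 1 | Classification output; attention weights | Not an element-wise hidden activation; naive exp overflows, so use stable/fused implementations |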
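The lower bounds in the Range column can be checked numerically. This sketch uses the exact GELU (via math.erf) and a simple grid search over [-5, 0]; the grid resolution is an arbitrary choice.

```python
import math

def gelu(x):      # exact GELU: x * Phi(x), Phi = standard normal CDF
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def silu(x):      # SiLU / Swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

grid = [i / 10000 for i in range(-50000, 1)]   # x in [-5, 0]
print(round(min(gelu(x) for x in grid), 3))     # -0.17
print(round(min(silu(x) for x in grid), 3))     # -0.278
```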

Why GELU Won in Transformers

GELU = Gaussian Error Linear Unit. Unlike ReLU's hard gate (0 or pass), GELU has a smooth gate: each input x is scaled by Φ(x), the probability that a standard Gaussian variable is less than x. This smooth gradient behavior empirically outperforms ReLU in transformers. LLaMA uses SiLU (very similar in shape) inside SwiGLU: SiLU(W₁x) ⊙ W₃x.

⚠️

Saturation = vanishing gradients. Sigmoid derivative = σ(x)(1-σ(x)) → max 0.25 at x=0, below 0.05 for |x| > 3. Stack 10 sigmoid layers: even in the best case the gradient shrinks by 0.25¹⁰ ≈ 10⁻⁶ (ignoring the weights). This is why deep networks were hard to train before ReLU + residual connections.
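The numbers in the callout, computed directly:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)

print(sigmoid_grad(0.0))   # 0.25, the maximum
print(sigmoid_grad(3.0))   # ≈ 0.045
print(0.25 ** 10)          # 9.5367431640625e-07: best-case factor after 10 layers
```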

09

Precision & Training Efficiency

The difference between FP32 and BF16 is not just speed — it fundamentally changes what you can train.

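| Format | Bits | Exponent bits | Mantissa bits | Range | Precision |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | ~7 decimal digits |
| FP16 | 16 | 5 | 10 | ±65,504 | ~3 decimal digits |
| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | ~2 decimal digits |
| INT8 | 8 | n/a | n/a | [-128, 127] | Integers only |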

FP16 vs BF16 — The Critical Distinction

FP16 Problems

Max value 65,504. Gradients or activations above 65,504 overflow to inf, which quickly turns into NaN. Values below ~6×10⁻⁵ become subnormal (losing precision), and values below ~6×10⁻⁸ underflow to 0. Requires loss scaling to keep gradients in the representable range.

BF16 Advantages

Same exponent range as FP32 → far fewer overflow/underflow issues. Drop-in replacement for FP32 without loss scaling. Lower precision (7 mantissa bits vs 23), but training is usually robust to this. Default for modern GPU training (A100+, H100).
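The range/precision trade-off can be checked with the standard library. Python's struct module supports IEEE half precision through the 'e' format; BF16 has no struct format, so this sketch emulates it by rounding an FP32 bit pattern to its top 16 bits (round to nearest even, NaN/inf handling ignored). Function names are illustrative.

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Emulated BF16: keep the top 16 bits of the FP32 pattern, rounding to nearest even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_fp16(65504.0))                            # 65504.0: FP16 max
print(to_fp16(1e-8))                               # 0.0: underflow in FP16
print(to_bf16(1e-8))                               # ≈ 1e-08: BF16 keeps FP32's exponent range
print(to_fp16(1.0 + 2**-9), to_bf16(1.0 + 2**-9))  # 1.001953125 1.0: BF16 has fewer mantissa bits
# struct.pack("<e", 70000.0) raises OverflowError: beyond FP16's max of 65,504
```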

Mixed Precision Training

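Run most of the forward and backward math in FP16/BF16 for speed, but compute the loss (and other numerically sensitive ops) in FP32 and keep an FP32 master copy of the weights for the optimizer step. PyTorch's autocast gets the same effect by keeping parameters in FP32 and casting inputs per operation.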

FP16 weights → Forward (FP16) → Loss (FP32) → Backward (FP16) → Opt step (FP32) → Copy to FP16

Loss Scaling (for FP16)

Multiply the loss by a large scalar S before backward → gradients are S× larger → less underflow. After backward, divide the gradients by S before the optimizer step. PyTorch's GradScaler does this automatically with dynamic scaling: if it detects inf/NaN gradients, it skips the step and halves S; after a run of overflow-free steps, it grows S again.

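import torch

scaler = torch.amp.GradScaler("cuda")   # PyTorch >= 2.3; older: torch.cuda.amp.GradScaler()

optimizer.zero_grad(set_to_none=True)
with torch.autocast("cuda", dtype=torch.float16):
    output = model(x)
    loss = criterion(output, y)

scaler.scale(loss).backward()   # backward on S * loss
scaler.step(optimizer)          # unscales grads; skips the step if any are inf/NaN
scaler.update()                 # halves S after an overflow, grows it after clean steps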
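A stdlib-only illustration of why scaling helps, using Python's half-precision struct format as a stand-in for FP16 gradient storage. GradScaler itself works on tensors; the gradient value and the update_scale helper are illustrative, with constants matching GradScaler's defaults (initial scale 2¹⁶, backoff 0.5, growth 2× after 2000 clean steps).

```python
import struct

def to_fp16(x):
    """Round-trip a float through IEEE half precision (what an FP16 gradient can hold)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8                  # a small but meaningful gradient value
S = 2.0 ** 16                     # loss scale

unscaled = to_fp16(true_grad)     # stored directly in FP16
scaled = to_fp16(true_grad * S)   # backward on S * loss yields S * grad
recovered = scaled / S            # unscale in FP32 before the optimizer step
print(unscaled)                   # 0.0: the gradient underflowed
print(recovered)                  # ≈ 1e-08: preserved

def update_scale(S, found_inf, good_steps, growth_interval=2000):
    """Dynamic policy: halve S on overflow (and skip the step); double it after a clean run."""
    if found_inf:
        return S * 0.5, 0
    good_steps += 1
    if good_steps == growth_interval:
        return S * 2.0, 0
    return S, good_steps

print(update_scale(65536.0, found_inf=True, good_steps=500))    # (32768.0, 0)
print(update_scale(65536.0, found_inf=False, good_steps=1999))  # (131072.0, 0)
```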
💡

Tensor Cores: NVIDIA Tensor Cores accelerate matrix multiplications in FP16/BF16/TF32. On A100: 312 TFLOPS peak in BF16 vs 19.5 TFLOPS in plain FP32, a 16× gap in peak throughput (real end-to-end speedups are smaller). Shapes that are multiples of 8 (ideally 64) use Tensor Cores most efficiently, which is one reason model dimensions are usually values like 512, 1024, or 4096.

10

Memory Optimization Techniques

How to train models that don't fit in GPU memory — the trade-offs that define production LLM training.

| Technique | What it does | Memory saved | Compute cost | Use when |
|---|---|---|---|---|
| Gradient Checkpointing | Recompute activations during backward instead of storing them | Up to ~10× activation memory | ~+30% compute | Large batches, deep models |
| Mixed Precision | FP16/BF16 activations and matmuls | ~2× activation memory | ~0% (often faster) | Always |
| ZeRO Stage 1 | Shard optimizer state across GPUs | Up to ~4× total model-state memory (large N) | Small comm overhead | Multi-GPU, Adam optimizer |
| ZeRO Stage 2 | Shard optimizer state + gradients | Up to ~8× total | Moderate comm | Large models, many GPUs |
| ZeRO Stage 3 | Shard optimizer state + gradients + parameters | ~Linear with N GPUs | Higher comm | Models too large for one GPU |
| CPU Offloading | Move optimizer state to CPU RAM | Frees GPU optimizer memory | PCIe bandwidth bottleneck | Limited GPU memory, PCIe 4.0+ |
| Activation Offloading | Move activations to CPU during forward | Frees GPU activation memory | High bandwidth cost | Extreme memory pressure |

Gradient Checkpointing Deep Dive

How It Works

Without checkpointing: ALL activations from forward pass stored in GPU memory for backward pass. With checkpointing: only store activations at checkpoint boundaries (e.g., every N transformer blocks). During backward, recompute activations between checkpoints on the fly.

In PyTorch:

from torch.utils.checkpoint import checkpoint

output = checkpoint(block, x, use_reentrant=False)  # block's activations are recomputed during backward

ZeRO Optimizer (DeepSpeed)

ZeRO = Zero Redundancy Optimizer. In standard DDP, every GPU holds a full copy of the optimizer state. For a 7B-param model with mixed-precision Adam: 7B×4 (FP32 master weights) + 7B×4 (m) + 7B×4 (v) = 84 GB just for the optimizer. ZeRO shards this across N GPUs. With 8 GPUs, that's 84/8 = 10.5 GB per GPU.
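This sketch extends the arithmetic to every ZeRO stage using the ZeRO paper's per-parameter accounting for mixed-precision Adam: 2 bytes FP16 params + 2 bytes FP16 grads + 12 bytes optimizer state (FP32 master + m + v). It ignores activations, buffers, and communication scratch space; the function is illustrative.

```python
def zero_bytes_per_gpu(n_params, n_gpus, stage, k_opt=12):
    """Model-state bytes per GPU for mixed-precision Adam under ZeRO stage 0-3."""
    p, g, o = 2 * n_params, 2 * n_params, k_opt * n_params
    if stage == 0:          # plain DDP: everything replicated
        return p + g + o
    if stage == 1:          # shard optimizer state
        return p + g + o / n_gpus
    if stage == 2:          # shard optimizer state + gradients
        return p + (g + o) / n_gpus
    if stage == 3:          # shard everything, parameters included
        return (p + g + o) / n_gpus
    raise ValueError("stage must be 0-3")

GB = 1e9
for stage in range(4):
    print(stage, zero_bytes_per_gpu(7e9, 8, stage) / GB)
# 0 112.0 / 1 38.5 / 2 26.25 / 3 14.0
```

Stage 1 reproduces the 10.5 GB of sharded optimizer state per GPU on top of 28 GB of replicated FP16 params and grads.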

11

Batch Size & Scaling

Batch size is not just a throughput knob — it fundamentally changes training dynamics.

What Batch Size Controls

Each gradient update is computed on a sample of the data. Larger batches = more accurate gradient estimate (less noise), but also fewer updates per epoch. The gradient noise from small batches acts as a regularizer and helps escape sharp minima.

Small Batches (32–256)

Noisy gradients → acts like regularization. Often finds flatter minima (better generalization). More optimizer steps per epoch. Works poorly on GPUs (low utilization). Good for: fine-tuning, small models, generalization-sensitive tasks.

Large Batches (2048+)

Accurate gradient → fast convergence. Lower variance → may overfit or find sharp minima (worse generalization unless LR scaled). Excellent GPU utilization. May need LR warmup. Good for: LLM pre-training, high-throughput training.

Linear Scaling Rule

📐

When you increase batch size by k×, scale LR by k× to maintain the same training dynamics. Intuition: larger batch → lower gradient noise → gradient is more accurate → you can afford to take a larger step. Caveat: this holds for moderate scaling (≤32×). Very large batches need additional warmup and learning rate tuning.

Effective Batch Size

effective_batch = micro_batch × grad_accum_steps × n_gpus

If micro_batch=4, accum_steps=8, n_gpus=8: effective_batch = 4×8×8 = 256. This is what matters for training dynamics. You can trade any of these three factors against each other based on hardware constraints.
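The two formulas from this section as code: a sketch with illustrative names. The linear scaling rule is a heuristic for moderate batch increases, as the caveat above notes.

```python
def effective_batch(micro_batch, grad_accum_steps, n_gpus):
    return micro_batch * grad_accum_steps * n_gpus

def linearly_scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling rule: scale LR by the same factor k as the batch size."""
    return base_lr * new_batch / base_batch

b = effective_batch(4, 8, 8)
print(b)                                     # 256
print(effective_batch(2, 16, 8) == b)        # True: same batch, different hardware split
print(linearly_scaled_lr(3e-4, 256, 1024))   # ≈ 0.0012 for a 4x larger batch
```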

12

Regularization

Overfitting detection and prevention — keeping your model generalizable.

Detecting Overfitting

The signature: training loss continues to decrease but validation loss stagnates or increases. Monitor both curves throughout training. The gap between train and val loss = degree of overfitting.

| Technique | How it works | When to use | Key params |
|---|---|---|---|
| Dropout | Randomly zero activations during training (forces the network not to rely on specific neurons) | Dense layers; usually not attention | p=0.1–0.3 |
| Weight Decay (L2) | Penalize large weights: add λ‖θ‖² to the loss, or (AdamW) decay the weights directly | Always; use AdamW | λ=0.01–0.1 |
| Data Augmentation | Artificially expand training data with transforms | Vision: always. NLP: less common | Task-specific |
| Early Stopping | Stop when val loss stops improving | When compute is limited | patience=5–20 epochs |
| Label Smoothing | Soften hard targets: instead of [0, 1, 0], use [0.05, 0.9, 0.05] | Classification, LLMs | ε=0.1 |
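Label smoothing has two common conventions, sketched below with an illustrative function. The [0.05, 0.9, 0.05] example in the table spreads ε over the K−1 other classes; PyTorch's CrossEntropyLoss(label_smoothing=ε) instead mixes the one-hot target with a uniform distribution over all K classes.

```python
def smooth_labels(target, num_classes, eps, spread_over_all=False):
    """spread_over_all=False: 1-eps on the target, eps split over the other K-1 classes.
    spread_over_all=True:  (1-eps) * one_hot + eps/K (uniform mixture over all classes)."""
    if spread_over_all:
        return [(1 - eps) * (i == target) + eps / num_classes for i in range(num_classes)]
    off = eps / (num_classes - 1)
    return [1 - eps if i == target else off for i in range(num_classes)]

print(smooth_labels(1, 3, 0.1))                        # ≈ [0.05, 0.9, 0.05]
print(smooth_labels(1, 3, 0.1, spread_over_all=True))  # ≈ [0.0333, 0.9333, 0.0333]
```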

Dropout in Transformers

Modern LLMs (e.g., LLaMA) often use very low or zero dropout during pre-training: with very large datasets seen roughly once, overfitting risk is low. Dropout rates of 0.0–0.1 are common. When fine-tuning on small datasets, increase to 0.1–0.3.

Weight Decay Intuition

Without weight decay, nothing prevents weights from growing arbitrarily large (as long as loss decreases). Large weights make the model sensitive to small input perturbations → poor generalization. Weight decay pulls weights back toward zero at each step, keeping the model in a well-behaved regime.

13

Training Stability & Debugging

The practical guide to why training broke — and how to fix it.

| Issue | Symptoms | Root cause | Fix |
|---|---|---|---|
| Loss not decreasing | Flat loss from step 1, or plateaus early | LR too low; wrong loss function; data/label bug; gradient not flowing | Check p.requires_grad; verify data pipeline; LR finder; check for bugs in loss |
| NaN in training | Loss = NaN, gradients = NaN, everything breaks | FP16 overflow; LR too high; bad input (inf/NaN in data); log(0) | Check data for NaN/inf; add loss scaling; clip gradients; reduce LR; switch to BF16 |
| Exploding gradients | Loss spikes suddenly; grad norm ≫ 10 | LR too high; bad init; no clipping; a specific batch with large loss | clip_grad_norm_(params, 1.0); reduce LR; add warmup |
| Loss oscillates | Loss goes up and down, no trend | LR too high; inconsistent batch sizes; unstable architecture | Halve LR; add momentum; check BatchNorm usage |
| Validation ≫ train loss | Big gap between train and val | Overfitting | More data; more augmentation; increase dropout; add weight decay; early stop |
| Both losses high | Model not learning at all | Underfitting: model too small, LR too low, too few epochs | Increase model capacity; increase LR; train longer; check data quality |
| GPU OOM | CUDA out of memory | Batch too large; activations not freed; long sequences | Reduce batch size; gradient checkpointing; mixed precision; ZeRO |
| Training too slow | GPU util < 70% | Data loading bottleneck; small batch; Python GIL; not using Tensor Cores | More DataLoader workers; pin_memory=True; persistent_workers=True; increase batch; use BF16 |

Debugging Checklist

Before You Change Anything

  1. Verify loss on random weights (should be ~log(num_classes) for CE)
  2. Overfit on a single batch (loss should → 0)
  3. Check data loader outputs visually
  4. Verify labels are correct and not shifted
  5. Log grad_norm, loss, lr at every step
  6. Check if model is in train mode (model.train())
  7. Confirm optimizer.zero_grad() is called
14

Distributed Training

Scaling training across multiple GPUs and machines — the architecture of modern LLM training.

| Strategy | What's split | Communication | Use when |
|---|---|---|---|
| Data Parallelism (DDP) | Data batch split across GPUs; full model on each | All-reduce gradients after each step | Model fits on 1 GPU, want more throughput |
| Tensor Parallelism | Weight matrices split across GPUs (column/row) | All-reduce within forward/backward | Single layer too large for 1 GPU (e.g., linear 65536→65536) |
| Pipeline Parallelism | Model layers split across GPUs (stages) | Activations passed between stages | Deep models spanning many GPUs or nodes |
| Sequence Parallelism | Sequence dimension split across GPUs | Collectives along the sequence dimension (e.g., all-gather / reduce-scatter) | Very long sequences (>16k), attention memory |
| ZeRO (FSDP) | Optimizer state, gradients, params sharded | All-gather params when needed; reduce-scatter gradients | Memory-constrained distributed training |

Data Parallelism (DDP) — The Standard

Each GPU has a full model copy. Each GPU processes a subset of the batch. After backward, gradients are all-reduced (summed + divided) across GPUs — every GPU ends up with the same average gradient. Effectively multiplies batch size by N GPUs.

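import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group("nccl")                 # one process per GPU, e.g. launched with torchrun
rank = int(os.environ["LOCAL_RANK"])            # set by torchrun
torch.cuda.set_device(rank)
model = DDP(model.to(rank), device_ids=[rank])  # all-reduce runs automatically during .backward()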
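DDP hands the all-reduce to a communication backend such as NCCL, which chooses among several algorithms. The ring all-reduce is a classic one and is easy to simulate in-process. In this sketch (pure Python, each worker's gradient as a list, illustrative function name), every worker ends up with the same averaged gradient while exchanging only one chunk per step.

```python
def ring_allreduce_mean(buffers):
    """Simulate a ring all-reduce (sum) across len(buffers) workers, then average.
    n-1 reduce-scatter steps leave each worker owning one fully summed chunk;
    n-1 all-gather steps then pass those chunks around the ring."""
    n, length = len(buffers), len(buffers[0])
    bounds = [(c * length // n, (c + 1) * length // n) for c in range(n)]
    data = [list(b) for b in buffers]

    # Reduce-scatter: at step s, worker r sends chunk (r - s) % n to worker r+1, which adds it
    for s in range(n - 1):
        msgs = []
        for r in range(n):
            lo, hi = bounds[(r - s) % n]
            msgs.append(((r + 1) % n, lo, data[r][lo:hi]))
        for dst, lo, payload in msgs:
            for k, val in enumerate(payload):
                data[dst][lo + k] += val

    # Worker r now holds the complete sum for chunk (r + 1) % n.
    # All-gather: at step s, worker r forwards chunk (r + 1 - s) % n; the receiver overwrites
    for s in range(n - 1):
        msgs = []
        for r in range(n):
            lo, hi = bounds[(r + 1 - s) % n]
            msgs.append(((r + 1) % n, lo, data[r][lo:hi]))
        for dst, lo, payload in msgs:
            data[dst][lo:lo + len(payload)] = payload

    return [[v / n for v in row] for row in data]

grads_per_gpu = [
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [10.0, 20.0, 30.0, 40.0, 50.0],
    [100.0, 200.0, 300.0, 400.0, 500.0],
]
result = ring_allreduce_mean(grads_per_gpu)
print(result[0])                                  # [37.0, 74.0, 111.0, 148.0, 185.0]
print(all(row == result[0] for row in result))    # True: identical average on every worker
```

Each worker sends 2·(n−1)/n of its buffer in total, independent of n for large n, which is why this algorithm scales well with GPU count.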

Communication Overhead

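The bottleneck in multi-node training is often not compute but gradient synchronization. NVLink (within an A100 node): 600 GB/s. InfiniBand (across nodes): 200–400 Gb/s per link, i.e., 25–50 GB/s (note gigabits vs gigabytes). For a 7B model, gradients are ~14 GB in FP16 (28 GB in FP32); a ring all-reduce sends about 2× that per GPU, so one FP16 sync over a single 400 Gb/s link takes on the order of 0.5 s, less with multiple NICs per node. With 10-step gradient accumulation this cost is paid once per 10 micro-batches (in DDP, skip the sync on intermediate micro-steps with model.no_sync()). Overlap compute and communication wherever possible.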
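A bandwidth-only estimate behind those numbers. This sketch ignores latency, compute/communication overlap, and hierarchical (intra-node then inter-node) algorithms; the function name is illustrative.

```python
def ring_allreduce_seconds(buffer_bytes, link_bytes_per_s, n_workers):
    """A ring all-reduce sends 2 * (n - 1) / n of the buffer per worker."""
    return 2 * (n_workers - 1) / n_workers * buffer_bytes / link_bytes_per_s

grad_bytes = 7e9 * 2      # 7B params, FP16 gradients = 14 GB
ib_400 = 400e9 / 8        # one 400 Gb/s InfiniBand link = 50 GB/s
t = ring_allreduce_seconds(grad_bytes, ib_400, n_workers=8)
print(round(t, 2))        # 0.49 s per sync over a single link
print(round(t / 10, 3))   # 0.049 s per micro-batch with 10-step accumulation
```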

🏗️

3D Parallelism: the largest LLM training runs typically combine all three: Tensor Parallel (within a node, e.g., 8 GPUs), Pipeline Parallel (across nodes, e.g., 4–8 stages), and Data Parallel (across the remaining GPUs). This is called "3D parallelism."

15

Training LLMs vs CNNs vs VLMs

Key differences in training dynamics across model families.

| Property | CNN (ResNet, ConvNeXt) | LLM (GPT, LLaMA) | VLM (CLIP, LLaVA) |
|---|---|---|---|
| Input | Fixed-size image [B, C, H, W] | Token sequence [B, S] | Image + text |
| Main op | Convolution | Causal self-attention | Cross-modal alignment (contrastive loss, projected image tokens, or cross-attention) |
| Memory cost driver | Activation maps (large H×W) | Activations growing with S (O(S²) for naive attention); KV cache at inference | Both + modality projectors |
| Batch size | 256–4096 | ~0.5M–4M tokens per step at pre-training scale | Varies (contrastive pre-training uses very large batches) |
| Precision | FP16 or BF16 | BF16 (critical) | BF16 for both modalities |
| Optimizer | SGD (with momentum) or AdamW | AdamW (standard) | AdamW |
| LR schedule | Step / Cosine | Warmup + Cosine | Warmup + Cosine, staged |
| Grad stability | Generally stable | Can spike; needs clipping | Modal misalignment causes spikes |
| Dominant cost | Spatial convolutions | Attention O(S²), FFN | Image encoding + alignment |

Sequence Length Impact in LLMs

Going from S=2048 to S=8192 (4× longer): naive attention memory scales as O(S²), so 16× more attention memory. KV cache during inference = 2 × n_layers × n_heads × d_head × S × dtype_bytes per sequence (the 2 is for K and V). For LLaMA-7B (32 layers, 32 heads, d_head = 128) at S=4096 in FP16: ~2 GB KV cache per request. This is why long context is expensive.
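The KV-cache formula as a calculator, using the LLaMA-7B-like shape quoted above (a sketch; real serving stacks may use grouped-query attention or cache quantization, which change the numbers).

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, batch=1, bytes_per_el=2):
    """K and V (the factor 2) for every layer, head, and position, per sequence in the batch."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch * bytes_per_el

GiB = 2**30
print(kv_cache_bytes(32, 32, 128, 4096) / GiB)   # 2.0
print(kv_cache_bytes(32, 32, 128, 8192) / GiB)   # 4.0: the KV cache grows linearly in S
```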

VLM Training Specifics

VLMs differ in their training objective: CLIP-style models use a contrastive loss, while generative VLMs such as LLaVA and Flamingo use next-token cross-entropy on text conditioned on image features. Key challenge: the vision and language components are pre-trained separately and must be aligned. Training is often staged: (1) freeze both pre-trained models and train the projection layer; (2) unfreeze the language model for joint fine-tuning; (3) optionally unfreeze the vision encoder. For contrastive training, many negatives are critical: CLIP used a batch size of 32,768.

16

Practical Training Pipeline

Production-ready PyTorch training loop with all the essentials.

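import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# ─── Setup ───────────────────────────────────────────────────
# MyModel and dataset are placeholders for your own model and Dataset.
device = torch.device("cuda")
model = MyModel().to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.1,
    betas=(0.9, 0.95),    # β2=0.95 is common for LLM pre-training
)
criterion = nn.CrossEntropyLoss()
loader = DataLoader(dataset, batch_size=32, num_workers=4,
                    pin_memory=True, persistent_workers=True)

ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0
USE_FP16 = False   # set True on GPUs without BF16 support (e.g., V100)
AMP_DTYPE = torch.float16 if USE_FP16 else torch.bfloat16

# Loss scaling is only needed for FP16; when disabled, the scaler calls pass through.
scaler = torch.amp.GradScaler("cuda", enabled=USE_FP16)
# Cosine decay over the number of optimizer steps (add warmup as in Section 06).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=len(loader) // ACCUM_STEPS
)

# ─── Training Loop ───────────────────────────────────────────
model.train()
optimizer.zero_grad(set_to_none=True)

for step, (x, y) in enumerate(loader):
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    # Forward pass with mixed precision
    with torch.autocast("cuda", dtype=AMP_DTYPE):
        logits = model(x)
        loss = criterion(logits, y) / ACCUM_STEPS   # average over micro-batches

    # Backward (gradients accumulate in .grad)
    scaler.scale(loss).backward()

    if (step + 1) % ACCUM_STEPS == 0:
        # Unscale before clipping so the threshold applies to the true gradients
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), MAX_GRAD_NORM
        )

        scaler.step(optimizer)   # skips the step if FP16 grads contain inf/NaN
        scaler.update()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)

        # Logging: .item() forces a CPU-GPU sync, so only do it here.
        # The loss shown is the last micro-batch's, rescaled.
        print(f"step={step}, loss={loss.item() * ACCUM_STEPS:.4f}, "
              f"grad_norm={grad_norm.item():.3f}, "
              f"lr={scheduler.get_last_lr()[0]:.6f}")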

DataLoader Best Practices

loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=8,           # parallel data loading
    pin_memory=True,         # faster CPU→GPU transfer
    persistent_workers=True, # avoid worker re-init overhead
    prefetch_factor=2,       # pre-load batches ahead
    drop_last=True,          # avoid partial batch issues
)
17

Engineering Tips

High-value production insights — what separates a struggling run from a clean one.

✓ Always Do These

  • Log: loss, grad_norm, lr, GPU utilization, tokens/sec — every step or every 10 steps
  • Set seeds: torch.manual_seed(42), numpy.random.seed(42), random.seed(42), and torch.backends.cudnn.deterministic=True for debugging
  • Use set_to_none=True in optimizer.zero_grad() (the default since PyTorch 2.0): frees gradient memory instead of writing zeros
  • Use torch.compile(model) in PyTorch 2.0+: often a 20–50% speedup, model-dependent
  • Profile before optimizing: use torch.profiler to find the actual bottleneck
  • Use pin_memory=True in DataLoader when training on GPU
  • Save checkpoints regularly with both model state AND optimizer state

✗ Common Mistakes

  • Forgetting optimizer.zero_grad() → accumulated gradients = effectively higher LR → exploding
  • Calling model.eval() during training → dropout off, BatchNorm uses running stats → silent bug
  • Accumulating loss tensors for logging (e.g., total += loss) without .detach() or .item() → keeps each step's computation graph alive → memory leak
  • Calling loss.item() every step → forces a CPU-GPU sync → measurable slowdown (often cited as 10–30%, workload-dependent)
  • Forgetting to unscale before gradient clipping in mixed precision → clips scaled gradients
  • Using SGD where Adam clearly works better (Transformers are almost always trained with Adam/AdamW)
  • Not normalizing loss by accumulation steps → gradients ACCUM_STEPS× too large (for SGD, equivalent to multiplying the LR; it also distorts gradient clipping)

Speed-Up Tricks

| Trick | Typical speedup (workload-dependent) | Cost |
|---|---|---|
| torch.compile() | 20–50% | First step slow (~30 s compile) |
| BF16 / Mixed Precision | 2–3× | Minimal accuracy loss |
| FlashAttention-2 | 2–4× for attention | Requires installation |
| Fused optimizers | 10–20% | Use fused=True in AdamW |
| Increase batch size | Linear with GPU util | May need LR scaling |
| Channels-last memory format | 10–40% for CNNs | model.to(memory_format=torch.channels_last) |
18

Mental Models & Intuition

The way to think about training that makes everything else click.

Training = Error Correction Loop

The model makes a prediction. We measure how wrong it was (loss). We figure out which weights were responsible and to what degree (backprop). We nudge each weight in the direction that would have reduced the error (optimizer step). Repeat. Training is the model progressively learning from its own mistakes.

Gradients = Direction of Improvement

A gradient is not "how much to move" — it's "which direction increases loss fastest." We move in the opposite direction. The magnitude tells us the sensitivity: a large gradient means this weight matters a lot for the current predictions. A zero gradient means this weight doesn't matter right now.

Optimizer = Step Size Strategy

SGD: same step everywhere. Momentum: bigger steps in consistent directions. Adam: smaller steps for noisy parameters, bigger for consistent ones. The optimizer answers: "we know which direction to move, but how boldly should we move?" AdamW adds: "and let's not let weights grow too large."

Loss Landscape = Terrain

Imagine the loss as a mountainous terrain in N-dimensional space. Training is hiking downhill. Flat regions and saddle points = learning stalls. Sharp valleys = fast convergence but often poorer generalization. Flat, wide minima = better generalization. The optimizer is your hiking strategy; LR is your stride length.

More Key Intuitions

Batch Size = Signal Quality

Small batch = one hiker taking a step based on what they can see locally. Large batch = many hikers averaging their observations. Large batches give more accurate steps but lose the "noisy exploration" that helps find better minima.

Activations = What the Network "Knows"

Each layer's output is a representation of the input at that level of abstraction. Early layers detect edges/phonemes. Middle layers detect shapes/words. Late layers detect objects/semantics. Training = learning which representations are useful for the task.

19

Code Snippets Reference

Drop-in code for the most common training operations.

Mixed Precision (BF16 — Modern Best Practice)

import torch

# BF16: no loss scaling needed (A100+, H100)
optimizer.zero_grad(set_to_none=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
    output = model(x)
    loss = criterion(output, y)
loss.backward()     # outside autocast; grads match the (FP32) parameter dtype
optimizer.step()

Gradient Clipping

loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0   # standard for LLMs; try 0.5–2.0
)
optimizer.step()
# grad_norm tells you the norm BEFORE clipping

AdamW Optimizer Setup

# Standard LLM setup
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),    # β2=0.95 is common for LLMs (faster adaptation)
    eps=1e-8,
    weight_decay=0.1,
    fused=True,           # single fused kernel for all params; often faster
)

# Or: parameter groups (no weight decay on biases, norms, embeddings).
# Substring matching depends on your model's parameter names; an alternative
# is to decay only tensors with p.ndim >= 2.
no_decay = ["bias", "LayerNorm.weight", "embedding"]
params = [
    {"params": [p for n, p in model.named_parameters()
                if p.requires_grad and not any(nd in n for nd in no_decay)],
     "weight_decay": 0.1},
    {"params": [p for n, p in model.named_parameters()
                if p.requires_grad and any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(params, lr=3e-4)

Gradient Checkpointing

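import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Option 1: wrap individual expensive blocks
class TransformerBlock(nn.Module):
    def _forward(self, x):
        ...  # the block's actual attention + MLP computation goes here

    def forward(self, x):
        # activations inside _forward are recomputed during backward
        return checkpoint(self._forward, x, use_reentrant=False)

# Option 2: enable via HuggingFace Transformers
model.gradient_checkpointing_enable()

# Option 3: FSDP-style activation checkpointing (private module path; may change)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
apply_activation_checkpointing(
    model, check_fn=lambda m: isinstance(m, TransformerBlock)
)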

Checkpoint Saving/Loading

# Save
torch.save({
    "step": step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "scaler": scaler.state_dict(),   # FP16 only
    "loss": loss.item(),
}, f"ckpt_step_{step}.pt")

# Load (create model, optimizer, scheduler, and scaler first)
ckpt = torch.load("ckpt_step_1000.pt", map_location=device)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
scaler.load_state_dict(ckpt["scaler"])   # FP16 only
start_step = ckpt["step"] + 1            # resume after the saved step
💡

Quick training sanity checklist: (1) Loss at random init ≈ log(N_classes) for CE? (2) Overfit a single batch in <100 steps? (3) Grad norm in the 0.1–2.0 range? (4) GPU utilization >85%? (5) No NaN anywhere? If all 5 pass, your training setup is probably sound; now focus on hyperparameters and data.