Foundations & Motivation
What is Fine-Tuning?
Fine-tuning is the process of continuing to train a pretrained model on a curated, task-specific dataset to adapt its behavior. A base LLM (e.g., Llama-3, Mistral, Qwen) has already internalized language structure, world knowledge, and general reasoning from pretraining on trillions of tokens. Fine-tuning updates some or all of this model's weights (or, with parameter-efficient methods, a small set of added weights) to specialize it for your distribution.
Mechanically: you load a checkpoint, provide a dataset in the model's expected format, and run gradient-descent updates (usually with a smaller learning rate than pretraining). The result is a model that behaves differently — more in-domain, more instruction-following, or more aligned with your preferences.
Pretraining gives the model a "world model." Fine-tuning gives it a "job description." You're mostly not teaching it new facts; you're adjusting how it uses what it already knows.
Fine-Tuning vs. Prompting vs. RAG
| Dimension | Prompting / Few-shot | RAG | Fine-Tuning |
|---|---|---|---|
| Knowledge injection | Limited (context window) | Dynamic, always fresh | Static at training time |
| Latency | Low without examples; grows with prompt length | +retrieval overhead | Low (no extra context) |
| Cost per inference | High (long prompts) | Medium | Low token count |
| Style/format control | Moderate | Moderate | Excellent |
| Domain behavior | Surface-level | Data-dependent | Deep adaptation |
| Maintenance | Easy | Index updates needed | Retrain for updates |
| Data needed | 0–10 examples | Documents to index | 100s–100k+ examples |
| Hallucination risk | Moderate | Lower (grounded) | Risk of overfitting |
When Fine-Tuning is the Right Choice
- You need consistent output format/structure (JSON, structured reports) and prompting is unreliable
- You need the model to internalize domain terminology deeply (medical, legal, industrial)
- Inference cost is critical — removing system prompts saves real money at scale
- You need behavioral guarantees (e.g., never produce X, always respond in Y style)
- Latency is tight and you can't afford long context windows
- Base model consistently fails even with good prompts on your task
When Fine-Tuning is NOT the Right Choice
- Your data is dynamic and changes frequently — RAG is better
- You have fewer than ~100 high-quality examples — prompting first
- You're in early product stage — optimize prompts first, FT later
- You need to add new factual knowledge — FT doesn't reliably inject facts
- Your base model already does the task reasonably well with prompting
True Costs of Fine-Tuning
Fine-Tuning Methods
Full Fine-Tuning
What it is: All model parameters are updated during training. The optimizer maintains states for every weight in the network. Maximum expressiveness: the model can change fundamentally.
When to Use
- Large budget, maximum performance required
- You have 100k+ high-quality examples
- Task requires deep behavioral change from base model
- You need the absolute best model and will serve it at scale
- Continual pretraining on new domain corpus
When NOT to Use
- Small dataset (<5k samples): high risk of overfitting
- Limited GPU budget: a 7B model typically needs multiple A100 80GB GPUs (or one with ZeRO CPU offload), and a 70B model needs 16 or more
- Quick iteration needed — 1 training run takes hours to days
- When you need to maintain base capabilities (catastrophic forgetting risk)
Pros
- Maximum task performance ceiling
- Full adaptation of all layers
- No inference overhead
- Best for domain shift
Cons
- Massive GPU VRAM requirement
- Catastrophic forgetting risk
- Slow iteration cycle
- Storage: full model copy per FT
VRAM Estimate
- Rule of thumb: FP32 training needs ~6× the FP32 model size in VRAM (weights, gradients, AdamW states, activations)
- 7B model FP32 → ~168 GB VRAM (4× A100 80GB in practice)
- 7B model BF16 + gradient checkpointing → ~80 GB
LoRA (Low-Rank Adaptation)
What it is: Instead of updating weight matrix W directly, LoRA decomposes the update ΔW = A×B where A ∈ ℝ^(d×r) and B ∈ ℝ^(r×k), with rank r << d. Only A and B are trained. This reduces trainable parameters by 100-10000× while preserving most of the expressiveness for task-specific adaptation.
# LoRA applied to attention and MLP projection layers
# Original: W ∈ R^(4096 × 4096) = 16.7M params
# LoRA r=16: A(4096×16) + B(16×4096) = 131K params → 128× reduction
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # example base model
lora_config = LoraConfig(
    r=16,  # rank: start with 16, tune up for harder tasks
    lora_alpha=32,  # scaling = alpha/r = 2.0
    target_modules=[  # which projection matrices to adapt
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention
        "gate_proj", "up_proj", "down_proj"  # MLP too
    ],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# For Llama-3.1-8B with these settings:
# trainable params: 41,943,040 || all params: 8,072,204,288 || trainable%: 0.5196
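To make the ΔW = A×B arithmetic concrete, here is a minimal NumPy sketch of a LoRA-augmented linear layer using the shapes from the comment above (d = k = 4096, r = 16, alpha = 32). It follows the common convention of initializing one factor (here B) to zero so the adapted layer starts out identical to the frozen one; the function and variable names are illustrative and are not PEFT's internals.

```python
import numpy as np

d, k, r, alpha = 4096, 4096, 16, 32
scale = alpha / r  # 2.0, the alpha/r scaling applied to the LoRA path
rng = np.random.default_rng(0)

W = rng.standard_normal((d, k), dtype=np.float32) * 0.02  # frozen pretrained weight
A = rng.standard_normal((d, r), dtype=np.float32) * 0.01  # trainable, d x r
B = np.zeros((r, k), dtype=np.float32)                     # trainable, r x k, zero init

def lora_forward(x: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """y = xW + scale * (xA)B; the low-rank path never materializes the d x k update."""
    return x @ W + scale * ((x @ A) @ B)

full_params = d * k          # 16,777,216
lora_params = d * r + r * k  # 131,072
print(full_params // lora_params)  # 128

# Merging: fold the (trained) update into W so inference needs no extra matmuls.
B_trained = rng.standard_normal((r, k), dtype=np.float32) * 0.01  # stand-in for a trained B
W_merged = W + scale * (A @ B_trained)
x = rng.standard_normal((2, d), dtype=np.float32)
print(np.allclose(x @ W_merged, lora_forward(x, A, B_trained), atol=1e-3))  # True
```

The last check is why a merged LoRA model has no inference overhead, while an unmerged one pays for the extra (x·A)·B product on every forward pass.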
Pros
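- Far less VRAM than full FT (gradients and optimizer states only for the small fraction of trainable params)
- Fast iteration (often under an hour on 7B for small datasets)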
- Multiple adapters, one base
- Mergeable back into base weights
- Strong performance — often 95%+ of full FT
Cons
- Suboptimal for deep domain shift
- Extra latency if not merged
- Rank selection is a hyperparameter
- Hard to adapt very new concepts
LoRA Rank Selection Guide
| Rank (r) | Use Case | Trainable % (8B model, all 7 projections) | Typical Alpha |
|---|---|---|---|
| 4 | Simple format/style change, classification | ~0.13% | 8–16 |
| 8 | Instruction following, chat adaptation | ~0.26% | 16–32 |
| 16 | Domain adaptation, complex tasks (default) | ~0.5% | 32–64 |
| 32 | Hard tasks, specialized domains | ~1% | 64 |
| 64+ | Approaching full FT territory, high complexity | ~2%+ | 64–128 |
QLoRA
What it is: QLoRA loads the base model in 4-bit NormalFloat (NF4) quantization, then trains LoRA adapters in BF16. The frozen quantized model uses 4 bits/param, while the LoRA weights and the matmul computations stay in higher precision (BF16). Double quantization compresses the quantization constants themselves. Result: fine-tune a 65B model on a single A100 80GB.
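import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
model_id = "meta-llama/Llama-3.1-8B-Instruct"  # any causal LM checkpoint
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # quantize the quantization constants: saves ~0.4 bits/param more
    bnb_4bit_quant_type="nf4",  # NF4 suits normally distributed weights better than plain int4
    bnb_4bit_compute_dtype=torch.bfloat16  # de-quantize to BF16 for the matmuls
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)
# prepare_model_for_kbit_training enables gradient checkpointing (by default)
# and casts the remaining non-quantized parameters (e.g., layer norms) to FP32
model = prepare_model_for_kbit_training(model)
# Then add LoRA on top as usual (get_peft_model with a LoraConfig)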
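QLoRA's memory savings come from storing each frozen weight as a 4-bit code plus a shared scale per block of weights. The NumPy sketch below illustrates that storage scheme with plain symmetric absmax int4 quantization over blocks of 64 values. It is a simplification for intuition only: NF4 uses a normal-quantile codebook instead of evenly spaced levels, and real implementations pack two 4-bit codes per byte inside fused GPU kernels.

```python
import numpy as np

def quantize_int4_blockwise(w: np.ndarray, block: int = 64):
    """Symmetric absmax quantization: each block of `block` values shares one FP32 scale."""
    flat = w.reshape(-1, block)
    scales = np.abs(flat).max(axis=1, keepdims=True) / 7.0  # map [-absmax, absmax] onto [-7, 7]
    scales[scales == 0] = 1.0                                # avoid division by zero for all-zero blocks
    codes = np.clip(np.round(flat / scales), -7, 7).astype(np.int8)  # 4-bit range, held in int8 here
    return codes, scales.astype(np.float32)

def dequantize(codes: np.ndarray, scales: np.ndarray, shape) -> np.ndarray:
    return (codes.astype(np.float32) * scales).reshape(shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256), dtype=np.float32) * 0.02
codes, scales = quantize_int4_blockwise(w)
w_hat = dequantize(codes, scales, w.shape)

# Storage: 4 bits per weight + one 32-bit scale per 64 weights.
bits_per_param = 4 + 32 / 64
print(bits_per_param)                  # 4.5
print(float(np.abs(w - w_hat).max()))  # worst-case reconstruction error
```

The per-block scale overhead (0.5 bits/param in this sketch) is the part double quantization shrinks, by quantizing the scales themselves.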
Pros
- 70B model on 1× A100 80GB
- Near-LoRA quality
- Democratizes large model FT
Cons
- ~30% slower training than LoRA
- Quantization can hurt rare token performance
- bitsandbytes dependency (best supported on NVIDIA CUDA GPUs)
In practice, QLoRA delivers ~98% of LoRA quality. The 2% loss comes from quantization noise in the frozen base. For most production tasks this is acceptable, but for high-stakes applications run a comparison eval before committing.
Adapters
What it is: Bottleneck adapter modules are inserted after each transformer sub-layer (attention, FFN). The adapter projects to a lower-dimensional space and back: h → W_down → activation → W_up → h. Original weights are frozen. AdapterHub popularized this; MAD-X extended it for multilingual work.
Modern status: LoRA has largely displaced adapters for most use cases because LoRA achieves similar parameter efficiency with lower inference latency (adapters add new layers; LoRA modifies existing ones).
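The bottleneck mechanics described above fit in a few lines. This NumPy sketch assumes the common residual form h + f(h·W_down)·W_up, with ReLU as f and a zero-initialized W_up so the adapter starts as an identity; the sizes (hidden 768, bottleneck 64) are illustrative.

```python
import numpy as np

hidden, bottleneck = 768, 64
rng = np.random.default_rng(0)
W_down = rng.standard_normal((hidden, bottleneck), dtype=np.float32) * 0.02  # trainable
W_up = np.zeros((bottleneck, hidden), dtype=np.float32)                      # trainable, zero init

def adapter(h: np.ndarray) -> np.ndarray:
    """Project down, apply a nonlinearity, project back up, add the residual."""
    z = np.maximum(h @ W_down, 0.0)  # ReLU in the bottleneck
    return h + z @ W_up

h = rng.standard_normal((4, hidden), dtype=np.float32)
print(np.allclose(adapter(h), h))  # True at init: the adapter is an identity
adapter_params = hidden * bottleneck * 2  # 98,304 per inserted adapter (biases omitted)
```

Because the ReLU sits between the two projections, the adapter cannot be folded into the surrounding weights the way a LoRA update can, so it remains an extra step in every forward pass.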
Pros
- Very modular — swap adapters per task
- Composable: stack/merge adapters
- Clean separation of concerns
Cons
- Added inference latency (new layers)
- Cannot be merged into base weights (the bottleneck is nonlinear)
- Less parameter efficient than LoRA in practice
When to still use adapters: When you need to serve many tasks from one base model and swap adapters dynamically at runtime without reloading the model. Example: multi-tenant SaaS with per-customer behavior.
Prefix Tuning & Prompt Tuning
Prefix Tuning: Prepends trainable continuous vectors to the key/value states of every attention layer. These "virtual tokens" guide model behavior without touching any weights. They are trained via backprop through the frozen model.
Prompt Tuning: Simpler — only prepends trainable tokens to the input embedding layer (not every layer). Scales well with model size (at 10B+ it approaches full FT performance).
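Prompt tuning can be sketched in a few lines: a small matrix of trainable "virtual token" embeddings is concatenated in front of the frozen input embeddings, and only that matrix would receive gradients. The sizes below are illustrative assumptions (20 virtual tokens, hidden size 4096, a toy 1,000-token vocabulary).

```python
import numpy as np

n_virtual, hidden, vocab = 20, 4096, 1000
rng = np.random.default_rng(0)
embedding_table = rng.standard_normal((vocab, hidden), dtype=np.float32)      # frozen
soft_prompt = rng.standard_normal((n_virtual, hidden), dtype=np.float32) * 0.01  # the only trainable tensor

def embed_with_soft_prompt(input_ids: np.ndarray) -> np.ndarray:
    """Return [soft_prompt ; embeddings(input_ids)] with shape (n_virtual + seq_len, hidden)."""
    token_embeds = embedding_table[input_ids]
    return np.concatenate([soft_prompt, token_embeds], axis=0)

x = embed_with_soft_prompt(np.array([1, 42, 7]))
print(x.shape)           # (23, 4096)
print(soft_prompt.size)  # 81920 trainable values
```

Prefix tuning differs in where the vectors go: it injects trainable vectors into the key/value states of every attention layer rather than only at the input embeddings.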
Pros
- Extremely few trainable params (<0.1%)
- Base model completely frozen
- Multiple soft prompts per base
Cons
- Underperforms LoRA on smaller models
- Training instability
- Hard to interpret/debug
- Inference overhead (prepended tokens)
In most real-world pipelines, reach for LoRA over prefix/prompt tuning. Prefix tuning is compelling only when you have strict constraints on modifying the base model and need ultra-low parameter overhead.
Instruction Tuning (SFT)
What it is: Supervised fine-tuning on (instruction, response) pairs to teach a base model to follow human instructions. This is how base models become chat/instruct models. Think: Alpaca (Stanford), FLAN, WizardLM, OpenHermes. Usually implemented as LoRA or full FT over instruction-formatted data.
// Instruction dataset format (Alpaca-style)
{
"instruction": "Extract all machine names from the following BOM.",
"input": "CNC Lathe x2, Deburring Station x1, CMM x1...",
"output": "[\"CNC Lathe\", \"Deburring Station\", \"CMM\"]"
}
// Chat format (preferred for modern models)
{
"messages": [
{"role": "system", "content": "You are a factory planning expert."},
{"role": "user", "content": "What machines do I need for TAKT time of 45s?"},
{"role": "assistant", "content": "For a 45s TAKT time with the following operations..."}
]
}
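If your data is in the Alpaca layout but your target model expects chat messages, the conversion is mechanical. The sketch below assumes the instruction and input are joined with a blank line into one user turn and that a single fixed system prompt is prepended; both are choices for illustration, not a standard.

```python
import json

SYSTEM_PROMPT = "You are a factory planning expert."  # assumed; keep identical at inference

def alpaca_to_chat(row: dict, system_prompt: str = SYSTEM_PROMPT) -> dict:
    """Convert an {instruction, input, output} row into a {messages: [...]} row."""
    user = row["instruction"]
    if row.get("input"):  # "input" is optional in Alpaca-style data
        user = f"{user}\n\n{row['input']}"
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user},
            {"role": "assistant", "content": row["output"]},
        ]
    }

row = {
    "instruction": "Extract all machine names from the following BOM.",
    "input": "CNC Lathe x2, Deburring Station x1, CMM x1",
    "output": json.dumps(["CNC Lathe", "Deburring Station", "CMM"]),
}
print(json.dumps(alpaca_to_chat(row), indent=2))
```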
Quality Tiers
- Tier 1: Human-written, expert-verified pairs. Most expensive, best quality. <1k pairs can outperform 100k Tier 3.
- Tier 2: LLM-generated, human-reviewed. Good balance of cost and quality.
- Tier 3: Fully synthetic, no review. Risk of hallucinations propagating into model.
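Continual Pretraining
What it is: Running next-token prediction (the original pretraining objective) on new domain text to inject domain knowledge into the model's weights. Not supervised pairs, just raw text corpora. Example: CodeLlama (Llama 2 further pretrained on code). BioMedLM (PubMed) and Legal-BERT (court documents) apply the same domain-pretraining idea, though BioMedLM was trained from scratch rather than from a general checkpoint.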
When this is the right move: When your domain has highly specialized terminology and reasoning patterns that diverge significantly from general web text. "The base model doesn't even know what half these words mean" is the signal.
Typical Pipeline
- Collect domain corpus (10M–10B tokens)
- Continual pretrain (full FT or LoRA) on raw text — next token prediction loss
- Then instruction-tune on task-specific pairs
- Optional: DPO/RLHF alignment step
Continual pretraining with high LR on domain-only data will degrade general capabilities. Use learning rate warm-up, add 5-10% general domain data to the mix, and use LoRA if you need to preserve base behavior.
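Mixing in a share of general-domain data can be done at the document level before training. The helper below is a hypothetical sketch that samples general documents so they make up a target fraction (e.g., 10%) of the final mix; in practice you would likely balance by token count rather than document count.

```python
import random

def mix_corpora(domain_docs: list[str], general_docs: list[str],
                general_frac: float = 0.10, seed: int = 0) -> list[str]:
    """Return a shuffled mix in which about `general_frac` of documents are general-domain."""
    if not 0 <= general_frac < 1:
        raise ValueError("general_frac must be in [0, 1)")
    rng = random.Random(seed)
    # Solve n_gen / (n_domain + n_gen) = general_frac for n_gen.
    n_gen = round(len(domain_docs) * general_frac / (1 - general_frac))
    n_gen = min(n_gen, len(general_docs))
    mixed = list(domain_docs) + rng.sample(general_docs, n_gen)
    rng.shuffle(mixed)
    return mixed

domain = [f"domain doc {i}" for i in range(900)]
general = [f"general doc {i}" for i in range(1000)]
mix = mix_corpora(domain, general, general_frac=0.10)
print(len(mix))  # 900 domain + 100 general
```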
RLHF (Reinforcement Learning from Human Feedback)
What it is: A 3-stage pipeline: (1) SFT on demonstration data, (2) train a reward model (RM) on human preference comparisons (chosen vs rejected), (3) optimize the SFT model against the RM using PPO. Variants of this pipeline are used to align assistants such as ChatGPT, Claude, and Gemini.
Stage 1 (SFT): Standard instruction tuning on human demonstrations; establishes baseline behavior.
Stage 2 (Reward model): Train a model to score responses as humans would, using a Bradley-Terry model on preference pairs.
Stage 3 (PPO): RL loop of generate → score → update. A KL penalty prevents over-optimization (reward hacking).
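The reward model's Bradley-Terry objective (stage 2) reduces to a logistic loss on the score gap between the chosen and rejected responses. A minimal sketch, assuming the reward model has already produced a scalar score for each:

```python
import math

def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def bradley_terry_loss(r_chosen: float, r_rejected: float) -> float:
    """-log sigmoid(r_chosen - r_rejected): small when chosen clearly outscores rejected."""
    return softplus(-(r_chosen - r_rejected))

print(round(bradley_terry_loss(2.0, 0.0), 4))  # 0.1269
print(round(bradley_terry_loss(0.0, 0.0), 4))  # 0.6931 (= log 2, no preference yet)
```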
When to use RLHF in production: Rarely, unless you're building a general-purpose assistant at scale. The pipeline is fragile, requires careful reward model curation, and PPO training is notoriously unstable. DPO is almost always preferred for specialized tasks.
Without a strong KL penalty, the policy model will find exploits in the reward model — verbose but empty responses, specific trigger phrases that score high. Always include KL divergence regularization from the SFT reference model.
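One common way to apply that KL regularization is to subtract a penalty from the reward-model score before the PPO update. The sketch assumes you already have per-token log-probabilities of the sampled response under the policy and under the frozen SFT reference; `beta` and the numbers are illustrative. Many implementations apply the penalty per token rather than summed per sample; this sketch sums it for simplicity.

```python
def kl_shaped_reward(rm_score: float, policy_logprobs: list[float],
                     ref_logprobs: list[float], beta: float = 0.05) -> float:
    """RM score minus beta times a single-sample estimate of KL(policy || reference).

    The estimate is sum_t [log pi(y_t) - log pi_ref(y_t)] over the sampled response tokens.
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must align token-for-token")
    kl_estimate = sum(p - q for p, q in zip(policy_logprobs, ref_logprobs))
    return rm_score - beta * kl_estimate

# The policy finds this response far likelier than the reference does, so it is penalized.
print(kl_shaped_reward(1.0, [-0.1, -0.2, -0.1], [-2.0, -2.5, -1.8], beta=0.1))
```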
DPO (Direct Preference Optimization)
What it is: DPO reformulates RLHF as a classification-style loss on preference pairs: no separate reward model, no RL loop. Given (chosen, rejected) pairs, DPO directly raises the policy's likelihood of chosen responses relative to rejected ones, measured against a frozen reference model, using a closed-form equivalence to the KL-regularized RLHF objective.
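from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
SFT_CHECKPOINT = "./merged-model"  # hypothetical path to your SFT model
# Dataset format: each row has prompt, chosen, rejected
example_row = {
    "prompt": "What is the optimal TAKT time for this line?",
    "chosen": "The optimal TAKT time is calculated by dividing...",
    "rejected": "TAKT time is when you make things fast."
}
dataset = Dataset.from_list([example_row])  # in practice, thousands of such rows
model = AutoModelForCausalLM.from_pretrained(SFT_CHECKPOINT)      # policy being optimized
ref_model = AutoModelForCausalLM.from_pretrained(SFT_CHECKPOINT)  # frozen SFT reference
tokenizer = AutoTokenizer.from_pretrained(SFT_CHECKPOINT)
training_args = DPOConfig(
    beta=0.1,  # KL regularization strength (0.05–0.5)
    learning_rate=5e-7,  # very low: DPO is sensitive
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=1,  # often 1-3 epochs is enough
    output_dir="./dpo-output"
)
trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,  # SFT model as reference
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,  # named `tokenizer=` in TRL releases before 0.12
)
trainer.train()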
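For intuition about what `beta` controls, here is the per-pair loss DPO optimizes in its standard sigmoid form (TRL's default `loss_type`). It is a plain-Python sketch that assumes you already have the summed log-probabilities of each full response under the policy and the reference model; DPOTrainer computes these in batches for you.

```python
import math

def dpo_loss(pi_chosen: float, pi_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """-log sigmoid(beta * [(log pi(y_w) - log ref(y_w)) - (log pi(y_l) - log ref(y_l))]).

    Arguments are summed log-probabilities of each full response given the prompt.
    """
    logits = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    z = -logits  # -log sigmoid(x) == softplus(-x), computed stably below
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

# At initialization the policy equals the reference, so every pair costs log(2).
print(round(dpo_loss(-40.0, -35.0, -40.0, -35.0), 4))  # 0.6931
# Raising chosen and lowering rejected (relative to the reference) reduces the loss.
print(dpo_loss(-38.0, -37.0, -40.0, -35.0) < math.log(2))  # True
```

A larger `beta` makes the loss react more sharply to the same log-ratio gap, which in effect keeps the policy closer to the reference.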
Pros
- No reward model needed
- No RL instability (PPO-free)
- Simple to implement with TRL
- Excellent for format/tone alignment
Cons
- Requires paired preference data
- Sensitive to beta hyperparameter
- Less explored than RLHF for general alignment
DPO in Practice
DPO is the go-to for teams that want alignment-style improvements without the RLHF machinery. Use it to: reduce hallucinations (rejected=wrong answers), improve response format (rejected=unstructured), enforce safety (rejected=harmful outputs). Typical dataset: 1,000–50,000 preference pairs.
Method Selection Framework
Decision Framework
Rules of Thumb — Quick Reference
GPU × Method × Model Size Matrix
| Model Size | Method | Min VRAM | Recommended Setup | Approx. Cost/run |
|---|---|---|---|---|
| 7B | QLoRA | 10 GB | 1× RTX 3090/4090 | $2–10 |
| 7B | LoRA | 20 GB | 1× A100 40GB | $5–20 |
| 7B | Full FT | 80 GB | 1× A100 80GB + ZeRO-3 with CPU offload | $20–100 |
| 13B | QLoRA | 16 GB | 1× A100 40GB | $5–20 |
| 13B | LoRA | 40 GB | 2× A100 40GB | $10–40 |
| 34B | QLoRA | 40 GB | 1× A100 80GB | $20–80 |
| 70B | QLoRA | 80 GB | 1× A100 80GB | $50–200 |
| 70B | LoRA | 160 GB | 2× A100 80GB | $100–400 |
Dataset Design
100 high-quality, diverse, expert-verified examples will usually outperform 10,000 noisy examples. Data quality is the biggest lever you have. Invest here before anywhere else.
Dataset Formats
Instruction Format (Alpaca-style)
{
"instruction": "Classify the following manufacturing defect.",
"input": "Surface crack detected on bearing race, 2mm depth",
"output": "Defect type: Surface crack | Severity: High | Action: Reject"
}
Chat Format (Preferred for instruct models)
{
"messages": [
{
"role": "system",
"content": "You are an industrial manufacturing AI assistant specializing in factory planning. Always provide quantitative answers with units."
},
{
"role": "user",
"content": "Our demand is 800 units/day with 8-hour shifts. Calculate TAKT time."
},
{
"role": "assistant",
"content": "TAKT Time = Available Production Time / Customer Demand\n= (8 hours × 3600 sec/hr) / 800 units\n= 28,800 sec / 800\n= **36 seconds per unit**\n\nThis means one unit must exit the line every 36 seconds to meet demand."
}
]
}
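The arithmetic in the assistant turn above is easy to reproduce in code, which is also a cheap way to validate numeric targets before they go into a training set. A small sketch with hypothetical helper names; the machine-count helper assumes one part per machine per cycle and rounds up.

```python
import math

def takt_time_sec(available_hours: float, demand_units: int) -> float:
    """TAKT = available production time / customer demand."""
    return available_hours * 3600 / demand_units

def machines_required(cycle_time_sec: float, takt_sec: float) -> int:
    """Parallel machines needed for a station to keep pace with TAKT (rounded up).

    The small tolerance keeps float noise from bumping exact multiples up by one.
    """
    return math.ceil(cycle_time_sec / takt_sec - 1e-9)

print(takt_time_sec(8, 800))         # 36.0
takt = takt_time_sec(8, 500)         # 57.6
print(machines_required(240, takt))  # 5 (4.17 ideal, rounded up)
```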
Classification Format
// For single-label classification, frame as generation:
{
"messages": [
{"role": "system", "content": "Classify support tickets. Reply with ONLY the category label."},
{"role": "user", "content": "Machine stopped mid-cycle, error code E-404"},
{"role": "assistant", "content": "MECHANICAL_FAULT"}
]
}
RAG-Augmented Format
{
"messages": [
{"role": "system", "content": "Answer based ONLY on the provided context. Say 'Not found in context' if the answer is not there."},
{
"role": "user",
"content": "Context: [Machine spec sheet content here]\n\nQuestion: What is the cycle time of the CNC lathe?"
},
{"role": "assistant", "content": "According to the spec sheet, the CNC lathe has a cycle time of 45 seconds per operation."}
]
}
Dataset Size Guidelines
| Task Type | Minimum Viable | Good Quality | Strong Performance |
|---|---|---|---|
| Classification (few labels) | 50/class | 200/class | 1000+/class |
| Structured extraction | 200 | 1,000 | 5,000+ |
| Instruction following | 500 | 5,000 | 50,000+ |
| Conversational / chat | 1,000 | 10,000 | 100,000+ |
| Domain pretraining | 10M tokens | 100M tokens | 1B+ tokens |
| DPO preference pairs | 500 | 5,000 | 50,000+ |
| Code generation | 1,000 | 10,000 | 100,000+ |
Synthetic Data Generation
When real data is scarce, LLMs can bootstrap datasets. This is now standard practice.
# Synthetic data generation pipeline
import json
from typing import List
import openai
from pydantic import BaseModel
class TrainingExample(BaseModel):
    instruction: str
    context: str
    response: str
    quality_score: float  # 0-1 self-rating from the generator; filter below 0.7
SEED_EXAMPLES = [
    "Calculate TAKT time for 500 units/day, 2-shift operation",
    "Identify bottleneck station from cycle times: [45s, 38s, 62s, 41s]"
]
def generate_variations(seed: str, n: int = 10) -> List[TrainingExample]:
    prompt = f"""Generate {n} variations of this manufacturing planning question.
Vary: numbers, machine types, complexity, phrasing.
Seed: {seed}
Return a JSON object with a single key "examples" whose value is an array of objects
with keys: instruction, context, response, quality_score"""
    response = openai.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"}  # JSON mode returns an object, not a bare array
    )
    examples = json.loads(response.choices[0].message.content)
    return [TrainingExample(**e) for e in examples.get("examples", []) if e.get("quality_score", 0) > 0.7]
# Key: always review a random sample before training on synthetic data
Common Dataset Mistakes
- Training on model outputs without filtering: Model learns "confident wrongness" — hallucinations and formatting artifacts from the teacher model
- Low diversity: 1000 examples that are all minor variations of 10 templates. Model overfits to surface patterns.
- Length mismatch: Short inputs/outputs in training, long at inference. Model degrades on long outputs.
- Label leakage: Including the answer in the prompt/context during training. Model learns to pattern-match rather than reason.
- System prompt inconsistency: Training with 5 different system prompts. Model sees all as equally valid and gets confused at inference.
- Skipping data validation: Not checking for duplicates, near-duplicates, encoding issues, or truncated responses.
Data Quality Checklist
- Each response was written or reviewed by a domain expert
- Deduplicated on both input and output side (MinHash or embedding similarity)
- Length distribution of training data matches expected inference distribution
- No label leakage — answer not present in input
- System prompt is consistent and matches inference system prompt exactly
- All outputs end properly (no truncation)
- Random sample of 50 examples manually reviewed before training
- Held-out validation set (10-20%) from same distribution but not used in training
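The deduplication item can be prototyped without extra dependencies. The sketch below computes exact Jaccard similarity over word 3-gram shingles and greedily drops near-duplicates. It is O(n²), so for large datasets you would switch to MinHash/LSH (e.g., the datasketch library) or embedding similarity as the checklist suggests; the 0.8 threshold is an assumption to tune.

```python
def shingles(text: str, n: int = 3) -> set[tuple[str, ...]]:
    """Lower-cased word n-grams; short texts become a single shingle."""
    words = text.lower().split()
    if len(words) < n:
        return {tuple(words)}
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}

def jaccard(a: set, b: set) -> float:
    return len(a & b) / len(a | b) if a | b else 1.0

def dedup_near(texts: list[str], threshold: float = 0.8) -> list[str]:
    """Keep first occurrences; drop later texts whose similarity to a kept text is >= threshold."""
    kept, kept_shingles = [], []
    for t in texts:
        s = shingles(t)
        if all(jaccard(s, k) < threshold for k in kept_shingles):
            kept.append(t)
            kept_shingles.append(s)
    return kept

data = [
    "Calculate TAKT time for 500 units/day, 2-shift operation",
    "calculate takt time for 500 units/day, 2-shift operation",  # case-only variant
    "Identify bottleneck station from cycle times: [45s, 38s, 62s, 41s]",
]
print(dedup_near(data))  # the case-only variant is dropped
```

To deduplicate on both the input and output side, run it once over prompts and once over responses (or over their concatenation).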
Training Pipeline
Complete Pipeline Code
#!/usr/bin/env python3
"""
Complete QLoRA fine-tuning pipeline
Reference implementation; argument names follow TRL ~0.12-0.15 and shift between TRL releases
"""
import torch
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
import wandb
# ─── CONFIG ────────────────────────────────────────
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "./output/factory-planning-v1"
DATASET_PATH = "./data/training.jsonl"
# ─── 1. LOAD & VALIDATE DATA ───────────────────────
def load_and_validate_dataset(path: str) -> DatasetDict:
    ds = load_dataset("json", data_files=path, split="train")
    # Validate required fields
    assert "messages" in ds.column_names, "Dataset must have 'messages' column"
    # Keep rows whose final turn is an assistant reply longer than 10 chars
    # (lower the threshold for short-label tasks such as classification)
    ds = ds.filter(
        lambda x: x["messages"][-1]["role"] == "assistant"
        and len(x["messages"][-1]["content"]) > 10
    )
    print(f"Loaded {len(ds)} training examples")
    return ds.train_test_split(test_size=0.1, seed=42)
# ─── 2. MODEL LOADING ──────────────────────────────
def load_model_and_tokenizer(model_id: str):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="flash_attention_2",  # FA2 for speed (needs the flash-attn package)
        trust_remote_code=True  # only needed for custom architectures; runs code from the Hub repo
    )
    model = prepare_model_for_kbit_training(model)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # important for causal LM
    return model, tokenizer
# ─── 3. LORA CONFIG ────────────────────────────────
def get_lora_config() -> LoraConfig:
    return LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
# ─── 4. TOKENIZATION ───────────────────────────────
def format_example(example: dict, tokenizer) -> dict:
    # Apply chat template: critical to match inference behavior
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False
    )
    return {"text": formatted}
# ─── 5. TRAINING ───────────────────────────────────
def train():
    wandb.init(project="factory-planning-ft", name="qlora-v1")
    model, tokenizer = load_model_and_tokenizer(MODEL_ID)
    dataset = load_and_validate_dataset(DATASET_PATH)
    # Render each conversation with the model's chat template into a single "text" column
    dataset = dataset.map(
        lambda ex: format_example(ex, tokenizer),
        remove_columns=dataset["train"].column_names,
    )
    model = get_peft_model(model, get_lora_config())
    model.print_trainable_parameters()
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,  # effective batch = 2 × 8 = 16 per GPU
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        bf16=True,
        tf32=True,
        gradient_checkpointing=True,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
        save_total_limit=3,
        load_best_model_at_end=True,
        report_to="wandb",
        optim="paged_adamw_8bit",  # paged optimizer for QLoRA
        dataloader_num_workers=4,
        dataset_text_field="text",
        max_seq_length=2048,
        packing=True,  # pack multiple short examples into one sequence
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,  # `tokenizer=` in TRL releases before 0.12
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
if __name__ == "__main__":
    train()
Merging LoRA Weights Post-Training
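import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "./output/factory-planning-v1"  # where the LoRA adapter was saved
# Load the base model unquantized (BF16) for the merge; merging into a 4-bit model loses precision
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu"  # merge on CPU to avoid GPU OOM (an 8B model in BF16 needs ~16 GB RAM)
)
# Load LoRA adapter
peft_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
# Merge: this produces a standard HF model, no PEFT dependency at inference
merged_model = peft_model.merge_and_unload()
# Save merged model together with its tokenizer
merged_model.save_pretrained("./merged-model")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.save_pretrained("./merged-model")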
Hyperparameters & Training Details
Learning Rate
The single most impactful hyperparameter. Too high → training instability, loss spikes, degraded outputs. Too low → slow convergence, wasted compute.
| Method | LR Range | Default Start | Notes |
|---|---|---|---|
| Full Fine-Tuning | 1e-6 – 5e-5 | 2e-5 | Roughly an order of magnitude below typical pretraining peak LR |
| LoRA | 1e-4 – 5e-4 | 2e-4 | Higher OK since fewer params |
| QLoRA | 1e-4 – 3e-4 | 2e-4 | Same as LoRA in practice |
| DPO | 5e-8 – 5e-6 | 5e-7 | Very low — preference shifts are subtle |
| Continual Pretraining | 1e-5 – 1e-4 | 5e-5 | Use cosine decay to zero |
If loss goes to NaN immediately: LR too high. If loss barely moves after 100 steps: LR too low. If loss decreases then suddenly spikes: confirm gradient clipping is active (max_grad_norm=1.0, the TrainingArguments default) and lower LR by 3×.
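To see what learning rate the model actually receives at a given step, the warmup-plus-cosine shape used in the pipeline (lr_scheduler_type="cosine", warmup_ratio=0.05) can be reproduced standalone. This sketch mirrors the usual linear-warmup, cosine-to-zero formula; exact warmup-step rounding may differ slightly from the transformers implementation.

```python
import math

def lr_at_step(step: int, total_steps: int, base_lr: float = 2e-4,
               warmup_ratio: float = 0.05) -> float:
    """Linear warmup from 0 to base_lr, then cosine decay from base_lr to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

total = 1000
for s in (0, 25, 50, 525, 1000):
    print(s, f"{lr_at_step(s, total):.2e}")
```

Plotting this next to the loss curve makes it easier to tell whether a spike coincides with the end of warmup (peak LR) or with something in the data.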
All Key Hyperparameters
LoRA-Specific Parameters Deep Dive
Rank (r)
Controls expressiveness. r=16 is the default sweet spot for most tasks. Higher rank = more parameters = potentially better performance but diminishing returns beyond r=64. For very simple tasks (format changes), r=4–8 is sufficient and faster to train.
Alpha (lora_alpha)
Scaling factor applied to LoRA output: scale = alpha / r. Common practice: set alpha = 2 × r (scale=2.0) for a good starting point. Some practitioners use alpha = r (scale=1.0). Do not treat alpha independently — always think in terms of scale.
Target Modules
Which matrices to apply LoRA to. More modules = more expressive but more memory.
- Attention only (q, k, v, o): Safe default, lower VRAM
- Attention + MLP (add gate, up, down): Better for complex tasks, recommended
- All linear layers: Maximum expressiveness, use for hard domain shifts
Dropout
lora_dropout=0.05 is standard. Increase to 0.1 for small datasets (<1k). At inference, dropout is inactive once the model is in eval mode, so no change is needed.
Sequence Packing
# Without packing: many short sequences → wasted padding → slow
# [seq1|PAD|PAD|PAD] [seq2|PAD|PAD] [seq3|PAD|PAD|PAD|PAD]
# With packing: concatenate up to max_seq_len → GPU fully utilized
# [seq1|EOS|seq2|EOS|seq3|EOS|...] ← single batch item, all real tokens
# In SFTTrainer:
trainer = SFTTrainer(
...
packing=True, # enable sequence packing
max_seq_length=2048 # pack multiple short examples into this length
)
# Tip: packing can increase throughput 2-4× for short sequence datasets
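The packing idea in the diagram above can be sketched as concatenating tokenized examples with an EOS separator and slicing the stream into fixed-length rows. This is a simplified illustration (the token IDs and EOS id are made up); real trainers differ in details such as whether an example may be split across rows and how attention is masked between packed examples.

```python
def pack_sequences(tokenized: list[list[int]], eos_id: int, max_len: int) -> list[list[int]]:
    """Concatenate examples separated by EOS, then cut into rows of exactly max_len tokens.

    Trailing tokens that do not fill a whole row are dropped, so no padding is needed.
    """
    stream: list[int] = []
    for ids in tokenized:
        stream.extend(ids)
        stream.append(eos_id)
    n_rows = len(stream) // max_len
    return [stream[i * max_len:(i + 1) * max_len] for i in range(n_rows)]

examples = [[11, 12, 13], [21, 22], [31, 32, 33, 34], [41]]
rows = pack_sequences(examples, eos_id=2, max_len=4)
print(rows)  # [[11, 12, 13, 2], [21, 22, 2, 31], [32, 33, 34, 2]]
```

Note that the third example straddles two rows here; that is the trade-off this simple scheme makes for zero padding.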
Compute & GPU Planning
VRAM Estimation
Understanding VRAM consumption is critical for planning training runs. The breakdown for a training job:
- FP16/BF16 model weights: N_params × 2 bytes
- FP32 model weights: N_params × 4 bytes
- AdamW states: N_trainable × 8 bytes (first and second moments, FP32 each)
- Gradients: N_trainable × 4 bytes (FP32; 2 bytes if kept in BF16)
- Activations: roughly proportional to batch × seq_len × hidden × layers (use gradient checkpointing to reduce)
Typical totals for a 7B model:
- QLoRA 7B total: ~6–8 GB (with gradient checkpointing)
- LoRA 7B total: ~18–22 GB
- Full FT 7B BF16: ~60–80 GB (with gradient checkpointing + ZeRO)
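The per-component formulas above translate directly into a rough calculator. This sketch counts only weights, gradients, and AdamW states (activations and framework overhead are excluded, so treat results as a lower bound); the default byte counts per component are the assumptions listed in the text.

```python
def training_vram_gb(n_params: float, n_trainable: float,
                     weight_bytes: float = 2.0, grad_bytes: float = 4.0,
                     optim_bytes: float = 8.0) -> float:
    """Weights + gradients + AdamW states in GB (1 GB = 1e9 bytes). Excludes activations."""
    total = n_params * weight_bytes + n_trainable * (grad_bytes + optim_bytes)
    return total / 1e9

# Full FT, everything FP32: 7e9 * (4 + 4 + 8) bytes = 112 GB before activations
print(training_vram_gb(7e9, 7e9, weight_bytes=4.0))
# LoRA on a BF16 7B base with ~40M trainable params: frozen weights dominate (~14.5 GB)
print(training_vram_gb(7e9, 40e6))
# QLoRA: 4-bit base (0.5 bytes/param) with the same adapters (~4 GB)
print(training_vram_gb(7e9, 40e6, weight_bytes=0.5))
```

The gap between these lower bounds and the typical totals above is mostly activations, which is exactly what gradient checkpointing targets.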
Gradient Checkpointing
Trades compute for memory — activations are recomputed during backward pass instead of stored. Reduces activation VRAM by 60-70% at the cost of ~30% slower training. Always enable for large models.
# In TrainingArguments:
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False} # newer, more stable
# Or directly on model:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
Multi-GPU Training
Data Parallelism (DDP) — Most Common
Each GPU holds a full model copy. Gradients are synchronized across GPUs. Works when the model, gradients, and optimizer states fit in a single GPU's VRAM. Throughput scales near-linearly with GPU count.
# Launch DDP training with torchrun
torchrun --nproc_per_node=4 train.py
# Or with accelerate
accelerate launch --num_processes 4 train.py
ZeRO (Zero Redundancy Optimizer)
DeepSpeed's ZeRO shards optimizer states (ZeRO-1), gradients (ZeRO-2), or model parameters (ZeRO-3) across GPUs. Required for full fine-tuning of large models on limited hardware.
| Stage | What's Sharded | Per-GPU Memory Reduction vs. DDP | Communication |
|---|---|---|---|
| ZeRO-1 | Optimizer states | up to ~4× | Low |
| ZeRO-2 | Optimizer + Gradients | up to ~8× | Medium |
| ZeRO-3 | Optimizer + Grad + Params | ~N× on N GPUs (e.g., 64× on 64) | High |
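The table's reduction factors follow from the ZeRO paper's accounting for mixed-precision Adam: 2 bytes/param of FP16/BF16 weights, 2 of gradients, and 12 of optimizer state (FP32 master weights plus two moments), 16 in total. A sketch of per-GPU model-state memory under each stage, ignoring activations and communication buffers:

```python
def zero_bytes_per_gpu(n_params: float, n_gpus: int, stage: int, k_optim: int = 12) -> float:
    """Model-state bytes per GPU for ZeRO stage 0 (plain DDP) through 3."""
    w, g, o = 2 * n_params, 2 * n_params, k_optim * n_params
    if stage == 0:
        return w + g + o                # everything replicated
    if stage == 1:
        return w + g + o / n_gpus       # shard optimizer states
    if stage == 2:
        return w + (g + o) / n_gpus     # + shard gradients
    if stage == 3:
        return (w + g + o) / n_gpus     # + shard parameters
    raise ValueError("stage must be 0, 1, 2 or 3")

psi, n = 7e9, 64
base = zero_bytes_per_gpu(psi, n, 0)
for stage in (1, 2, 3):
    print(f"ZeRO-{stage}: {base / zero_bytes_per_gpu(psi, n, stage):.1f}x less than DDP")
```

With 64 GPUs this gives roughly 3.8×, 7.2×, and 64×; the first two approach 4× and 8× as the GPU count grows, while ZeRO-3 keeps scaling with it.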
Precision Guide
| Precision | Bytes/Param | Use For | Notes |
|---|---|---|---|
| FP32 | 4 | Optimizer states, gradient accumulation | Ground truth precision |
| BF16 | 2 | Training on Ampere+ GPUs (A100, H100) | Preferred for training: same exponent range as FP32, no loss scaling needed |
| FP16 | 2 | Training on Volta/Turing (V100, T4) | Risk of overflow — needs loss scaling |
| INT8 | 1 | Inference, some FT (LoRA adapters still BF16) | bitsandbytes LLM.int8() |
| NF4/INT4 | 0.5 | QLoRA training, inference | State of the art for memory efficiency |
| FP8 | 1 | H100 training/inference | Transformer Engine, highest throughput |
Cost Optimization Strategies
- Use spot/preemptible instances: 60-80% cost reduction. Use checkpointing every 50-100 steps to survive interruptions.
- Start with small experiments: Train for 100 steps on 10% of data to validate the pipeline before full runs.
- Sequence packing: 2-4× throughput improvement for datasets with short sequences.
- Flash Attention 2: Attention memory grows linearly rather than quadratically with sequence length, plus a significant speedup (requires an Ampere-or-newer GPU and the flash-attn package).
- Choose right model size: Don't train 70B if 13B achieves 95% of the performance on your task.
- Paged AdamW: Pages optimizer states to CPU RAM when GPU memory spikes, avoiding OOM crashes. Minimal throughput impact for LoRA.
Evaluation & Debugging
Evaluation Framework
Evaluation is where most fine-tuning projects fail. Perplexity on the training distribution tells you almost nothing about production performance. Build task-specific evals from day one.
| Task Type | Primary Metrics | Secondary Metrics | Pitfalls |
|---|---|---|---|
| Classification | F1 (weighted/macro) | Precision, Recall, Confusion matrix | Class imbalance inflating accuracy |
| Extraction / NER | Exact match, F1 | Partial match F1 | Format errors counted as wrong |
| Generation | LLM-as-judge (1-5 scale) | BLEU/ROUGE (weak signal) | Reference-based metrics miss quality |
| QA / RAG | Exact match, EM@k | Faithfulness score | Not testing hallucination on OOD |
| Structured output | JSON validity %, field accuracy | Schema compliance rate | Only testing valid examples |
| Chat / alignment | Human preference rate | Helpfulness, harmlessness scores | Positional bias in pairwise eval |
LLM-as-Judge (Practical)
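import json
from openai import OpenAI
client = OpenAI()  # reads OPENAI_API_KEY from the environment
def llm_judge_response(question: str, ground_truth: str, model_response: str, judge_model: str = "gpt-4o") -> dict:
    prompt = f"""Rate this response on a scale of 1-5.
Question: {question}
Reference Answer: {ground_truth}
Model Response: {model_response}
Score on:
1. Correctness (1-5): Does it answer correctly?
2. Completeness (1-5): Does it cover all key points?
3. Format (1-5): Is it well-structured?
Return JSON: {{"correctness": X, "completeness": X, "format": X, "reasoning": "..."}}"""
    # Use a capable judge model (GPT-4o, Claude Sonnet, etc.)
    # Key: the judge model should be stronger than the model being evaluated
    response = client.chat.completions.create(
        model=judge_model,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},  # force parseable JSON output
        temperature=0,  # reduce scoring variance between runs
    )
    return json.loads(response.choices[0].message.content)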
Failure Modes & Fixes
Overfitting
Symptoms: Training loss continuously decreasing, val loss diverging. Model memorizes training examples verbatim. Produces gibberish on OOD inputs.
Fixes: Reduce epochs. Increase weight_decay. Add dropout (lora_dropout=0.1). Augment dataset. Use early stopping on val loss.
Catastrophic Forgetting
Symptoms: Model excels on fine-tuned task but fails on tasks it could do before (general QA, reasoning, instruction following). Often happens with full FT or high-rank LoRA on small datasets.
Fixes: Switch to LoRA. Add 10-20% general instruction data to training mix (replay buffer). Lower learning rate. Fewer epochs.
Output Degeneration
Symptoms: Model produces repetitive text, infinite loops, or garbage tokens. Often caused by training on poorly formatted or truncated examples.
Fixes: Audit training data for truncated outputs. Ensure EOS tokens are present. Check padding_side setting. Validate tokenization pipeline outputs.
Format Non-Compliance
Symptoms: Model was supposed to output JSON but adds markdown, prose, or extra explanation around it.
Fixes: Check that system prompt in training exactly matches inference. Add more format-only examples. Use constrained decoding (outlines library) at inference as safety net.
Debugging Checklist
- Inspect 10 random tokenized training examples — does the formatted text look right?
- Check that labels have -100 for prompt tokens (only train on completion)
- Verify training loss decreases in the first 50 steps. If not: LR too low or data issue
- Inspect validation examples every 500 steps. Sample 10 outputs from val set manually
- Plot training vs val loss curve. Divergence = overfit
- Test inference with the exact system prompt used in training
- Check GPU utilization — should be >90%. If low: data loading bottleneck
- Verify merged model produces same outputs as unmerged PEFT model on 5 test prompts
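For the label check in the list above, it helps to know what a correct label row looks like. A minimal sketch that builds labels for one example given where the assistant response starts: -100 is the default `ignore_index` of PyTorch's cross-entropy loss, which is why Hugging Face trainers use it for masked positions. Labels stay the same length as `input_ids` because causal LM models shift them internally. How you locate `response_start` depends on your chat template and is assumed here.

```python
IGNORE_INDEX = -100  # torch.nn.CrossEntropyLoss default ignore_index

def build_labels(input_ids: list[int], response_start: int) -> list[int]:
    """Copy input_ids into labels, masking every prompt token so loss covers only the response."""
    if not 0 <= response_start <= len(input_ids):
        raise ValueError("response_start out of range")
    return [IGNORE_INDEX] * response_start + input_ids[response_start:]

def check_labels(input_ids: list[int], labels: list[int]) -> bool:
    """Same length, at least one supervised token, and unmasked labels equal the inputs."""
    if len(labels) != len(input_ids):
        return False
    supervised = [(i, l) for i, l in zip(input_ids, labels) if l != IGNORE_INDEX]
    return bool(supervised) and all(i == l for i, l in supervised)

ids = [1, 882, 25, 4761, 30, 2, 78, 9, 128009]  # made-up token ids: prompt (5) + response (4)
labels = build_labels(ids, response_start=5)
print(labels)                     # [-100, -100, -100, -100, -100, 2, 78, 9, 128009]
print(check_labels(ids, labels))  # True
```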
Deployment Considerations
Serving Stack Options
| Framework | Best For | Key Feature | Maturity |
|---|---|---|---|
| vLLM | High-throughput production APIs | PagedAttention, continuous batching, OpenAI-compatible | Production |
| TGI (HuggingFace) | HF-native deployments | Tensor parallelism, quantization, streaming | Production |
| Ollama | Local/edge deployment | Simple setup, GGUF support | Production (small scale) |
| llama.cpp server | CPU inference, edge | No GPU required, GGUF | Production |
| LitServe | Custom inference logic | Flexible Python API, batching | Growing |
Serving Merged LoRA with vLLM
# Serve merged model (recommended for production)
vllm serve ./merged-model \
--dtype bfloat16 \
--max-model-len 4096 \
--tensor-parallel-size 1 \
--enable-prefix-caching \
--served-model-name factory-planning-v1
# Serve with LoRA adapter (good for multi-adapter setups)
vllm serve meta-llama/Llama-3.1-8B-Instruct \
--enable-lora \
--lora-modules factory-v1=./output/factory-planning-v1 \
--max-lora-rank 16
LoRA vs Merged: When to Merge
Merge when:
- Single adapter per deployment
- Latency is critical
- You want to quantize the final model
- Using non-PEFT inference stacks
Keep LoRA adapters separate when:
- Multiple tasks, one base model
- Frequently updating adapters
- A/B testing different adapter versions
- Memory is constrained and dynamic adapter switching is needed
Merged-model release checklist:
- Quantize to INT8/FP8 for efficiency
- Run full eval suite on merged model
- Test edge cases before promoting
- Version and tag the merged artifact
A/B Testing Fine-Tuned Models
- Deploy baseline and candidate behind feature flag or % rollout
- Log all inputs and outputs to both models for offline evaluation
- Run both models on same input when possible (shadow mode)
- Define success metric before launch (not after seeing results)
- Monitor for latency regression — fine-tuned model may have different throughput
- Set minimum sample size for statistical significance before declaring winner
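To fix the minimum sample size before launch, one common approach is a power calculation for comparing two success rates (task success, thumbs-up rate, and so on). The sketch below uses the standard normal-approximation formula for a two-sided two-proportion z-test; the baseline rate, the smallest lift worth detecting, alpha and power are assumptions you choose up front, and `samples_per_arm` is a hypothetical helper name.

```python
import math
from statistics import NormalDist


def samples_per_arm(p_baseline: float, p_candidate: float,
                    alpha: float = 0.05, power: float = 0.8) -> int:
    """Minimum samples per arm for a two-sided two-proportion z-test
    (normal approximation)."""
    if not (0 < p_baseline < 1 and 0 < p_candidate < 1) or p_baseline == p_candidate:
        raise ValueError("rates must be in (0, 1) and differ")
    z_alpha = NormalDist().inv_cdf(1 - alpha / 2)  # critical value of the test
    z_beta = NormalDist().inv_cdf(power)           # quantile for the desired power
    p_bar = (p_baseline + p_candidate) / 2
    numerator = (
        z_alpha * math.sqrt(2 * p_bar * (1 - p_bar))
        + z_beta * math.sqrt(p_baseline * (1 - p_baseline)
                             + p_candidate * (1 - p_candidate))
    ) ** 2
    return math.ceil(numerator / (p_baseline - p_candidate) ** 2)


# Example: detect a lift in success rate from 70% to 75%
n = samples_per_arm(0.70, 0.75)
print(f"{n} samples per arm")
```

Smaller lifts need disproportionately more traffic, since the required sample size scales with one over the squared difference.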
Real-World Use Cases
Use Case 1: Manufacturing Workflow Assistant
Applicable to: NeoFAB-style factory planning platforms
Dataset Design for Manufacturing AI
{
"messages": [
{
"role": "system",
"content": "You are a factory planning AI. Return ONLY valid JSON matching the specified schema."
},
{
"role": "user",
"content": "Plan a station for gear hobbing operation. Required output: 500 gears/day, 8hr shift. Cycle time per gear: 4 minutes."
},
{
"role": "assistant",
"content": "{\"station\": \"Gear Hobbing\", \"takt_time_sec\": 57.6, \"cycle_time_sec\": 240, \"machines_required\": 5, \"utilization_pct\": 83.3, \"bottleneck\": false, \"notes\": \"4.17 machines ideal; round up to 5 for buffer\"}"
}
]
}
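Because the assistant labels contain arithmetic, it can help to compute them programmatically instead of writing them by hand, so the model never trains on a wrong number. A minimal sketch, assuming the field names from the example above and the simple rule machines = ceil(cycle time / takt time); `plan_station` is a hypothetical helper, and the `bottleneck` field is left as a placeholder.

```python
import json
import math


def plan_station(station: str, demand_per_day: int, shift_hours: float,
                 cycle_time_min: float) -> str:
    """Build the assistant JSON for one station-planning training example."""
    available_sec = shift_hours * 3600
    takt_sec = available_sec / demand_per_day    # time budget per part
    cycle_sec = cycle_time_min * 60
    ideal_machines = cycle_sec / takt_sec         # fractional machine count
    machines = math.ceil(ideal_machines)          # can't install part of a machine
    utilization = 100 * ideal_machines / machines
    return json.dumps({
        "station": station,
        "takt_time_sec": round(takt_sec, 1),
        "cycle_time_sec": round(cycle_sec),
        "machines_required": machines,
        "utilization_pct": round(utilization, 1),
        "bottleneck": False,  # placeholder: a real planner would compare against line takt
        "notes": f"{ideal_machines:.2f} machines ideal; round up to {machines}",
    })


print(plan_station("Gear Hobbing", 500, 8, 4))
```

For the gear hobbing prompt this reproduces the numbers in the example: 57.6 s takt, 5 machines, 83.3% utilization.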
Use Case 2: Document QA System
Fine-tune for grounded QA over technical documents (spec sheets, SOPs, compliance docs).
- Method: LoRA SFT on RAG-augmented examples
- Key training signal: Model must learn to say "Not in the provided documents" — include negative examples in training
- Dataset: 500 (context, question, answer) triples from real documents, 200 unanswerable questions
- Critical: Train on same chunking strategy you'll use in production RAG
Use Case 3: Classification (Ticket / Defect Routing)
- Method: QLoRA on 7B model, or consider smaller model (Phi-3 mini) for cost
- Format: Single-token or few-token output (category label only)
- Dataset: 200+ examples per class, balanced or weighted loss
- Tip: Compare against fine-tuned BERT/DeBERTa — for pure classification, discriminative models often outperform generative LLMs at 100× lower inference cost
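For the weighted-loss option, one common choice is inverse-frequency class weights, normalized so a perfectly balanced dataset gets 1.0 for every class (the same formula scikit-learn uses for `class_weight="balanced"`). A minimal sketch; how the weights are fed into the loss (for example as the per-class weight of a cross-entropy loss) depends on your training setup, and the ticket labels are made up.

```python
from collections import Counter


def balanced_class_weights(labels: list[str]) -> dict[str, float]:
    """weight(c) = n_samples / (n_classes * count(c)); rare classes get larger weights."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * k) for c, k in counts.items()}


labels = ["hardware"] * 600 + ["software"] * 300 + ["billing"] * 100
print(balanced_class_weights(labels))
```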
Use Case 4: Chatbot / Customer Support
- Method: Instruction tuning (multi-turn chat format) + optional DPO for tone alignment
- Dataset: 5,000–50,000 multi-turn conversations. Include edge cases, refusals, handoffs.
- Critical hyperparams: max_seq_length must cover longest conversations (4096+ for multi-turn)
- Post-training: DPO on 1,000 preference pairs to align tone, reduce hallucination
- Eval: Human evaluation is mandatory — automated metrics miss conversational quality
Practical Tips
What Actually Matters in Production
Data quality > Model size > Method choice > Hyperparameters. Most teams spend 80% of time on hyperparameters and 20% on data. Invert this ratio.
1. System Prompt Consistency is Non-Negotiable
The system prompt used during training must be byte-for-byte identical to production. Even a minor difference (extra space, different wording) causes silent degradation. Treat the system prompt as part of the model artifact — version it alongside weights.
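One lightweight way to enforce this is to store a hash of the exact training system prompt next to the weights and check it at serving time. A minimal sketch; the `system_prompt.json` file name, its layout and the helper names are assumptions, not a standard format.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def prompt_fingerprint(system_prompt: str) -> str:
    """SHA-256 of the exact UTF-8 bytes, so even a changed space is detected."""
    return hashlib.sha256(system_prompt.encode("utf-8")).hexdigest()


def save_prompt_metadata(model_dir: str, system_prompt: str) -> None:
    # Hypothetical layout: a small JSON file stored next to the model weights
    meta = {"system_prompt": system_prompt, "sha256": prompt_fingerprint(system_prompt)}
    Path(model_dir, "system_prompt.json").write_text(json.dumps(meta), encoding="utf-8")


def check_serving_prompt(model_dir: str, serving_prompt: str) -> None:
    """Raise if the prompt used at serving time differs from the training prompt."""
    meta = json.loads(Path(model_dir, "system_prompt.json").read_text(encoding="utf-8"))
    if prompt_fingerprint(serving_prompt) != meta["sha256"]:
        raise ValueError("Serving system prompt differs from the training system prompt")


with tempfile.TemporaryDirectory() as model_dir:
    save_prompt_metadata(model_dir, "You are a factory planning AI.")
    check_serving_prompt(model_dir, "You are a factory planning AI.")  # passes
    try:
        check_serving_prompt(model_dir, "You are a factory planning AI. ")  # trailing space
    except ValueError as err:
        print("caught:", err)
```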
2. Train on Completion Only
Mask out the prompt/instruction tokens so the model computes loss only on the assistant response. Training on the full sequence teaches the model to predict its own instructions, wasting capacity. Check that labels contain -100 at every prompt token position.
# Correct: only train on completion tokens
from trl import DataCollatorForCompletionOnlyLM  # newer TRL releases deprecate this in favor of SFTConfig options; check your version
response_template = "<|assistant|>"  # depends on the model's chat template
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)
# Pass data_collator=collator to SFTTrainer (with packing disabled); it sets labels before the response to -100
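Conceptually, a completion-only collator does roughly the following (TRL's real implementation differs in details): find where the response starts and set every label before it to -100, the default `ignore_index` of PyTorch's cross-entropy loss. A minimal pure-Python sketch on token IDs; it masks everything up to the last occurrence of the template (so only the final assistant turn is trained) and assumes the template tokenizes to the same IDs wherever it appears, which is not always true because some tokenizers merge tokens differently depending on context.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_prompt_labels(input_ids: list[int], response_template_ids: list[int]) -> list[int]:
    """Copy input_ids into labels, masking everything up to and including the
    last occurrence of the response template."""
    n = len(response_template_ids)
    start = -1
    for i in range(len(input_ids) - n, -1, -1):  # search from the end
        if input_ids[i:i + n] == response_template_ids:
            start = i + n
            break
    if start == -1:
        # Template not found (e.g. truncated away): ignore the whole example
        return [IGNORE_INDEX] * len(input_ids)
    return [IGNORE_INDEX] * start + input_ids[start:]


ids = [1, 50, 51, 52, 9, 9, 70, 71, 2]  # prompt, then template [9, 9], then response
print(mask_prompt_labels(ids, [9, 9]))
# [-100, -100, -100, -100, -100, -100, 70, 71, 2]
```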
3. Validate Before You Scale
Run a 100-step "smoke test" with 10% of data before committing to a full training run. Verify: loss decreases, GPU utilization is high, sample outputs look reasonable. This costs $0.50 and saves $50 in wasted compute.
4. The Chat Template Problem
Every model family has its own chat template (Llama-2 and Mistral use [INST] tags, Llama-3 uses <|start_header_id|> role headers, Qwen uses ChatML-style <|im_start|> markers). Always use tokenizer.apply_chat_template() rather than formatting prompts manually. Get this wrong and you'll wonder why your trained model produces garbage.
5. Monitor Token Length Distribution
Compute the percentile distribution of token lengths in your training set. If 50% of examples are <100 tokens but every example is padded to max_seq_length=2048, most of each batch is wasted padding (dynamic padding or packing avoids this). If training examples are longer than inference inputs, the model may underperform on short inputs.
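A quick way to get that distribution is to tokenize every example once and look at a few percentiles. The sketch below works on a plain list of lengths; with a Hugging Face tokenizer you could build it as `[len(tokenizer(t)["input_ids"]) for t in texts]`. It uses the nearest-rank definition of a percentile (other definitions interpolate and give slightly different numbers), and the padding estimate assumes every example is padded to `max_seq_length`.

```python
import math


def length_percentiles(lengths: list[int], pcts=(50, 90, 95, 99, 100)) -> dict[int, int]:
    """Nearest-rank percentiles of token lengths."""
    ordered = sorted(lengths)
    out = {}
    for p in pcts:
        rank = max(1, math.ceil(p / 100 * len(ordered)))  # 1-based nearest rank
        out[p] = ordered[rank - 1]
    return out


def padding_waste(lengths: list[int], max_seq_length: int) -> float:
    """Fraction of positions that are padding if every example is padded to max_seq_length."""
    kept = sum(min(n, max_seq_length) for n in lengths)
    return 1 - kept / (max_seq_length * len(lengths))


lengths = [80, 95, 120, 300, 450, 900, 1500, 60, 70, 2100]
print(length_percentiles(lengths))
print(f"padding waste at 2048: {padding_waste(lengths, 2048):.0%}")
```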
6. Flash Attention 2 — Always On
FA2 is free performance on any Ampere or newer GPU. It reduces attention memory from quadratic to linear in sequence length (the compute is still quadratic), speeds attention up by 2-4×, and is now stable. Pass attn_implementation="flash_attention_2" to from_pretrained(). If that fails because the package is missing, install it with pip install flash-attn --no-build-isolation.
7. Checkpoint Strategy
Save every 50-100 steps with a rolling window of 3-5 checkpoints (save_total_limit). The best checkpoint is often not the last one. Use load_best_model_at_end=True with metric_for_best_model="eval_loss"; this requires the save and eval schedules to line up (same strategy, save_steps a multiple of eval_steps). For spot instances, save every 50 steps to avoid losing progress on preemption.
8. Cheap Performance Tricks
- Increase batch diversity: More diverse data matters more than more data from the same distribution
- Clean before volume: Remove duplicates, filter low-quality examples, verify outputs. This alone often improves final performance by 5-15%
- Warm-start with a better base: Fine-tuning Mistral-7B-Instruct is easier and faster than fine-tuning Mistral-7B-Base. Start from instruct checkpoints.
- Multi-epoch with shuffling: 3 epochs with data shuffled between epochs often beats 1 epoch on 3× more data
- Eval set is your north star: Build your eval set before any training. Never let eval set contaminate training. Eval quality determines if you can trust your training signal.
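For the "remove duplicates" step, exact duplicates after light normalization are cheap to catch with a hash set. A minimal sketch, assuming chat-format examples with a `messages` list; the normalization (collapse whitespace, lowercase) and keying on the whole conversation are assumptions, and near-duplicates need fuzzier methods such as MinHash.

```python
import hashlib
import json
import re


def _normalize(text: str) -> str:
    return re.sub(r"\s+", " ", text).strip().lower()


def dedupe_examples(examples: list[dict]) -> list[dict]:
    """Keep the first occurrence of each example whose normalized messages match."""
    seen, unique = set(), []
    for ex in examples:
        key_src = json.dumps([(m["role"], _normalize(m["content"])) for m in ex["messages"]])
        key = hashlib.sha1(key_src.encode("utf-8")).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(ex)
    return unique


data = [
    {"messages": [{"role": "user", "content": "Plan a  station"}, {"role": "assistant", "content": "OK"}]},
    {"messages": [{"role": "user", "content": "plan a station "}, {"role": "assistant", "content": "ok"}]},
    {"messages": [{"role": "user", "content": "Plan two stations"}, {"role": "assistant", "content": "OK"}]},
]
print(len(dedupe_examples(data)))  # 2
```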
Common Production Pitfalls
- Overfitting to eval set through hyperparameter search: If you tune 20+ hyperparameter configurations on the same val set, you've overfit the val set. Hold out a final test set untouched until deployment decision.
- Ignoring inference vs training behavior gap: Temperature, top_p, repetition_penalty all affect output. Match inference sampling params to what you used when collecting "good" training outputs.
- Not versioning data and model together: 6 months later you can't reproduce a training run because you don't know which data version was used.
- Deploying without latency testing: A fine-tuned 70B model may be 5× slower than a 13B. Test latency under realistic concurrent load before production.
- Assuming more epochs = better: For SFT, 1-3 epochs is almost always optimal; 10+ epochs nearly always overfits. Trust eval loss over intuition.
Final Checklist Before Going to Production
- Eval suite runs automatically and results are logged
- Model tested on 50+ adversarial / edge case examples manually
- Inference system prompt matches training system prompt exactly
- Latency benchmarked at expected concurrent load
- LoRA weights merged (if single-task deployment)
- Model and data artifacts versioned and stored
- Rollback plan defined (previous model ready to restore)
- Monitoring in place: latency p50/p95, error rate, output length distribution
- A/B test framework ready if incrementally rolling out
Data Engine & Continuous Improvement Loop
Fine-tuning is not a one-time event. The best production ML teams run a data flywheel: deploy → observe failures → mine hard examples → relabel → retrain. Each iteration compounds. Teams that do this consistently outperform those doing larger one-shot training runs.
The Data Flywheel
Active Learning Strategies
Uncertainty Sampling
Sample examples where the model is least confident. For generation tasks, use self-consistency: generate N outputs for the same input. High variance in outputs = model is uncertain = high-value labeling target.
import json
from collections import Counter
import numpy as np
def is_valid_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False
def uncertainty_score(prompt: str, model, n_samples: int = 5) -> float:
    """Higher score = model is more uncertain = prioritize for labeling.
    Assumes model.generate(prompt) samples (temperature > 0) and returns a string;
    with greedy decoding every sample is identical and the score is always 0."""
    outputs = [model.generate(prompt) for _ in range(n_samples)]
    # For classification-style outputs: entropy over the distinct answers
    counts = Counter(outputs)
    probs = np.array(list(counts.values())) / n_samples
    entropy = -np.sum(probs * np.log(probs))
    return float(entropy)  # 0 when all samples agree, log(n_samples) when all differ
# For structured output: JSON failure rate as a proxy for uncertainty
def json_uncertainty(prompt: str, model, n: int = 5) -> float:
    valid_count = sum(1 for _ in range(n) if is_valid_json(model.generate(prompt)))
    return 1.0 - (valid_count / n)  # 1.0 = always fails, 0.0 = always succeeds
Failure Mining Pipeline
import json
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class ProductionLog:
    request_id: str
    prompt: str
    response: str
    latency_ms: float
    user_feedback: Optional[str]  # "thumbs_up" | "thumbs_down" | None
    automated_score: float  # 0-1 from your eval model
def is_valid_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False
def expects_json(prompt: str) -> bool:
    # Hypothetical heuristic: replace with however your app flags JSON requests
    return "json" in prompt.lower()
def mine_failures(logs: List[ProductionLog], threshold: float = 0.5) -> List[ProductionLog]:
    failures = []
    for log in logs:
        # Hard failures: explicit negative user feedback
        if log.user_feedback == "thumbs_down":
            failures.append(log)
        # Soft failures: low auto-score
        elif log.automated_score < threshold:
            failures.append(log)
        # Format failures: invalid JSON when JSON expected
        elif expects_json(log.prompt) and not is_valid_json(log.response):
            failures.append(log)
    return failures
# Failures become labeling candidates: send them to the annotation queue.
# Relabeled failures go into the next training iteration with 3× upsampling.
Online vs Offline Data Collection
| Strategy | When | Latency | Quality | Scale |
|---|---|---|---|---|
| Offline batch mining | Weekly/monthly retrain cycles | High (days) | High (human review) | Limited |
| Online shadow labeling | High-traffic production | Low (real-time) | Medium (auto-scored) | Unlimited |
| Canary-based collection | Pre-launch testing | Days | High (controlled) | Small |
| Synthetic generation | Cold start / edge cases | Hours | Medium (needs filter) | Unlimited |
If you only retrain on failures, you risk degrading on the cases the model already handles well. Always maintain 30-50% "positive" examples in retraining mix — cases the model gets right that represent the full input distribution.
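Putting this section's two rules of thumb together (3× upsampling of relabeled failures, and 30-50% positive examples in the mix), a retraining set might be assembled as below. A minimal sketch; `build_retrain_mix`, the default 40% positive share, and sampling positives without replacement (capped at what is available) are assumptions.

```python
import random


def build_retrain_mix(failures: list, positives: list, upsample: int = 3,
                      positive_share: float = 0.4, seed: int = 0) -> list:
    """Upsample relabeled failures, then add enough positives to reach positive_share."""
    if not 0 <= positive_share < 1:
        raise ValueError("positive_share must be in [0, 1)")
    rng = random.Random(seed)
    hard = failures * upsample
    # Solve n_pos / (n_pos + len(hard)) = positive_share for n_pos
    n_pos = round(positive_share * len(hard) / (1 - positive_share))
    pos = rng.sample(positives, min(n_pos, len(positives)))
    mix = hard + pos
    rng.shuffle(mix)
    return mix


mix = build_retrain_mix(failures=[f"fail{i}" for i in range(100)],
                        positives=[f"ok{i}" for i in range(1000)])
print(len(mix))
```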
Evaluation Framework Design
Most teams under-invest in eval. A rigorous eval system is the difference between "we think it's better" and "we know it's better by X% on Y metric with Z confidence."
Three-Layer Eval System
- Golden set: ~500 hand-crafted examples that never change, run on every training iteration. This is your regression baseline; never leak these into training.
- Production sample: 1-5% of live traffic sampled weekly and auto-scored with an LLM judge. Catches distribution drift your golden set misses.
- Shadow evaluation: run the candidate model in parallel on live traffic before promoting it and compare outputs systematically. This gives real-world signal before any user impact.
LLM-as-Judge Pipeline (Production Grade)
import asyncio
import numpy as np
from pydantic import BaseModel
class JudgeScore(BaseModel):
    correctness: int  # 1-5
    completeness: int  # 1-5
    format_compliance: int  # 1-5
    hallucination_flag: bool
    reasoning: str
# Judge prompt: the most important part
JUDGE_SYSTEM = """You are an expert evaluator for manufacturing AI systems.
Rate responses on correctness, completeness, format compliance.
Flag any numerical errors or unsupported claims as hallucinations.
Be strict: a 5/5 means a domain expert would accept this without edits."""
async def judge_single(example: dict, judge_model: str) -> JudgeScore:
    # Hypothetical helper: send JUDGE_SYSTEM plus the example to judge_model and parse
    # the JSON reply, e.g. JudgeScore.model_validate_json(reply) with pydantic v2
    raise NotImplementedError
async def judge_batch(examples: list, judge_model: str = "gpt-4o") -> list:
    # Parallelize with asyncio for speed (add an asyncio.Semaphore to respect rate limits)
    tasks = [judge_single(ex, judge_model) for ex in examples]
    return await asyncio.gather(*tasks)
def compute_eval_metrics(scores: list) -> dict:
    return {
        "mean_correctness": float(np.mean([s.correctness for s in scores])),
        "mean_completeness": float(np.mean([s.completeness for s in scores])),
        "format_pass_rate": float(np.mean([s.format_compliance >= 4 for s in scores])),
        "hallucination_rate": float(np.mean([s.hallucination_flag for s in scores])),
        # Weighted average of the three 1-5 scores, normalized by 5 (range 0.2-1.0)
        "composite_score": float(np.mean([
            (s.correctness * 0.5 + s.completeness * 0.3 + s.format_compliance * 0.2) / 5.0
            for s in scores
        ]))
    }
Pairwise Comparison Eval
Instead of scoring model outputs independently, compare them head-to-head. This is more reliable and aligns with how humans naturally judge quality. Used by OpenAI, Anthropic, Google for model comparison.
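PAIRWISE_PROMPT = """Compare these two responses to the same question.
Question: {question}
Response A: {response_a}
Response B: {response_b}
Which is better? Consider: accuracy, completeness, conciseness.
Output JSON: {{"winner": "A" | "B" | "tie", "reasoning": "..."}}"""
def deconflict_results(result_1: str, result_2: str) -> str:
    # result_2 was judged with the order swapped, so map its labels back first
    swapped_back = {"A": "B", "B": "A", "tie": "tie"}[result_2]
    # Count a win only if both orderings agree; disagreement suggests position bias
    return result_1 if result_1 == swapped_back else "tie"
def run_pairwise_eval(questions, baseline_responses, candidate_responses):
    # judge_pairwise(q, a, b) is a hypothetical helper that formats PAIRWISE_PROMPT,
    # calls the judge model, and returns "A", "B", or "tie"
    wins = {"A": 0, "B": 0, "tie": 0}  # A = baseline, B = candidate
    for q, a, b in zip(questions, baseline_responses, candidate_responses):
        # Run twice with A/B order swapped to control for position bias
        result_1 = judge_pairwise(q, a, b)
        result_2 = judge_pairwise(q, b, a)  # candidate is now "A" in the prompt
        winner = deconflict_results(result_1, result_2)
        wins[winner] += 1
    win_rate = wins["B"] / max(1, wins["A"] + wins["B"])  # ties excluded
    return win_rate  # rule of thumb: >0.55 suggests a meaningful improvement; check significance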
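Whether a win rate such as 0.55 is meaningful depends on how many decisive (non-tie) comparisons it rests on. One simple check, offered here as an assumption rather than a prescribed method, is an exact two-sided sign (binomial) test of the null hypothesis that the candidate wins 50% of non-tie comparisons.

```python
from math import comb


def sign_test_p_value(candidate_wins: int, baseline_wins: int) -> float:
    """Exact two-sided binomial test of H0: P(candidate wins | not a tie) = 0.5."""
    n = candidate_wins + baseline_wins
    if n == 0:
        return 1.0
    k = max(candidate_wins, baseline_wins)
    # P(X >= k) under Binomial(n, 0.5), doubled for a two-sided test
    tail = sum(comb(n, i) for i in range(k, n + 1)) / 2 ** n
    return min(1.0, 2 * tail)


print(sign_test_p_value(55, 45))    # 55% win rate on 100 decisive comparisons
print(sign_test_p_value(550, 450))  # same win rate on 1,000
```

The same 55% win rate is indistinguishable from chance on 100 comparisons but clearly significant on 1,000.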
Eval Drift Monitoring
Production data distribution shifts over time. Your golden eval set becomes less representative. Detect this before it silently degrades model performance.
- Track input embedding distribution over time (cosine distance from training distribution)
- Monitor output length distribution — sudden changes signal behavioral drift
- Track LLM-judge scores on weekly production samples. Alert if 7-day rolling average drops >5%
- Monitor per-category performance, not just aggregate — a category can collapse while overall looks stable
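The rolling-average alert above can be computed from daily mean judge scores. A minimal sketch, assuming one aggregated score per day and reading "drops >5%" as relative to the previous 7-day window; the function name and defaults are illustrative. Running it per category, not only on the aggregate, covers the last bullet too.

```python
def rolling_drop_alert(daily_scores: list[float], window: int = 7,
                       max_drop: float = 0.05) -> bool:
    """True if the latest `window`-day mean is more than `max_drop` (relative)
    below the mean of the `window` days before it."""
    if len(daily_scores) < 2 * window:
        return False  # not enough history to compare two full windows
    current = sum(daily_scores[-window:]) / window
    previous = sum(daily_scores[-2 * window:-window]) / window
    return current < previous * (1 - max_drop)


scores = [0.82] * 7 + [0.80, 0.78, 0.76, 0.75, 0.74, 0.74, 0.73]
print(rolling_drop_alert(scores))
```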
Regression Testing Before Deployment
Every deployment must pass: (1) golden set performance ≥ current production model, (2) no format regression on structured output examples, (3) no catastrophic forgetting on general-capability probe set. Automate these as CI gates.
# CI/CD gate for model deployment
# eval_on_golden_set, eval_format_compliance and eval_on_general_probes are
# your own (hypothetical) eval helpers returning scores in [0, 1]
def deployment_gate(candidate_model, production_model, eval_sets) -> bool:
    checks = {
        "golden_set_parity": False,
        "format_regression": False,
        "capability_preservation": False
    }
    # Check 1: golden set within 2% of the production baseline (tolerance for eval noise)
    candidate_score = eval_on_golden_set(candidate_model)
    baseline_score = eval_on_golden_set(production_model)
    checks["golden_set_parity"] = candidate_score >= baseline_score * 0.98
    # Check 2: format compliance (JSON validity, schema match)
    format_score = eval_format_compliance(candidate_model, eval_sets["format"])
    checks["format_regression"] = format_score >= 0.95
    # Check 3: general capability stays above an absolute floor
    general_score = eval_on_general_probes(candidate_model)
    checks["capability_preservation"] = general_score >= 0.80  # some degradation OK
    return all(checks.values())
Tokenization & Context Strategy
Tokenization Mismatch Issues
The most underrated source of silent bugs in fine-tuning pipelines. The model you train and the model you serve must use identical tokenization.
Every model family has a unique chat template (special tokens, turn delimiters, system prompt format). Llama-3 marks turns with <|start_header_id|>role<|end_header_id|> and <|eot_id|>, Qwen3 uses <|im_start|>/<|im_end|>, Mistral uses [INST] ... [/INST]. If you manually construct the prompt string instead of using apply_chat_template(), you will silently break the model.
# BAD — manual formatting (common mistake)
prompt = f"[INST] {instruction} [/INST]" # breaks for any non-Mistral model
# GOOD — use the tokenizer's own template
messages = [{"role": "user", "content": instruction}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True # adds assistant turn opener at end
)
# Verify: print 3 examples of your formatted training data before training
for i in range(3):
    print(repr(tokenizer.apply_chat_template(dataset[i]["messages"], tokenize=False)))
Truncation Strategies
When examples exceed max_seq_length, how you truncate matters significantly.
| Strategy | How | Best For | Risk |
|---|---|---|---|
| Head truncation | Keep first N tokens, drop end | When beginning is most important (instructions, system prompt) | Cuts off response — label tokens lost |
| Tail truncation | Keep last N tokens, drop start | Completion-only tasks where end matters | Loses instruction/context |
| Middle truncation | Keep first+last N/2 tokens | Long documents with key info at boundaries | Loses middle context, can confuse model |
| Sliding window | Split into overlapping chunks | Long doc QA, document summarization | Increases dataset size, position artifacts |
| Hierarchical chunking | Semantic-aware splitting | Structured documents (manuals, specs) | Requires more preprocessing |
Sliding Window Implementation
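def sliding_window_tokenize(
    text: str,
    tokenizer,
    max_length: int = 2048,
    stride: int = 256  # overlap (in tokens) between consecutive windows
) -> list:
    if not 0 <= stride < max_length:
        raise ValueError("stride must be in [0, max_length)")
    # No BOS/EOS here, so decoded chunks don't contain special-token text
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    for start in range(0, len(tokens), max_length - stride):
        chunk = tokens[start : start + max_length]
        chunks.append(tokenizer.decode(chunk))
        if start + max_length >= len(tokens):  # window reached the end; skip redundant tails
            break
    return chunks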
Long Context Fine-Tuning
Models pretrained on 4k context often need explicit training to handle 32k+ contexts reliably. Key considerations:
- RoPE scaling: Many models use RoPE (Rotary Position Embedding). For context extension beyond pretraining length, enable dynamic NTK-aware RoPE scaling: rope_scaling={"type": "dynamic", "factor": 2.0}
- Long-context data distribution: Mix genuinely long examples (not just padded short ones). 10-20% long examples in training is often enough.
- Position interpolation: If extending beyond training context, use linear interpolation of positions rather than extrapolation
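- VRAM explosion: Attention is O(n²) in sequence length, so going from 4k to 8k quadruples the memory of a materialized attention matrix. Flash Attention 2 never materializes that matrix, bringing attention memory down to O(n) (compute stays O(n²)).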
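The quadratic term is easy to check with back-of-the-envelope arithmetic. The sketch sizes the attention score matrix that a naive (non-fused) implementation materializes for one layer, assuming 32 heads, batch size 1 and 2-byte (BF16) scores; real models and kernels differ.

```python
def attn_scores_bytes(seq_len: int, n_heads: int = 32, batch: int = 1,
                      bytes_per_elem: int = 2) -> int:
    """Size of the (batch, heads, seq_len, seq_len) score matrix for one layer."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


for n in (4096, 8192, 32768):
    print(f"{n:>6} tokens: {attn_scores_bytes(n) / 2**30:.0f} GiB per layer")
```

Under these assumptions the matrix grows from 1 GiB per layer at 4k tokens to 4 GiB at 8k and 64 GiB at 32k.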
Training Stability & Real-World Tricks
Loss Spike Debugging
Loss spikes (sudden jumps mid-training) are common and often fixable. Systematic diagnosis:
| Symptom | Most Likely Cause | Fix |
|---|---|---|
| Loss spikes then recovers | Bad batch (outlier example) | Gradient clipping, inspect that batch, filter outliers |
| Loss spikes, stays high | LR too high | Reduce LR by 3-10×, restart from last checkpoint |
| Loss goes to NaN immediately | LR way too high OR FP16 overflow | Switch to BF16, reduce LR by 100× |
| Loss oscillates wildly | Batch too small (high gradient variance) | Increase gradient_accumulation_steps |
| Loss plateaus early | LR too low, warmup too long, OR wrong target_modules | Raise LR, shorten warmup, check LoRA target_modules |
| Val loss diverges at epoch 2+ | Overfitting | Reduce epochs, increase weight_decay, add dropout |
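To find which batch caused a spike, it helps to flag the exact steps where loss jumps well above its recent level. A minimal sketch, assuming per-step losses pulled from your logs; the 20-step trailing median and the 2× threshold are arbitrary starting points.

```python
import math
from statistics import median


def find_loss_spikes(losses: list[float], window: int = 20,
                     factor: float = 2.0) -> list[int]:
    """Return step indices whose loss exceeds `factor` x the median of the
    previous `window` finite losses; NaN/inf losses are always flagged."""
    spikes = []
    for step, loss in enumerate(losses):
        if not math.isfinite(loss):
            spikes.append(step)
            continue
        history = [x for x in losses[max(0, step - window):step] if math.isfinite(x)]
        if len(history) >= 5 and loss > factor * median(history):
            spikes.append(step)
    return spikes


losses = [2.0 - 0.01 * i for i in range(30)]
losses[25] = 5.0  # a bad batch
print(find_loss_spikes(losses))  # [25]
```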
Gradient Clipping
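max_grad_norm=1.0 is the standard (and the Trainer default). Whenever the global L2 norm of all gradients exceeds this threshold, every gradient is scaled down by the same factor, so no single update can move the weights too far. This dramatically stabilizes training. For DPO, use a tighter clip of 0.3-0.5 since preference shifts are smaller.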
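import torch
from transformers import TrainingArguments
# In TrainingArguments (always include this)
args = TrainingArguments(
    output_dir="out",
    max_grad_norm=1.0,  # clip gradients whose global L2 norm exceeds 1.0
    # ...other training arguments
)
# Manual implementation if not using Trainer: call after loss.backward(), before optimizer.step()
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0,
    norm_type=2.0
)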
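The arithmetic behind `clip_grad_norm_` is to rescale all gradients together so their combined L2 norm is at most `max_norm`: the update direction is preserved and only its size shrinks. A pure-Python sketch on a flat list of gradient values (PyTorch also adds a tiny epsilon to the denominator, omitted here):

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float = 1.0) -> list[float]:
    """Scale grads by max_norm / total_norm when total_norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


print(clip_by_global_norm([3.0, 4.0]))  # norm 5.0, so every value is scaled by 0.2
```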
Warmup Strategies Compared
| Scheduler | Shape | Best For | Notes |
|---|---|---|---|
| cosine | Smooth decay to near-zero | Most fine-tuning tasks | Default choice — smooth landing avoids final oscillation |
| linear | Linear decay to zero | Debugging runs | Predictable, easier to reason about |
| cosine_with_restarts | Cosine that resets periodically | Long training runs | Can escape local minima, complex to tune |
| constant_with_warmup | Warmup then flat | Very short runs (<500 steps) | No decay → risk of not converging |
| polynomial | Configurable decay curve | Research experiments | Rarely needed in FT context |
Warmup Ratio Rule of Thumb
Warmup prevents early large gradient updates from damaging the pretrained weights. Use warmup_ratio=0.05 (5% of total steps) for standard fine-tuning. For DPO and very small datasets (<500 examples), increase to 0.1.
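To see what `warmup_ratio` means in steps, and what the default cosine schedule looks like, here is a sketch of linear warmup followed by half-cosine decay to zero, intended to mirror the shape of the cosine-with-warmup schedule in transformers (the library's implementation is the reference). The step count is approximate: it assumes one optimizer step per effective batch and rounds partial batches up.

```python
import math


def total_optimizer_steps(n_examples: int, per_device_batch: int, grad_accum: int,
                          n_devices: int, epochs: int) -> int:
    effective_batch = per_device_batch * grad_accum * n_devices
    return math.ceil(n_examples / effective_batch) * epochs


def lr_at(step: int, peak_lr: float, total_steps: int, warmup_ratio: float = 0.05) -> float:
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps  # linear ramp from 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay to 0


total = total_optimizer_steps(10_000, per_device_batch=4, grad_accum=4, n_devices=2, epochs=3)
print(total, math.ceil(total * 0.05))  # total steps, warmup steps
```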
Label Smoothing
Softens hard one-hot targets: instead of 1.0 for the correct token, use 1 - ε, distributing ε across other tokens. Reduces overconfidence, improves calibration. Useful for tasks where near-correct answers should not be heavily penalized.
TrainingArguments(
label_smoothing_factor=0.1, # 0.0 (off) to 0.1 is typical range
# Note: do NOT use with completion-only training (DataCollatorForCompletionOnlyLM)
# It interacts poorly with -100 label masking
)
Mixed Precision Pitfalls
- FP16 on older GPUs: FP16 has limited dynamic range (max ≈ 65504). Large gradient norms overflow to NaN. Symptoms: sudden NaN loss. Fix: switch to BF16 (much larger range) or enable dynamic loss scaling (fp16=True in TrainingArguments handles this automatically).
- BF16 on Volta GPUs (V100): V100 doesn't natively support BF16. Will silently fall back to FP32, increasing VRAM. Always check GPU generation before specifying dtype.
- Optimizer in wrong precision: AdamW momentum states should always be FP32 even when model is FP16/BF16. HuggingFace Trainer handles this — but custom training loops often miss it.
- QLoRA + BF16: Frozen base is NF4, LoRA adapters train in BF16. The upcast happens automatically via bnb_4bit_compute_dtype; don't override it.
Gradient Checkpointing Trade-offs
- Pros: Reduces activation memory by 60-70%. Enables training longer sequences and larger batch sizes. Essential for full FT.
- Cons: ~30% slower training, because discarded activations are recomputed during the backward pass.
- When to use: If you have spare VRAM, skip it for speed. Only enable it when you'd OOM otherwise. QLoRA almost always needs it.
Multi-Task & Multi-Domain Fine-Tuning
Dataset Blending Ratios
When fine-tuning on multiple data sources simultaneously, the mix ratio determines what the model optimizes hardest. This is not trivial — get it wrong and the model collapses into one task or loses general capability.
70% domain-specific task data · 20% general instruction data · 10% safety/alignment data. Adjust based on how domain-specialized you need the model to be and how much capability preservation matters.
from datasets import concatenate_datasets, interleave_datasets
# Method 1: Weighted interleaving (recommended)
# Each step, a dataset is sampled proportionally to its weight
mixed_dataset = interleave_datasets(
datasets=[domain_ds, general_ds, safety_ds],
probabilities=[0.70, 0.20, 0.10],
seed=42,
stopping_strategy="first_exhausted" # or "all_exhausted" to oversample
)
# Method 2: Concatenate + shuffle (simpler but less controlled)
from_each = {
"domain": domain_ds.select(range(7000)),
"general": general_ds.select(range(2000)),
"safety": safety_ds.select(range(1000)),
}
full_ds = concatenate_datasets(list(from_each.values())).shuffle(seed=42)
Task Balancing Techniques
- Temperature sampling: Sample dataset i with probability proportional to n_i^(1/T), where n_i is its size. T=1 gives size-proportional sampling; T>1 (e.g., 2-5) flattens the mix toward uniform, while T<1 favors larger datasets even more. Useful for very imbalanced sources.
- Loss weighting per task: Assign different loss multipliers per example type. Domain examples get weight 1.0, general instructions get 0.5. Effectively a soft version of ratio control.
- Curriculum learning: Start with easy/general examples, gradually introduce domain-specific hard examples. Can improve convergence for extreme domain shifts.
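One common convention for temperature sampling sets p_i proportional to n_i^(1/T), where n_i is dataset i's size: T=1 reproduces size-proportional sampling and larger T flattens the mix toward uniform (some texts write the exponent as T itself, in which case T<1 flattens). A minimal sketch; the sizes are illustrative, and the resulting list could be passed as the `probabilities` argument of `interleave_datasets`.

```python
def temperature_probs(sizes: list[int], T: float = 1.0) -> list[float]:
    """p_i proportional to n_i ** (1 / T): T=1 is size-proportional, larger T is flatter."""
    weights = [n ** (1.0 / T) for n in sizes]
    total = sum(weights)
    return [w / total for w in weights]


sizes = [9000, 1000]                  # e.g. domain vs general examples
print(temperature_probs(sizes, T=1))  # ≈ [0.9, 0.1]
print(temperature_probs(sizes, T=2))  # ≈ [0.75, 0.25]
```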
Instruction Diversity — Why It Matters
Models trained on narrow instruction templates overfit to those templates and fail on paraphrased or structurally different inputs. Instruction diversity is as important as answer quality.
# Bad: 1000 examples all starting with "Extract the following..."
# Model learns the surface pattern, not the underlying skill
# Good: Same task, diverse instruction templates
instruction_templates = [
"Extract the machine names from: {input}",
"From the following BOM, identify all machines: {input}",
"List every piece of equipment mentioned in: {input}",
"What manufacturing equipment is referenced here? {input}",
"Parse out equipment names from this text: {input}",
"Equipment extraction task. Input: {input}"
]
# Use LLM to generate N instruction variants for each example
# min 5 variants per skill = meaningfully more robust model
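Once you have template variants, one simple way to apply them is to pick a random template per example with a fixed seed, so the dataset stays reproducible. A minimal sketch; it assumes each raw example is a dict with hypothetical `input` and `output` fields and a template list like the one above.

```python
import random


def diversify(examples: list[dict], templates: list[str], seed: int = 42) -> list[dict]:
    """Render each example with a randomly chosen instruction template."""
    rng = random.Random(seed)
    rendered = []
    for ex in examples:
        template = rng.choice(templates)
        rendered.append({"messages": [
            {"role": "user", "content": template.format(input=ex["input"])},
            {"role": "assistant", "content": ex["output"]},
        ]})
    return rendered


templates = [
    "Extract the machine names from: {input}",
    "List every piece of equipment mentioned in: {input}",
]
data = [{"input": "Line 3 has a CNC lathe and a gear hobber.", "output": "CNC lathe, gear hobber"}]
print(diversify(data, templates)[0]["messages"][0]["content"])
```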
Multi-Task with LoRA: Separate Adapters vs One Adapter
| Approach | Pros | Cons | When to Use |
|---|---|---|---|
| One LoRA, all tasks | Single model, simple serving | Tasks can interfere, harder to debug | Tasks are related, shared representations help |
| Separate LoRA per task | No task interference, modular | Multiple adapters to manage | Unrelated tasks, per-customer adapters |
| Full FT on mixed data | Best cross-task generalization | High compute, catastrophic forgetting risk | Large dataset, tasks share significant structure |
Catastrophic Forgetting
Why It Happens
Neural networks store knowledge in weight configurations optimized for their training distribution. When you fine-tune on a new distribution, gradient updates that improve performance on the new task can destructively overwrite weights encoding prior knowledge. The optimizer has no mechanism to "protect" old knowledge — it only minimizes current loss.
Every fine-tuning method sits on a spectrum from highly plastic (full FT — learns fast, forgets fast) to highly stable (prompt tuning — barely learns, never forgets). LoRA with low rank sits in the middle — enough plasticity for real adaptation, enough stability to preserve base capabilities.
Forgetting Risk by Method
| Method | Forgetting Risk | Why |
|---|---|---|
| Full Fine-Tuning | Very High | All weights updated, no protection for base knowledge |
| LoRA (r=64+) | Medium | More parameters, higher risk of overwriting base representations |
| LoRA (r=8-16) | Low | Limited parameter budget constrains how much base is modified |
| QLoRA | Low (similar to LoRA at the same rank) | Base weights frozen (and quantized), as in LoRA; only the low-rank adapter changes behavior |
| Adapters | Very Low | New layers inserted, original weights untouched |
| Prefix/Prompt Tuning | Minimal | Base weights completely frozen |
Mitigation Strategies
1. Replay Buffers
Include examples from the original pretraining/instruction-tuning distribution in each fine-tuning batch. Even 5-10% general data dramatically reduces forgetting.
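from datasets import interleave_datasets, load_dataset
# Replay buffer: always mix in general instruction data
# Ratio: 90% task-specific, 10% replay from a general instruction dataset
replay_buffer = load_dataset("teknium/OpenHermes-2.5", split="train").shuffle(seed=42)
# Size the buffer to ~1/9 of the task set so neither source runs out early
n_replay = min(len(replay_buffer), max(1, len(task_dataset) // 9))
replay_buffer = replay_buffer.select(range(n_replay))
# Both datasets must share the same columns/features (e.g., convert both to "messages")
combined = interleave_datasets(
    [task_dataset, replay_buffer],
    probabilities=[0.90, 0.10],
    seed=42,
    stopping_strategy="all_exhausted"
)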
2. Elastic Weight Consolidation (EWC)
Adds regularization term to the loss that penalizes changes to weights that were important for previous tasks. Computationally expensive but principled. Rarely used in practice — replay buffers and LoRA achieve similar results more cheaply.
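In its standard form, EWC anchors each parameter θ_i to its value before fine-tuning θ*_i, weighting the penalty by the diagonal Fisher information F_i (an estimate of how much the previous task depends on that parameter); λ sets the strength of the penalty:

```latex
\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta^{*}_i\right)^2,
\qquad
F_i \approx \mathbb{E}\left[\left(\frac{\partial \log p(y \mid x, \theta)}{\partial \theta_i}\Big|_{\theta = \theta^{*}}\right)^{2}\right]
```

Estimating F_i needs gradients over old-task data and stores one extra value per parameter, which is a large part of why EWC is expensive at LLM scale.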
3. Lower Learning Rate
Smaller LR = smaller weight updates = less forgetting. The tradeoff is slower convergence on the new task. Often the best first lever to pull when forgetting is observed.
4. Use LoRA with Low Rank
The most practical solution for most teams. Low-rank constraints limit how much of the base model's representations can be overwritten.
Detecting Catastrophic Forgetting
- Maintain a general capability probe set: 50-100 examples from standard benchmarks (MMLU, HellaSwag subset, basic math). Eval after every training run.
- Track instruction-following rate: Can the model still follow basic instructions unrelated to your task? "Write a haiku about autumn" — still works?
- Compare base model vs fine-tuned on your probe set. If fine-tuned model drops by >10%, you have a forgetting problem.
- Eval on adjacent tasks: If you FT for TAKT calculation, also test on basic arithmetic — did math capability degrade?
Inference Optimization
KV Cache Internals
During autoregressive generation, the model computes Key-Value pairs for every attention head for every token in the context. The KV cache stores these so they don't need to be recomputed for earlier tokens on each new generation step.
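The cache grows linearly with context length and batch size, which is why it often dominates serving memory. A back-of-the-envelope sketch: per token, each layer stores one key and one value vector per KV head. The example numbers assume a Llama-3.1-8B-style configuration (32 layers, 8 KV heads via grouped-query attention, head dimension 128) and a BF16 cache; check your model's config for the real values.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """2 (K and V) x layers x KV heads x head_dim x tokens x batch x bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


per_token = kv_cache_bytes(32, 8, 128, seq_len=1)
full = kv_cache_bytes(32, 8, 128, seq_len=8192)
print(f"{per_token / 1024:.0f} KiB per token, {full / 2**30:.1f} GiB for one 8k sequence")
```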
vLLM: Continuous Batching Explained
Traditional static batching: wait for N requests, process together, all finish before next batch. Problem: short requests wait for long ones; GPU sits idle between batches.
vLLM's continuous batching: as soon as any sequence finishes generation, a new request fills its slot. The GPU is always saturated. This is the single biggest throughput improvement in modern LLM serving.
- Static batching: Req1 (100 tokens) waits for Req2 (500 tokens) to finish. Throughput is limited by the longest request.
- Continuous batching: Req1 finishes at step 100 and is immediately replaced by Req5. The GPU never waits. 3-10× throughput.
- PagedAttention: KV cache stored in non-contiguous pages (like OS virtual memory). Eliminates fragmentation. Enables 2-3× more concurrent sequences.
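The scheduling difference shows up even in a toy simulation that counts decode steps only (one token per sequence per step) and ignores prefill and memory limits, so it illustrates the effect rather than predicting real throughput. Request lengths and the slot count are made up.

```python
import heapq


def static_batching_steps(lengths: list[int], batch_size: int) -> int:
    """Each batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))


def continuous_batching_steps(lengths: list[int], slots: int) -> int:
    """A finished request's slot is refilled immediately by the next queued request."""
    free_at = [0] * slots  # step at which each slot becomes free (min-heap)
    heapq.heapify(free_at)
    for n in lengths:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + n)
    return max(free_at)


requests = [100, 500, 100, 500, 100, 100]
print(static_batching_steps(requests, 2), continuous_batching_steps(requests, 2))
```

With two slots, static batching needs 1,100 steps for these requests while continuous batching finishes in 700.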
Speculative Decoding
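Use a small "draft" model (e.g., 1-3B) to propose N tokens ahead, then verify them with the large "target" model in a single forward pass. The target accepts the longest prefix of draft tokens that passes its acceptance check (exact match under greedy decoding, rejection sampling otherwise, which keeps the output distribution unchanged) and adds one token of its own, so each verification pass yields between 1 and N+1 tokens. Typical speedup: 2-3× when draft and target come from the same model family (they must share a tokenizer).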
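# vLLM speculative decoding (flag names vary by version: newer releases take a
# JSON --speculative-config instead of these flags; check `vllm serve --help`)
vllm serve meta-llama/Llama-3.1-70B-Instruct \
--speculative-model meta-llama/Llama-3.2-1B-Instruct \
--num-speculative-tokens 5 \
--use-v2-block-manager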
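How much speculative decoding helps depends on how often the target accepts draft tokens. Under the simplifying assumption, common in the speculative decoding literature, that each draft token is accepted independently with probability alpha, the expected number of tokens produced per target forward pass with k draft tokens is (1 - alpha^(k+1)) / (1 - alpha). A small sketch:

```python
def expected_tokens_per_pass(alpha: float, k: int) -> float:
    """Expected tokens per target forward pass with k draft tokens and an
    i.i.d. per-token acceptance rate alpha (includes the target's own token)."""
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1:
        return k + 1.0
    return (1 - alpha ** (k + 1)) / (1 - alpha)


for alpha in (0.6, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_pass(alpha, k=5), 2))
```

This counts tokens per target pass, not wall-clock speedup; the draft model's own cost has to be subtracted.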
Throughput vs Latency Trade-offs
| Optimization | Throughput Impact (+ higher, - lower) | Latency Impact (+ higher, - lower) | When to Use |
|---|---|---|---|
| Larger batch size | +++ throughput | +++ latency | Async/offline processing |
| Quantization (INT8/FP8) | ++ throughput | - latency | Almost always — minimal quality loss |
| Speculative decoding | neutral/slight+ | -- latency | Interactive chat, low-latency APIs |
| Prefix caching | ++ throughput | -- latency (cached) | Shared system prompts, repeated prefixes |
| Tensor parallelism | neutral | --- latency | Large models, latency-critical serving |
| Smaller model | +++ throughput | --- latency | When quality is acceptable |
Automatic Prefix Caching
If many requests share the same prefix (e.g., a long system prompt), vLLM's automatic prefix caching (APC) reuses the computed KV states for that prefix across requests. Enable it with --enable-prefix-caching. For a 2000-token system prompt at 1M requests/day, that is up to 2 billion prefill tokens per day served from cache instead of recomputed.
vllm serve ./merged-model \
--enable-prefix-caching \
--max-model-len 8192 \
--dtype bfloat16 \
--quantization fp8  # FP8 compute needs Hopper/Ada GPUs (e.g., H100); on A100 it runs weight-only, saving memory with smaller speed gains
Fine-Tuning + RAG Hybrid Architectures
Fine-tuning teaches the model how to behave. RAG gives the model what to say. These are complementary — the best production systems for knowledge-intensive tasks combine both. Use FT to get the format/reasoning right; use RAG to keep the facts current.
When to Combine FT + RAG
- Knowledge is dynamic (product specs, prices, regulations change) — RAG handles freshness
- You need consistent output format/structure — FT handles this better than prompting
- Domain terminology is specialized — FT trains the model to understand domain language so RAG context is better utilized
- You want the model to "say I don't know" reliably when context is absent — FT can train this behavior
- System prompt is long and expensive — FT internalizes the behavior, shortening the prompt
Architecture Pattern
"""
Pattern: FT for behavior, RAG for knowledge
Fine-tuned model learns:
- How to structure the response (JSON schema, sections)
- How to say "not in provided context"
- Domain-specific reasoning patterns
- Tone, format, conciseness requirements
RAG provides:
- Current machine specs and pricing
- Latest regulatory requirements
- Customer-specific configurations
- Documents that change frequently
"""
# Training data format for FT+RAG model
training_example = {
"messages": [
{
"role": "system",
"content": "You are a factory planning assistant. Answer using ONLY the provided context."
},
{
"role": "user",
"content": (
"Context:\n{retrieved_chunks}\n\n"
"Question: {user_question}"
)
},
{
"role": "assistant",
"content": "{grounded_answer}"
}
]
}
# Crucially: include "not found" examples in training set!
# Model learns to abstain when context doesn't support the answer
Anti-Patterns: Don't Do This
❌ Using FT to Inject Factual Knowledge
Fine-tuning a model on "Q: What is the cycle time of CNC-2000? A: 45 seconds" does not reliably inject this fact. The model may learn the pattern but hallucinate similar-sounding facts at inference. Use RAG for facts. Use FT for behavior.
❌ Training on Retrieved Context Without Negative Examples
If every training example has context that answers the question, the model never learns to say "not in context." Make roughly 15-20% of training examples cases where the retrieved context is irrelevant and the correct answer is an abstention.
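One way to hit that ratio is to pair some questions with context retrieved for a *different* question and label them with a fixed abstention answer. A minimal sketch, assuming you already have (question, chunks, answer) triples; the `ABSTAIN` wording, the 0.18 default ratio, and the function names are illustrative choices. Mismatched context can occasionally still contain the answer, so spot-check the generated negatives.

```python
import random
from typing import Dict, List, Tuple

SYSTEM = "You are a factory planning assistant. Answer using ONLY the provided context."
ABSTAIN = "The provided context does not contain the information needed to answer this question."


def make_example(question: str, chunks: List[str], answer: str) -> Dict:
    """Fill the chat-format training template used for the FT+RAG model."""
    context = "\n\n".join(chunks)
    return {"messages": [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
        {"role": "assistant", "content": answer},
    ]}


def build_ft_rag_dataset(triples: List[Tuple[str, List[str], str]],
                         negative_ratio: float = 0.18, seed: int = 0) -> List[Dict]:
    """Positives plus enough abstentions to make up ~negative_ratio of the result.
    Each negative pairs question i with the chunks retrieved for question j != i."""
    rng = random.Random(seed)
    data = [make_example(q, chunks, a) for q, chunks, a in triples]
    # n_neg / (n_pos + n_neg) = r  =>  n_neg = r * n_pos / (1 - r)
    n_neg = round(negative_ratio * len(triples) / (1 - negative_ratio))
    if n_neg and len(triples) < 2:
        raise ValueError("need at least two questions to build mismatched negatives")
    for _ in range(n_neg):
        i, j = rng.sample(range(len(triples)), 2)  # two distinct indices
        data.append(make_example(triples[i][0], triples[j][1], ABSTAIN))
    rng.shuffle(data)
    return data
```

With 100 positives and the default ratio this adds 22 abstentions, so they make up 22/122 (about 18%) of the final set.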
❌ Different Chunking Strategy in Training vs Inference
If training examples use 512-token chunks but inference uses 256-token chunks, the model sees a different context structure than it was trained on. Use the exact same retrieval and chunking pipeline for generating training data as you'll use in production.
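A simple guard is to route both training-data generation and serving through one chunking function whose parameters live in a single frozen config, then store that config's fingerprint alongside the training set so serving can refuse to start on a mismatch. The sketch below is hypothetical. The config fields are examples, and whitespace splitting stands in for your real tokenizer.

```python
import hashlib
import json
from dataclasses import asdict, dataclass
from typing import List


@dataclass(frozen=True)
class ChunkingConfig:
    chunk_tokens: int = 512
    overlap_tokens: int = 64
    splitter: str = "whitespace"  # placeholder: record your real tokenizer/splitter name

    def fingerprint(self) -> str:
        blob = json.dumps(asdict(self), sort_keys=True).encode()
        return hashlib.sha256(blob).hexdigest()[:12]


def chunk_text(text: str, cfg: ChunkingConfig) -> List[str]:
    """Sliding-window chunking over whitespace 'tokens' (stand-in for real tokens)."""
    words = text.split()
    step = cfg.chunk_tokens - cfg.overlap_tokens
    if step <= 0:
        raise ValueError("overlap_tokens must be smaller than chunk_tokens")
    return [" ".join(words[i:i + cfg.chunk_tokens])
            for i in range(0, max(len(words) - cfg.overlap_tokens, 1), step)]


def check_serving_config(trained_fingerprint: str, serving_cfg: ChunkingConfig) -> None:
    """Fail fast if serving would chunk differently than the training data did."""
    if serving_cfg.fingerprint() != trained_fingerprint:
        raise RuntimeError(f"Chunking mismatch: trained with {trained_fingerprint}, "
                           f"serving with {serving_cfg.fingerprint()}")
```

Saving `cfg.fingerprint()` in the dataset's metadata and calling `check_serving_config` at startup turns a silent quality regression into an explicit error.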
FT + RAG vs Pure RAG
| Dimension | Pure RAG (base model) | FT + RAG |
|---|---|---|
| Format consistency | Poor (prompt-dependent) | Excellent (trained in) |
| Abstention reliability | Moderate | High (trained on negatives) |
| Prompt length | Long (verbose system prompt) | Short (behavior is baked in) |
| Knowledge freshness | Always current | Always current |
| Setup complexity | Simple | More complex (FT + RAG pipeline) |
| Inference cost | Higher (long prompts) | Lower (short prompts) |
Data Formatting Edge Cases
Instruction Leakage
Instruction leakage occurs when the answer to a question can be derived from the instruction text itself, bypassing the need to actually learn the underlying skill. The model learns to pattern-match on the prompt rather than reason.
# BAD: Answer leaks from instruction
{
"instruction": "The defect category for surface cracks is SURFACE_CRACK. Classify the following defect.",
"input": "Surface crack detected on outer race",
"output": "SURFACE_CRACK"
}
# GOOD: Instruction describes task, not answer
{
"instruction": "Classify the manufacturing defect into one of: SURFACE_CRACK, DIMENSIONAL, INCLUSION, POROSITY.",
"input": "Linear mark, 1.2mm depth on outer race surface",
"output": "SURFACE_CRACK"
}
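For classification-style data, one cheap heuristic catches the BAD pattern: flag an example when its instruction names the gold label but none of the dataset's other labels. A task description that lists all options, like the GOOD example, passes. This is a sketch assuming the `instruction`/`output` field names used above. It will not catch paraphrased hints, so treat it as a first filter, not proof of clean data.

```python
import re
from typing import Dict, List


def find_label_leaks(examples: List[Dict[str, str]]) -> List[int]:
    """Return indices of examples whose instruction mentions the gold label
    and no other label seen in the dataset (a hint rather than a task)."""
    labels = {ex["output"].strip() for ex in examples}
    flagged = []
    for i, ex in enumerate(examples):
        gold = ex["output"].strip()
        mentioned = {lab for lab in labels
                     if re.search(rf"\b{re.escape(lab)}\b", ex["instruction"])}
        if gold in mentioned and len(mentioned) == 1:
            flagged.append(i)
    return flagged
```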
Overfitting to Prompt Templates
If all training examples start with the same instruction pattern ("Given the following...", "Your task is to..."), the model memorizes the format rather than the skill. At inference, slight rephrasing causes degradation.
Fix: Instruction Augmentation
import json

def augment_instruction(instruction: str, n_variants: int = 5) -> list:
    prompt = f"""Rewrite this instruction {n_variants} different ways.
Keep the task identical, vary only the phrasing and structure.
Instruction: {instruction}
Return as JSON array of strings."""
    # call_llm() is your own LLM client wrapper (not shown). It returns raw text,
    # so parse the JSON array before concatenating lists.
    variants = json.loads(call_llm(prompt))
    return [instruction] + variants  # original + N variants
EOS Token Handling
Missing or misplaced EOS tokens are one of the most common causes of "model that generates forever" at inference. Every training example's response must end with the tokenizer's EOS token.
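# Verify EOS tokens are present in your dataset
def audit_eos_tokens(dataset, tokenizer):
    issues = []
    eos_token_id = tokenizer.eos_token_id
    for i, example in enumerate(dataset):
        # add_special_tokens=False: check the EOS actually present in the text,
        # not one the tokenizer might append on its own
        tokens = tokenizer.encode(example["text"], add_special_tokens=False)
        # Note: some chat templates add a newline after the end-of-turn token;
        # account for that if your template does
        if not tokens or tokens[-1] != eos_token_id:
            issues.append(i)
    if issues:
        print(f"WARNING: {len(issues)} examples missing EOS token")
        print(f"Sample indices: {issues[:5]}")
    return len(issues) == 0
# With conversational ("messages") data, the chat template normally emits the
# end-of-turn/EOS token for you; run the audit anyway to confirm.
# If constructing text manually: append tokenizer.eos_token to every response.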
Chat Template Mismatch Bugs
One of the most common sources of "my fine-tuned model is worse than the base model" reports. A catalog of real bugs:
| Bug | Symptom | Root Cause | Fix |
|---|---|---|---|
| Wrong template used | Model ignores instructions, generates random text | Training on Llama template, serving Mistral model | Match tokenizer to model always |
| add_generation_prompt=False in inference | Model continues or repeats the user turn instead of responding | Assistant header not appended to the prompt at inference | Set True for inference, False for training |
| System prompt stripped | Model behaves like base model despite FT | Template doesn't include system role tokens | Verify system messages survive apply_chat_template |
| Padding side mismatch | Batched generation degrades, especially for the shorter prompts in a batch | Right-padding used for batched generation, so short prompts continue from pad tokens | Right-pad for training; set padding_side="left" for batched generation |
| Truncation cuts response | Training examples have truncated outputs | max_length too short, truncates from right | Truncate from left (input side), not right (output) |
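Why the padding row matters: in batched generation, new tokens are appended after the last column of the batch, so every row's last column must hold a real token. This stdlib-only illustration uses made-up token ids and an assumed pad id of 0. In Hugging Face Transformers the corresponding switch is `tokenizer.padding_side = "left"` before a batched `generate()` call.

```python
from typing import List, Tuple

PAD = 0  # assumed pad token id for this illustration


def pad_batch(seqs: List[List[int]], side: str) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest one; return (input_ids, attention_mask)."""
    width = max(len(s) for s in seqs)
    ids, mask = [], []
    for s in seqs:
        pad = [PAD] * (width - len(s))
        if side == "left":
            ids.append(pad + s)
            mask.append([0] * len(pad) + [1] * len(s))
        else:
            ids.append(s + pad)
            mask.append([1] * len(s) + [0] * len(pad))
    return ids, mask


def last_position_is_real(mask: List[List[int]]) -> List[bool]:
    """Generation continues from the final column: is it a real token in each row?"""
    return [row[-1] == 1 for row in mask]


batch = [[11, 12, 13, 14], [21, 22]]  # a long and a short prompt
for side in ("right", "left"):
    _, m = pad_batch(batch, side)
    print(side, last_position_is_real(m))
# right [True, False]  -> the short prompt would continue from padding
# left [True, True]    -> every row continues from its own last token
```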
Print repr(tokenizer.apply_chat_template([{"role": "user", "content": "test"}], tokenize=False, add_generation_prompt=True)) and inspect every special token. Then do the same for your inference code. They must be identical.
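The comparison in the tip above can be automated. A minimal sketch, assuming you pass in the string your training pipeline renders with `apply_chat_template(..., tokenize=False)` and the string your serving code actually sends. The special-token regex covers only a few common shapes (`<|...|>`, `[INST]`, `<s>`) and is an assumption to extend for your model family.

```python
import difflib
import re
from typing import List

# Common special-token shapes; extend for your model family (assumption).
SPECIAL = re.compile(r"<\|[^|>]*\|>|\[/?INST\]|</?s>")


def special_tokens(text: str) -> List[str]:
    return SPECIAL.findall(text)


def compare_prompts(train_render: str, serve_render: str) -> List[str]:
    """Describe how two rendered prompts differ; an empty list means identical."""
    if train_render == serve_render:
        return []
    problems = []
    t_tok, s_tok = special_tokens(train_render), special_tokens(serve_render)
    if t_tok != s_tok:
        problems.append(f"special tokens differ: train={t_tok} serve={s_tok}")
    problems.extend(difflib.unified_diff(
        train_render.splitlines(), serve_render.splitlines(),
        fromfile="train", tofile="serve", lineterm=""))
    if not problems:  # e.g. only a trailing newline differs
        problems.append(f"whitespace-only difference: {train_render!r} vs {serve_render!r}")
    return problems
```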
Failure Modes — Complete Reference
A structured taxonomy of common failure modes in production fine-tuning. Diagnostic-first: know the symptom, find the cause, apply the fix.
| Symptom | Root Cause | Fix | Prevention |
|---|---|---|---|
| Model repeats output in loops | Overfitting; missing EOS; repetitive training data | Reduce epochs; verify EOS in training data; add repetition_penalty at inference | Audit EOS tokens pre-training |
| Hallucinations increased vs baseline | Synthetic data without filtering; FT on "confident wrong" outputs | Filter training data with LLM judge; remove examples with unverified facts | Expert review of training data |
| Format breaks (invalid JSON) | Inconsistent output format in training data; template mismatch | Enforce schema via JSON validator during data generation; fix template | 100% format validation on training set |
| Loss goes NaN | LR too high; FP16 overflow; bad data (empty strings, null) | Reduce LR 10×; switch to BF16; add data validation step | Data validation pipeline, use BF16 |
| Loss spikes mid-training | Outlier batch (very long or very abnormal example) | Gradient clipping; filter length outliers; restart from checkpoint | Filter p99 length outliers pre-training |
| Val loss diverges, train continues down | Overfitting to training set | Early stopping; reduce epochs; increase weight_decay | Separate val set, monitor from step 1 |
| Model ignores system prompt | System prompt format wrong; template mismatch; system turn dropped or truncated during preprocessing | Debug chat template; verify system role survives tokenization | Print formatted examples pre-training |
| Model worse than base model | Data quality too low; catastrophic forgetting; wrong template | Audit data quality; add replay buffer; fix template | Baseline comparison mandatory before deploy |
| Good on eval, bad in production | Eval-train leak; eval distribution mismatch; sampling params differ | Rebuild eval set from held-out production data; match inference params | Separate train/val/test strictly |
| Slow training (low GPU util) | Data loading bottleneck; tokenization on-the-fly; small batch | Increase dataloader_num_workers; pre-tokenize dataset; raise per-device batch size (grad_accum preserves the effective batch but does not raise utilization) | Profile GPU util in first 50 steps |
| OOM mid-training | Sequence packing creating unexpectedly long sequences; no grad checkpointing | Add max_seq_length cap; enable gradient_checkpointing | Test with batch_size=1 first |
| Model refuses all inputs (safety over-trained) | Too many refusal examples in training data; DPO with bad rejected labels | Balance positive/refusal ratio; review DPO preference pairs | Monitor refusal rate on eval set |
| Inconsistent output length | Training outputs vary widely in length; no max_new_tokens at inference | Add length guidance to instructions; set max_new_tokens in generation config | Normalize output lengths in training data |
| Model generates in wrong language | Mixed language training data; base model default is different language | Add explicit language instruction; filter training data by language | Language detection on training set |
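Several prevention entries above ("add data validation step", "Filter p99 length outliers pre-training") can share one pre-training pass. A minimal sketch for chat-format data, assuming a `messages` list per example. `count_tokens` is a caller-supplied function (e.g. a thin wrapper around your tokenizer). The percentile uses the nearest-rank definition, which is one of several reasonable choices.

```python
from typing import Callable, Dict, List, Tuple

VALID_ROLES = {"system", "user", "assistant"}


def validate_example(ex: Dict) -> List[str]:
    """Return the problems found in one chat-format example (empty list = OK)."""
    msgs = ex.get("messages")
    if not isinstance(msgs, list) or not msgs:
        return ["missing or empty 'messages'"]
    problems = []
    for k, m in enumerate(msgs):
        if not isinstance(m, dict):
            problems.append(f"message {k}: not an object")
            continue
        if m.get("role") not in VALID_ROLES:
            problems.append(f"message {k}: bad role {m.get('role')!r}")
        content = m.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {k}: empty or non-string content")
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        problems.append("last message is not from the assistant")
    return problems


def p99(values: List[int]) -> int:
    """Nearest-rank 99th percentile, using integer arithmetic for the rank."""
    ordered = sorted(values)
    rank = max(1, -(-99 * len(ordered) // 100))  # ceil(0.99 * n)
    return ordered[rank - 1]


def clean_dataset(data: List[Dict], count_tokens: Callable[[Dict], int]
                  ) -> Tuple[List[Dict], Dict[str, int]]:
    """Drop invalid examples, then drop examples longer than the p99 length."""
    valid = [ex for ex in data if not validate_example(ex)]
    lengths = [count_tokens(ex) for ex in valid]
    cutoff = p99(lengths) if lengths else 0
    kept = [ex for ex, n in zip(valid, lengths) if n <= cutoff]
    stats = {"input": len(data), "invalid": len(data) - len(valid),
             "length_outliers": len(valid) - len(kept), "kept": len(kept)}
    return kept, stats
```

Logging `stats` for every dataset version makes it obvious when a new data drop suddenly contains many empty or oversized examples.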
Base Model Selection & Benchmarking
Model Family Comparison
| Family | Strengths | Weaknesses | Best For | License |
|---|---|---|---|---|
| Llama 3.x | Best overall quality/size, strong instruction following, huge ecosystem | Commercial restrictions for large deployments | General tasks, chat, instruction tuning | Llama Community |
| Qwen 2.5 / 3 | Excellent multilingual, strong math/code, long context (128k) | Less community fine-tune tooling | Multilingual, math, long doc tasks | Apache 2.0 (most) |
| Mistral / Mixtral | Strong reasoning; Mixtral's MoE architecture is efficient at inference; permissive license | Smaller community vs Llama, less tooling | Efficient inference, multi-domain MoE | Apache 2.0 (most open-weight releases) |
| Gemma 2 / 3 | Strong for size, good for edge/on-device | Limited commercial use cases for some sizes | On-device, constrained compute | Gemma ToS |
| Phi-3 / 4 | Exceptional performance at tiny sizes (3.8B) | Very small context (some variants) | Classification, edge deployment, latency-critical | MIT |
| DeepSeek-R1 | Strong reasoning, CoT quality | Large, slow, complex | Complex reasoning, code, math tasks | MIT (R1); distills inherit their Qwen/Llama base-model licenses |
Pre-FT Benchmarking Checklist
Before committing to a base model for fine-tuning, run this systematic evaluation. Never choose a base model based on leaderboard position alone — leaderboards measure average capability, not your specific task.
"""
Pre-FT Base Model Evaluation Checklist
Run all candidates on the SAME test set before choosing base model
"""
eval_plan = {
"zero_shot": {
"description": "No examples, just instruction. Baseline task capability.",
"n_examples": 50,
"metric": "task_accuracy"
},
"few_shot": {
"description": "3-5 examples in context. Does model use them correctly?",
"n_examples": 50,
"metric": "task_accuracy"
},
"format_compliance": {
"description": "JSON / structured output. Can it follow format without FT?",
"n_examples": 30,
"metric": "json_valid_rate"
},
"latency": {
"description": "TTFT and tokens/sec at expected concurrency",
"warmup_requests": 10,
"metric": "p50_ttft_ms, tokens_per_sec"
},
"cost": {
"description": "Cost per 1M tokens at expected daily volume",
"metric": "usd_per_1M_tokens"
}
}
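The plan only pays off if every candidate is scored by identical code on identical inputs. A minimal, model-agnostic sketch of the zero-shot/few-shot and format-compliance parts, assuming each candidate is wrapped as a plain `prompt -> completion` callable (for example, around an OpenAI-compatible client). Exact-match accuracy is a placeholder for your real task metric. Latency and cost need separate load tests.

```python
import json
from typing import Callable, Dict, List

Generate = Callable[[str], str]  # prompt -> completion; wraps your model client


def is_valid_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False


def compare_candidates(candidates: Dict[str, Generate],
                       test_set: List[Dict[str, str]],
                       few_shot_prefix: str = "") -> Dict[str, Dict[str, float]]:
    """Score every candidate on the SAME test set with the SAME metrics."""
    report = {}
    n = len(test_set)
    for name, generate in candidates.items():
        outputs = [generate(few_shot_prefix + ex["prompt"]) for ex in test_set]
        report[name] = {
            # placeholder metric: exact match after stripping whitespace
            "task_accuracy": sum(o.strip() == ex["expected"].strip()
                                 for o, ex in zip(outputs, test_set)) / n,
            "json_valid_rate": sum(is_valid_json(o) for o in outputs) / n,
        }
    return report
```

Run it once with `few_shot_prefix=""` (zero-shot) and once with your 3-5 examples prepended (few-shot) to fill the first three rows of the plan.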
Choosing Model Size: The Practical Framework
Fine-tuning from an Instruct checkpoint (e.g., Llama-3.1-8B-Instruct) is almost always better than from the base model for task-specific fine-tuning. Instruct models already know how to follow instructions — your training budget goes further on task adaptation rather than basic instruction following.
Tooling Ecosystem
Library Selection Guide
| Tool | Purpose | When to Use | Maturity |
|---|---|---|---|
| HuggingFace TRL | SFT, DPO, RLHF training loops | Default for SFT and DPO. Best-maintained, most integrations. | Production |
| PEFT | LoRA, QLoRA, adapters, prefix tuning | Always — use alongside TRL or Trainer | Production |
| Axolotl | Config-driven training orchestration | When you want to launch FT jobs via YAML config without writing training code. Great for quick experiments and team standardization. | Production |
| DeepSpeed | Distributed training, ZeRO optimization | Multi-GPU full fine-tuning of 7B+ models. Required for ZeRO-3. | Production |
| Unsloth | Fast LoRA/QLoRA training (2-5× speedup) | When training speed matters, single GPU, Llama/Mistral family. Monkey-patches HF internals for speed. | Growing |
| vLLM | High-throughput inference serving | Production serving of fine-tuned models. OpenAI-compatible API. | Production |
| LitGPT | Minimal training framework | Research, custom training loops, when you need full control | Research |
| torchtune | PyTorch-native fine-tuning | PyTorch-centric shops, when you want native torch without HF abstractions | Growing |
Axolotl Config Example
Axolotl lets you define an entire training run in YAML. Recommended for team environments where you want reproducibility without custom code.
# axolotl config: qlora_factory.yml
base_model: meta-llama/Llama-3.1-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true # target all linear layers
datasets:
  - path: data/factory_training.jsonl
    type: chat_template
train_on_inputs: false # completion only
val_set_size: 0.05
sequence_len: 2048
sample_packing: true
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.0002
optimizer: paged_adamw_8bit
lr_scheduler: cosine
warmup_ratio: 0.05
bf16: auto
flash_attention: true
gradient_checkpointing: true
wandb_project: factory-planning-ft
output_dir: ./output/factory-v1
# Launch training
axolotl train qlora_factory.yml
# Merge adapter after training
axolotl merge-lora qlora_factory.yml --lora-model-dir ./output/factory-v1
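It helps to know what schedule that config actually produces. In data-parallel training, the effective batch is micro_batch_size × gradient_accumulation_steps × number of GPUs. The sketch below computes it along with approximate optimizer and warmup steps. The 10,000-example dataset is hypothetical, warmup steps are assumed to round up, and sample packing (enabled above) changes the real count because several examples share one sequence. Frameworks also round partial accumulation steps slightly differently.

```python
import math


def training_schedule(n_examples: int, micro_batch: int, grad_accum: int,
                      n_gpus: int, epochs: int, warmup_ratio: float) -> dict:
    """Effective batch, optimizer steps, and warmup steps for a data-parallel run.
    Ignores sample packing, which reduces the number of sequences per epoch."""
    effective_batch = micro_batch * grad_accum * n_gpus
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(warmup_ratio * total_steps),
    }


# Config values from above on one GPU, with a hypothetical 10,000-example dataset:
# val_set_size: 0.05 holds out 500 examples, leaving 9,500 for training.
print(training_schedule(9_500, micro_batch=2, grad_accum=8, n_gpus=1,
                        epochs=3, warmup_ratio=0.05))
```

Adding GPUs multiplies the effective batch, so a config copied from a single-GPU run takes proportionally fewer optimizer steps unless you lower gradient_accumulation_steps.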
DeepSpeed ZeRO-3 Config
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "cpu"},
"offload_param": {"device": "cpu"},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"gather_16bit_weights_on_model_save": true
},
"bf16": {"enabled": "auto"},
"gradient_clipping": 1.0,
"train_batch_size": "auto"
}
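Why stage 3 matters for multi-GPU full fine-tuning: under the ZeRO paper's mixed-precision Adam accounting, each parameter costs 2 bytes (fp16/bf16 weights) + 2 bytes (gradients) + 12 bytes (fp32 master copy, momentum, variance), or 16 bytes in total. Each ZeRO stage partitions more of that across the data-parallel GPUs. The estimate below covers model states only. It ignores activations, buffers, fragmentation, and the CPU offload enabled in the config above, which moves optimizer state and parameters off the GPU at the cost of slower steps.

```python
def zero_model_state_gb(n_params: float, n_gpus: int, stage: int,
                        optimizer_bytes_per_param: int = 12) -> float:
    """Per-GPU memory (GB) for weights + grads + Adam state under ZeRO stages 0-3.
    Model states only: no activations, buffers, or CPU offload."""
    p = 2 * n_params                          # fp16/bf16 weights
    g = 2 * n_params                          # fp16/bf16 gradients
    k = optimizer_bytes_per_param * n_params  # fp32 master + momentum + variance
    if stage == 0:
        total = p + g + k
    elif stage == 1:   # partition optimizer states
        total = p + g + k / n_gpus
    elif stage == 2:   # + partition gradients
        total = p + (g + k) / n_gpus
    elif stage == 3:   # + partition parameters
        total = (p + g + k) / n_gpus
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


# Example: an 8B model on 8 GPUs
for stage in (0, 1, 2, 3):
    print(stage, round(zero_model_state_gb(8e9, n_gpus=8, stage=stage), 1))
```

For 8B parameters on 8 GPUs this works out to 128, 44, 30, and 16 GB per GPU for stages 0 through 3, before any activation memory.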
Safety & Guardrails
Safety Fine-Tuning vs. Moderation Layer
Two distinct approaches to making deployed models safe. Both have a role in production systems.
Safety fine-tuning: Bake safety behavior into model weights via DPO/SFT on refusal data, so the model inherently declines harmful requests. Lowest latency, no extra inference cost. Risks: over-refusal, and jailbreaks can still get through.
Moderation layer: A separate classifier (small model or rule-based) checks inputs and outputs, e.g., OpenAI Moderation API, Llama Guard, or Perspective API. Flexible and updatable without retraining the base model, but adds latency.
Use both in layers. Safety FT for primary defense (low overhead). Moderation classifier for hard policy enforcement (compliance, legal). Defense in depth.
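The layering can be expressed as a thin wrapper: screen the input, call the safety-tuned model, screen the output. This is an architectural sketch only. `check_input` and `check_output` are placeholder callables standing in for a classifier such as Llama Guard or a moderation API, whose real client code is not shown, and the refusal text is an example.

```python
from typing import Callable

Moderator = Callable[[str], bool]  # returns True if the text violates policy
Model = Callable[[str], str]       # the safety-fine-tuned model

REFUSAL = "Sorry, I can't help with that request."  # example wording


def guarded_generate(prompt: str, model: Model,
                     check_input: Moderator, check_output: Moderator) -> str:
    """Defense in depth: moderation before and after the fine-tuned model."""
    if check_input(prompt):
        return REFUSAL          # blocked before spending any generation compute
    reply = model(prompt)       # primary defense: safety behavior in the weights
    if check_output(reply):
        return REFUSAL          # hard policy enforcement on what leaves the system
    return reply
```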
Policy Tuning via DPO
DPO is the practical way to enforce domain-specific policies: "always cite sources," "never claim certainty about regulatory requirements," "always recommend consulting an expert for safety-critical decisions."
// DPO preference pair for policy enforcement
{
"prompt": "What torque spec should I use for the M12 bolt in this assembly?",
"chosen": "Based on the provided spec sheet, the recommended torque is 85 Nm. However, always verify with the latest version of the assembly manual and consult your process engineer for safety-critical fasteners.",
"rejected": "Use 85 Nm torque for the M12 bolt. This is the standard specification."
}
// Chosen: grounded + appropriate hedge
// Rejected: confident but potentially dangerous without source
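The `//` comments above are for explanation only; JSON itself has no comment syntax. In practice, preference pairs are usually stored as JSONL with `prompt`, `chosen`, and `rejected` keys, the explicit-prompt preference format that TRL's DPOTrainer accepts. A minimal writer sketch that skips malformed or degenerate pairs. The function names are illustrative.

```python
import json
from typing import Dict, Iterable, List

FIELDS = ("prompt", "chosen", "rejected")


def validate_pair(pair: Dict[str, str]) -> List[str]:
    """Return the problems with one preference pair (empty list = OK)."""
    problems = [f"missing or empty '{k}'" for k in FIELDS
                if not isinstance(pair.get(k), str) or not pair[k].strip()]
    if not problems and pair["chosen"].strip() == pair["rejected"].strip():
        problems.append("chosen and rejected are identical")  # no learning signal
    return problems


def write_preference_jsonl(pairs: Iterable[Dict[str, str]], path: str) -> int:
    """Write valid pairs as one JSON object per line; return how many were written."""
    written = 0
    with open(path, "w", encoding="utf-8") as f:
        for pair in pairs:
            if validate_pair(pair):
                continue  # skip (or log) malformed pairs
            record = {k: pair[k] for k in FIELDS}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1
    return written
```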
Jailbreak Resistance
Fine-tuned models inherit the base model's jailbreak vulnerabilities, and fine-tuning can amplify them by weakening the base model's existing safety behavior. Practical defenses:
- Include adversarial examples in DPO data — rejected = complying with jailbreak, chosen = appropriate refusal
- Use Llama Guard or a similar classifier as a moderation layer for high-stakes applications
- Test against known jailbreak patterns (prefix injection, role-playing, encoding tricks) before deployment
- Keep system prompt logic in model weights (via FT), not just at inference — system prompts can be stripped by adversarial API users
Over-indexing on safety examples in training creates an over-refusing model that frustrates legitimate users. Target: <5% false-refusal rate on legitimate requests. Monitor this metric explicitly. DPO pairs should have roughly equal chosen-safety and chosen-helpful examples.
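Both numbers in that paragraph are easy to track. A minimal sketch: the false-refusal rate over a held-out set of legitimate prompts, and the share of DPO pairs whose chosen response is a refusal. The regex refusal detector is a crude assumption that misses many phrasings. An LLM judge or labeled review is more reliable for the real metric.

```python
import re
from typing import Callable, Dict, List

# Crude refusal detector (assumption): swap in an LLM judge for real evaluations.
REFUSAL_RE = re.compile(
    r"\b(?:i can(?:no|['’])t|i(?: am|['’]m) (?:not able|unable) to|i won['’]t)\b",
    re.IGNORECASE,
)


def is_refusal(text: str) -> bool:
    return bool(REFUSAL_RE.search(text))


def false_refusal_rate(legit_prompts: List[str], model: Callable[[str], str]) -> float:
    """Share of legitimate prompts the model refuses (target in the text: < 5%)."""
    if not legit_prompts:
        raise ValueError("need at least one legitimate prompt")
    return sum(is_refusal(model(p)) for p in legit_prompts) / len(legit_prompts)


def chosen_refusal_share(pairs: List[Dict[str, str]]) -> float:
    """Share of DPO pairs whose chosen response is a refusal (aim for roughly 0.5)."""
    return sum(is_refusal(p["chosen"]) for p in pairs) / len(pairs)
```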
Experiment Tracking
You will run the same experiment twice. You will not know which checkpoint produced which results. You will lose your best config. Set up experiment tracking before running your first training job — not after the third one fails.
Weights & Biases Setup
import wandb
from transformers import TrainerCallback

# Initialize at start of training script.
# MODEL_ID, dataset, and get_git_hash() are defined elsewhere in your script.
wandb.init(
    project="llm-factory-finetuning",
    name=f"qlora-r16-lr2e4-ep3-{wandb.util.generate_id()}",
    config={
        # Always log the full config, not just the "interesting" params
        "base_model": MODEL_ID,
        "method": "qlora",
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 2e-4,
        "epochs": 3,
        "batch_size": 2,
        "grad_accum": 8,
        "max_seq_len": 2048,
        "dataset_version": "v3.1",  # CRITICAL: version your data
        "dataset_size": len(dataset),
        "git_commit": get_git_hash(),  # reproducibility
    },
    tags=["qlora", "factory-planning", "v3-data"],
)

# Log custom metrics during training
class EvalCallback(TrainerCallback):
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not state.is_world_process_zero:
            return  # log once, from the main process
        # Log task-specific metrics beyond just loss.
        # run_task_eval() is your own eval helper (not shown); the Trainer
        # passes the current model to callbacks as kwargs["model"].
        task_metrics = run_task_eval(kwargs["model"])
        wandb.log({
            "eval/json_validity": task_metrics["json_valid_rate"],
            "eval/takt_accuracy": task_metrics["takt_accuracy"],
            "eval/composite_score": task_metrics["composite"],
            "train/global_step": state.global_step,  # align with Trainer's x-axis
        })
What to Log (Non-Negotiables)
W&B vs MLflow: When to Use What
| | Weights & Biases | MLflow |
|---|---|---|
| Best for | Real-time training monitoring, collaboration, rich visualizations | Self-hosted, compliance requirements, artifact versioning |
| Setup cost | Minutes (SaaS) | Hours (self-hosted setup) |
| Artifact storage | Good (W&B Artifacts) | Excellent (MLflow Models) |
| Experiment comparison | Best-in-class UI | Functional but less polished |
| HuggingFace Trainer integration | report_to="wandb" | report_to="mlflow" |
| Cost | Free tier, $50+/month for teams | Free (self-hosted) |
Advanced Topics
MoE Fine-Tuning (Mixture of Experts)
MoE models (Mixtral 8×7B, Qwen MoE, DeepSeek-MoE) have a sparse architecture: only a subset of "expert" FFN layers are active per token. This creates interesting dynamics for fine-tuning.
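- LoRA on MoE: Apply LoRA to the attention projections and to each expert's projection matrices. Whether to also adapt the router is a judgment call: adding it to target_modules lets routing shift toward your domain, but it is often left frozen because changing it can disturb load balancing, so compare against a frozen-router run. Some practitioners also prefer targeting fewer experts at higher rank over all experts at low rank; treat that as something to test, not a rule.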
- Load balancing: During FT, if training distribution activates only a few experts heavily, those experts overfit while others degrade. Monitor expert utilization distribution during training.
- VRAM challenge: Mixtral 8×7B has 47B total params but only 13B active per token. All 47B still need to be in VRAM. QLoRA helps substantially.
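from peft import LoraConfig

# LoRA on Mixtral: attention projections plus every expert's FFN matrices.
# Module names follow the Hugging Face Mixtral implementation; confirm them
# with model.named_modules() for your transformers version.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "w1", "w2", "w3",  # expert FFN layers in Mixtral
        # add "gate" to also adapt the router (block_sparse_moe.gate)
    ],
    modules_to_save=["lm_head"],  # trained and saved in full, not via LoRA
    task_type="CAUSAL_LM",
)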
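The VRAM bullet above is simple arithmetic: memory scales with *total* parameters (every expert must be resident), while per-token compute scales with *active* parameters. A weights-only sketch using Mixtral 8×7B's roughly 46.7B total parameters. It excludes KV cache, activations, LoRA weights, optimizer state, and the small per-block overhead of 4-bit quantization constants.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


MIXTRAL_TOTAL_PARAMS = 46.7e9  # all experts must sit in VRAM, even if only 2 run per token

for label, bits in (("bf16", 16), ("int8", 8), ("4-bit (QLoRA)", 4)):
    print(f"{label:14s} {weight_memory_gb(MIXTRAL_TOTAL_PARAMS, bits):6.1f} GB")
```

Roughly 93 GB in bf16 versus about 23 GB at 4 bits is why QLoRA turns Mixtral fine-tuning from a multi-GPU job into something that can fit on a single large GPU.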
Vision-Language Model (VLM) Fine-Tuning
Models like LLaVA, Qwen-VL, InternVL, and Llama-3.2-Vision accept image + text input. Fine-tuning adds domain-specific visual understanding.
- What to fine-tune: Usually LoRA on the language model component only. The vision encoder is often frozen (it's already very capable). The cross-modal projection layer can also be trained.
- Dataset format: Image + text pairs with conversation format. Images are encoded by the vision encoder before being passed to the LLM.
- Use cases: PCB defect detection, medical image Q&A, factory floor monitoring, document understanding with visual layouts.
- Key challenge: Images tokenize to 256-1024 tokens depending on resolution. A batch of 8 images at 512 tokens = 4096 extra tokens. Plan VRAM accordingly.
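The per-image token count comes from the vision encoder's patch grid: one token per patch, optionally merged. A back-of-the-envelope sketch, assuming a ViT-style encoder with square patches. For example, a 336-pixel encoder with 14-pixel patches (the LLaVA-1.5 setup) yields 24 × 24 = 576 tokens, and some models merge 2×2 neighboring patches into one token. Real preprocessors resize images to patch multiples rather than rounding up, so treat the result as an estimate.

```python
import math


def image_tokens(height: int, width: int, patch: int = 14, merge: int = 1) -> int:
    """Visual tokens for a ViT-style encoder: one per patch, with optional
    merge x merge pooling of neighboring patches into a single token."""
    grid_h = math.ceil(height / patch)
    grid_w = math.ceil(width / patch)
    return math.ceil(grid_h / merge) * math.ceil(grid_w / merge)


print(image_tokens(336, 336))           # 24 x 24 patches -> 576 tokens
print(image_tokens(448, 448, merge=2))  # 32 x 32 patches, 2x2 merge -> 256 tokens
```

Multiply by images per batch to size the extra sequence length: 8 images at 576 tokens each add 4,608 tokens before any text.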
Knowledge Distillation After Fine-Tuning
Use your fine-tuned large model (teacher) to generate training data for a smaller model (student). Effectively compresses the fine-tuned behavior into a more efficient model.
"""
Distillation pipeline:
1. Fine-tune 70B teacher on your task (expensive, once)
2. Use teacher to generate high-quality (input, output) pairs at scale
3. Fine-tune 7B student on teacher's outputs
4. Student typically recovers most of the teacher's task quality (~85-90% is a rough, task-dependent estimate) at roughly 10× lower inference cost (7B vs 70B parameters)
"""
import asyncio

async def distill_dataset(prompts: list, teacher_model, n_samples: int = 3) -> list:
    # For each prompt: sample n_samples teacher outputs concurrently,
    # then keep the one an LLM judge rates best.
    # generate() and select_best() are your own async helpers (not shown).
    examples = []
    for prompt in prompts:
        outputs = await asyncio.gather(
            *(generate(teacher_model, prompt) for _ in range(n_samples))
        )
        best = await select_best(list(outputs), prompt)  # judge selects best
        examples.append({"prompt": prompt, "response": best})
    return examples
Model Merging
Combine multiple fine-tuned models by merging their weights. The goal is a single model that keeps the capabilities of all source models, at no extra inference cost because the result is one model of the original size. Key techniques:
SLERP (Spherical Linear Interpolation)
Interpolate between two model checkpoints along the surface of a hypersphere rather than along a straight line. This avoids the norm shrinkage of linear interpolation and is often reported to preserve model capabilities better. Supported by mergekit.
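A minimal SLERP sketch on flattened weight vectors (plain Python lists, for illustration). The angle is measured between the normalized vectors, and the interpolation weights are applied to the original tensors. When the vectors are nearly parallel, the sketch falls back to linear interpolation to avoid dividing by sin(θ) ≈ 0. Merge tools apply this per tensor, often with a per-layer t. Treat this as the textbook formula, not any tool's exact code.

```python
import math
from typing import List


def slerp(t: float, v0: List[float], v1: List[float], eps: float = 1e-8) -> List[float]:
    """Spherical linear interpolation: t=0 returns v0, t=1 returns v1."""
    n0 = math.sqrt(sum(x * x for x in v0))
    n1 = math.sqrt(sum(x * x for x in v1))
    cos_theta = sum(a * b for a, b in zip(v0, v1)) / max(n0 * n1, eps)
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard acos against rounding
    if abs(cos_theta) > 1 - 1e-6:  # nearly (anti)parallel: plain lerp is stable
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    theta = math.acos(cos_theta)
    s0 = math.sin((1 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


mid = slerp(0.5, [1.0, 0.0], [0.0, 1.0])
print(math.hypot(*mid))  # stays on the unit circle; lerp's midpoint would have norm ~0.707
```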
TIES (TrIm, Elect Sign & Merge)
Compute a "task vector" (fine-tuned model minus base model) for each fine-tuned model, trim each one to its largest-magnitude entries, elect a sign per parameter, and combine only the entries that agree with that sign before adding the result back to the base.
DARE (Drop And REscale)
Randomly zero out a fraction p of the delta weights, then rescale the surviving entries by 1/(1 - p) so the expected delta is unchanged. This reduces interference between merged models and is commonly combined with TIES-style sign election (dare_ties in mergekit).
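The two ideas differ in how they sparsify deltas: TIES trims by magnitude (deterministic), while DARE drops at random and rescales. A simplified sketch on flat lists, assuming positive per-model weights and a weighted average over sign-agreeing entries, which is how `normalize: true` behaves in spirit. Treat the details as an approximation of what mergekit does, not its implementation.

```python
import random
from typing import List


def trim(delta: List[float], density: float) -> List[float]:
    """TIES trim: keep roughly the top `density` fraction of entries by magnitude."""
    k = max(1, round(density * len(delta)))
    threshold = sorted((abs(d) for d in delta), reverse=True)[k - 1]
    return [d if abs(d) >= threshold else 0.0 for d in delta]  # ties may keep a few extra


def dare(delta: List[float], drop_p: float, rng: random.Random) -> List[float]:
    """DARE: drop each entry with probability drop_p, rescale survivors by 1/(1-drop_p)."""
    keep = 1.0 - drop_p
    return [d / keep if rng.random() < keep else 0.0 for d in delta]


def ties_merge(base: List[float], finetuned: List[List[float]],
               weights: List[float], density: float) -> List[float]:
    """Trim each task vector, elect a sign per parameter from the weighted sum,
    then take the weighted average of only the entries that agree with it."""
    deltas = [trim([f - b for f, b in zip(ft, base)], density) for ft in finetuned]
    merged = []
    for i, b in enumerate(base):
        weighted = [w * d[i] for w, d in zip(weights, deltas)]
        elected = 1.0 if sum(weighted) >= 0 else -1.0
        agree = [d[i] * elected > 0 for d in deltas]  # nonzero and same sign
        num = sum(wd for wd, ok in zip(weighted, agree) if ok)
        den = sum(w for w, ok in zip(weights, agree) if ok)
        merged.append(b + (num / den if den else 0.0))
    return merged
```

Swapping `trim` for `dare` inside `ties_merge` gives the DARE+TIES combination; with `drop_p=0` DARE leaves a delta unchanged.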
# mergekit config: TIES-merge two fine-tuned checkpoints
# (each LoRA adapter already merged into full weights)
merge_method: ties   # use dare_ties for DARE's random drop instead of magnitude trim
base_model: meta-llama/Llama-3.1-8B-Instruct
models:
  - model: ./factory-planning-merged   # TAKT/capacity planning
    parameters:
      weight: 0.6
      density: 0.7   # TIES trim: keep the top 70% of each delta by magnitude
  - model: ./machine-selection-merged  # machine selection skills
    parameters:
      weight: 0.4
      density: 0.7
parameters:
  normalize: true
dtype: bfloat16
Model merging is not a replacement for proper fine-tuning — it's a way to combine capabilities cheaply. Best use case: you have two specialized models and need a generalist without the cost of retraining from scratch on combined data. Always eval the merged model — quality can vary unpredictably and requires empirical validation.