# Training Loop Implementation
Navigation: 📖 Main README | 💾 Checkpointing | 🎯 PEFT Techniques | 🔍 Code Quality
## Overview
Status: ✅ Complete - Full training loop patterns and best practices documented
Note: Each code example in this document repeats its imports so that every snippet is self-contained and easy to copy.
This capability covers complete training loop implementations for ML models, including batch processing, gradient updates, and epoch management.
## Planned Content

This document will cover:

- Basic Training Loop: Single-epoch training implementation
- Multi-Epoch Training: Complete training with validation
- Distributed Training: Multi-GPU and distributed strategies
- Mixed Precision: AMP and gradient scaling
- Progress Tracking: Metrics, logging, and monitoring
## Current Implementation

Training loop patterns in codex follow PyTorch best practices and are implemented across:

- `agents/` modules - Agent training implementations with various strategies
- `tests/agents/` - Training loop test examples and validation
- Distributed support - Multi-GPU training with DDP and FSDP
- Mixed precision - Automatic Mixed Precision (AMP) for faster training
## Helper Functions
The examples below use these common helper functions:
```python
import torch
import torch.nn as nn
from pathlib import Path


def compute_loss(outputs, targets, loss_fn=None):
    """Compute loss between model outputs and targets.

    Args:
        outputs: Model predictions
        targets: Ground truth labels
        loss_fn: Optional custom loss function (defaults to CrossEntropyLoss)

    Returns:
        Loss tensor
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    return loss_fn(outputs, targets)


def save_checkpoint(model, optimizer, epoch, path):
    """Save a basic training checkpoint.

    Args:
        model: PyTorch model
        optimizer: PyTorch optimizer
        epoch: Current epoch number
        path: Path to save checkpoint

    Note:
        For advanced checkpointing with scheduler state, metadata,
        and cloud storage, see CheckpointManager in checkpointing.md
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
```
## Complete Training Loop Pattern
```python
import torch
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm


def train_loop(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    num_epochs,
    device,
    checkpoint_dir=None,
    use_amp=True,
):
    """Complete training loop with validation and checkpointing."""
    scaler = GradScaler() if use_amp else None
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for batch_idx, (inputs, targets) in enumerate(train_pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            # Mixed precision forward pass
            if use_amp:
                with autocast():
                    outputs = model(inputs)
                    loss = compute_loss(outputs, targets)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = compute_loss(outputs, targets)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            train_pbar.set_postfix({"loss": loss.item()})

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        val_loss = validate(model, val_loader, device)

        # Learning rate scheduling (expects a metric-based scheduler such as
        # ReduceLROnPlateau, which steps on the validation loss)
        if scheduler:
            scheduler.step(val_loss)

        # Checkpoint best model (checkpoint_dir is expected to be a pathlib.Path)
        if checkpoint_dir and val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = checkpoint_dir / f"best_model_epoch{epoch}.pt"
            # Note: For production use, see CheckpointManager in checkpointing.md
            # which handles scheduler state, metadata, and advanced features
            save_checkpoint(model, optimizer, epoch, checkpoint_path)

        print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_loss={val_loss:.4f}")

    return model


def validate(model, val_loader, device):
    """Validation loop."""
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = compute_loss(outputs, targets)
            val_loss += loss.item()

    return val_loss / len(val_loader)
```
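The call below is a minimal usage sketch of `train_loop`; the toy model, synthetic data, and hyperparameters are illustrative assumptions rather than values from this repository.

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy classifier and synthetic data, purely to make the sketch runnable
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4)
).to(device)
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,))),
    batch_size=32, shuffle=True,
)
val_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
    batch_size=32,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# ReduceLROnPlateau matches the scheduler.step(val_loss) call inside train_loop
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

model = train_loop(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    num_epochs=5,
    device=device,
    checkpoint_dir=checkpoint_dir,
    use_amp=torch.cuda.is_available(),  # autocast/GradScaler need a CUDA device here
)
```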
## Distributed Training Loop
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def distributed_train_loop(
    model,
    train_loader,
    optimizer,
    num_epochs,
    rank,
    world_size,
):
    """Training loop for distributed training."""
    # Wrap model with DDP
    model = model.to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    for epoch in range(num_epochs):
        # Set the epoch on the DistributedSampler so shuffling differs per epoch
        train_loader.sampler.set_epoch(epoch)
        ddp_model.train()

        for batch in train_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(rank), targets.to(rank)

            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = compute_loss(outputs, targets)
            loss.backward()
            optimizer.step()

        # Only rank 0 logs; the barrier keeps all processes aligned at epoch end
        if rank == 0:
            print(f"Epoch {epoch+1} complete")
        dist.barrier()
```
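The loop above assumes the process group is already initialized and that `train_loader` uses a `DistributedSampler`. A single-node launch sketch might look like the following; the toy model, synthetic data, and `MASTER_PORT` value are illustrative assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def run_worker(rank, world_size, num_epochs=3):
    """Per-process entry point; a single-node, one-GPU-per-process setup is assumed."""
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Toy model and synthetic data, only to make the sketch self-contained
    model = torch.nn.Linear(16, 4)
    dataset = TensorDataset(torch.randn(512, 16), torch.randint(0, 4, (512,)))

    # DistributedSampler shards the dataset so each rank sees a distinct slice;
    # set_epoch() inside distributed_train_loop reshuffles the shards each epoch
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    distributed_train_loop(model, train_loader, optimizer, num_epochs, rank, world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Launch one worker process per local GPU
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)
```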
## Gradient Accumulation Pattern
```python
import torch


def train_with_gradient_accumulation(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    accumulation_steps: int = 4,
    device: str = "cuda",
) -> None:
    """Training loop with gradient accumulation for large effective batch sizes."""
    model.train()
    optimizer.zero_grad()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = compute_loss(outputs, targets)

        # Scale loss by accumulation steps so the summed gradients average out
        loss = loss / accumulation_steps
        loss.backward()

        # Update weights every N steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Flush gradients left over from a partial final accumulation group
    if (batch_idx + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()
```
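With this pattern the effective batch size is the DataLoader batch size multiplied by `accumulation_steps`: for example, a per-step batch of 8 with `accumulation_steps=4` produces weight updates computed from 32 samples, while only 8 samples ever occupy memory during a forward/backward pass.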
## Related Capabilities
- functional-training: Functional training patterns
- experiment-management: Training experiment tracking
- checkpointing: Model checkpoint management
## Example Pattern
```python
def train_epoch(model, dataloader, optimizer, device):
    """Basic training epoch implementation."""
    model.train()
    total_loss = 0.0

    for batch in dataloader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = compute_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)
```
## Best Practices

- Error Handling: Wrap training in try-except for graceful failures
- Progress Tracking: Use tqdm for visual progress bars and logging
  - Display current loss in progress bar
  - Log metrics to tensorboard/wandb/mlflow
  - Save training curves for analysis
- Gradient Clipping: Prevent exploding gradients (a sketch covering clipping, early stopping, and seeding follows this list)
- Early Stopping: Stop training when validation plateaus
- Deterministic Training: Set seeds for reproducibility
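Several of these practices are easy to omit in a first implementation. The sketch below shows seeding, gradient clipping, and early stopping together; it reuses `compute_loss` and `validate` from earlier, assumes `model`, `train_loader`, `val_loader`, `optimizer`, and `device` are already defined, and the `max_norm=1.0` and `patience=3` values are illustrative defaults.

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def train_epoch_with_clipping(model, dataloader, optimizer, device, max_norm=1.0):
    """Training epoch with gradient clipping to guard against exploding gradients."""
    model.train()
    total_loss = 0.0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = compute_loss(model(inputs), targets)
        loss.backward()
        # Clip the global gradient norm before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


# Simple early stopping: halt when validation loss has not improved
# for `patience` consecutive epochs.
set_seed(42)
patience, epochs_without_improvement, best_val_loss = 3, 0, float("inf")
for epoch in range(20):
    train_loss = train_epoch_with_clipping(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)
    if val_loss < best_val_loss:
        best_val_loss, epochs_without_improvement = val_loss, 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
```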
## Integration Points
- Checkpointing: Save model state during training (see checkpointing.md)
- Experiment Tracking: Log metrics to MLflow (see experiment_management.md); a minimal logging sketch follows this list
- Functional Training: Pure functional training patterns (see functional_training.md)
- CI/CD Tests: Training loop validation in tests/agents/test_*_training.py
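As a rough illustration of the experiment-tracking integration, an MLflow wrapper around the basic epoch functions might look like this; the experiment name, hyperparameters, and metric keys are placeholder assumptions, and experiment_management.md documents the project's actual conventions.

```python
import mlflow

mlflow.set_experiment("training-loop-demo")  # hypothetical experiment name

with mlflow.start_run():
    mlflow.log_params({"lr": 1e-3, "batch_size": 32, "num_epochs": 5})
    for epoch in range(5):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_loss = validate(model, val_loader, device)
        # One metric point per epoch, keyed by step so MLflow can plot curves
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
```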
## Performance Optimization

### Mixed Precision Training
```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

# One AMP training step (model, optimizer, inputs, and targets as defined above)
optimizer.zero_grad()
with autocast():
    outputs = model(inputs)
    loss = compute_loss(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
### DataLoader Optimization
```python
import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,      # Parallel data loading
    pin_memory=True,    # Faster GPU transfer
    prefetch_factor=2,  # Batches pre-fetched per worker (requires num_workers > 0)
)
```