Model Checkpointing¶
Overview¶
Status: ✅ Complete - Full implementation details and examples provided
This capability covers model checkpointing strategies for saving and loading model state during training, enabling recovery from failures and reproducible experiments.
Contents¶
This document covers:
- Checkpoint Formats: PyTorch, SafeTensors, ONNX
- Checkpoint Strategies: Best-model, periodic, latest
- State Management: Model, optimizer, scheduler state
- Distributed Checkpointing: Multi-GPU checkpoint handling
- Cloud Storage: S3, Azure, GCS integration
Checkpoint Components¶
A complete checkpoint includes:
- Model State: Model weights and architecture
- Optimizer State: Optimizer parameters and momentum
- Scheduler State: Learning rate schedule state
- Training Metadata: Epoch, step, metrics, random state
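The components above can be collected into a single dictionary. A minimal sketch; the `build_checkpoint` helper and its field names are illustrative, not a fixed schema:

```python
import random

def build_checkpoint(model_state, optimizer_state, scheduler_state,
                     epoch, step, metrics):
    """Bundle everything needed to resume training into one dict."""
    return {
        'model_state_dict': model_state,          # weights
        'optimizer_state_dict': optimizer_state,  # momentum, adaptive LR stats
        'scheduler_state_dict': scheduler_state,  # LR schedule position
        'epoch': epoch,
        'step': step,
        'metrics': metrics,                       # e.g. {'val_loss': 0.42}
        # Python RNG state; in practice also capture torch/numpy RNG states
        'rng_state': random.getstate(),
    }

ckpt = build_checkpoint({}, {}, {}, epoch=5, step=1200,
                        metrics={'val_loss': 0.42})
```

Saving all of these together is what makes an interrupted run resumable bit-for-bit rather than approximately.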
Current Implementation¶
Checkpointing patterns in the codebase:
- See PyTorch save/load in training code
- Check experiment management for checkpoint tracking
Related Capabilities¶
- experiment-management: Checkpoint versioning and tracking
- reproducibility: Deterministic checkpoint loading
- functional-training: Training checkpoint integration
Example Usage¶
```python
import torch

def save_checkpoint(model, optimizer, epoch, path):
    """Save a training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, optimizer, path):
    """Load a training checkpoint."""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
```
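A quick round-trip sanity check for the pattern above, using a toy model (the file name `demo_checkpoint.pt` is arbitrary):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save: same dict layout as save_checkpoint above
checkpoint = {
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'demo_checkpoint.pt')

# Load into fresh instances, as a restarted job would
restored_model = torch.nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
loaded = torch.load('demo_checkpoint.pt')
restored_model.load_state_dict(loaded['model_state_dict'])
restored_optimizer.load_state_dict(loaded['optimizer_state_dict'])
start_epoch = loaded['epoch'] + 1  # resume from the next epoch
```

Loading immediately after saving, as here, is also a cheap integrity check worth keeping in real training code.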
Checkpoint Strategies¶
Best Model¶
Save checkpoint when validation metric improves:
```python
if val_loss < best_val_loss:
    save_checkpoint(model, optimizer, epoch, 'best_model.pt')
    best_val_loss = val_loss
```
Periodic¶
Save checkpoint at regular intervals:
```python
if epoch % save_frequency == 0:
    save_checkpoint(model, optimizer, epoch, f'checkpoint_epoch_{epoch}.pt')
```
Latest¶
Always maintain the latest checkpoint:
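A minimal sketch of keeping a `latest.pt` that is always valid: write to a temporary file, then rename over the old one, so a crash mid-write never corrupts the previous checkpoint (here `pickle` stands in for `torch.save` to keep the sketch dependency-free):

```python
import os
import pickle

def save_latest(state, path='latest.pt'):
    """Atomically overwrite the latest checkpoint."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)  # atomic rename on POSIX and Windows

save_latest({'epoch': 7})
with open('latest.pt', 'rb') as f:
    restored = pickle.load(f)
```

The same temp-file-then-rename pattern applies unchanged when the payload is written with `torch.save`.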
Advanced Features¶
SafeTensors Support¶
```python
from safetensors.torch import save_file

def save_safetensors(model, path):
    """Save weights in the faster, safer SafeTensors format."""
    save_file(model.state_dict(), path)
```
Cloud Storage Backend¶
```python
class S3CheckpointBackend:
    """Store checkpoints in S3 for durability."""

    def upload_checkpoint(self, local_path):
        # Upload to S3 with retry logic (e.g. exponential backoff)
        pass
```
Distributed Checkpointing¶
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_distributed_checkpoint(model: FSDP, path, rank):
    """Save an FSDP model checkpoint from rank 0."""
    # All ranks must call state_dict() so the full state can be gathered;
    # only rank 0 writes the result to disk.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        state_dict = model.state_dict()
    if rank == 0:
        torch.save(state_dict, path)
```
Best Practices¶
- Checkpoint Frequency: Balance storage costs with recovery needs
  - Save every N steps or epochs
  - Use validation-triggered saves for best models
  - Consider adaptive frequency based on training stability
- State Completeness: Save all training state together
  - Model parameters
  - Optimizer state (momentum, adaptive learning rates)
  - Scheduler state
  - Random number generator state
  - Training metadata (epoch, step, metrics)
- Validation & Testing: Always test checkpoint recovery
  - Load immediately after save to verify integrity
  - Run a validation forward pass to ensure the model works
  - Check that file size matches expected dimensions
- Storage Management: Implement cleanup policies
  - Retention: Keep best N by metric, latest N by time, or periodic N
  - Archive: Move old checkpoints to cold storage
  - Compression: Use zstd/lz4 for large checkpoints
- Error Handling: Robust checkpoint operations
  - Handle disk-full errors gracefully (keep the previous checkpoint)
  - Retry cloud uploads with exponential backoff
  - Atomic saves (write to a temp file, then rename)
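The "latest N by time" retention policy above can be sketched with a small cleanup helper using only the standard library (`prune_checkpoints` is hypothetical; the file-name pattern matches the periodic-save example earlier):

```python
from pathlib import Path

def prune_checkpoints(directory, pattern='checkpoint_epoch_*.pt', keep=3):
    """Delete all but the `keep` most recently modified checkpoints."""
    ckpts = sorted(Path(directory).glob(pattern),
                   key=lambda p: p.stat().st_mtime, reverse=True)
    for stale in ckpts[keep:]:
        stale.unlink()
    return [p.name for p in ckpts[:keep]]
```

Run this after each successful save; pruning before the new checkpoint is safely on disk risks deleting your only good copy.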
Integration Points¶
- Training Loops: Auto-checkpoint in `agents/training_*.py` modules
- Experiment Tracking: Log checkpoint paths to MLflow
- CI/CD Tests: Checkpoint save/load tests in `tests/agents/test_*_training.py`
- Distributed Training: FSDP/DDP checkpoint strategies in multi-GPU setups
- Cloud Backends: S3/Azure/GCS integration for persistence
Future Enhancements¶
- Incremental Checkpointing: Save only parameter deltas for massive models
- Checkpoint Sharding: Split across multiple files for parallel I/O
- Automatic Repair: Detect and recover from corrupted checkpoints
- Checkpoint Compression: Reduce storage with efficient compression
- Direct Remote Save: Skip local disk, write directly to cloud storage