Saving and Loading PyTorch Models
Saving and loading models is essential for training, evaluation, and deployment. PyTorch provides multiple approaches: saving a state_dict() for checkpointing during training, torch.save(model) for complete model serialization, and ONNX export for cross-framework compatibility. Choosing the right approach impacts reproducibility, model size, and deployment flexibility.
State dict: saving parameters and buffers
The state dictionary contains all learnable parameters and persistent buffers (such as batch norm running statistics). Saving only the state dict is lightweight and flexible: you recreate the model architecture in code and load the saved tensors into it.
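To make the parameter/buffer distinction concrete, here is a minimal sketch. The two-layer model is a hypothetical toy, chosen only because BatchNorm1d registers buffers; it checks that the state dict keys are exactly the union of parameter names and buffer names.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: BatchNorm1d registers running statistics as buffers
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

param_names = {name for name, _ in model.named_parameters()}
buffer_names = {name for name, _ in model.named_buffers()}
state_keys = set(model.state_dict().keys())

# Parameters are trained by the optimizer; buffers are saved but not trained
print("Parameters:", sorted(param_names))
print("Buffers:   ", sorted(buffer_names))

# For this model, the state dict is the union of both sets
keys_match = state_keys == (param_names | buffer_names)
print("State dict = parameters + buffers:", keys_match)
```

Because the buffers travel with the state dict, a reloaded batch norm layer behaves the same in eval mode as the one that was saved.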
Saving and loading state dicts
import torch
import torch.nn as nn

# Define model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create model and run one forward pass (no training happens here)
model = SimpleNet()
x = torch.randn(32, 10)
y = model(x)

# Save state dict (lightweight, recommended)
torch.save(model.state_dict(), '/tmp/model_state.pth')

# Load state dict into a new model instance
# weights_only=True restricts unpickling to tensors and plain containers
model_new = SimpleNet()
model_new.load_state_dict(torch.load('/tmp/model_state.pth', weights_only=True))

# Verify parameters match
for p1, p2 in zip(model.parameters(), model_new.parameters()):
    assert torch.allclose(p1, p2), "Parameters don't match"
print("Model state dict saved and loaded successfully")

# Inspect state dict
state = model.state_dict()
for name, param in state.items():
    print(f"{name}: shape {param.shape}")
Checkpoint pattern for training
import torch
import torch.nn as nn
import torch.optim as optim

# SimpleNet is the class defined in the previous example
model = SimpleNet()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training loop with checkpointing (random data stands in for a real dataset)
best_loss = float('inf')
for epoch in range(3):
    x = torch.randn(64, 10)
    y = torch.randint(0, 2, (64,))
    output = model(x)
    loss = criterion(output, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save checkpoint if loss improves
    if loss.item() < best_loss:
        best_loss = loss.item()
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'loss': loss.item(),
        }
        torch.save(checkpoint, '/tmp/best_checkpoint.pth')
        print(f"Epoch {epoch}: New best loss {loss.item():.4f}, checkpoint saved")

# Resume training from checkpoint
checkpoint = torch.load('/tmp/best_checkpoint.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch'] + 1  # continue with the epoch after the saved one
print(f"Resuming training at epoch {start_epoch}; checkpoint loss was {checkpoint['loss']:.4f}")
Complete model serialization: torch.save and torch.load
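Save the entire model object (module structure + weights) as a single file. This is more convenient but less flexible than state dicts: the file is a pickle that refers to the model's class by its import path, so the code that defines the model must still be importable when the file is loaded.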
| Approach | File Size | Flexibility | Best For |
|---|---|---|---|
| State dict | Small | High (recreate architecture separately) | Training, checkpoints |
| Complete model | Larger | Low (architecture fixed in file) | Quick prototyping, inference |
| ONNX export | Medium | Cross-framework | Deployment, non-Python inference |
Saving complete models
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

# Save entire model (module structure + weights)
torch.save(model, '/tmp/complete_model.pth')

# Load entire model directly
# weights_only=False is required to unpickle a full module on PyTorch 2.6+,
# where weights_only defaults to True; only do this with files you trust
loaded_model = torch.load('/tmp/complete_model.pth', weights_only=False)

# Use immediately (no need to recreate architecture)
x = torch.randn(4, 10)
output = loaded_model(x)
print(f"Loaded model output shape: {output.shape}")

# Trade-off: less portable if code changes
# Better practice: always save state_dict and recreate architecture
Handling device compatibility
Ensure models and data are on the same device when loading.
Loading models on different devices
import torch
import torch.nn as nn

model = nn.Linear(10, 5)
x = torch.randn(2, 10)

# Save from the GPU if one is available, otherwise from the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
x = x.to(device)
torch.save(model.state_dict(), '/tmp/model.pth')

# Load on CPU (regardless of save device)
model_cpu = nn.Linear(10, 5)
model_cpu.load_state_dict(
    torch.load('/tmp/model.pth', map_location='cpu', weights_only=True)
)

# Load on GPU (only possible when CUDA is available)
if torch.cuda.is_available():
    model_gpu = nn.Linear(10, 5)
    model_gpu.load_state_dict(
        torch.load('/tmp/model.pth', map_location='cuda:0', weights_only=True)
    )
    # load_state_dict copies values into the existing (CPU) parameters,
    # so the module itself still has to be moved
    model_gpu = model_gpu.to('cuda:0')

# Flexible: load to any device
def load_to_device(model, checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    return model

model = load_to_device(nn.Linear(10, 5), '/tmp/model.pth', device='cpu')
print(f"Model loaded to {next(model.parameters()).device}")
Exporting to ONNX for cross-framework compatibility
Export PyTorch models to ONNX (Open Neural Network Exchange) for inference in other runtimes and frameworks (ONNX Runtime, TensorRT, Core ML, TensorFlow).
Exporting and validating ONNX models
import torch
import torch.nn as nn
import onnx

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()
model.eval()  # Set to eval mode (no dropout/batch norm randomness)

# Define dummy input for tracing
dummy_input = torch.randn(1, 10)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    '/tmp/model.onnx',
    input_names=['input'],
    output_names=['output'],
    opset_version=14,
    verbose=False
)
print("Model exported to ONNX")

# Load and validate ONNX model
onnx_model = onnx.load('/tmp/model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed")

# Compare PyTorch vs ONNX outputs (requires onnxruntime)
try:
    import numpy as np
    import onnxruntime as ort

    # Run ONNX inference
    session = ort.InferenceSession('/tmp/model.onnx')
    input_data = dummy_input.numpy()
    onnx_output = session.run(None, {'input': input_data})

    # Run PyTorch inference
    with torch.no_grad():
        pytorch_output = model(dummy_input).numpy()

    # Compare shapes and values
    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(f"ONNX output shape: {onnx_output[0].shape}")
    print(f"Outputs match: {np.allclose(pytorch_output, onnx_output[0], atol=1e-5)}")
except ImportError:
    print("onnxruntime not installed; skipping validation")
Model quantization for inference efficiency
Reduce model size and inference latency by converting to lower precision (int8, float16).
Quantization patterns
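import io
import time
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)
model.eval()
x = torch.randn(1, 100)

# Dynamic quantization (post-training): Linear weights are stored as int8,
# activations are quantized on the fly at inference time
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

# Compare serialized size. Counting parameters() is misleading here because
# quantized Linear layers keep their weights in packed form, not as nn.Parameter.
def serialized_size(m):
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes

print(f"Original model size: {serialized_size(model):,} bytes")
print(f"Quantized model size: {serialized_size(quantized_model):,} bytes")

# Inference speed comparison (rough; results vary with hardware and model size)
with torch.no_grad():
    # Original model
    start = time.perf_counter()
    for _ in range(100):
        _ = model(x)
    original_time = time.perf_counter() - start

    # Quantized model
    start = time.perf_counter()
    for _ in range(100):
        _ = quantized_model(x)
    quantized_time = time.perf_counter() - start

print(f"Original inference: {original_time:.4f}s")
print(f"Quantized inference: {quantized_time:.4f}s")
print(f"Speedup: {original_time / quantized_time:.2f}x")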
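The section also mentions float16. As a minimal sketch of that option, the snippet below casts a copy of a small model to half precision and compares the bytes held by the tensors in each state dict. The helper name tensor_bytes is hypothetical, introduced here only for illustration; whether float16 inference is fast or accurate enough depends on the hardware and the model.

```python
import copy
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)

def tensor_bytes(m):
    # Hypothetical helper: total bytes held by the tensors in the state dict
    return sum(t.numel() * t.element_size() for t in m.state_dict().values())

# Cast a copy so the float32 original stays available for comparison
half_model = copy.deepcopy(model).half()

fp32_bytes = tensor_bytes(model)
fp16_bytes = tensor_bytes(half_model)

print(f"float32 tensors: {fp32_bytes:,} bytes")
print(f"float16 tensors: {fp16_bytes:,} bytes")
print(f"Ratio: {fp32_bytes / fp16_bytes:.1f}x")
```

Saving half_model.state_dict() with torch.save then produces a correspondingly smaller file; the tensors can be cast back with .float() after loading if full precision is needed at inference time.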
Version compatibility and reproducibility
Handle PyTorch version changes and ensure reproducible loading.
Robust model loading
import datetime
import torch
import torch.nn as nn

model = nn.Linear(10, 5)

# Save with version information
# Plain strings keep the checkpoint loadable with weights_only=True
checkpoint = {
    'model_state': model.state_dict(),
    'pytorch_version': str(torch.__version__),
    'architecture': 'Linear(10, 5)',
    'timestamp': datetime.datetime.now().isoformat()
}
torch.save(checkpoint, '/tmp/model_versioned.pth')

# Load with validation
checkpoint = torch.load('/tmp/model_versioned.pth', weights_only=True)
print(f"Model trained with PyTorch {checkpoint['pytorch_version']}")
print(f"Current PyTorch version: {torch.__version__}")

# Recreate model and load state
model = nn.Linear(10, 5)
model.load_state_dict(checkpoint['model_state'])

# For reproducibility, set random seeds
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
print("Model loaded successfully with reproducible configuration")
Key Takeaways
- Save state dicts (model.state_dict()) for checkpointing and training; they are lightweight and flexible, requiring architecture to be recreated.
- Complete model serialization (torch.save(model)) is convenient but less portable; prefer state dicts for production code.
- Use map_location when loading to ensure compatibility across devices (CPU/GPU).
- Export to ONNX for cross-framework deployment and inference optimization in production systems.
- Quantize models post-training to reduce size (int8 weights take about a quarter of the space of float32 weights) and often inference latency; measure the accuracy impact on your own validation data.
Frequently Asked Questions
Should I save state dict or complete model?
Save state dict for production and long-term use; it's more portable and decouples architecture from weights. Save complete models only for quick prototyping. Always include metadata (epoch, loss, date) in checkpoint dictionaries.
How do I handle model architecture changes when loading old checkpoints?
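Carefully match old state dict keys to new architecture keys and map them manually: new_dict[new_key] = old_dict[old_key]. Loading with strict=False reports which keys are missing or unexpected instead of raising an error. For incompatible changes, train from scratch or use knowledge distillation, which trains the new architecture to reproduce the old model's outputs rather than copying weights.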
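A minimal sketch of that manual key mapping follows. OldNet, NewNet, and the rename table are hypothetical stand-ins for a refactor in which layers were renamed but kept their shapes; a real migration would need its own mapping.

```python
import torch
import torch.nn as nn

class OldNet(nn.Module):  # hypothetical original architecture
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 2)

class NewNet(nn.Module):  # hypothetical refactor: same shapes, new attribute names
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(10, 64)
        self.head = nn.Linear(64, 2)

old_state = OldNet().state_dict()

# Hypothetical prefix mapping from old layer names to new ones
rename = {"fc1.": "encoder.", "fc2.": "head."}

new_state = {}
for old_key, tensor in old_state.items():
    new_key = old_key
    for old_prefix, new_prefix in rename.items():
        if old_key.startswith(old_prefix):
            new_key = new_prefix + old_key[len(old_prefix):]
            break
    new_state[new_key] = tensor

new_model = NewNet()
# strict=False returns the mismatches instead of raising, so they can be inspected
result = new_model.load_state_dict(new_state, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
```

Both lists being empty means every tensor found a home. A non-empty missing_keys list points at layers that kept their random initialization, which is worth checking before trusting the loaded model.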
Can I save and load models across PyTorch versions?
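Usually yes, but breaking changes can occur. Save version information with the checkpoint. Pass weights_only=True to torch.load (available since PyTorch 1.13 and the default since 2.6) so that only tensors and plain Python containers are unpickled; this makes loading safer but does not repair an incompatible checkpoint. Test compatibility before deploying.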
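As a rough sketch of how the version check and restricted loading fit together, the snippet below round-trips a checkpoint through an in-memory buffer and compares major version numbers. Treating a major-version difference as the warning threshold is an assumption made for illustration, not a PyTorch guarantee.

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(10, 5)

# Save to an in-memory buffer so the sketch leaves no files behind
buffer = io.BytesIO()
torch.save(
    {
        "model_state": model.state_dict(),
        # Stored as a plain str so the restricted unpickler accepts it
        "pytorch_version": str(torch.__version__),
    },
    buffer,
)
buffer.seek(0)

# Restricted unpickling: tensors and plain containers only
checkpoint = torch.load(buffer, map_location="cpu", weights_only=True)

# Assumed policy: warn when the major version differs
saved_major = checkpoint["pytorch_version"].split(".")[0]
current_major = str(torch.__version__).split(".")[0]
version_mismatch = saved_major != current_major
if version_mismatch:
    print(f"Warning: checkpoint from PyTorch {checkpoint['pytorch_version']}, "
          f"running {torch.__version__}")

restored = nn.Linear(10, 5)
restored.load_state_dict(checkpoint["model_state"])
print("Checkpoint loaded with weights_only=True")
```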
How do I reduce model size for deployment?
Use quantization (int8 weights take about 75% less space than float32 weights), pruning (remove small weights), and distillation (train a smaller student on a larger teacher). ONNX export and format optimization also help. The accuracy cost of quantization depends on the model and the method, so measure it on a validation set before deploying.
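Pruning is the one technique in this answer not demonstrated elsewhere in the article, so here is a minimal sketch using torch.nn.utils.prune. The layer size and the 50% amount are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(100, 50)

# Zero out the 50% of weights with the smallest absolute value.
# This adds weight_orig (parameter) and weight_mask (buffer) to the layer.
prune.l1_unstructured(layer, name="weight", amount=0.5)

sparsity = (layer.weight == 0).float().mean().item()
print(f"Fraction of zero weights: {sparsity:.2f}")

# Make the pruning permanent: weight becomes a plain parameter again
prune.remove(layer, "weight")
param_names = [name for name, _ in layer.named_parameters()]
print("Parameters after prune.remove:", param_names)
```

Note that the pruned weight is still a dense tensor, so the saved file does not shrink by itself. The zeros only pay off when combined with compression or a sparse storage format or runtime.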
What is the difference between scripting and exporting with ONNX?
torch.jit.script() compiles PyTorch to TorchScript for deployment without a Python interpreter. ONNX export creates a standard format that any ONNX-compatible runtime can execute. Use TorchScript for PyTorch-only deployment, ONNX for cross-framework compatibility.
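For comparison with the ONNX export shown earlier, a minimal TorchScript round trip might look like the sketch below. SmallNet is a hypothetical model, and an in-memory buffer stands in for a file path.

```python
import io
import torch
import torch.nn as nn

class SmallNet(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = SmallNet().eval()

# Compile the module to TorchScript
scripted = torch.jit.script(model)

# Serialize code and weights together; a file path works the same way
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Loading does not need the SmallNet class definition
restored = torch.jit.load(buffer)

x = torch.randn(4, 10)
with torch.no_grad():
    outputs_match = torch.allclose(model(x), restored(x))
print("TorchScript output matches eager output:", outputs_match)
```

Unlike torch.save(model), the saved TorchScript archive carries its own program, which is what allows it to be loaded from C++ via LibTorch without the original Python source.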