
Saving and Loading PyTorch Models

Saving and loading models is essential for training, evaluation, and deployment. PyTorch supports several approaches: saving the state_dict() for checkpointing during training, serializing the complete model object with torch.save(model), and exporting to ONNX for cross-framework compatibility. The right choice affects reproducibility, model size, and deployment flexibility.

State dict: saving parameters and buffers

The state dictionary maps each layer's name to its learnable parameters and persistent buffers (such as batch norm running statistics). Saving only the state dict is lightweight and flexible: to reload, you recreate the model architecture in code and load the saved tensors into it.

Saving and loading state dicts

```python
import torch
import torch.nn as nn

# Define model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create model and run a forward pass (training omitted for brevity)
model = SimpleNet()
x = torch.randn(32, 10)
y = model(x)

# Save state dict (lightweight, recommended)
torch.save(model.state_dict(), '/tmp/model_state.pth')

# Load state dict into a new model instance
# weights_only=True restricts unpickling to tensors and plain containers
model_new = SimpleNet()
model_new.load_state_dict(torch.load('/tmp/model_state.pth', weights_only=True))

# Verify parameters match
for p1, p2 in zip(model.parameters(), model_new.parameters()):
    assert torch.allclose(p1, p2), "Parameters don't match"

print("Model state dict saved and loaded successfully")

# Inspect state dict
state = model.state_dict()
for name, param in state.items():
    print(f"{name}: shape {param.shape}")
```
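The claim that the state dict holds buffers as well as parameters is easy to check. This is a minimal sketch that assumes a small hypothetical model with a BatchNorm1d layer, which is not part of the example above. The running statistics appear in state_dict() but not in parameters().

```python
import torch
import torch.nn as nn

# Hypothetical model with a BatchNorm layer, used only to illustrate buffers
model = nn.Sequential(nn.Linear(10, 4), nn.BatchNorm1d(4))

param_names = {name for name, _ in model.named_parameters()}
state_names = set(model.state_dict().keys())

# Buffers are state dict entries that are not learnable parameters
buffer_names = sorted(state_names - param_names)
print(buffer_names)  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```

Because buffers such as running_mean travel with the state dict, a model restored this way behaves identically in eval mode.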

Checkpoint pattern for training

```python
import torch
import torch.nn as nn
import torch.optim as optim

# SimpleNet is defined in the previous example
model = SimpleNet()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Training loop with checkpointing
best_loss = float('inf')

for epoch in range(3):
    x = torch.randn(64, 10)
    y = torch.randint(0, 2, (64,))

    output = model(x)
    loss = criterion(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save checkpoint if loss improves
    if loss.item() < best_loss:
        best_loss = loss.item()
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'loss': best_loss,
        }
        torch.save(checkpoint, '/tmp/best_checkpoint.pth')
        print(f"Epoch {epoch}: New best loss {best_loss:.4f}, checkpoint saved")

# Resume training from checkpoint
checkpoint = torch.load('/tmp/best_checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch'] + 1  # the saved epoch already completed

print(f"Resuming at epoch {start_epoch} (best loss so far {checkpoint['loss']:.4f})")
```
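For resuming after an interruption, a common pattern is to keep a "latest" checkpoint in addition to the best one. Resumption then starts at the epoch after the one that was saved. The sketch below uses save_checkpoint and load_checkpoint, which are hypothetical helpers and not PyTorch APIs. To avoid leaving a half-written checkpoint if the process dies mid-save, it writes to a temporary file and renames it into place with os.replace. That rename is atomic on POSIX filesystems.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, epoch, model, optimizer):
    """Hypothetical helper: write to a temp file, then rename it into place."""
    checkpoint = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }
    # Create the temp file in the same directory so os.replace does not
    # have to cross filesystems
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
    os.close(fd)
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on POSIX filesystems


def load_checkpoint(path, model, optimizer):
    """Hypothetical helper: restore state and return the next epoch to run."""
    if not os.path.exists(path):
        return 0  # no checkpoint yet: start from scratch
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    return checkpoint['epoch'] + 1  # the saved epoch already finished


model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
ckpt_path = os.path.join(tempfile.mkdtemp(), 'latest_checkpoint.pth')

start_epoch = load_checkpoint(ckpt_path, model, optimizer)
for epoch in range(start_epoch, start_epoch + 3):
    x = torch.randn(64, 10)
    y = torch.randint(0, 2, (64,))
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    save_checkpoint(ckpt_path, epoch, model, optimizer)  # keep the latest state

resume_epoch = load_checkpoint(ckpt_path, model, optimizer)
print(f"Next run would resume at epoch {resume_epoch}")
```

The "best" checkpoint is still useful for evaluation. The "latest" one is what lets a preempted job continue without repeating work.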

Complete model serialization: torch.save and torch.load

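torch.save(model) pickles the entire model object (architecture and weights) into a single file. This is more convenient than a state dict but less flexible. Loading requires the model's class definitions to be importable from the same module paths as at save time. Unpickling can also execute arbitrary code, so load full-model files only from trusted sources.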

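| Approach | File size | Flexibility | Best for |
|---|---|---|---|
| State dict | Small (weights and buffers only) | High (recreate architecture separately) | Training, checkpoints |
| Complete model | Slightly larger (adds pickled module structure) | Low (tied to the class definitions at save time) | Quick prototyping, inference |
| ONNX export | Comparable (weights plus computation graph) | Cross-framework | Deployment, non-Python inference |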

Saving complete models

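```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)

# Save entire model (architecture + weights) via pickle
torch.save(model, '/tmp/complete_model.pth')

# Load entire model directly. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which rejects pickled modules; pass weights_only=False
# only for files from a trusted source
loaded_model = torch.load('/tmp/complete_model.pth', weights_only=False)

# Use immediately (no need to recreate architecture)
x = torch.randn(4, 10)
output = loaded_model(x)
print(f"Loaded model output shape: {output.shape}")

# Trade-off: the pickle stores references to the classes and module paths
# used at save time, so renaming or moving that code can break loading.
# Better practice: save the state_dict and recreate the architecture
```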
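The "File size" column in the table above can be checked directly by serializing both forms into memory. In this minimal sketch, serialized_size is a hypothetical helper. Exact byte counts vary across PyTorch versions. The weights are identical in both files, so the difference comes from the pickled module structure.

```python
import io

import torch
import torch.nn as nn


def serialized_size(obj):
    """Hypothetical helper: number of bytes torch.save writes for obj."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return len(buffer.getvalue())


model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 10))
n_params = sum(p.numel() for p in model.parameters())

state_bytes = serialized_size(model.state_dict())
full_bytes = serialized_size(model)
print(f"{n_params} float32 params = {4 * n_params} bytes of raw weights")
print(f"state_dict file: {state_bytes} bytes, full model file: {full_bytes} bytes")
```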

Handling device compatibility

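By default, torch.load restores each tensor to the device it was saved from. As a result, a checkpoint written on a GPU fails to load on a CPU-only machine unless you remap it. Pass map_location to choose the target device, and make sure the model and its input data end up on the same device.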

Loading models on different devices

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 5)
x = torch.randn(2, 10)

# Save on GPU (if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
x = x.to(device)  # inputs must live on the same device as the model

torch.save(model.state_dict(), '/tmp/model.pth')

# Load on CPU (regardless of save device)
model_cpu = nn.Linear(10, 5)
model_cpu.load_state_dict(
    torch.load('/tmp/model.pth', map_location='cpu')
)

# Load on GPU (only possible when CUDA is available)
if torch.cuda.is_available():
    model_gpu = nn.Linear(10, 5).to('cuda:0')
    model_gpu.load_state_dict(
        torch.load('/tmp/model.pth', map_location='cuda:0')
    )

# Flexible: load to any device
def load_to_device(model, checkpoint_path, device):
    # map_location decides where the loaded tensors are placed, but
    # load_state_dict copies them into the model's existing parameters,
    # so the model itself must also be moved with .to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)

model = load_to_device(nn.Linear(10, 5), '/tmp/model.pth', device='cpu')
print(f"Model loaded to {next(model.parameters()).device}")
```
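map_location accepts more than a device string. According to the torch.load documentation, it can also be a torch.device, a dict that maps saved locations to target locations, or a callable taking (storage, location). This small sketch exercises each form on a checkpoint saved on the CPU. On a GPU machine, the same calls would also remap CUDA-saved tensors, but the sketch only verifies the CPU case.

```python
import os
import tempfile

import torch
import torch.nn as nn

path = os.path.join(tempfile.mkdtemp(), 'model.pth')
torch.save(nn.Linear(10, 5).state_dict(), path)

# torch.device target
state_a = torch.load(path, map_location=torch.device('cpu'))
# dict: tensors saved on cuda:0 go to the CPU; other locations are unchanged
state_b = torch.load(path, map_location={'cuda:0': 'cpu'})
# callable: returning the CPU-deserialized storage keeps every tensor on the CPU
state_c = torch.load(path, map_location=lambda storage, loc: storage)

devices = {t.device.type for s in (state_a, state_b, state_c) for t in s.values()}
print(devices)  # {'cpu'}
```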

Exporting to ONNX for cross-framework compatibility

Export PyTorch models to ONNX (Open Neural Network Exchange) to run inference outside PyTorch. ONNX Runtime, TensorRT, and other tools that consume the ONNX format can all load these files.

Exporting and validating ONNX models

```python
import numpy as np
import torch
import torch.nn as nn
import onnx

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()
model.eval()  # Eval mode: disables dropout, uses batch norm running statistics

# Dummy input used to trace the model during export
dummy_input = torch.randn(1, 10)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    '/tmp/model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},  # variable batch size
    opset_version=14,
    verbose=False
)

print("Model exported to ONNX")

# Load and validate the ONNX model structure
onnx_model = onnx.load('/tmp/model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed")

# Compare PyTorch vs ONNX outputs (requires onnxruntime)
try:
    import onnxruntime as ort

    # Run ONNX inference on the CPU
    session = ort.InferenceSession('/tmp/model.onnx', providers=['CPUExecutionProvider'])
    input_data = dummy_input.numpy()
    onnx_output = session.run(None, {'input': input_data})

    # Run PyTorch inference
    with torch.no_grad():
        pytorch_output = model(dummy_input).numpy()

    # Compare values, not just shapes (small float differences are expected)
    np.testing.assert_allclose(pytorch_output, onnx_output[0], rtol=1e-3, atol=1e-5)
    print("PyTorch and ONNX outputs match")
except ImportError:
    print("onnxruntime not installed; skipping output comparison")
```

Model quantization for inference efficiency

Reduce model size and, often, inference latency by converting weights (and, for some methods, activations) to lower precision such as int8 or float16.

Quantization patterns

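```python
import io
import time

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)
model.eval()  # quantize a model in inference mode

x = torch.randn(1, 100)

# Dynamic quantization (post-training): Linear weights are stored as int8,
# activations are quantized on the fly at inference time
quantized_model = torch.ao.quantization.quantize_dynamic(
    model,
    {nn.Linear},  # layer types to quantize
    dtype=torch.qint8
)

def state_dict_bytes(m):
    # Quantized Linear layers keep weights in packed params rather than
    # nn.Parameter, so parameters() would miss them; measure bytes instead
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return len(buffer.getvalue())

print(f"Original model size: {state_dict_bytes(model):,} bytes")
print(f"Quantized model size: {state_dict_bytes(quantized_model):,} bytes")

# Rough inference timing (torch.utils.benchmark gives more reliable numbers;
# for a model this small, overhead can outweigh any gain)
def time_inference(m, n_iters=100):
    with torch.no_grad():
        m(x)  # warm-up
        start = time.perf_counter()
        for _ in range(n_iters):
            m(x)
        return time.perf_counter() - start

original_time = time_inference(model)
quantized_time = time_inference(quantized_model)

print(f"Original inference: {original_time:.4f}s")
print(f"Quantized inference: {quantized_time:.4f}s")
print(f"Speedup: {original_time / quantized_time:.2f}x")
```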
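Under the hood, int8 quantization maps each float to an 8-bit integer through a scale and a zero point: x_q = clamp(round(x / scale) + zero_point, -128, 127). The sketch below applies symmetric per-tensor quantization by hand with torch.quantize_per_tensor. It shows both the 4x storage reduction and the bounded rounding error. The choice scale = max|w| / 127 is an illustrative assumption; quantize_dynamic picks its own parameters.

```python
import torch

torch.manual_seed(0)
w = torch.randn(50, 100)  # stand-in for a Linear weight

# Symmetric per-tensor int8 quantization (illustrative choice of scale)
scale = w.abs().max().item() / 127
q = torch.quantize_per_tensor(w, scale=scale, zero_point=0, dtype=torch.qint8)

int_values = q.int_repr()    # the stored int8 values
w_restored = q.dequantize()  # back to float32: (int_values - 0) * scale

bytes_fp32 = w.numel() * w.element_size()
bytes_int8 = int_values.numel() * int_values.element_size()
max_error = (w - w_restored).abs().max().item()

print(f"float32: {bytes_fp32} bytes, int8: {bytes_int8} bytes")
print(f"max rounding error {max_error:.5f} (scale / 2 = {scale / 2:.5f})")
```

Rounding error per weight is at most half a quantization step. This is why accuracy loss is often small, although the effect on a full model still has to be measured.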

Version compatibility and reproducibility

Handle PyTorch version changes and ensure reproducible loading.

Robust model loading

```python
import datetime

import torch
import torch.nn as nn

model = nn.Linear(10, 5)

# Save with version information (plain Python types only)
checkpoint = {
    'model_state': model.state_dict(),
    'pytorch_version': str(torch.__version__),
    'architecture': 'Linear(10, 5)',
    'timestamp': datetime.datetime.now().isoformat(),
}

torch.save(checkpoint, '/tmp/model_versioned.pth')

# Load and compare versions
checkpoint = torch.load('/tmp/model_versioned.pth', map_location='cpu')

print(f"Model trained with PyTorch {checkpoint['pytorch_version']}")
print(f"Current PyTorch version: {torch.__version__}")
if checkpoint['pytorch_version'] != str(torch.__version__):
    print("Warning: checkpoint was saved with a different PyTorch version")

# Recreate model and load state
model = nn.Linear(10, 5)
model.load_state_dict(checkpoint['model_state'])

# Seeds do not affect loading; set them so later runs (dropout,
# data shuffling, random initialization) are repeatable
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print("Model loaded and random seeds set")
```

Key Takeaways

  • Save state dicts (model.state_dict()) for checkpointing and training. They are lightweight and flexible but require recreating the architecture in code before loading.
  • Complete model serialization (torch.save(model)) is convenient but less portable; prefer state dicts for production code.
  • Use map_location when loading to ensure compatibility across devices (CPU/GPU).
  • Export to ONNX for cross-framework deployment and inference optimization in production systems.
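  • Post-training quantization can reduce model size (int8 weights use a quarter of the storage of float32) and CPU inference latency. The actual speedup and accuracy cost depend on the model and hardware, so measure both.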

Frequently Asked Questions

Should I save state dict or complete model?

Save state dict for production and long-term use; it's more portable and decouples architecture from weights. Save complete models only for quick prototyping. Always include metadata (epoch, loss, date) in checkpoint dictionaries.

How do I handle model architecture changes when loading old checkpoints?

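Match old state dict keys to the new architecture's keys and rename them manually: new_dict[new_key] = old_dict[old_key]. Calling model.load_state_dict(new_dict, strict=False) returns the missing and unexpected keys, which shows what still needs mapping. For incompatible changes such as different layer shapes, train from scratch or use knowledge distillation. Distillation trains the new architecture to mimic the old model's outputs rather than copying its weights.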
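You can combine manual key mapping with load_state_dict(strict=False), which reports unmatched keys instead of raising. This minimal sketch assumes two hypothetical classes, OldNet and NewNet. In NewNet, a layer was renamed from fc to head, and a new layer was added.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):  # hypothetical original architecture
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)


class NewNet(nn.Module):  # hypothetical refactor: 'fc' renamed to 'head'
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(10, 2)
        self.extra = nn.Linear(2, 2)  # new layer with no saved weights


old_state = OldNet().state_dict()

# Rename keys whose prefix moved; copy everything else unchanged
renames = {'fc.': 'head.'}
new_state = {}
for key, value in old_state.items():
    for old_prefix, new_prefix in renames.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
            break
    new_state[key] = value

model = NewNet()
result = model.load_state_dict(new_state, strict=False)
print("missing:", result.missing_keys)        # ['extra.weight', 'extra.bias']
print("unexpected:", result.unexpected_keys)  # []
```

Any key listed as missing keeps its fresh initialization, so check that list before trusting the loaded model.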

Can I save and load models across PyTorch versions?

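Usually yes, but breaking changes can occur. Save version information with the checkpoint and test loading before deploying. Note that torch.load(..., weights_only=True) is a safety measure rather than a compatibility fix. It has been available since PyTorch 1.13 and became the default in 2.6. It restricts unpickling to tensors, primitive types, and standard containers. State dicts therefore load fine, but full pickled models need weights_only=False, which you should use only for files you trust.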
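Storing metadata as plain Python types means the checkpoint still loads under weights_only=True. This minimal sketch assumes the same metadata fields as the versioning example earlier. The key choices are converting the timestamp to an ISO string and the version to a plain str rather than storing the original objects.

```python
import datetime
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(10, 5)
checkpoint = {
    'model_state': model.state_dict(),
    # Plain str rather than the torch.__version__ object
    'pytorch_version': str(torch.__version__),
    # ISO string rather than a datetime object, which the restricted
    # weights_only unpickler would not accept by default
    'timestamp': datetime.datetime.now().isoformat(),
}
path = os.path.join(tempfile.mkdtemp(), 'model_versioned.pth')
torch.save(checkpoint, path)

loaded = torch.load(path, weights_only=True)
print(type(loaded['timestamp']).__name__, loaded['pytorch_version'])
```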

How do I reduce model size for deployment?

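Use quantization (int8 weights take 75% less storage than float32), pruning (zero out small weights), and distillation (train a smaller student model to mimic a larger teacher). Exporting to ONNX does not shrink the weights by itself, but ONNX runtimes offer their own graph optimizations and quantization tools. Quantized layers are roughly 4x smaller. The accuracy cost depends on the model and method, so evaluate it on a validation set.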
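Pruning deserves a caveat: zeroing weights does not by itself make a dense tensor smaller. This sketch uses torch.nn.utils.prune.l1_unstructured on an illustrative layer. It zeroes the 50% of weights with the smallest magnitude, then makes the change permanent. Turning those zeros into real size savings would require sparse storage or file compression, which this sketch does not do.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(100, 50)  # illustrative layer, 5000 weights

# Zero out the 50% of weights with the smallest absolute value
prune.l1_unstructured(layer, name='weight', amount=0.5)
prune.remove(layer, 'weight')  # make pruning permanent (drops mask and weight_orig)

zeros = int((layer.weight == 0).sum())
dense_bytes = layer.weight.numel() * layer.weight.element_size()
print(f"{zeros} of {layer.weight.numel()} weights are zero")
print(f"dense storage is still {dense_bytes} bytes")
```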

What is the difference between scripting and exporting with ONNX?

torch.jit.script() compiles a PyTorch model to TorchScript, which can be saved and run without a Python interpreter (for example, from C++ via LibTorch). ONNX export produces a standard format that any ONNX-compatible runtime can execute. Use TorchScript for PyTorch-only deployment and ONNX for cross-framework compatibility.
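Here is a minimal TorchScript round trip. It assumes the same SimpleNet layout used earlier, redefined here so the block is self-contained. The saved file carries the compiled code, so torch.jit.load does not need the Python class definition.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNet().eval()
scripted = torch.jit.script(model)  # compile to TorchScript

path = os.path.join(tempfile.mkdtemp(), 'model_scripted.pt')
scripted.save(path)

# Loading needs only the file, not the SimpleNet class
loaded = torch.jit.load(path)
x = torch.randn(4, 10)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(f"Outputs match: {same}")
```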
