
Debugging and Optimizing PyTorch Models

Debugging and optimization are critical skills for productive deep learning work. Common issues include shape mismatches, gradient problems, memory leaks, and slow training. Profiling identifies bottlenecks, showing whether training is compute-bound (optimize the model) or I/O-bound (optimize data loading). PyTorch provides a built-in profiler (torch.profiler), autograd debugging aids such as anomaly detection, and optimization utilities such as automatic mixed precision to speed up both development and deployment.

Common training issues and diagnostics

Identify and fix frequent problems that prevent or slow convergence.

Diagnosing shape mismatches and tensor errors

import torch
import torch.nn as nn

# Issue: shape mismatch in the forward pass
class BuggyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, x):
        x = self.fc1(x)  # [batch, 64]
        x = self.fc2(x)  # [batch, 32]
        # Bug: adds a [batch, 10] slice to a [batch, 32] tensor
        return x + x[:, :10]

# Catch and understand the error
try:
    model = BuggyModel()
    x = torch.randn(16, 10)
    output = model(x)
except RuntimeError as e:
    print(f"Error caught: {e}")
    print("Solution: ensure all tensor operations produce compatible shapes")

# Fixed version
class FixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = FixedModel()
output = model(torch.randn(16, 10))
print(f"Fixed model output shape: {output.shape}")  # torch.Size([16, 32])
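For larger models, printing shapes by hand gets tedious. One option is to attach a forward hook to every leaf module and record each output shape during a single forward pass; the records then show the last layer that ran before a mismatch. A minimal sketch using `register_forward_hook`; the helper `trace_shapes` is a hypothetical name, not a PyTorch API:

```python
import torch
import torch.nn as nn

def trace_shapes(model: nn.Module, example_input: torch.Tensor):
    """Run one forward pass and record (module name, output shape) for each leaf module."""
    records = []
    handles = []
    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only (Linear, ReLU, ...)
            def hook(mod, inputs, output, name=name):
                records.append((name, tuple(output.shape)))
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    except RuntimeError as e:
        # Records collected so far show the last layer that ran successfully
        print(f"Forward pass failed after {len(records)} module(s): {e}")
    finally:
        for h in handles:  # remove hooks so they don't fire on later calls
            h.remove()
    return records

good = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 32))
for name, shape in trace_shapes(good, torch.randn(16, 10)):
    print(f"layer {name}: {shape}")

# Second layer expects 32 input features but receives 64
bad = nn.Sequential(nn.Linear(10, 64), nn.Linear(32, 5))
print(trace_shapes(bad, torch.randn(4, 10)))
```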

Detecting NaN and infinite gradients

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=1.0)  # intentionally high LR
criterion = nn.MSELoss()

# Training with potential numerical instability
for step in range(5):
    x = torch.randn(32, 10)
    y = torch.randn(32, 5)

    output = model(x)
    loss = criterion(output, y)

    # Check the loss first: a NaN/Inf loss means training has already diverged
    if torch.isnan(loss):
        print(f"Step {step}: loss is NaN, training diverged")
        break
    if torch.isinf(loss):
        print(f"Step {step}: loss is Inf, likely exploding gradients")
        break

    optimizer.zero_grad()
    loss.backward()

    # Check for NaN or Inf in gradients
    bad_grad = False
    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any():
                print(f"NaN detected in {name}.grad")
                bad_grad = True
            if torch.isinf(param.grad).any():
                print(f"Inf detected in {name}.grad")
                bad_grad = True
    if bad_grad:
        print(f"Step {step}: skipping optimizer step")
        continue

    # Mitigation: clip the global gradient norm before every update
    # (lowering the learning rate is the other common fix)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    print(f"Step {step}: loss = {loss.item():.4f}, "
          f"grad norm before clipping = {total_norm.item():.4f}")
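The checks above report that a gradient became non-finite, but not which operation produced it. PyTorch's `torch.autograd.detect_anomaly()` context manager makes the backward pass raise an error naming the autograd function that first returned NaN. A minimal sketch with a deliberately invalid `sqrt`; the function name `find_nan_source` is hypothetical, and because anomaly mode slows training considerably, it is best enabled only while debugging:

```python
import torch

def find_nan_source():
    """Return the anomaly-mode error message for a backward pass that produces NaN."""
    x = torch.tensor([-1.0, 4.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        # sqrt(-1) is NaN in the forward pass, so its gradient is NaN as well
        y = torch.sqrt(x).sum()
        try:
            y.backward()
        except RuntimeError as e:
            # The message names the backward function that returned NaN
            return str(e)
    return None

print(find_nan_source())
```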

Profiling and identifying bottlenecks

Use PyTorch profilers to measure execution time and find which operations are slow.

CPU and GPU profiling with torch.profiler

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
    nn.Linear(1000, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)

x = torch.randn(64, 1000, device=device)

# Always profile CPU activity; add CUDA kernels when a GPU is present
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Warm-up pass so one-time setup costs don't dominate the profile
model(x)

with profile(
    activities=activities,
    record_shapes=True,
    profile_memory=True
) as prof:
    with record_function("model_inference"):
        output = model(x)

# Print top operations by total time
sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

# Export a trace viewable in chrome://tracing or Perfetto
prof.export_chrome_trace("/tmp/pytorch_trace.json")
print("Chrome trace saved to /tmp/pytorch_trace.json")
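Profiling a single forward pass can miss costs in the backward pass and optimizer. For training loops, `torch.profiler` accepts a `schedule` that skips and warms up a few steps before recording, plus an `on_trace_ready` callback that fires when a recording window ends; `prof.step()` marks step boundaries. A sketch, assuming a toy model and random data; the step counts are arbitrary choices:

```python
import torch
import torch.nn as nn
from torch.profiler import profile, schedule, record_function, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(1000, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

tables = []

def on_trace_ready(prof):
    # Called when an active window finishes; summarize the recorded steps
    sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
    tables.append(prof.key_averages().table(sort_by=sort_key, row_limit=10))

# Skip 1 step, warm up for 1 step (traced but discarded), then record 3 steps
with profile(
    activities=activities,
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=on_trace_ready,
) as prof:
    for step in range(6):
        x = torch.randn(64, 1000, device=device)
        y = torch.randint(0, 10, (64,), device=device)
        with record_function("train_step"):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        prof.step()  # tell the profiler that one training step has ended

print(tables[0])
```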

Identifying I/O vs compute bottlenecks

import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

def sync():
    # CUDA kernels run asynchronously; wait for them so timings are accurate
    if device == "cuda":
        torch.cuda.synchronize()

# Create dummy dataset
X = torch.randn(10000, 100)
y = torch.randint(0, 10, (10000,))
dataset = TensorDataset(X, y)

model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)

criterion = nn.CrossEntropyLoss()

# The __main__ guard is required for num_workers > 0 on platforms that spawn workers
if __name__ == "__main__":
    for num_workers in [0, 4]:
        loader = DataLoader(
            dataset,
            batch_size=128,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=(device == "cuda"),  # speeds up host-to-GPU copies
        )

        data_time = 0.0
        compute_time = 0.0
        start = time.perf_counter()
        for batch_x, batch_y in loader:
            # Time spent waiting for the loader to deliver this batch
            data_time += time.perf_counter() - start

            # Time spent on the transfer plus forward and backward pass
            t0 = time.perf_counter()
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            model.zero_grad()
            loss = criterion(model(batch_x), batch_y)
            loss.backward()
            sync()
            compute_time += time.perf_counter() - t0

            start = time.perf_counter()

        n = len(loader)
        print(f"num_workers={num_workers}: data {data_time / n * 1000:.2f} ms/batch, "
              f"compute {compute_time / n * 1000:.2f} ms/batch")
        if data_time > compute_time:
            print("  -> I/O-bound: increase num_workers, use faster storage, or preload data")
        else:
            print("  -> Compute-bound: optimize the model (e.g., mixed precision)")
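A complementary check is to time the data pipeline alone: if iterating over the loader without running the model takes nearly as long as a full training epoch, data loading is the bottleneck. A minimal sketch; `loader_throughput` is a hypothetical helper, and `num_workers=0` keeps the example free of multiprocessing:

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def loader_throughput(loader, max_batches=None):
    """Iterate a DataLoader without running any model; return batches per second."""
    n = 0
    start = time.perf_counter()
    for _ in loader:
        n += 1
        if max_batches is not None and n >= max_batches:
            break
    elapsed = time.perf_counter() - start
    return n / elapsed if elapsed > 0 else float("inf")

dataset = TensorDataset(torch.randn(10000, 100), torch.randint(0, 10, (10000,)))
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0)
print(f"Loader alone: {loader_throughput(loader):.1f} batches/s")
```

Comparing this number with the batches per second of the full training loop shows how much of each step is spent waiting on data.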

Memory profiling and optimization

Identify memory hotspots and reduce consumption.

Finding memory leaks with torch.cuda

import torch
import torch.nn as nn

def memory_usage():
    # GB currently occupied by live tensors
    return torch.cuda.memory_allocated() / 1e9

# Create model (requires a CUDA GPU)
model = nn.Sequential(
    nn.Linear(10000, 5000),
    nn.ReLU(),
    nn.Linear(5000, 2000)
).cuda()

x = torch.randn(256, 10000, device="cuda")

torch.cuda.reset_peak_memory_stats()  # start peak tracking from here
print(f"Initial memory: {memory_usage():.2f} GB")

# Forward pass: autograd keeps intermediate activations for the backward pass
output = model(x)
print(f"After forward: {memory_usage():.2f} GB")

# Backward pass: saved activations are released, and one .grad tensor
# per parameter is allocated
loss = output.sum()
loss.backward()
print(f"After backward: {memory_usage():.2f} GB")

# empty_cache() returns unused *cached* blocks to the driver. It lowers
# memory_reserved() but cannot free tensors that are still referenced,
# so memory_allocated() does not change.
torch.cuda.empty_cache()
print(f"After empty_cache: {memory_usage():.2f} GB allocated, "
      f"{torch.cuda.memory_reserved() / 1e9:.2f} GB reserved")

# Peak memory since the last reset
peak_memory = torch.cuda.max_memory_allocated() / 1e9
print(f"Peak memory used: {peak_memory:.2f} GB")

# Reset before the next measurement
torch.cuda.reset_peak_memory_stats()
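One common cause of memory that grows every iteration is accumulating the loss tensor itself (e.g., for logging) instead of its Python value: each addition keeps that iteration's autograd graph reachable, so it cannot be freed. The sketch below runs on the CPU and shows the difference in what gets accumulated; on a GPU, the leaky version would show up as steadily rising `torch.cuda.memory_allocated()`:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()

running_loss_leaky = 0.0
running_loss_ok = 0.0
for _ in range(100):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = criterion(model(x), y)
    # Leaky: adding the tensor chains every iteration's graph together,
    # so none of them (or their saved activations) can be freed
    running_loss_leaky = running_loss_leaky + loss
    # Safe: .item() returns a plain Python float with no graph attached
    running_loss_ok += loss.item()

print(type(running_loss_leaky), running_loss_leaky.requires_grad)  # <class 'torch.Tensor'> True
print(type(running_loss_ok))  # <class 'float'>
```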

Debugging gradient flow and training convergence

Track gradients to understand training dynamics.

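Issue | Symptom | Solution
Vanishing gradients | Gradient norms shrink toward 0, especially in early layers | Use ReLU-family activations, proper (e.g., He) initialization, batch norm, skip connections
Exploding gradients | Gradient norms grow very large; loss becomes Inf or NaN | Clip gradients, lower the learning rate, add normalization layers
Dead neurons | Some ReLU units output 0 for every input | Use LeakyReLU or similar, initialize properly, lower the learning rate
Slow convergence | Loss plateaus early | Tune the learning rate (too low is a common cause), adjust the architecture, verify the data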

Monitoring gradient statistics

import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training with gradient monitoring
for step in range(3):
    x = torch.randn(128, 100, device=device)
    y = torch.randint(0, 10, (128,), device=device)

    output = model(x)
    loss = criterion(output, y)

    optimizer.zero_grad()
    loss.backward()

    # Per-parameter L2 gradient norms: tiny values in early layers hint at
    # vanishing gradients, very large values at exploding gradients
    grad_norms = {
        name: param.grad.norm(2).item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }
    # Global norm: the quantity clip_grad_norm_ compares against max_norm
    total_norm = sum(n ** 2 for n in grad_norms.values()) ** 0.5

    optimizer.step()

    print(f"Step {step}: loss = {loss.item():.4f}, global grad norm = {total_norm:.6f}")
    for name, norm in grad_norms.items():
        print(f"    {name}: {norm:.6f}")
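The "dead neurons" row in the table can be checked directly. A minimal sketch that attaches forward hooks to every `nn.ReLU` and reports the fraction of units that output 0 for every sample in a batch; it assumes activations shaped [batch, features], and `dead_relu_fraction` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def dead_relu_fraction(model: nn.Module, x: torch.Tensor):
    """For each nn.ReLU, return the fraction of units that are 0 for every sample in x."""
    fractions = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            def hook(mod, inputs, output, name=name):
                # A unit counts as dead on this batch if it is zero for all samples
                dead = (output == 0).all(dim=0)
                fractions[name] = dead.float().mean().item()
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    return fractions

model = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)
print(dead_relu_fraction(model, torch.randn(512, 100)))
```

A fraction that climbs during training suggests units are being pushed permanently negative, often by a learning rate that is too high.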

Performance optimization strategies

Apply targeted optimizations based on profiling results.

Optimizing for training speed

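import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

device = "cuda"  # this example (float16 autocast + GradScaler) needs a CUDA GPU

# Model
model = nn.Sequential(
    nn.Linear(512, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 100)
).to(device)

# Data stays on the CPU: pin_memory and worker processes only work with CPU tensors
X = torch.randn(10000, 512)
y = torch.randint(0, 100, (10000,))
loader = DataLoader(
    TensorDataset(X, y),
    batch_size=256,
    shuffle=True,
    pin_memory=True,
    num_workers=4
)

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler("cuda")  # older PyTorch: torch.cuda.amp.GradScaler()

def train_baseline():
    # Standard FP32 training
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        output = model(batch_x)
        loss = criterion(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def train_optimized(accumulation_steps=2):
    # Mixed precision + gradient accumulation
    optimizer.zero_grad()
    for i, (batch_x, batch_y) in enumerate(loader):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        # Run eligible ops (e.g., matmuls) in float16
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            output = model(batch_x)
            loss = criterion(output, batch_y) / accumulation_steps

        # Scale the loss so small float16 gradients don't underflow to 0
        scaler.scale(loss).backward()

        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

def timed(fn):
    # Synchronize so the timer covers all queued GPU work
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start

if __name__ == "__main__":  # required for num_workers > 0 on spawn-based platforms
    train_baseline()  # warm-up epoch so one-time CUDA setup isn't counted

    baseline_time = timed(train_baseline)
    print(f"Baseline training: {baseline_time:.2f}s")

    optimized_time = timed(train_optimized)
    print(f"Optimized training: {optimized_time:.2f}s")
    print(f"Speedup: {baseline_time / optimized_time:.2f}x")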
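Gradient accumulation mainly trades time for memory: it reproduces the gradient of a larger batch from several smaller forward/backward passes, so on its own it does not make training faster, while mixed precision typically provides the speedup on GPUs with Tensor Cores. Dividing each micro-batch loss by `accumulation_steps` is what makes the accumulated gradient match. A small CPU check, assuming the default `reduction="mean"` and equal-sized micro-batches:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()  # default reduction="mean"
x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

# Gradient from one full batch of 16 samples
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 2 micro-batches of 8, each loss divided by 2
accumulation_steps = 2
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

Without the division, the accumulated gradient would be `accumulation_steps` times too large, which acts like an unintended learning-rate increase.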

Debugging common convergence issues

Address training that doesn't improve or diverges.

Systematic debugging checklist

import copy
import torch
import torch.nn as nn

# Checklist for non-converging training
def diagnose_convergence(model, criterion, x_sample, y_sample):
    """Diagnose why a model isn't converging.

    The overfitting test in step 4 trains a copy, so `model`'s weights are not changed.
    """
    print("=== Convergence Diagnosis ===\n")

    # 1. Check the loss value
    model.zero_grad()
    output = model(x_sample)
    loss = criterion(output, y_sample)
    print(f"1. Loss value: {loss.item():.4f}")
    if not torch.isfinite(loss):
        print("   Problem: loss is NaN/Inf → reduce learning rate or clip gradients\n")
        return

    # 2. Check that gradients exist and are non-zero
    loss.backward()
    grad_norms = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
    model.zero_grad()  # these gradients were only needed for diagnosis

    if not grad_norms:
        print("2. Gradients: NONE found")
        print("   Problem: no gradients computed → verify the loss depends on the model output\n")
        return

    avg_grad = sum(grad_norms) / len(grad_norms)
    print(f"2. Average gradient norm: {avg_grad:.6f}")
    # Rough thresholds; sensible values depend on the model and loss scale
    if avg_grad < 1e-7:
        print("   Problem: vanishing gradients → use batch norm, skip connections\n")
    elif avg_grad > 10:
        print("   Problem: exploding gradients → clip gradients, lower LR\n")

    # 3. Check the model output range
    out = output.detach()
    print(f"3. Model output range: [{out.min().item():.2f}, {out.max().item():.2f}]")
    if out.std() < 1e-3:
        print("   Problem: very low output variance → check weight initialization\n")

    # 4. Check whether a copy of the model can overfit this single batch
    print("4. Overfitting test (50 iterations on 1 batch):")
    probe = copy.deepcopy(model)
    probe.train()
    opt = torch.optim.SGD(probe.parameters(), lr=0.1)

    for i in range(50):
        opt.zero_grad()
        l = criterion(probe(x_sample), y_sample)
        l.backward()
        opt.step()
        if i % 10 == 0:
            print(f"   Iteration {i}: loss = {l.item():.4f}")

    print("   If loss falls toward 0: the model can learn; check data, labels, and hyperparameters")
    print("   If loss stays flat: suspect a bug in the forward pass, loss function, or labels\n")

# Test diagnosis
model = nn.Linear(10, 2)
x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

diagnose_convergence(model, criterion, x, y)
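Step 1 of the checklist can be made more concrete for classification. A freshly initialized model produces roughly uniform predictions, so cross-entropy should start near ln(C) for C classes (about 2.30 for 10 classes). An initial loss far above that often points to a bad initialization or mis-encoded targets. A minimal check on random data:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 10
model = nn.Linear(20, num_classes)
x = torch.randn(256, 20)
y = torch.randint(0, num_classes, (256,))

with torch.no_grad():
    initial_loss = nn.CrossEntropyLoss()(model(x), y).item()

# Cross-entropy of a perfectly uniform prediction over num_classes
expected = math.log(num_classes)
print(f"Initial loss {initial_loss:.3f} vs ln({num_classes}) = {expected:.3f}")
```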

Key Takeaways

  • Debug shape mismatches early using assertions and .shape checks; catch NaN/Inf in loss and gradients to detect numerical instability.
  • Profile to identify bottlenecks: CPU/GPU activity shows if training is compute-bound (optimize model) or I/O-bound (optimize data loading).
  • Monitor gradient norms to detect vanishing/exploding gradients; use gradient clipping, batch norm, and careful learning rate selection to stabilize training.
  • Measure peak GPU memory with torch.cuda.max_memory_allocated(); mixed precision reduces activation memory, and gradient accumulation lets you train with a larger effective batch size than fits in memory at once.
  • Test convergence on single batches—if model can't overfit, it's an architecture/data issue; if it can, the problem is optimization or generalization.

Frequently Asked Questions

My model trains but validation loss doesn't decrease. What's wrong?

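Training loss that keeps falling while validation loss stalls or rises usually indicates overfitting, or a mismatch between the training and validation data. Check that (1) the train and validation sets are properly separated and come from the same distribution, (2) validation data goes through the same preprocessing as training data (minus random augmentation), and (3) evaluation runs with model.eval() inside torch.no_grad(). Use data augmentation and regularization (L2 weight decay, dropout) to improve generalization.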
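A minimal sketch of an evaluation loop that follows point (3); the `evaluate` helper is hypothetical and assumes the criterion uses the default `reduction="mean"`:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion, device="cpu"):
    """Average validation loss with eval-mode layers and no autograd graph."""
    was_training = model.training
    model.eval()  # dropout off; batch norm uses running statistics
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            # Mean loss per batch, weighted by batch size for an exact dataset average
            total_loss += criterion(model(xb), yb).item() * xb.size(0)
            total_samples += xb.size(0)
    model.train(was_training)  # restore the previous mode
    return total_loss / total_samples

model = nn.Sequential(nn.Linear(10, 32), nn.Dropout(0.5), nn.ReLU(), nn.Linear(32, 2))
val_loader = DataLoader(
    TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,))), batch_size=32
)
print(f"Validation loss: {evaluate(model, val_loader, nn.CrossEntropyLoss()):.4f}")
```

Because dropout is disabled in eval mode, repeated evaluations of the same model on the same data return the same value, which makes validation curves comparable across epochs.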

How do I know if my learning rate is too high?

The loss diverges (becomes NaN/Inf) or oscillates wildly, and gradient norms are typically very large. One practical approach: start with a small learning rate (e.g., 0.0001), increase it gradually until training becomes unstable, then use about half that value.
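This procedure is essentially a learning-rate range test. A sketch that raises the learning rate exponentially on each step and records the loss, stopping once the loss exceeds 4 times the best value seen. The helper name, the exponential schedule, the 4x threshold, and the use of one fixed batch are all assumptions chosen for brevity; in practice you would draw a new batch each step and run on a copy of the model:

```python
import math
import torch
import torch.nn as nn

def lr_range_test(model, criterion, x, y, start_lr=1e-5, end_lr=1.0, num_steps=50):
    """Increase the LR exponentially each step and record (lr, loss) pairs.

    Note: this trains (modifies) `model`.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=start_lr)
    gamma = (end_lr / start_lr) ** (1 / (num_steps - 1))  # per-step multiplier
    lrs, losses = [], []
    lr = start_lr
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())
        # Stop once the loss blows up relative to the best value seen so far
        if not math.isfinite(loss.item()) or loss.item() > 4 * min(losses):
            break
        lr *= gamma
    return lrs, losses

torch.manual_seed(0)
model = nn.Linear(10, 1)
x, y = torch.randn(256, 10), torch.randn(256, 1)
lrs, losses = lr_range_test(model, nn.MSELoss(), x, y)
for lr, loss in zip(lrs[::10], losses[::10]):
    print(f"lr={lr:.2e}  loss={loss:.4f}")
```

Plotting loss against learning rate (log scale) and picking a value somewhat below where the loss starts rising mirrors the "half that value" rule above.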

Why is my GPU utilization low despite a large batch size?

Data loading may be the bottleneck. Increase num_workers and enable pin_memory=True. Also check: model forward/backward don't have Python loops (use batched operations), and you're using torch.no_grad() during evaluation.
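Replacing per-sample Python loops with a single batched call gives the GPU larger kernels and removes per-sample launch overhead. A small illustration with an `nn.Linear` layer; both versions compute the same result:

```python
import torch
import torch.nn as nn

layer = nn.Linear(64, 32)
x = torch.randn(256, 64)

with torch.no_grad():
    # Slow: one small matrix-vector product (one kernel launch on GPU) per sample
    looped = torch.stack([layer(sample) for sample in x])
    # Fast: a single batched matrix multiply over all 256 samples
    batched = layer(x)

print(looped.shape, batched.shape)
print(torch.allclose(looped, batched, atol=1e-5))
```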

How do I handle class imbalance?

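Use WeightedRandomSampler to oversample minority classes, or weight the loss per class with nn.CrossEntropyLoss(weight=class_weights). (Writing criterion(output, target) * class_weights[target] does not work with the default mean reduction, because the loss is already a scalar; per-sample weighting needs reduction='none' before averaging.) Alternatively, use focal loss, which down-weights easy, well-classified examples.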
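Both options can be sketched on a hypothetical 90/10 dataset. Usually you pick one, since combining oversampling with a weighted loss tends to over-correct. Inverse-frequency weights are one common choice and are assumed here:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)
# Hypothetical imbalanced dataset: 900 samples of class 0, 100 of class 1
labels = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])
features = torch.randn(1000, 16)
dataset = TensorDataset(features, labels)

class_counts = torch.bincount(labels)        # tensor([900, 100])
class_weights = 1.0 / class_counts.float()   # rarer class gets a larger weight

# Option 1: oversample; each sample is drawn with probability proportional to its weight
sample_weights = class_weights[labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Option 2: keep the data as-is and weight the loss per class.
# With reduction="mean", PyTorch divides by the sum of the selected weights,
# so the overall scale of class_weights does not matter.
criterion = nn.CrossEntropyLoss(weight=class_weights)

drawn = torch.cat([yb for _, yb in loader])
print(f"Fraction of class 1 after resampling: {(drawn == 1).float().mean().item():.2f}")
```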

What should I do if training is extremely slow?

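Profile first to determine whether training is compute-bound or I/O-bound. If compute-bound, use mixed precision, reduce model size, or increase batch size. If I/O-bound, increase num_workers, use SSD storage, or preload the data into memory. As a rough rule of thumb, sustained GPU utilization well below about 80% suggests the GPU is waiting on data loading or the CPU.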

Further Reading