
PyTorch Training Loops: Step-by-Step

A PyTorch training loop is the core engine of model learning, repeating a cycle of forward propagation, loss computation, backpropagation, and parameter updates across batches of data. This iterative refinement process reduces loss by gradient descent, adjusting weights and biases to improve predictions. Understanding the mechanics of a complete training loop—including validation, early stopping, and learning rate scheduling—is essential for building effective models.

The standard training loop structure

Most PyTorch training loops follow the same pattern: initialize a model and optimizer, loop over epochs and batches, compute the loss, backpropagate, update parameters, and evaluate on validation data. According to PyTorch tutorials (2026), this structure is the foundation for training any supervised learning model.

Basic training loop example

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Create synthetic training data
X_train = torch.randn(100, 20) # 100 samples, 20 features
y_train = torch.randint(0, 3, (100,)) # 3 classes

# Create a simple model
class SimpleClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleClassifier(input_size=20, num_classes=3)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create data loader
dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss = 0
    for batch_X, batch_y in train_loader:
        # Forward pass
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)

        # Backward pass and optimization
        optimizer.zero_grad()  # Clear old gradients
        loss.backward()        # Compute gradients
        optimizer.step()       # Update parameters

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Loss functions and their selection

Choose loss functions based on the task. Classification uses cross-entropy, regression uses mean squared error, and specialized tasks have specialized losses.

| Task | Loss Function | PyTorch Class | Output Shape |
| --- | --- | --- | --- |
| Multi-class classification | Cross-entropy | nn.CrossEntropyLoss() | Logits [N, C] |
| Binary classification | Binary cross-entropy | nn.BCELoss() | Probabilities [N, 1] |
| Regression | Mean squared error | nn.MSELoss() | Predictions [N] |
| Multi-label classification | Binary cross-entropy with logits | nn.BCEWithLogitsLoss() | Logits [N, C] |

Using different loss functions

import torch
import torch.nn as nn

# Classification loss
classification_loss = nn.CrossEntropyLoss()
logits = torch.randn(8, 5) # Batch of 8, 5 classes
targets = torch.tensor([0, 2, 1, 3, 2, 0, 4, 1])
loss = classification_loss(logits, targets)
print(f"Classification loss: {loss.item():.4f}")

# Regression loss
regression_loss = nn.MSELoss()
predictions = torch.randn(16, 1)
targets = torch.randn(16, 1)
loss = regression_loss(predictions, targets)
print(f"Regression loss: {loss.item():.4f}")

# Binary classification with sigmoid
bce_loss = nn.BCEWithLogitsLoss() # Includes sigmoid internally
logits = torch.randn(10, 1)
labels = torch.randint(0, 2, (10, 1)).float()
loss = bce_loss(logits, labels)
print(f"Binary cross-entropy loss: {loss.item():.4f}")
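
The table lists nn.BCELoss() as taking probabilities and nn.BCEWithLogitsLoss() as taking logits. The following is a small sketch, on made-up random data, checking that the two agree when the sigmoid is applied by hand for BCELoss. The variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the comparison is repeatable

logits = torch.randn(10, 1)                    # raw model outputs
labels = torch.randint(0, 2, (10, 1)).float()  # 0/1 targets as floats

# Option A: apply sigmoid yourself, then use BCELoss on probabilities
probs = torch.sigmoid(logits)
loss_from_probs = nn.BCELoss()(probs, labels)

# Option B: pass raw logits; the sigmoid is folded into the loss
loss_from_logits = nn.BCEWithLogitsLoss()(logits, labels)

print(f"BCELoss on sigmoid(logits):  {loss_from_probs.item():.6f}")
print(f"BCEWithLogitsLoss on logits: {loss_from_logits.item():.6f}")
```

The two values should match up to floating-point error. Option B is generally the safer choice because it is more numerically stable for logits of large magnitude.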

Optimizers and gradient updates

Optimizers apply gradients to parameters using various algorithms—SGD, Adam, RMSprop—each with different learning dynamics.

Common optimizer patterns

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 5)
model.train()

# SGD with momentum
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Adam (adaptive learning rate, often requires less tuning)
optimizer_adam = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))

# RMSprop (good for RNNs)
optimizer_rmsprop = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)

# Using Adam optimizer in a mini training loop
criterion = nn.MSELoss()
optimizer = optimizer_adam

x = torch.randn(32, 10)
y = torch.randn(32, 5)

for step in range(3):
    # Forward pass
    output = model(x)
    loss = criterion(output, y)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Inspect gradients before update (optional)
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm()

    # Optimizer step updates all parameters
    optimizer.step()

    print(f"Step {step}, Loss: {loss.item():.4f}")
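
To make "optimizer.step() updates all parameters" concrete, here is a minimal sketch that checks plain SGD (no momentum, no weight decay) against the textbook rule p_new = p - lr * grad. The model size, data, and learning rate are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

lr = 0.1
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=lr)  # no momentum, no weight decay

x = torch.randn(8, 4)
y = torch.randn(8, 2)

loss = nn.MSELoss()(model(x), y)
optimizer.zero_grad()
loss.backward()

# Predict what plain SGD should produce: p_new = p - lr * grad
expected = [p.detach() - lr * p.grad for p in model.parameters()]

optimizer.step()

# Compare the optimizer's result with the hand-computed update
matches = [torch.allclose(p.detach(), e) for p, e in zip(model.parameters(), expected)]
print(matches)
```

Adam and RMSprop keep running statistics per parameter, so their updates do not reduce to this one-line rule.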

Training and validation cycles

Separate training from validation to monitor generalization. Training updates parameters; validation evaluates performance on unseen data without gradient tracking.

Complete training and validation loop

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data
X_train = torch.randn(200, 20)
y_train = torch.randint(0, 3, (200,))
X_val = torch.randn(50, 20)
y_val = torch.randint(0, 3, (50,))

model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 3)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=16)

num_epochs = 5

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    for batch_X, batch_y in train_loader:
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation phase (no gradient tracking)
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_X, batch_y in val_loader:
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            val_loss += loss.item()

            # Compute accuracy
            predictions = outputs.argmax(dim=1)
            correct += (predictions == batch_y).sum().item()
            total += batch_y.size(0)

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    accuracy = 100 * correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | Accuracy: {accuracy:.2f}%")
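
The model in the loop above has no dropout or batch norm layers, so model.train() and model.eval() do not change its output there. As an illustrative sketch of why the mode switch matters once such layers are present, the snippet below runs a standalone nn.Dropout layer in both modes on an input of all ones.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)  # about half the entries zeroed, survivors scaled by 1/(1-p) = 2

layer.eval()
eval_out = layer(x)   # dropout is a no-op in eval mode

print("zeros in train mode:", (train_out == 0).sum().item())
print("eval output unchanged:", torch.equal(eval_out, x))
```

Note that model.eval() only changes layer behavior. Disabling gradient tracking is a separate concern handled by torch.no_grad().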

Gradient clipping and numerical stability

Prevent exploding gradients (common in RNNs) by clipping gradient norms.

Implementing gradient clipping

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.LSTM(input_size=20, hidden_size=64, batch_first=True)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Synthetic sequential data
x = torch.randn(8, 10, 20) # Batch of 8, sequence length 10, 20 features
y = torch.randn(8, 10, 64) # Targets match the LSTM output shape [batch, seq, hidden_size]

# Training step with gradient clipping
output, _ = model(x)
loss = criterion(output, y)

optimizer.zero_grad()
loss.backward()

# Clip gradients: max norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()

print(f"Loss: {loss.item():.4f}")
print("Gradient clipping applied to prevent explosion")
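
As a quick check of what clipping does, the sketch below compares the total gradient norm before and after the call. It assumes a small linear model and deliberately large inputs, both chosen only to provoke large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(20, 1)
x = torch.randn(8, 20) * 100.0  # large inputs to provoke large gradients
y = torch.randn(8, 1)

loss = nn.MSELoss()(model(x), y)
loss.backward()

max_norm = 1.0
# clip_grad_norm_ returns the total gradient norm measured BEFORE clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

# Recompute the total norm across all parameters after clipping
norm_after = torch.norm(torch.stack([p.grad.norm(2) for p in model.parameters()]), 2)

print(f"Norm before clipping: {norm_before.item():.2f}")
print(f"Norm after clipping:  {norm_after.item():.4f}")
```

Clipping rescales every gradient by the same factor, so the update direction is preserved and only its length shrinks.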

Learning rate scheduling and adaptive training

Adjust learning rates during training to improve convergence. High initial rates enable fast learning; lower rates near the end refine solutions.

Learning rate schedulers

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Step decay: reduce LR by 0.5 every 3 epochs
scheduler_step = StepLR(optimizer, step_size=3, gamma=0.5)

# Plateau reduction: reduce LR when validation loss stops improving
scheduler_plateau = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Example: using StepLR
x = torch.randn(32, 10)
y = torch.randn(32, 1)
criterion = nn.MSELoss()

for epoch in range(6):
    output = model(x)
    loss = criterion(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Step scheduler after each epoch
    scheduler_step.step()

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Loss = {loss.item():.4f}, LR = {current_lr:.6f}")

# Using ReduceLROnPlateau (typically with validation loss)
for epoch in range(6):
    output = model(x)
    loss = criterion(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Step the scheduler on the monitored metric. The training loss stands in
    # for a validation loss here because this example has no validation set.
    scheduler_plateau.step(loss.item())

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}: Loss = {loss.item():.4f}, LR = {current_lr:.6f}")
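
The StepLR schedule can also be written in closed form: the learning rate in a given epoch is base_lr * gamma ** (epoch // step_size). The sketch below compares that formula with what the scheduler reports. It uses a throwaway model and skips the actual training step, since only the schedule is of interest.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 1)
base_lr, step_size, gamma = 0.001, 3, 0.5
optimizer = optim.SGD(model.parameters(), lr=base_lr)
scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

observed, predicted = [], []
for epoch in range(9):
    observed.append(scheduler.get_last_lr()[0])                  # LR used during this epoch
    predicted.append(base_lr * gamma ** (epoch // step_size))    # closed-form value
    optimizer.step()   # no gradients were computed, so parameters are left unchanged
    scheduler.step()   # advance the schedule by one epoch

for epoch, (o, p) in enumerate(zip(observed, predicted)):
    print(f"Epoch {epoch}: scheduler LR = {o:.6f}, formula LR = {p:.6f}")
```

Calling optimizer.step() before scheduler.step() mirrors the order used in a real loop and avoids PyTorch's warning about stepping the scheduler first.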

Key Takeaways

  • The training loop consists of: forward pass, loss computation, backward pass (backpropagation), and optimizer step (parameter update).
  • Always call optimizer.zero_grad() before backward() to prevent gradient accumulation; call optimizer.step() after to apply updates.
  • Use model.train() during training so dropout and batch norm run in training mode; during validation, call model.eval() to switch those layers to inference behavior and wrap the loop in torch.no_grad() to disable gradient tracking.
  • Choose loss functions based on task: cross-entropy for classification, MSE for regression, BCE for binary problems.
  • Learning rate schedulers adapt the learning rate during training; StepLR reduces at fixed intervals, ReduceLROnPlateau reduces when validation performance plateaus.

Frequently Asked Questions

Why do I call optimizer.zero_grad() instead of manually zeroing each parameter?

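optimizer.zero_grad() is a convenience method that clears the gradients of every parameter the optimizer manages in one call. In PyTorch 2.0 and later it resets them to None by default (set_to_none=True) rather than filling them with zeros. Clearing each parameter's .grad by hand is equivalent but verbose. Always clear gradients before backward() so they do not accumulate across backward passes.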
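
The accumulation behavior is easy to see on a single tensor. This is a minimal sketch with a hand-made parameter w rather than a real model; setting w.grad = None stands in for what optimizer.zero_grad() does for each managed parameter.

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# First backward pass: d/dw of sum(2 * w) is 2 for every element
(2 * w).sum().backward()
first = w.grad.clone()

# Second backward pass WITHOUT clearing: the new gradient is added to the old one
(2 * w).sum().backward()
accumulated = w.grad.clone()

# Clear the gradient, then run backward again
w.grad = None
(2 * w).sum().backward()
after_reset = w.grad.clone()

print(first)        # tensor([2., 2., 2.])
print(accumulated)  # tensor([4., 4., 4.])
print(after_reset)  # tensor([2., 2., 2.])
```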

What is the difference between loss.backward() and loss.backward(retain_graph=True)?

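By default, backward() frees the computation graph after computing gradients. retain_graph=True keeps it in memory so you can call backward() again on the same graph, for example when you backpropagate several losses separately through shared layers, or while debugging. It costs extra memory; if you only need the combined gradient, sum the loss terms and call backward() once instead.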
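
A small sketch of both behaviors, using a toy tensor w and a loss that saves intermediate values for the backward pass. The names are illustrative.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# Case 1: default backward() frees the graph, so a second call fails
loss = (w * w).sum()
loss.backward()
try:
    loss.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True

# Case 2: retain_graph=True keeps the graph alive for another pass
w.grad = None
loss = (w * w).sum()
loss.backward(retain_graph=True)  # grad = 2w
loss.backward()                   # grad accumulates to 4w

print("second default backward failed:", second_call_failed)
print("grad after two passes:", w.grad)
```

The second pass in Case 2 adds to the existing gradient, which is the same accumulation behavior that optimizer.zero_grad() exists to reset.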

How do I prevent overfitting during training?

Use validation monitoring and early stopping: track validation loss and stop training once it has not improved for several epochs. Add regularization. For an L1 penalty, add it to the loss: loss = criterion(...) + 0.0001 * sum(p.abs().sum() for p in model.parameters()). For L2, the usual route is the optimizer's weight_decay argument. Dropout also helps, and batch normalization can have a mild regularizing effect.
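
The early-stopping rule described above can be sketched as a small helper. EarlyStopping here is a hypothetical class written for this example, not part of PyTorch, and the validation losses are made-up numbers standing in for a real validation loop.

```python
class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses standing in for a real validation loop
val_losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.63, 0.50]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(val_loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best validation loss {stopper.best_loss:.2f}")
```

In a real loop you would typically also save model.state_dict() whenever the best loss improves, so the best checkpoint can be restored after stopping. Note that this run never sees the final 0.50; patience trades off wasted epochs against stopping too early.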

What does shuffle=True do in DataLoader?

shuffle=True draws a new random sample order at the start of each epoch, which prevents the model from picking up patterns tied to data order and usually helps generalization. Use shuffle=False for evaluation, or when order matters (e.g., time series).
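
To see the effect, this sketch iterates a tiny dataset of the integers 0 to 9 with and without shuffling. Passing a seeded torch.Generator is optional and is used here only to make the shuffled order reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

ordered = DataLoader(dataset, batch_size=10, shuffle=False)

g = torch.Generator().manual_seed(0)  # seeded generator for reproducible shuffling
shuffled = DataLoader(dataset, batch_size=10, shuffle=True, generator=g)

in_order = next(iter(ordered))[0].tolist()
epoch_1 = next(iter(shuffled))[0].tolist()  # a new permutation is drawn...
epoch_2 = next(iter(shuffled))[0].tolist()  # ...each time the loader is iterated

print("shuffle=False:        ", in_order)
print("shuffle=True, epoch 1:", epoch_1)
print("shuffle=True, epoch 2:", epoch_2)
```

Every epoch still visits each sample exactly once; only the order changes.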

How do I track multiple metrics during training?

Maintain a dictionary of lists, one per metric, and compute each metric alongside the loss, for example accuracy = (outputs.argmax(1) == targets).float().mean(). For visualization, log the values to TensorBoard with torch.utils.tensorboard.SummaryWriter.
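
A minimal sketch of the dictionary approach follows. The random outputs and targets are stand-ins for real model predictions and labels, and the history name is an arbitrary choice.

```python
import torch
import torch.nn.functional as F
from collections import defaultdict

torch.manual_seed(0)

history = defaultdict(list)  # metric name -> list of per-epoch values

for epoch in range(3):
    # Stand-ins for real model outputs and labels
    outputs = torch.randn(32, 3)
    targets = torch.randint(0, 3, (32,))

    loss = F.cross_entropy(outputs, targets)
    accuracy = (outputs.argmax(1) == targets).float().mean()

    # Store plain Python floats so the history holds no tensors or graphs
    history["loss"].append(loss.item())
    history["accuracy"].append(accuracy.item())

for name, values in history.items():
    print(name, [round(v, 4) for v in values])
```

Calling .item() before storing keeps the history lightweight and avoids holding references to computation graphs.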

Further Reading