Convolutional Neural Networks in PyTorch
Convolutional Neural Networks (CNNs) are specialized architectures for processing spatial data like images. Convolutional layers apply learnable filters across image regions, pooling layers downsample feature maps, and fully connected layers perform final classification. CNNs have achieved state-of-the-art results on image recognition, object detection, and segmentation tasks since their breakthrough in 2012.
Understanding convolution operations
A convolution applies a learnable filter (kernel) across image regions, computing dot products to extract features. The kernel slides across the image, capturing patterns like edges, textures, and shapes at different scales.
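The sliding dot product described above can be written out by hand. The following is a minimal sketch, not how PyTorch implements convolution internally: it loops over every valid position of a random 3x3 kernel on a 5x5 input and checks the result against `F.conv2d`. The names `image`, `kernel`, `manual`, and `reference` are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image = torch.randn(5, 5)   # single-channel 5x5 input
kernel = torch.randn(3, 3)  # one 3x3 filter (random here; learned in a real layer)

# Slide the kernel over every valid position and take a dot product with the patch.
out_h = image.shape[0] - kernel.shape[0] + 1
out_w = image.shape[1] - kernel.shape[1] + 1
manual = torch.zeros(out_h, out_w)
for i in range(out_h):
    for j in range(out_w):
        patch = image[i:i + 3, j:j + 3]
        manual[i, j] = (patch * kernel).sum()

# F.conv2d expects [batch, channels, H, W] inputs and [out_ch, in_ch, kH, kW] weights.
reference = F.conv2d(image[None, None], kernel[None, None])[0, 0]
print(manual.shape)
print(torch.allclose(manual, reference, atol=1e-5))
```

Note that deep learning libraries, PyTorch included, compute a cross-correlation here: the kernel is not flipped before sliding, which is why the plain loop matches.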
How convolution works: mechanics and visualization
import torch
import torch.nn as nn
import torch.nn.functional as F
# Create a single-channel input (like grayscale image)
x = torch.randn(1, 1, 5, 5) # Batch=1, Channels=1, Height=5, Width=5
# Create a 3x3 convolutional kernel
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0)
# Apply the 3x3 kernel to the 5x5 image (no padding)
print(f"Input shape: {x.shape}")
output = conv(x)
print(f"Output shape: {output.shape}") # [1, 1, 3, 3] (5-3+1=3)
# Convolution with padding preserves spatial dimensions
conv_padded = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1)
output_padded = conv_padded(x)
print(f"Output shape (padded): {output_padded.shape}") # [1, 1, 5, 5]
# Stride=2 downsamples
conv_stride = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)
output_stride = conv_stride(x)
print(f"Output shape (stride=2): {output_stride.shape}") # [1, 1, 3, 3]
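The three output shapes above all follow from one formula: output = floor((input + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1. As a small sketch, the helper below (the name `conv_output_size` is illustrative, not a PyTorch function) computes that value and compares it with what `nn.Conv2d` actually produces for the same settings.

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.randn(1, 1, 5, 5)
for k, s, p in [(3, 1, 0), (3, 1, 1), (3, 2, 1)]:
    predicted = conv_output_size(5, k, s, p)
    actual = nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p)(x).shape[-1]
    print(f"kernel={k}, stride={s}, padding={p}: predicted {predicted}, actual {actual}")
```

The same helper is handy when sizing the first fully connected layer of a network: apply it once per convolution or pooling layer and multiply the final height, width, and channel count.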
Building a CNN for image classification
Stack convolutional blocks (conv + activation + pooling) to extract hierarchical features, then flatten and classify.
Simple CNN architecture for CIFAR-10
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # Feature extraction layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Classification layers
        self.fc1 = nn.Linear(128 * 4 * 4, 256) # 32->16->8->4 after 3 poolings
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x) # 32x32 -> 16x16
        # Block 2
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x) # 16x16 -> 8x8
        # Block 3
        x = self.conv3(x)
        x = self.relu(x)
        x = self.pool(x) # 8x8 -> 4x4
        # Flatten and classify
        x = x.view(x.size(0), -1) # Flatten to [batch, 128*4*4]
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
# Test the model
model = SimpleCNN(num_classes=10)
x = torch.randn(8, 3, 32, 32) # Batch of 8 CIFAR-10 images
output = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}") # [8, 10]
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
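To see where those parameters come from, the count can be reproduced by hand. This is a sketch in plain Python with two illustrative helpers (`conv_params` and `linear_params` are not library functions): a convolution has one k x k kernel per input/output channel pair plus one bias per output channel, and a linear layer has one weight per input/output pair plus one bias per output.

```python
def conv_params(in_ch, out_ch, k):
    # one k x k kernel per (input, output) channel pair, plus one bias per output channel
    return out_ch * (in_ch * k * k + 1)

def linear_params(in_features, out_features):
    # one weight per (input, output) pair, plus one bias per output
    return out_features * (in_features + 1)

# Layer sizes copied from SimpleCNN with num_classes=10
counts = {
    "conv1": conv_params(3, 32, 3),
    "conv2": conv_params(32, 64, 3),
    "conv3": conv_params(64, 128, 3),
    "fc1": linear_params(128 * 4 * 4, 256),
    "fc2": linear_params(256, 10),
}
total = sum(counts.values())
for name, n in counts.items():
    print(f"{name}: {n:>8,} ({n / total:.1%})")
print(f"total: {total:,}")  # 620,362
```

The first fully connected layer holds roughly 85% of the weights, which is typical when a network flattens its feature maps before classifying.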
Feature map visualization and understanding
Visualize what filters learn at different layers to understand CNN representations.
Extracting and inspecting intermediate feature maps
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.transforms import ToTensor, Normalize, Compose
# Load pretrained ResNet18 (the weights argument replaces the deprecated pretrained=True)
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
# Create a hook to capture intermediate activations
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
# Register hook on a specific layer (here, the first residual block of layer1)
model.layer1[0].register_forward_hook(get_activation('layer1_0'))
# Forward pass on dummy image
x = torch.randn(1, 3, 224, 224)
_ = model(x)
# Access captured feature map
feature_map = activation['layer1_0']
print(f"Feature map shape: {feature_map.shape}") # [1, 64, 56, 56] (stem conv and max pool each halve 224)
# Visualize a single filter's output
first_filter_output = feature_map[0, 0, :, :] # First batch, first channel
print(f"Single feature map shape: {first_filter_output.shape}") # [56, 56]
Pooling, stride, and receptive field
Pooling reduces spatial dimensions and enlarges the receptive field, the region of the input that influences each output value, for every layer that follows.
| Operation | Effect | Receptive Field |
|---|---|---|
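| Conv stride=1 | No spatial downsampling | Grows by (kernel size - 1) times the accumulated stride |
| Conv stride=2 | Downsample by 2 | Same growth at this layer; doubles the growth of every later layer |
| MaxPool 2x2 | Downsample by 2, keep max | Doubles later growth, like a stride=2 conv |
| AvgPool 2x2 | Downsample by 2, average | Doubles later growth, like a stride=2 conv |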
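The growth pattern in the table can be computed with a short recurrence: each layer adds (kernel size - 1) times the accumulated stride to the receptive field, and then multiplies the accumulated stride by its own stride. The function below is a sketch under that standard recurrence; the name `receptive_field` and the `(kernel_size, stride)` tuple format are illustrative choices, and dilation is ignored.

```python
def receptive_field(layers):
    """layers: list of (kernel_size, stride) tuples, ordered from input to output."""
    rf, jump = 1, 1  # jump = distance in input pixels between adjacent outputs
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf, jump

# Three stacked 3x3 convolutions, stride 1: grows slowly
print(receptive_field([(3, 1)] * 3))  # (7, 1)

# Three blocks of conv 3x3 (stride 1) + pool 2x2 (stride 2)
rf, jump = receptive_field([(3, 1), (2, 2)] * 3)
print(rf, jump)  # 22 8
```

With downsampling between the convolutions, the same three 3x3 kernels cover a 22-pixel window instead of a 7-pixel one.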
Pooling and downsampling patterns
import torch
import torch.nn as nn
# Input: 32x32 image
x = torch.randn(1, 3, 32, 32)
# Method 1: Max pooling
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
x_maxpool = max_pool(x)
print(f"After MaxPool: {x_maxpool.shape}") # [1, 3, 16, 16]
# Method 2: Average pooling
avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
x_avgpool = avg_pool(x)
print(f"After AvgPool: {x_avgpool.shape}") # [1, 3, 16, 16]
# Method 3: Stride in convolution (no separate pooling)
conv_stride = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
x_conv_stride = conv_stride(x)
print(f"After Conv stride=2: {x_conv_stride.shape}") # [1, 32, 16, 16]
# Adaptive pooling: output fixed size regardless of input
adaptive_pool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
x_adaptive = adaptive_pool(torch.randn(1, 3, 32, 32))
print(f"After AdaptiveAvgPool to 7x7: {x_adaptive.shape}") # [1, 3, 7, 7]
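Shapes alone do not show what the two pooling types compute. As a small illustration on a 4x4 input with known values (using the functional forms `F.max_pool2d` and `F.avg_pool2d`), each non-overlapping 2x2 window is reduced to its maximum or its mean:

```python
import torch
import torch.nn.functional as F

# 4x4 input holding the values 0..15, row by row
x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
print(x[0, 0])

max_out = F.max_pool2d(x, kernel_size=2, stride=2)
avg_out = F.avg_pool2d(x, kernel_size=2, stride=2)

print(max_out[0, 0])  # [[5, 7], [13, 15]]: largest value in each window
print(avg_out[0, 0])  # [[2.5, 4.5], [10.5, 12.5]]: mean of each window
```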
Modern CNN architectures: ResNet and beyond
Use pretrained models for transfer learning or study their design patterns.
Loading and fine-tuning ResNet
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
# Load pretrained ResNet18 (the weights argument replaces the deprecated pretrained=True)
model = resnet18(weights=ResNet18_Weights.DEFAULT)
# Freeze all parameters (feature extraction mode)
for param in model.parameters():
    param.requires_grad = False
# Replace final classification layer for your task (5 classes instead of 1000)
num_classes = 5
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Only train the new final layer
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
# Forward pass
x = torch.randn(8, 3, 224, 224)
output = model(x)
print(f"Output shape: {output.shape}") # [8, 5]
# Fine-tune: unfreeze later layers for more training
for param in model.layer4.parameters():
    param.requires_grad = True
# Add more parameters to optimizer
optimizer = optim.Adam(
    list(model.layer4.parameters()) + list(model.fc.parameters()),
    lr=0.0001 # Lower learning rate for fine-tuning
)
print("Model ready for fine-tuning")
Batch normalization in CNNs
Batch normalization stabilizes training, allows higher learning rates, and can act as a mild regularizer because each sample is normalized with noisy mini-batch statistics.
Batch norm mechanics and usage
import torch
import torch.nn as nn
class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithBatchNorm, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) # After each conv
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x) # Normalize after conv, before activation
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
model = CNNWithBatchNorm()
# Batch norm behaves differently in train vs eval mode
model.train() # Normalize with mini-batch statistics and update the running averages
x_train = torch.randn(32, 3, 32, 32)
output_train = model(x_train)
model.eval() # Use running statistics (no randomness)
x_eval = torch.randn(32, 3, 32, 32)
output_eval = model(x_eval)
print(f"Train output: {output_train.shape}, Eval output: {output_eval.shape}")
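The shapes match in both modes, but the values do not. As a sketch of the mechanics on a standalone `nn.BatchNorm2d` layer, the code below reproduces the train-mode output from the per-channel batch mean and (biased) variance, and the eval-mode output from the layer's `running_mean` and `running_var` buffers. It relies on the layer's default initialization (scale 1, shift 0); the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2  # batch with non-zero mean and non-unit variance

# Train mode: normalize each channel with statistics of the current batch.
bn.train()
out_train = bn(x)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual_train = (x - mean) / torch.sqrt(var + bn.eps)  # weight=1, bias=0 at initialization
print(torch.allclose(out_train, manual_train, atol=1e-4))

# Eval mode: normalize with the running averages accumulated during training.
bn.eval()
out_eval = bn(x)
manual_eval = (x - bn.running_mean[None, :, None, None]) / torch.sqrt(
    bn.running_var[None, :, None, None] + bn.eps
)
print(torch.allclose(out_eval, manual_eval, atol=1e-4))

# Same input, different outputs: the two modes use different statistics.
print(torch.allclose(out_train, out_eval))
```

After a single training step the running averages have barely moved from their initial values, so the eval-mode output here is far from the train-mode output; the gap narrows as training accumulates more batches.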
Key Takeaways
- Convolutions apply learnable filters across spatial dimensions to extract features; stride downsamples, padding preserves dimensions.
- Build CNNs by stacking conv-relu-pool blocks for feature extraction, then fully connected layers for classification.
- Pooling (max or average) reduces spatial dimensions and increases receptive field, improving computational efficiency and translation invariance.
- Batch normalization stabilizes training and allows higher learning rates; always use model.train() and model.eval() correctly.
- Transfer learning with pretrained models (ResNet, VGG, EfficientNet) is more practical than training from scratch for most image tasks.
Frequently Asked Questions
What is the receptive field and why does it matter?
The receptive field is the region of input that influences an output neuron. Larger receptive fields capture broader context, critical for tasks like object detection. Design networks so that the final layers see the entire image (e.g., 224x224 image with receptive field ≥224).
Should I use stride=2 or MaxPool for downsampling?
Both halve the spatial resolution, but they do it differently. A stride=2 convolution learns its downsampling filter, while MaxPool has no parameters and simply keeps the largest value in each window. Many modern networks favor stride=2 convolutions, but MaxPool is simpler and works well. Try both for your task.
How do I decide between global average pooling and flattening?
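Global average pooling reduces each feature map to a single value, which shrinks the classifier, tends to improve generalization, and makes the head independent of input resolution. Flattening keeps every spatial position as a separate feature, at the cost of a larger fully connected layer tied to one input size. Global average pooling is preferred in modern architectures; flatten only when the final feature maps are small.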
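The difference is easy to quantify. The following sketch compares two hypothetical classifier heads for 128-channel feature maps (the names `flatten_head`, `gap_head`, and `n_params` are illustrative): one flattens a 4x4 map, the other applies global average pooling via `nn.AdaptiveAvgPool2d(1)`.

```python
import torch
import torch.nn as nn

flatten_head = nn.Sequential(nn.Flatten(), nn.Linear(128 * 4 * 4, 10))
gap_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, 10))

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(flatten_head))  # 2048 * 10 + 10 = 20,490
print(n_params(gap_head))      # 128 * 10 + 10 = 1,290

# The pooled head accepts feature maps of any spatial size.
features_small = torch.randn(2, 128, 4, 4)
features_large = torch.randn(2, 128, 7, 7)
print(gap_head(features_small).shape, gap_head(features_large).shape)  # both [2, 10]
# flatten_head(features_large) would raise an error: it expects exactly 128*4*4 inputs
```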
Can I use 1x1 convolutions?
Yes. A 1x1 convolution mixes channels at each pixel without any spatial interaction, which makes it useful for channel reduction (less memory and computation downstream). Followed by an activation, it also adds nonlinearity. ResNet and Inception architectures use them extensively for computational efficiency.
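One way to build intuition is to note that a 1x1 convolution is the same computation as a linear layer applied independently at every pixel. The sketch below reduces 256 channels to 64 and verifies that equivalence by copying the convolution's weights into an `nn.Linear`; the channel counts and the name `reduce` are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 256, 14, 14)

reduce = nn.Conv2d(256, 64, kernel_size=1)  # channel reduction, spatial size unchanged
y = reduce(x)
print(y.shape)  # [2, 64, 14, 14]
print(sum(p.numel() for p in reduce.parameters()))  # 256 * 64 + 64 = 16,448

# Same computation as a Linear layer applied to the channel vector at each pixel.
linear = nn.Linear(256, 64)
with torch.no_grad():
    linear.weight.copy_(reduce.weight.view(64, 256))
    linear.bias.copy_(reduce.bias)
y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # channels last, then back
print(torch.allclose(y, y_linear, atol=1e-4))
```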
Why do I get different results in train vs eval mode?
Batch normalization uses running statistics in eval mode and mini-batch statistics in train mode. Dropout is disabled in eval. Always call model.eval() before evaluation and model.train() before training to ensure correct behavior.
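The dropout half of that answer can be checked in a few lines. In this sketch a standalone `nn.Dropout(p=0.5)` is applied to a vector of ones: in train mode, elements are either zeroed or scaled by 1 / (1 - p) = 2 so the expected value stays the same, and in eval mode the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 1000)

dropout.train()
y_train = dropout(x)
print(y_train.unique())                 # only 0 (dropped) and 2 (kept, rescaled)
print((y_train == 0).float().mean())    # fraction dropped, close to 0.5

dropout.eval()
y_eval = dropout(x)
print(torch.equal(y_eval, x))           # eval mode is the identity
```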