Transfer Learning in PyTorch
Transfer learning reuses knowledge from large-scale pretraining (for example on ImageNet or COCO) to solve new tasks with limited data. Instead of training from scratch, you load pretrained weights, adapt the final layers to your task, and fine-tune selectively. This typically cuts training time substantially and markedly improves accuracy on small datasets, and it is standard practice in production computer vision and NLP.
Understanding transfer learning fundamentals
Transfer learning rests on the observation that early-layer features (edges, textures) learned on large datasets generalize across domains, while later layers become increasingly task-specific. Freezing the early layers and retraining the later ones lets you adapt the model while reducing the risk of overfitting.
Feature extraction vs fine-tuning strategies
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
# Load ResNet50 with ImageNet-pretrained weights
# (torchvision >= 0.13 API; older versions used pretrained=True)
pretrained_model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Strategy 1: Feature extraction (freeze everything except a new final layer)
for param in pretrained_model.parameters():
    param.requires_grad = False  # Freeze all pretrained parameters
# Replace the final classification layer (newly created layers are trainable by default)
num_classes = 10  # Your task
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, num_classes)
# Only the new final layer is trainable
trainable_params = sum(p.numel() for p in pretrained_model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in pretrained_model.parameters())
print(f"Feature extraction mode: {trainable_params:,} / {total_params:,} trainable")
# Strategy 2: Fine-tuning (also unfreeze the last residual stage, layer4)
for param in pretrained_model.layer4.parameters():
    param.requires_grad = True
# Lower learning rate for the pretrained layer4 than for the new head;
# only parameters passed to the optimizer are updated
optimizer = optim.Adam([
    {'params': pretrained_model.layer4.parameters(), 'lr': 0.0001},
    {'params': pretrained_model.fc.parameters(), 'lr': 0.001},
])
trainable_params = sum(p.numel() for p in pretrained_model.parameters() if p.requires_grad)
print(f"Fine-tuning mode: {trainable_params:,} / {total_params:,} trainable")
Choosing pretrained models for your task
Select architectures and weights based on your task domain, model size, and computational budget.
| Model | Pretraining dataset | Relative speed | Relative accuracy | Use case |
|---|---|---|---|---|
| ResNet50 | ImageNet | Medium | High | General vision tasks |
| EfficientNetB0 | ImageNet | Fast | High | Mobile/edge deployment |
| Vision Transformer | ImageNet-21k | Slow | Very high | Large datasets, fine details |
| CLIP | ~400M web image-text pairs | Slow (depends on backbone) | Very high, strong zero-shot | Zero-shot, multimodal tasks |
Loading different pretrained architectures
import torch
from torchvision.models import (
    resnet18, ResNet18_Weights,
    resnet50, ResNet50_Weights,
    vgg16, VGG16_Weights,
    mobilenet_v2, MobileNet_V2_Weights,
    efficientnet_b0, EfficientNet_B0_Weights,
)
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
# ResNet family: good balance of speed and accuracy
resnet_18 = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_50 = resnet50(weights=ResNet50_Weights.DEFAULT)
# VGG: simple architecture, large memory footprint
vgg = vgg16(weights=VGG16_Weights.DEFAULT)
# Mobile-optimized: fast, lower accuracy
mobilenet = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
# EfficientNet: strong accuracy-to-speed tradeoff
efficientnet = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
# Standard ImageNet preprocessing for PIL images
transform = Compose([
    Resize(256),      # Resize the shorter side to 256 pixels
    CenterCrop(224),  # Crop the 224x224 input size the models were trained on
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
])
# Test forward pass in inference mode
resnet_50.eval()  # Use BatchNorm running statistics
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    output = resnet_50(x)
print(f"ResNet50 output shape: {output.shape}")  # torch.Size([2, 1000]) (ImageNet classes)
Adapting pretrained models to new tasks
Replace the final layer, adjust input/output shapes, and prepare for fine-tuning.
Adapting ResNet for a custom 5-class task
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
# Load pretrained model
model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Inspect original final layer
print(f"Original fc layer: {model.fc}")  # Linear(in_features=2048, out_features=1000, bias=True)
# Approach 1: Replace final layer
model.fc = nn.Linear(model.fc.in_features, 5)  # 5 classes for custom task
# Approach 2: Add custom layers after feature extraction
class CustomResNet(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        # Load pretrained ResNet50 and drop its final fc layer
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        # Custom classification head
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    def forward(self, x):
        # Extract features: [batch, 2048, 1, 1] after global average pooling
        features = self.features(x)
        features = torch.flatten(features, 1)  # [batch, 2048]
        # Classify
        output = self.classifier(features)
        return output
# Test custom model
model = CustomResNet(num_classes=5)
x = torch.randn(4, 3, 224, 224)
output = model(x)
print(f"Custom ResNet output shape: {output.shape}")  # torch.Size([4, 5])
Progressive fine-tuning strategies
Train in stages, unfreezing layers progressively; this helps limit catastrophic forgetting of pretrained features and tends to stabilize convergence.
Layer-by-layer fine-tuning schedule
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader, TensorDataset
# Load pretrained model
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)  # Custom task: 10 classes
# Create dummy data for illustration
X_train = torch.randn(200, 3, 224, 224)
y_train = torch.randint(0, 10, (200,))
loader = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=32,
    shuffle=True
)
criterion = nn.CrossEntropyLoss()
model.train()  # Note: BatchNorm running statistics update in train mode, even in frozen layers
# Stage 1: Train final layer only (1 epoch here for illustration; use more in practice)
print("Stage 1: Training final layer")
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
for epoch in range(1):
    for batch_x, batch_y in loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"  Stage 1 Epoch {epoch + 1}: Loss = {loss.item():.4f}")  # Last batch loss
# Stage 2: Unfreeze layer4 and fine-tune it with a lower learning rate than fc
print("Stage 2: Fine-tuning layer4 + fc")
for param in model.layer4.parameters():
    param.requires_grad = True
optimizer = optim.Adam([
    {'params': model.layer4.parameters(), 'lr': 0.0001},
    {'params': model.fc.parameters(), 'lr': 0.001},
])
for epoch in range(1):
    for batch_x, batch_y in loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"  Stage 2 Epoch {epoch + 1}: Loss = {loss.item():.4f}")
# Stage 3: Full model fine-tuning with a very small learning rate
print("Stage 3: Full model fine-tuning")
for param in model.parameters():
    param.requires_grad = True
optimizer = optim.Adam(model.parameters(), lr=0.00001)
for epoch in range(1):
    for batch_x, batch_y in loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"  Stage 3 Epoch {epoch + 1}: Loss = {loss.item():.4f}")
Handling input shape mismatches and architectural modifications
Adapt models to different input sizes and domain requirements.
Modifying models for different input sizes
import torch
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights
# VGG16 was trained on 224x224 inputs
model = vgg16(weights=VGG16_Weights.DEFAULT)
model.eval()
# The convolutional trunk accepts other input sizes, but its feature map
# grows with the input (five 2x2 max-pools give an overall stride of 32)
with torch.no_grad():
    for size in (224, 256):
        x = torch.randn(1, 3, size, size)
        feature_map = model.features(x)      # [1, 512, size/32, size/32]
        pooled = model.avgpool(feature_map)  # always [1, 512, 7, 7]
        output = model(x)                    # always [1, 1000]
        print(f"{size}x{size}: features {tuple(feature_map.shape)}, "
              f"pooled {tuple(pooled.shape)}, output {tuple(output.shape)}")
# torchvision's VGG already ends its trunk with nn.AdaptiveAvgPool2d((7, 7)),
# so the fully connected classifier always receives 512 * 7 * 7 features.
# The same mechanism, made explicit in a custom wrapper:
class AdaptiveVGG(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT)
        self.features = vgg.features
        # Adaptive pooling to a fixed 7x7 grid, whatever the input size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = vgg.classifier
    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)  # Handle any (sufficiently large) input size
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
model = AdaptiveVGG().eval()
with torch.no_grad():
    output_small = model(torch.randn(1, 3, 224, 224))
    output_large = model(torch.randn(1, 3, 512, 512))
print(f"Output shapes match: {output_small.shape == output_large.shape}")  # True
Avoiding overfitting with transfer learning
When fine-tuning on small datasets, use the following techniques to reduce overfitting.
| Technique | Effect | When to use |
|---|---|---|
| Lower learning rate | Smaller parameter updates | Always during fine-tuning |
| Dropout/weight decay | Regularization | Small datasets (e.g., <10k samples) |
| Freeze early layers | Keep pretrained features fixed | Very limited data |
| Data augmentation | Increase effective dataset size | Small datasets |
| Early stopping | Stop when validation loss stops improving | Whenever a held-out validation set is available |
Regularization during fine-tuning
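import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)
# Fine-tune with a low learning rate plus weight decay
optimizer = optim.SGD(
    model.parameters(),
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0001  # L2 regularization
)
criterion = nn.CrossEntropyLoss()
# Synthetic small dataset, split into training and validation sets
X = torch.randn(500, 3, 224, 224)
y = torch.randint(0, 10, (500,))
train_loader = DataLoader(TensorDataset(X[:400], y[:400]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[400:], y[400:]), batch_size=32)
# Early stopping monitors validation loss, not training loss
best_val_loss = float('inf')
best_state = None
patience = 3
patience_counter = 0
for epoch in range(10):
    model.train()
    for batch_x, batch_y in train_loader:
        loss = criterion(model(batch_x), batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on held-out data (eval mode: dropout off, BatchNorm uses running stats)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            val_loss += criterion(model(batch_x), batch_y).item() * batch_x.size(0)
    val_loss /= len(val_loader.dataset)
    print(f"Epoch {epoch}: Validation loss = {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
# Restore the weights from the best validation epoch
model.load_state_dict(best_state)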
Key Takeaways
- Transfer learning leverages pretrained weights from large datasets (ImageNet) to solve new tasks with minimal training data and time.
- Feature extraction (freeze all but the final layer) is fastest and suits small datasets; progressive or full fine-tuning (unfreezing layers gradually or all at once) pays off as the dataset grows.
- Use lower learning rates during fine-tuning so pretrained weights change only slightly, which limits catastrophic forgetting; staged unfreezing (fc → layer4 → full model) can improve convergence.
- Adapt models to new tasks by replacing the final layer or adding a custom classification head, and handle different input sizes with adaptive pooling.
- Prevent overfitting on small datasets with weight decay, dropout, data augmentation, and early stopping during fine-tuning.
Frequently Asked Questions
When should I use feature extraction vs fine-tuning?
As a rough guide, use feature extraction for very small datasets (under about 1,000 samples) or a tight compute budget; it is fast and limits overfitting. Use full fine-tuning for larger datasets (over about 10,000 samples), where there is enough data to adapt deeper layers. For medium datasets (1k to 10k samples), progressive fine-tuning is a good middle ground.
What learning rate should I use for fine-tuning?
Start about 10x lower than you would when training from scratch (e.g., 0.0001 instead of 0.001). Use different learning rates for different layers: lower for early layers (to preserve pretrained features) and higher for the final layers (for task adaptation). A learning rate scheduler can then decay the rates during training.
Can I fine-tune a model trained on a different task?
Yes, features learned on ImageNet transfer well to many vision tasks. Models trained on related tasks (e.g., object detection to instance segmentation) transfer even better. Unrelated domains require careful layer selection and may need more fine-tuning.
How do I know if a pretrained model will help?
If your target dataset is small (say, under 50k samples) or your task resembles ImageNet classification (object recognition, scene understanding), pretrained models usually help substantially. For domains far from natural photos (medical imaging, aerial photos), pretrained models still tend to help but may need more extensive fine-tuning.
Should I use batch normalization statistics from pretraining?
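Usually, yes. In model.eval() mode, BatchNorm layers normalize with their running statistics, so always validate and test in eval mode. In model.train() mode they normalize with per-batch statistics and update their running statistics (track_running_stats=True is already the default). With small fine-tuning batches those batch statistics are noisy, so a common option is to keep the BatchNorm layers frozen in eval mode during fine-tuning. Alternatively, if your data differs strongly from the pretraining data, run a few forward passes over your dataset in train mode, without optimizer steps, to re-estimate the running statistics before full fine-tuning.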