DataLoader and Datasets in PyTorch
PyTorch's Dataset and DataLoader classes decouple data loading from model training, enabling efficient batching, shuffling, and parallel data loading. Dataset defines how individual samples are loaded and preprocessed; DataLoader wraps a Dataset to handle batching, multi-worker parallelism, and dynamic sampling. This separation is critical for training on large datasets that don't fit in memory.
The Dataset abstraction
A map-style Dataset is a mapping from indices to data samples. You implement __len__() to return the total number of samples and __getitem__(idx) to return the sample at index idx. Companion libraries such as torchvision provide built-in datasets for common benchmarks; custom datasets inherit from torch.utils.data.Dataset.
Creating a custom dataset
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        # Store data and labels as tensors or numpy arrays
        self.data = data
        self.labels = labels
        self.transform = transform  # Optional preprocessing function

    def __len__(self):
        # Return total number of samples
        return len(self.data)

    def __getitem__(self, idx):
        # Return single sample (data, label) pair
        sample = self.data[idx]
        label = self.labels[idx]
        # Apply optional transformation
        if self.transform:
            sample = self.transform(sample)
        return sample, label

# Create dataset with synthetic data
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 3, 100)
dataset = CustomDataset(X, y)
print(f"Dataset length: {len(dataset)}")
# Access individual samples
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")
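When the data does not fit in memory, __getitem__ can read each sample from disk on demand instead of indexing an in-memory array. The following is a minimal sketch under assumptions: the LazyNpyDataset class, the one-file-per-sample layout, and the file names are hypothetical and chosen only for illustration.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class LazyNpyDataset(Dataset):
    """Hypothetical dataset that keeps only file paths in memory."""

    def __init__(self, paths, labels):
        self.paths = paths
        self.labels = labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The array is read from disk only when this index is requested
        sample = torch.from_numpy(np.load(self.paths[idx]))
        return sample, self.labels[idx]


# Write a few synthetic samples to a temporary directory (illustrative layout)
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(5):
    path = os.path.join(tmp_dir, f"sample_{i}.npy")
    np.save(path, np.full(20, i, dtype=np.float32))
    paths.append(path)

lazy_dataset = LazyNpyDataset(paths, labels=[0, 1, 2, 0, 1])
sample, label = lazy_dataset[3]
print(f"Sample shape: {sample.shape}, Label: {label}")
```

Only the list of paths lives in memory, so the dataset size is bounded by disk rather than RAM.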
Using PyTorch's built-in datasets
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
# Define preprocessing pipeline
transform = Compose([
    ToTensor(),  # Convert PIL images to tensors and scale to [0, 1]
    # ImageNet channel statistics, reused here as a common default;
    # CIFAR-10's own channel statistics differ slightly
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])
])
# Load CIFAR-10 dataset (automatically downloads if not present)
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
val_dataset = CIFAR10(root='./data', train=False, transform=transform, download=False)
print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
# Check a sample
image, label = train_dataset[0]
print(f"Image shape: {image.shape}, Label: {label}")
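Normalize computes (x - mean) / std separately for each channel. The ImageNet statistics used above are a common default, but the statistics can also be estimated from the training data itself. A sketch in plain torch, assuming a random tensor as a stand-in for a real image dataset:

```python
import torch

torch.manual_seed(0)

# Stand-in for a stack of training images: (num_images, channels, height, width)
images = torch.rand(64, 3, 32, 32)

# Per-channel statistics: reduce over every dimension except the channel one
mean = images.mean(dim=(0, 2, 3))
std = images.std(dim=(0, 2, 3))

# Same arithmetic that Normalize applies, broadcast over the channel dimension
normalized = (images - mean[None, :, None, None]) / std[None, :, None, None]

print(f"Channel means after normalization: {normalized.mean(dim=(0, 2, 3))}")
print(f"Channel stds after normalization:  {normalized.std(dim=(0, 2, 3))}")
```

After normalization each channel has approximately zero mean and unit standard deviation.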
DataLoader for batching and sampling
DataLoader wraps a Dataset and provides batching, shuffling, and parallel data loading. It yields mini-batches of data instead of individual samples, essential for GPU training efficiency.
Creating and using DataLoaders
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
# Create synthetic data
X = torch.randn(100, 20)
y = torch.randint(0, 3, (100,))
# Option 1: Using TensorDataset (simple wrapper for tensors)
dataset = TensorDataset(X, y)
# Create DataLoader with common parameters
train_loader = DataLoader(
    dataset,
    batch_size=16,    # Samples per batch
    shuffle=True,     # Reshuffle the data every epoch
    num_workers=0,    # Number of worker processes (0 = load in the main process)
    pin_memory=True,  # Use page-locked host memory (faster CPU-to-GPU copies)
    drop_last=False   # Keep the last batch even if it is smaller than batch_size
)
# Iterate over batches
print("First 3 batches:")
for batch_idx, (batch_X, batch_y) in enumerate(train_loader):
    if batch_idx >= 3:
        break
    print(f"Batch {batch_idx}: X shape {batch_X.shape}, y shape {batch_y.shape}")
# Train and val loaders typically use different settings
# (the same dataset is reused for both here only to keep the example short)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset, batch_size=64, shuffle=False) # No shuffle for validation
print(f"\nTrain loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
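In practice the training and validation loaders need disjoint samples. One way to get them is torch.utils.data.random_split, sketched here; the 80/20 split and the seed are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

X = torch.randn(100, 20)
y = torch.randint(0, 3, (100,))
dataset = TensorDataset(X, y)

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [80, 20], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

print(f"Train samples: {len(train_set)}, batches: {len(train_loader)}")
print(f"Val samples: {len(val_set)}, batches: {len(val_loader)}")
```

random_split returns Subset objects that share the underlying data, so no samples are copied.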
Parallel data loading with num_workers
Parallel loading via multiple worker processes can speed up training by overlapping I/O and preprocessing with computation. The gain is largest when data loading, not the model, is the bottleneck.
| Setting | Purpose | Trade-off |
|---|---|---|
| num_workers=0 | Single-threaded loading | Slower; good for small datasets, debugging |
| num_workers=4 | 4 parallel workers | Faster I/O; higher memory, CPU usage |
| num_workers=8 | 8 parallel workers | Best for large datasets; requires more CPU |
Using parallel data loading
import torch
from torch.utils.data import DataLoader, TensorDataset
import time
# Create synthetic data
X = torch.randn(10000, 100)
y = torch.randint(0, 10, (10000,))
dataset = TensorDataset(X, y)
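# Compare main-process loading with 4 worker processes.
# The __main__ guard is needed on platforms that start workers with "spawn"
# (Windows, macOS), where each worker re-imports this script.
if __name__ == "__main__":
    for num_workers in [0, 4]:
        loader = DataLoader(
            dataset,
            batch_size=128,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )
        start = time.perf_counter()
        for _ in loader:
            pass  # Just iterate to measure loading time
        elapsed = time.perf_counter() - start
        print(f"num_workers={num_workers}: {elapsed:.3f}s per epoch")
# Note: num_workers > 0 adds process startup and inter-process transfer overhead.
# For in-memory tensors like these, num_workers=0 is often faster; workers pay off
# when each sample needs real I/O or CPU-heavy preprocessing.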
Custom collate functions for flexible batching
By default, DataLoader stacks samples into tensors. Custom collate functions allow complex batching logic.
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
class VariableLengthDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences  # List of tensors with different lengths
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

# Create dataset with variable-length sequences
sequences = [
    torch.randn(10, 5),  # Length 10
    torch.randn(15, 5),  # Length 15
    torch.randn(8, 5),   # Length 8
]
labels = torch.tensor([0, 1, 0])
dataset = VariableLengthDataset(sequences, labels)

# Custom collate function pads sequences to same length
def collate_variable_length(batch):
    sequences, labels = zip(*batch)
    # Pad sequences to longest in batch
    padded_sequences = pad_sequence(sequences, batch_first=True)
    labels = torch.stack(labels)
    return padded_sequences, labels

# Use custom collate
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_variable_length
)
for batch_seq, batch_labels in loader:
    print(f"Padded sequences shape: {batch_seq.shape}, Labels: {batch_labels}")
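Downstream code often needs to know which positions in a padded batch are real data and which are padding. A possible variant of the collate function also returns the original lengths and a boolean mask. The name collate_with_mask and the layout of the returned tuple are illustrative choices, not a PyTorch convention.

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def collate_with_mask(batch):
    sequences, labels = zip(*batch)
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    padded = pad_sequence(sequences, batch_first=True)  # (batch, max_len, features)
    # mask[i, t] is True where position t holds real data for sequence i
    mask = torch.arange(padded.size(1))[None, :] < lengths[:, None]
    return padded, torch.stack(labels), lengths, mask


# A plain list of (sequence, label) pairs works as a map-style dataset
pairs = [
    (torch.randn(10, 5), torch.tensor(0)),
    (torch.randn(15, 5), torch.tensor(1)),
    (torch.randn(8, 5), torch.tensor(0)),
]
loader = DataLoader(pairs, batch_size=3, collate_fn=collate_with_mask)

padded, labels, lengths, mask = next(iter(loader))
print(f"Padded shape: {padded.shape}")
print(f"Lengths: {lengths.tolist()}, real positions per row: {mask.sum(dim=1).tolist()}")
```

The mask can be used to exclude padded positions from a loss or an attention computation.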
Data augmentation and preprocessing
Apply transformations to samples during loading to increase diversity and prevent overfitting.
Image augmentation pipeline
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import (
    Compose, RandomCrop, RandomHorizontalFlip,
    ToTensor, Normalize
)

# Training augmentations (random crop, flip)
train_transform = Compose([
    RandomCrop(32, padding=4),  # Random crop with padding
    RandomHorizontalFlip(),     # Randomly flip horizontally
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])
])

# Validation: no augmentation, only normalization
val_transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])
])

# Create datasets with different transforms
train_dataset = CIFAR10('./data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('./data', train=False, transform=val_transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")
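Because the transform runs inside __getitem__, a random augmentation produces a different result every time the same index is read, so each epoch sees a fresh variant of every sample. A small sketch in plain torch with a hypothetical noise transform (add_noise and NoisyDataset are illustrative names, not torchvision APIs):

```python
import torch
from torch.utils.data import Dataset


def add_noise(x):
    # Hypothetical augmentation: returns a new tensor with Gaussian noise added
    return x + 0.1 * torch.randn_like(x)


class NoisyDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


data = torch.zeros(4, 8)
augmented = NoisyDataset(data, transform=add_noise)

first_read = augmented[0]
second_read = augmented[0]
print(f"Same index, identical result: {torch.equal(first_read, second_read)}")
```

The stored data is never modified; only the returned copies differ between reads.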
Sampling strategies and weighted sampling
Control how samples are selected—uniform random, sequential, or weighted by importance.
Weighted sampling for imbalanced datasets
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import numpy as np
class ImbalancedDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

# Create imbalanced dataset (10 class-0, 90 class-1)
samples = torch.randn(100, 20)
labels = torch.cat([torch.zeros(10, dtype=torch.long), torch.ones(90, dtype=torch.long)])
dataset = ImbalancedDataset(samples, labels)
# Compute class weights (inverse of frequency)
class_counts = torch.bincount(labels)
class_weights = 1.0 / class_counts.float()
# Create sampler that weights by class
sample_weights = class_weights[labels]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),
    replacement=True
)
# DataLoader with weighted sampler (no shuffle; the sampler handles randomness)
loader = DataLoader(
    dataset,
    batch_size=16,
    sampler=sampler,  # Can't use shuffle=True with sampler
    num_workers=0
)
# Check that a batch contains a roughly balanced mix of classes
for batch_x, batch_y in loader:
    # minlength=2 keeps both counts even if a class is absent from the batch
    class_counts_in_batch = torch.bincount(batch_y, minlength=2)
    print(f"Batch class distribution: {class_counts_in_batch.tolist()}")
    break
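A single batch is a noisy check. With inverse-frequency weights each class receives the same total weight (10 × 1/10 = 1 and 90 × 1/90 = 1), so draws should split roughly 50/50 between the two classes. Iterating a sampler directly yields the drawn indices, which allows a check over many draws. The seed and the 10,000 draws are arbitrary choices for this sketch.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)

# Same imbalance as above: 10 samples of class 0, 90 samples of class 1
labels = torch.cat([torch.zeros(10, dtype=torch.long), torch.ones(90, dtype=torch.long)])
class_weights = 1.0 / torch.bincount(labels).float()
sample_weights = class_weights[labels]

# Draw many indices so the empirical class fraction is stable
sampler = WeightedRandomSampler(sample_weights, num_samples=10_000, replacement=True)
drawn_indices = torch.tensor(list(sampler))
drawn_labels = labels[drawn_indices]

fraction_class_0 = (drawn_labels == 0).float().mean().item()
print(f"Fraction of draws from class 0: {fraction_class_0:.3f}")
```

Because sampling uses replacement, the 10 minority samples are each drawn many times; the balance comes from repetition, not from new data.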
Key Takeaways
- Dataset subclasses implement __len__() and __getitem__(idx) to define how individual samples are loaded and preprocessed.
- DataLoader wraps a Dataset to provide batching, shuffling, and parallel loading, all of which are essential for efficient GPU training.
- Set num_workers > 0 for parallel data loading when per-sample loading or preprocessing is expensive; the trade-off is higher CPU and memory use in exchange for faster epochs.
- Use pin_memory=True for faster host-to-GPU transfer; apply shuffle=True to training loaders and shuffle=False to validation loaders.
- Custom collate functions enable complex batching logic like padding variable-length sequences or handling hierarchical data.
Frequently Asked Questions
What is the difference between shuffle=True and a Sampler?
shuffle=True tells the DataLoader to build a RandomSampler internally, which draws a new random permutation of the indices each epoch. Custom samplers (like WeightedRandomSampler) control which samples are selected and in what order. The two options are mutually exclusive: passing a sampler together with shuffle=True raises a ValueError.
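In current PyTorch releases, combining a sampler with shuffle=True fails at construction time with a ValueError. A quick check, using a tiny stand-in dataset and SequentialSampler purely for illustration:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.randn(10, 4))
sampler = SequentialSampler(dataset)

try:
    DataLoader(dataset, batch_size=2, sampler=sampler, shuffle=True)
    raised = False
except ValueError as err:
    raised = True
    print(f"ValueError: {err}")

# Passing only the sampler works; the sampler alone decides the order
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print(f"Batches with sampler only: {len(loader)}")
```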
Why does num_workers > 0 sometimes slow down training on small datasets?
Worker processes have startup overhead, and every batch must be serialized and passed from the worker to the main process. When each sample is cheap to load (for example, tensors already in memory), this overhead exceeds the benefit of parallel loading. Use num_workers=0 for small in-memory datasets and during debugging.
How do I handle variable-length sequences in batches?
Use a custom collate_fn that pads sequences to the longest length in the batch, or truncates to a fixed length. PyTorch provides pad_sequence() for padding, and pack_padded_sequence() with its inverse pad_packed_sequence() so that RNNs can skip the padded positions.
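The pad, pack, and unpack round trip through an RNN can be sketched as follows. The LSTM sizes are arbitrary, and enforce_sorted=False is used so the sequences do not have to be sorted by length beforehand.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)

sequences = [torch.randn(10, 5), torch.randn(15, 5), torch.randn(8, 5)]
lengths = torch.tensor([seq.size(0) for seq in sequences])  # must live on the CPU

# Pad to a common length: (batch, max_len, features)
padded = pad_sequence(sequences, batch_first=True)

# Pack so the LSTM processes only the real time steps
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

lstm = torch.nn.LSTM(input_size=5, hidden_size=4, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back to a padded tensor in the original batch order
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(f"Output shape: {out.shape}, lengths: {out_lengths.tolist()}")
```

Positions beyond each sequence's length in out are filled with zeros, and h_n holds the hidden state at each sequence's last real step rather than at the padded end.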
Can I use DataLoader for inference?
Yes. Create a dataset containing only inputs (no labels) and set batch_size to fit your GPU memory. Set shuffle=False so predictions stay in input order, and choose num_workers based on your system. Call model.eval() once before the loop and run the loop inside torch.no_grad().
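A minimal inference loop might look like this. The linear model is a placeholder for a trained network, and the sizes are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

model = torch.nn.Linear(20, 3)  # Placeholder for a trained model
inputs = torch.randn(50, 20)

# TensorDataset with a single tensor yields one-element tuples
loader = DataLoader(TensorDataset(inputs), batch_size=16, shuffle=False)

model.eval()  # Switch layers such as dropout and batch norm to inference mode
predictions = []
with torch.no_grad():  # Skip gradient tracking to save memory and time
    for (batch,) in loader:
        predictions.append(model(batch).argmax(dim=1))

predictions = torch.cat(predictions)
print(f"Predictions shape: {predictions.shape}")
```

With shuffle=False, predictions[i] corresponds to inputs[i].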
What is drop_last=True and when should I use it?
When drop_last=True, the DataLoader drops the last batch if it's smaller than batch_size. Use it when batch statistics (batch norm) require all batches to be the same size, or during training to avoid noisy small batches. For validation/inference, use drop_last=False to include all samples.
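The effect on batch counts is simple arithmetic: 100 samples with batch_size=16 give six full batches plus a remainder of four. A short check on synthetic data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 20))

keep_loader = DataLoader(dataset, batch_size=16, drop_last=False)
drop_loader = DataLoader(dataset, batch_size=16, drop_last=True)

keep_sizes = [batch[0].size(0) for batch in keep_loader]
drop_sizes = [batch[0].size(0) for batch in drop_loader]

print(f"drop_last=False: {len(keep_loader)} batches, sizes {keep_sizes}")  # 7 batches, last has 4 samples
print(f"drop_last=True:  {len(drop_loader)} batches, sizes {drop_sizes}")  # 6 batches of 16
```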