
DataLoader and Datasets in PyTorch

PyTorch's Dataset and DataLoader classes decouple data loading from model training, enabling efficient batching, shuffling, and parallel data loading. Dataset defines how individual samples are loaded and preprocessed; DataLoader wraps a Dataset to handle batching, multi-worker parallelism, and configurable sampling strategies. This separation is critical for training on large datasets that don't fit in memory, because samples can be loaded on demand while the model trains.

The Dataset abstraction

A map-style Dataset is a mapping from indices to data samples: you implement __len__() to return the total number of samples and __getitem__(idx) to return the sample at index idx. Custom datasets inherit from torch.utils.data.Dataset, and domain libraries such as torchvision provide ready-made datasets for common benchmarks.

Creating a custom dataset

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        # Store data and labels as tensors or numpy arrays
        self.data = data
        self.labels = labels
        self.transform = transform  # Optional preprocessing function

    def __len__(self):
        # Return total number of samples
        return len(self.data)

    def __getitem__(self, idx):
        # Return single sample (data, label) pair
        sample = self.data[idx]
        label = self.labels[idx]

        # Apply optional transformation
        if self.transform:
            sample = self.transform(sample)

        return sample, label

# Create dataset with synthetic data
X = np.random.randn(100, 20).astype(np.float32)
y = np.random.randint(0, 3, 100)

dataset = CustomDataset(X, y)
print(f"Dataset length: {len(dataset)}")

# Access individual samples
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")
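The CustomDataset above keeps every sample in memory. For data that does not fit in RAM, __getitem__ can instead load each sample on demand. The sketch below assumes the features and labels are stored as .npy files and uses NumPy memory mapping (np.load with mmap_mode="r"); the MemmapDataset class and the file names are hypothetical, written only for illustration.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class MemmapDataset(Dataset):
    """Hypothetical lazy dataset: reads single rows from .npy files on demand."""

    def __init__(self, features_path, labels_path):
        # mmap_mode="r" maps the files without reading them fully into RAM
        self.features = np.load(features_path, mmap_mode="r")
        self.labels = np.load(labels_path, mmap_mode="r")

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        # np.array(...) copies only this row out of the memory map
        x = torch.from_numpy(np.array(self.features[idx], dtype=np.float32))
        y = int(self.labels[idx])
        return x, y


# Write a small synthetic dataset to disk for the demo (hypothetical file names)
tmp_dir = tempfile.mkdtemp()
features_path = os.path.join(tmp_dir, "features.npy")
labels_path = os.path.join(tmp_dir, "labels.npy")
np.save(features_path, np.random.randn(1000, 20).astype(np.float32))
np.save(labels_path, np.random.randint(0, 3, 1000))

lazy_dataset = MemmapDataset(features_path, labels_path)
lazy_loader = DataLoader(lazy_dataset, batch_size=64, shuffle=True)

# The default collate function stacks the row tensors and turns the ints into a tensor
batch_x, batch_y = next(iter(lazy_loader))
print(batch_x.shape, batch_y.shape)
```

Only the rows requested by the sampler are read from disk, so memory use stays roughly constant regardless of file size.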

Using PyTorch's built-in datasets

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize

# Define preprocessing pipeline
transform = Compose([
    ToTensor(),  # Convert PIL images to tensors and scale to [0, 1]
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])  # ImageNet channel statistics (CIFAR-10's own differ slightly)
])

# Load CIFAR-10 dataset (automatically downloads if not present)
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
# train=False loads the 10,000-image test split, used here for validation;
# the archive downloaded above already contains it
val_dataset = CIFAR10(root='./data', train=False, transform=transform, download=False)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Check a sample
image, label = train_dataset[0]
print(f"Image shape: {image.shape}, Label: {label}")

DataLoader for batching and sampling

DataLoader wraps a Dataset and provides batching, shuffling, and parallel data loading. It yields mini-batches instead of individual samples, which is essential for efficient GPU training.
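Conceptually, with shuffle=True and the default collate function, DataLoader draws a random permutation of indices, slices it into groups of batch_size, and stacks the selected samples. The hypothetical simple_batches helper below is only a rough sketch of that behavior for in-memory tensors; the real DataLoader also supports custom samplers, collate functions, and worker processes.

```python
import torch


def simple_batches(X, y, batch_size, shuffle=True, drop_last=False, generator=None):
    """Hypothetical helper approximating DataLoader's default behavior
    for in-memory tensors: choose an index order, slice it into batches,
    and gather the selected samples."""
    n = X.shape[0]
    # shuffle=True ~ a fresh random permutation of all indices each epoch
    order = torch.randperm(n, generator=generator) if shuffle else torch.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        if drop_last and len(idx) < batch_size:
            break  # skip the final partial batch
        # Indexing with a tensor of indices gathers the rows into one batch
        yield X[idx], y[idx]


X = torch.randn(100, 20)
y = torch.randint(0, 3, (100,))
batches = list(simple_batches(X, y, batch_size=32))
print([bx.shape[0] for bx, _ in batches])  # [32, 32, 32, 4]
```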

Creating and using DataLoaders

import torch
from torch.utils.data import DataLoader, TensorDataset

# Create synthetic data
X = torch.randn(100, 20)
y = torch.randint(0, 3, (100,))

# TensorDataset wraps tensors that share the same first dimension
dataset = TensorDataset(X, y)

# Create DataLoader with common parameters
train_loader = DataLoader(
    dataset,
    batch_size=16,    # Samples per batch
    shuffle=True,     # Reshuffle indices at the start of every epoch
    num_workers=0,    # Worker subprocesses (0 = load in the main process)
    pin_memory=torch.cuda.is_available(),  # Page-locked host memory speeds up copies to a CUDA GPU
    drop_last=False   # Keep the final, smaller batch (True would discard it)
)

# Iterate over batches
print("First 3 batches:")
for batch_idx, (batch_X, batch_y) in enumerate(train_loader):
    if batch_idx >= 3:
        break
    print(f"Batch {batch_idx}: X shape {batch_X.shape}, y shape {batch_y.shape}")

# Create separate train/val loaders with different settings
# (one dataset is reused here for brevity; real train and val sets hold different samples)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset, batch_size=64, shuffle=False)  # No shuffle for validation

print(f"\nTrain loader: {len(train_loader)} batches")  # ceil(100 / 32) = 4
print(f"Val loader: {len(val_loader)} batches")        # ceil(100 / 64) = 2

Parallel data loading with num_workers

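With num_workers > 0, DataLoader starts that many worker processes that load and collate batches in the background, so data loading and preprocessing overlap with model computation instead of stalling it.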

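| Setting | Purpose | Trade-off |
|---|---|---|
| num_workers=0 | Load in the main process | No parallelism; simplest for debugging and for cheap, in-memory data |
| num_workers=4 | 4 worker processes | Overlaps loading with compute; higher CPU and memory usage |
| num_workers=8 | 8 worker processes | Helps when per-sample loading or augmentation is expensive; limited by available CPU cores |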

Using parallel data loading

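import time

import torch
from torch.utils.data import DataLoader, TensorDataset

# Create synthetic data
X = torch.randn(10000, 100)
y = torch.randint(0, 10, (10000,))
dataset = TensorDataset(X, y)

# The guard is required when workers start via "spawn" (the default on Windows and macOS)
if __name__ == "__main__":
    # Compare single-process vs multi-process loading
    for num_workers in [0, 4]:
        loader = DataLoader(
            dataset,
            batch_size=128,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available()
        )

        start = time.perf_counter()
        for _ in loader:
            pass  # Just iterate to measure loading time
        elapsed = time.perf_counter() - start

        print(f"num_workers={num_workers}: {elapsed:.3f}s per epoch")

# Note: worker startup and inter-process transfer add overhead. These samples are
# already in memory and cost almost nothing to "load", so num_workers=0 is often as
# fast or faster here; workers pay off when __getitem__ does real work (disk reads,
# decoding, augmentation), not simply when the dataset is large.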
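Workers only help when the per-sample cost is non-trivial. The sketch below assumes a hypothetical SlowDataset whose __getitem__ sleeps for a few milliseconds as a stand-in for disk reads or decoding; timings depend on the machine and core count, so no output is shown.

```python
import time

import torch
from torch.utils.data import Dataset, DataLoader


class SlowDataset(Dataset):
    """Hypothetical dataset whose __getitem__ simulates expensive work
    (disk reads, image decoding, augmentation) with a short sleep."""

    def __init__(self, n_samples=256, delay_s=0.005):
        self.n_samples = n_samples
        self.delay_s = delay_s

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        time.sleep(self.delay_s)  # stand-in for real per-sample cost
        return torch.full((10,), float(idx)), idx % 2


def time_one_epoch(num_workers, batch_size=32):
    """Iterate once over SlowDataset and return (batch count, seconds)."""
    loader = DataLoader(SlowDataset(), batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    n_batches = sum(1 for _ in loader)
    return n_batches, time.perf_counter() - start


if __name__ == "__main__":
    # Guard needed when workers start via "spawn" (Windows, macOS)
    for workers in [0, 4]:
        n_batches, elapsed = time_one_epoch(workers)
        print(f"num_workers={workers}: {n_batches} batches in {elapsed:.2f}s")
```

With num_workers=0 the sleeps run one after another; with several workers they overlap, so the epoch time should drop until worker startup or CPU limits dominate.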

Custom collate functions for flexible batching

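By default, DataLoader batches samples with default_collate, which stacks tensors (and converts NumPy arrays and Python numbers) into batched tensors; this fails when samples have different shapes. A custom collate_fn lets you implement other batching logic, such as padding.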

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class VariableLengthDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences = sequences  # List of tensors with different lengths
        self.labels = labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

# Create dataset with variable-length sequences
sequences = [
    torch.randn(10, 5),  # Length 10
    torch.randn(15, 5),  # Length 15
    torch.randn(8, 5),   # Length 8
]
labels = torch.tensor([0, 1, 0])

dataset = VariableLengthDataset(sequences, labels)

# Custom collate function pads sequences to same length
def collate_variable_length(batch):
    sequences, labels = zip(*batch)

    # Pad sequences to longest in batch
    padded_sequences = pad_sequence(list(sequences), batch_first=True)
    labels = torch.stack(labels)

    return padded_sequences, labels

# Use custom collate
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_variable_length
)

for batch_seq, batch_labels in loader:
    print(f"Padded sequences shape: {batch_seq.shape}, Labels: {batch_labels}")
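The padding step exists because default collation stacks samples. This short check calls torch.utils.data.default_collate directly: it batches equal-shaped samples but raises a RuntimeError for sequences of different lengths like the ones above.

```python
import torch
from torch.utils.data import default_collate

# Equal-shaped samples: each field is stacked along a new batch dimension
same_shape = [(torch.randn(10, 5), torch.tensor(0)), (torch.randn(10, 5), torch.tensor(1))]
xs, ys = default_collate(same_shape)
print(xs.shape, ys.shape)  # torch.Size([2, 10, 5]) torch.Size([2])

# Different lengths: torch.stack cannot combine them, so collation fails
ragged = [(torch.randn(10, 5), torch.tensor(0)), (torch.randn(15, 5), torch.tensor(1))]
stack_failed = False
try:
    default_collate(ragged)
except RuntimeError as err:
    stack_failed = True
    print("default_collate failed:", err)
```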

Data augmentation and preprocessing

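Transformations passed to a dataset run every time a sample is loaded. Random transformations therefore produce a different version of each sample every epoch, which increases data diversity and reduces overfitting.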

Image augmentation pipeline

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import (
    Compose, RandomCrop, RandomHorizontalFlip,
    ToTensor, Normalize
)

# Training augmentations (random crop, flip)
train_transform = Compose([
    RandomCrop(32, padding=4),  # Random crop with padding
    RandomHorizontalFlip(),     # Randomly flip horizontally
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])
])

# Validation: no augmentation, only normalization
val_transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225])
])

# Create datasets with different transforms
train_dataset = CIFAR10('./data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('./data', train=False, transform=val_transform, download=False)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")
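Because random transforms run inside __getitem__, repeated reads of the same index can return different versions of a sample. The sketch below mimics RandomHorizontalFlip on plain tensors so it runs without torchvision; RandomFlipDataset is a hypothetical class written for illustration, not a torchvision API.

```python
import torch
from torch.utils.data import Dataset


class RandomFlipDataset(Dataset):
    """Hypothetical dataset that applies a random horizontal flip in __getitem__,
    imitating what RandomHorizontalFlip does to image tensors."""

    def __init__(self, images, flip_prob=0.5):
        self.images = images  # tensor of shape (N, C, H, W)
        self.flip_prob = flip_prob

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        img = self.images[idx]
        if torch.rand(1).item() < self.flip_prob:
            img = torch.flip(img, dims=[-1])  # flip along the width axis
        return img


images = torch.arange(2 * 1 * 2 * 3, dtype=torch.float32).reshape(2, 1, 2, 3)
aug_dataset = RandomFlipDataset(images)

# Reading index 0 repeatedly yields both the original and the flipped image
torch.manual_seed(0)
views = {tuple(aug_dataset[0].flatten().tolist()) for _ in range(20)}
print(f"Distinct versions of sample 0 over 20 reads: {len(views)}")
```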

Sampling strategies and weighted sampling

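A sampler controls which indices the DataLoader draws and in what order: sequential (SequentialSampler, the default when shuffle=False), uniform random (RandomSampler, used when shuffle=True), or weighted by importance (WeightedRandomSampler).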

Weighted sampling for imbalanced datasets

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class ImbalancedDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

# Create imbalanced dataset (10 class-0, 90 class-1)
samples = torch.randn(100, 20)
labels = torch.cat([torch.zeros(10, dtype=torch.long), torch.ones(90, dtype=torch.long)])

dataset = ImbalancedDataset(samples, labels)

# Compute class weights (inverse of frequency)
class_counts = torch.bincount(labels)
class_weights = 1.0 / class_counts.float()

# Give every sample the weight of its class
sample_weights = class_weights[labels]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),
    replacement=True  # Minority-class samples are drawn repeatedly
)

# DataLoader with weighted sampler (the sampler handles randomness, so no shuffle)
loader = DataLoader(
    dataset,
    batch_size=16,
    sampler=sampler,  # Also passing shuffle=True would raise a ValueError
    num_workers=0
)

# Batches should now contain roughly equal numbers of each class
for batch_x, batch_y in loader:
    class_counts_in_batch = torch.bincount(batch_y, minlength=2)
    print(f"Batch class distribution: {class_counts_in_batch.tolist()}")
    break
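A single batch of 16 gives only a noisy picture of the balance. The sketch below rebuilds the same 10-vs-90 data with TensorDataset, fixes the random seed for repeatability, and counts labels over a full epoch to check that the sampler draws both classes at roughly equal rates.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Same 10-vs-90 imbalance as in the example above
samples = torch.randn(100, 20)
labels = torch.cat([torch.zeros(10, dtype=torch.long), torch.ones(90, dtype=torch.long)])
dataset = TensorDataset(samples, labels)

# Inverse-frequency weight for every sample
sample_weights = (1.0 / torch.bincount(labels).float())[labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

# Accumulate class counts over one whole epoch instead of a single batch
epoch_counts = torch.zeros(2, dtype=torch.long)
for _, batch_y in loader:
    epoch_counts += torch.bincount(batch_y, minlength=2)

print(f"Class counts over one epoch: {epoch_counts.tolist()}")
```

Without the sampler an epoch would contain exactly 10 class-0 samples; with it, the expected count is about 50.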

Key Takeaways

  • Dataset subclasses implement __len__() and __getitem__(idx) to define how individual samples are loaded and preprocessed.
  • DataLoader wraps a Dataset to provide batching, shuffling, and parallel loading, which keeps the GPU supplied with data during training.
  • Set num_workers > 0 when per-sample loading or preprocessing is expensive (disk I/O, decoding, augmentation); workers cost extra CPU and memory, and for cheap in-memory data num_workers=0 can be just as fast.
  • Use pin_memory=True to speed up host-to-GPU copies when training on a CUDA device; use shuffle=True for training loaders and shuffle=False for validation.
  • Custom collate functions enable complex batching logic like padding variable-length sequences or handling hierarchical data.

Frequently Asked Questions

What is the difference between shuffle=True and a Sampler?

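shuffle=True makes DataLoader use a RandomSampler, which draws a new random permutation of all indices each epoch (shuffle=False uses a SequentialSampler). A custom sampler such as WeightedRandomSampler replaces that default and controls which indices are drawn and in what order. The two options are mutually exclusive: passing a sampler together with shuffle=True raises a ValueError, so leave shuffle unset when you supply a sampler.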
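The snippet below checks both points with a tiny TensorDataset: the shuffle flag only selects the default sampler, and combining an explicit sampler with shuffle=True is rejected when the DataLoader is constructed.

```python
import torch
from torch.utils.data import (
    DataLoader, RandomSampler, SequentialSampler, TensorDataset, WeightedRandomSampler,
)

dataset = TensorDataset(torch.arange(10))

# shuffle=True / shuffle=False just choose the default sampler
shuffled = DataLoader(dataset, batch_size=4, shuffle=True)
ordered = DataLoader(dataset, batch_size=4, shuffle=False)
print(type(shuffled.sampler).__name__, type(ordered.sampler).__name__)
# RandomSampler SequentialSampler

# An explicit sampler plus shuffle=True raises a ValueError
weights = torch.ones(len(dataset))
conflict_rejected = False
try:
    DataLoader(dataset, sampler=WeightedRandomSampler(weights, num_samples=10), shuffle=True)
except ValueError as err:
    conflict_rejected = True
    print("ValueError:", err)
```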

Why does num_workers > 0 sometimes slow down training on small datasets?

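Each worker is a separate process that needs its own copy of the Dataset (inherited via fork or pickled under spawn), and every batch it produces is sent back to the main process. By default, workers are also restarted every epoch unless persistent_workers=True. When samples are cheap to load, such as tensors already in memory, this startup and transfer overhead can exceed the benefit of parallel loading. Use num_workers=0 for cheap, in-memory datasets or while debugging, and increase it when __getitem__ does real I/O or preprocessing.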

How do I handle variable-length sequences in batches?

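Use a custom collate_fn that pads sequences to the longest length in the batch, or truncates them to a fixed length. PyTorch provides pad_sequence() for padding, plus pack_padded_sequence() and pad_packed_sequence() to convert a padded batch to and from the packed format that recurrent modules (nn.RNN, nn.LSTM, nn.GRU) accept, so the RNN skips the padded time steps.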
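A sketch of the packing route, assuming a single-layer nn.LSTM with made-up sizes: sequences are padded with pad_sequence, packed with pack_padded_sequence so the LSTM skips the padding, and unpacked with pad_packed_sequence.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

# Three variable-length sequences with 5 features per time step
sequences = [torch.randn(10, 5), torch.randn(15, 5), torch.randn(8, 5)]
lengths = torch.tensor([s.shape[0] for s in sequences])  # must stay on the CPU

# Pad to the longest sequence: shape (batch, max_len, features) = (3, 15, 5)
padded = pad_sequence(sequences, batch_first=True)

# Pack so the LSTM skips padded steps; enforce_sorted=False accepts any length order
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

lstm = nn.LSTM(input_size=5, hidden_size=8, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack to a padded tensor in the original batch order; padding is filled with zeros
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out_lengths.tolist())  # torch.Size([3, 15, 8]) [10, 15, 8]
```

Because the padded steps are skipped, h_n holds each sequence's output at its own last valid step rather than at position 15.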

Can I use DataLoader for inference?

Yes. Create a dataset that contains only inputs (no labels), choose a batch_size that fits in GPU memory, set shuffle=False so predictions stay in input order, and pick num_workers for your system. Call model.eval() once before the loop and run the loop inside torch.no_grad() (or torch.inference_mode()).
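A minimal inference loop along those lines, with a small nn.Sequential standing in for a trained model (an assumption for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Inputs only: a TensorDataset over one tensor yields 1-tuples
inputs = torch.randn(50, 20)
infer_loader = DataLoader(TensorDataset(inputs), batch_size=16, shuffle=False)

# Hypothetical stand-in for a trained model
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.eval()  # put dropout / batch norm layers into inference behavior
all_preds = []
with torch.no_grad():  # no autograd graph is recorded inside this block
    for (batch_x,) in infer_loader:
        logits = model(batch_x.to(device))
        all_preds.append(logits.argmax(dim=1).cpu())

predictions = torch.cat(all_preds)  # one class index per input, in input order
print(predictions.shape)  # torch.Size([50])
```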

What is drop_last=True and when should I use it?

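With drop_last=True, the DataLoader discards the final batch when it has fewer than batch_size samples. Use it during training when a small final batch causes problems: for example, nn.BatchNorm1d raises an error in training mode on a batch with a single sample, and very small batches give noisy gradients and batch-norm statistics. Code that assumes a fixed batch size also needs it. For validation and inference, keep drop_last=False so every sample is evaluated.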
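A quick illustration with 97 samples and batch_size=32 (97 = 3 * 32 + 1): keeping the last batch leaves a single-sample batch, which is exactly the case where nn.BatchNorm1d in training mode raises a ValueError.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(97, 20))

keep_last = DataLoader(dataset, batch_size=32, drop_last=False)
drop_partial = DataLoader(dataset, batch_size=32, drop_last=True)
print(len(keep_last), len(drop_partial))   # 4 3
print([b[0].shape[0] for b in keep_last])  # [32, 32, 32, 1]

# Feeding the 1-sample final batch to BatchNorm in training mode fails
bn = nn.BatchNorm1d(20)
bn.train()
last_batch = list(keep_last)[-1][0]
single_sample_failed = False
try:
    bn(last_batch)
except ValueError as err:
    single_sample_failed = True
    print("ValueError:", err)
```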

Further Reading