Train-Test Splits: Preventing Model Overfitting
The train-test split is one of the most important steps in evaluating a machine learning model. It separates your data into training and holdout sets so you can measure whether your model generalizes to unseen data. Without it, you have no reliable way to tell a model that has learned general patterns from one that has merely memorized its training data, a failure called overfitting. scikit-learn's train_test_split() function performs this split in one call and supports shuffling, stratification, and random seeding for reproducible, unbiased evaluation.
Why Train-Test Splits Matter
Overfitting occurs when a model learns the training data so closely, including its noise and peculiarities, that it performs poorly on new data. Imagine training a spam detector only on your own inbox and then deploying it on emails it has never seen. It will likely perform worse than its training accuracy suggests, because it has adapted to the quirks of your inbox rather than to spam in general. A train-test split exposes this problem by simulating real-world conditions: you hide the test set during training and use it only to evaluate final performance.
The key principle is simple: your model must not see the test set during training. This also applies to hyperparameter tuning, feature engineering, and data cleaning. Any step that lets information from the test data "leak" into the model inflates your performance estimates. Keeping test data completely separate gives you an honest measure of generalization.
Basic Train-Test Split with scikit-learn
The train_test_split() function from sklearn.model_selection divides your data into training and test subsets:
```python
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

# Load data (150 samples, 4 features, 3 classes)
iris = load_iris()
X, y = iris.data, iris.target

# Split: 80% train, 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,    # 20% for testing
    random_state=42,  # Seed for reproducibility
)

print(f"Training set size: {X_train.shape[0]}")  # 120 samples
print(f"Test set size: {X_test.shape[0]}")       # 30 samples
print(f"Feature count: {X_train.shape[1]}")      # 4 features
```
If you pass neither test_size nor train_size, train_test_split() holds out 25% of the data for testing (a 75/25 split). An 80/20 split is another widely used convention. The random_state parameter seeds the shuffling so that every run produces the same split; set it whenever you share code or compare experiments.
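To see the default in action, the sketch below splits the 150-sample iris dataset without passing test_size, then repeats the split with an explicit 80/20 ratio. The sizes in the comments follow from scikit-learn rounding the test fraction up to a whole number of samples.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # 150 samples

# No test_size or train_size: scikit-learn holds out 25% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
default_sizes = (len(X_train), len(X_test))
print(default_sizes)  # (112, 38): 25% of 150 is 37.5, rounded up to 38

# Explicit 80/20 split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
sizes_80_20 = (len(X_train), len(X_test))
print(sizes_80_20)  # (120, 30)
```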
Stratification: Maintaining Class Distribution
For imbalanced classification datasets, a stratified split preserves the class distribution in both the train and test sets. If your dataset is 90% negative and 10% positive, a purely random split can, by chance, give the test set noticeably more or fewer positives than 10%. On a small dataset it might contain almost none, which makes the evaluation unreliable.
```python
# Without stratification: test set may have skewed class distribution
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    # No stratification
)
print("Train class 0 ratio:", (y_train == 0).sum() / len(y_train))
print("Test class 0 ratio:", (y_test == 0).sum() / len(y_test))

# With stratification: class distribution matches original dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    stratify=y,  # Stratify by class labels
    random_state=42,
)
print("Train class 0 ratio (stratified):", (y_train == 0).sum() / len(y_train))
print("Test class 0 ratio (stratified):", (y_test == 0).sum() / len(y_test))
# Both ratios now match the original dataset
```
Use stratify=y for classification tasks as a default, and especially with imbalanced data. For regression, stratify=y does not work directly, because continuous targets have no discrete classes to preserve.
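For regression, you can still approximate stratification. One common workaround is to bin the continuous target into quantiles and stratify on the bin labels while still training on the original target. The sketch below assumes synthetic data and four quartile bins; both are illustrative choices, not scikit-learn defaults.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))  # synthetic features
y = X @ np.array([1.5, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)  # continuous target

# Bin the target into quartiles: the three inner quantiles are the bin edges
edges = np.quantile(y, [0.25, 0.5, 0.75])
y_bins = np.digitize(y, edges)  # labels 0..3, about 50 samples each

# Stratify on the bin labels, but keep the original continuous y for training
X_train, X_test, y_train, y_test, bins_train, bins_test = train_test_split(
    X, y, y_bins, test_size=0.2, stratify=y_bins, random_state=42
)

print(np.bincount(bins_test))  # each quartile equally represented in the test set
```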
Avoiding Data Leakage
Data leakage occurs when information from the test set influences training or hyperparameter selection. Three common leakage patterns are:
1. Fitting Transformers on Combined Data
Never fit a scaler, encoder, or normalizer on train+test together:
```python
from sklearn.preprocessing import StandardScaler

# WRONG: Leakage, the scaler sees test data statistics
scaler = StandardScaler()
scaler.fit(X)  # X includes test data!
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# RIGHT: Fit only on training data
scaler = StandardScaler()
scaler.fit(X_train)  # Fit only on training features
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)  # Apply learned scaling to test
```
When you fit the scaler on training data alone, it learns the mean and standard deviation from training samples only. Applying it to the test data then reuses those training statistics, which mirrors production, where incoming data is scaled with statistics learned from historical data.
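Once you move from a single split to cross-validation, fitting the scaler by hand inside every fold becomes error-prone. scikit-learn's Pipeline handles this: cross_val_score refits the whole pipeline, scaler included, on each training fold only. The sketch below pairs StandardScaler with LogisticRegression on iris; the choice of model is illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# The scaler is part of the model, so it is fit only on the data the
# pipeline is trained on (each CV training fold, then all of X_train)
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

cv_scores = cross_val_score(pipe, X_train, y_train, cv=5)
print("CV accuracy:", cv_scores.mean())

# Final fit on the full training set; the test set only goes through transform()
pipe.fit(X_train, y_train)
print("Test accuracy:", pipe.score(X_test, y_test))
```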
2. Using Test Metrics for Hyperparameter Tuning
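Never tune hyperparameters using test set performance. Each time you keep a setting because it scored best on the test set, the test score becomes a more optimistic estimate of generalization; in effect, you are overfitting to the test set. Instead, hold out a separate validation set or use cross-validation (covered in a later article):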
```python
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# WRONG: Choosing max_depth based on test accuracy
best_depth = None
best_test_accuracy = 0
for depth in range(1, 20):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    test_accuracy = model.score(X_test, y_test)
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        best_depth = depth

# RIGHT: Use validation set or cross-validation for tuning
best_depth = None
best_cv_score = 0
for depth in range(1, 20):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    # cross_val_score handles internal train/validation splits
    cv_scores = cross_val_score(model, X_train, y_train, cv=5)
    mean_cv_score = cv_scores.mean()
    if mean_cv_score > best_cv_score:
        best_cv_score = mean_cv_score
        best_depth = depth

# Train final model with best_depth on full training data
final_model = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final_model.fit(X_train, y_train)
# Use the test set exactly once, for the final estimate
final_test_accuracy = final_model.score(X_test, y_test)
```
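scikit-learn can also run this cross-validated search for you. GridSearchCV tries each max_depth with 5-fold cross-validation on the training data and, by default (refit=True), retrains the best configuration on the full training set, so the test set is used only once at the end. This sketch re-creates an 80/20 iris split so it runs on its own.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"max_depth": list(range(1, 20))},
    cv=5,  # 5-fold CV on the training data only
)
search.fit(X_train, y_train)  # best model is then refit on all of X_train

best_depth = search.best_params_["max_depth"]
final_test_accuracy = search.score(X_test, y_test)  # scores search.best_estimator_
print(best_depth, search.best_score_, final_test_accuracy)
```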
3. Feature Engineering Based on Full Dataset
Never compute feature statistics (mean, variance, thresholds) on combined train+test data:
```python
# WRONG: Feature thresholds computed from combined data
feature_threshold = X.mean(axis=0)  # Per-feature means; X includes test data
X_train_engineered = X_train > feature_threshold
X_test_engineered = X_test > feature_threshold

# RIGHT: Compute thresholds on training data only
feature_threshold = X_train.mean(axis=0)  # Per-feature means from training data
X_train_engineered = X_train > feature_threshold
X_test_engineered = X_test > feature_threshold
```
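To make a learned feature like this reusable, and safe inside pipelines and cross-validation, you can wrap it in a transformer that computes its statistics in fit() and only applies them in transform(). MeanThresholder below is a hypothetical class written for illustration, not part of scikit-learn; it simply follows scikit-learn's fit/transform conventions.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


class MeanThresholder(TransformerMixin, BaseEstimator):
    """Hypothetical transformer: flags values above each feature's training mean."""

    def fit(self, X, y=None):
        # Learn per-feature means from the data passed to fit() (training data only)
        self.threshold_ = np.asarray(X).mean(axis=0)
        return self

    def transform(self, X):
        # Apply the stored training thresholds; nothing is learned here
        return (np.asarray(X) > self.threshold_).astype(int)


X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

thresholder = MeanThresholder().fit(X_train)
X_train_engineered = thresholder.transform(X_train)
X_test_engineered = thresholder.transform(X_test)  # uses training means
```

Because it learns in fit(), the same class can be placed inside a Pipeline, where cross-validation will refit it on each training fold automatically.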
Multi-Way Splits: Adding Validation Sets
Large datasets sometimes warrant a third split: training, validation, and test. Use validation for hyperparameter tuning and test for final evaluation:
```python
from sklearn.model_selection import train_test_split

# First split: 70% train+validation, 30% test
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Second split: of the 70%, split into 70% train, 30% validation
# So final ratios: 49% train, 21% validation, 30% test
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.3, stratify=y_temp, random_state=42
)

print(f"Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}")
```
Tune hyperparameters on the validation set, then measure final performance only on the test set.
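Because the second split takes a fraction of what remains, target ratios need converting. For a 60/20/20 split, the validation set is 20% of the whole dataset but 0.2 / 0.8 = 25% of the remaining 80%. The helper below, train_val_test_split, is a hypothetical function written for illustration that does this arithmetic.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


def train_val_test_split(X, y, val_frac, test_frac, random_state=None):
    """Hypothetical helper: stratified three-way split by fractions of the whole dataset."""
    # First split off the test set
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=test_frac, stratify=y, random_state=random_state
    )
    # Rescale val_frac into a fraction of the remaining (1 - test_frac) data
    relative_val = val_frac / (1 - test_frac)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=relative_val, stratify=y_temp,
        random_state=random_state,
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


X, y = load_iris(return_X_y=True)
X_train, X_val, X_test, y_train, y_val, y_test = train_val_test_split(
    X, y, val_frac=0.2, test_frac=0.2, random_state=42
)
print(len(X_train), len(X_val), len(X_test))  # 90 30 30
```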
Key Takeaways
- Train-test splits simulate real-world generalization by hiding test data during all training and tuning steps.
- Use test_size=0.2 (an 80/20 split) as a sensible default for most datasets, and always set random_state for reproducibility.
- Stratify on class labels for classification to maintain class balance across train and test sets.
- Fit all transformers (scalers, encoders) on training data only; apply to test data using the learned transformation.
- Never use test metrics for hyperparameter tuning—use validation sets or cross-validation instead.
Frequently Asked Questions
What test size should I use?
As a rough guideline: for datasets under 1,000 samples, use 20-30% for testing (70-80% for training); for large datasets (10,000+ samples), 10-15% is usually enough. The underlying rule is that you need enough test samples to measure performance reliably, typically at least 100-200.
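One way to see why the absolute number of test samples matters: if you treat each test prediction as an independent correct/incorrect trial, the standard error of a measured accuracy p on n test samples is approximately sqrt(p(1 - p) / n). The sketch below assumes that binomial model and a normal approximation, and shows how the uncertainty shrinks as the test set grows.

```python
import math


def accuracy_standard_error(accuracy, n_test):
    """Approximate standard error of an accuracy estimate (binomial model)."""
    return math.sqrt(accuracy * (1 - accuracy) / n_test)


for n in (30, 100, 200, 1000):
    se = accuracy_standard_error(0.9, n)
    # Under a normal approximation, about 95% of estimates fall within ~2 standard errors
    print(f"n={n:4d}: 90% accuracy +/- {2 * se:.3f}")
```

With 30 test samples, a measured 90% accuracy is only pinned down to roughly plus or minus 11 percentage points; with 1,000 samples, to about plus or minus 2.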
Does the 80/20 rule apply to time-series data?
Not as a random split. For time series, preserve temporal order: train on earlier data and test on later data to avoid look-ahead bias. Do not shuffle; sort your data chronologically and use train_test_split(..., shuffle=False), or index it by date manually.
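A minimal sketch of both options, assuming rows are already sorted from oldest to newest: shuffle=False keeps the last 20% of rows as the test set, and scikit-learn's TimeSeriesSplit produces expanding-window folds in which every training fold precedes its validation fold. The toy arrays stand in for a real time-ordered dataset.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit, train_test_split

# Toy time-ordered data: row i is observed at time step i
X = np.arange(100).reshape(-1, 1)
y = np.arange(100)

# Chronological holdout: no shuffling, so the last 20 rows become the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)
print(y_train.max(), y_test.min())  # 79 80

# Expanding-window cross-validation for tuning within the training period
tscv = TimeSeriesSplit(n_splits=3)
for train_idx, val_idx in tscv.split(X_train):
    # Every training index comes before every validation index
    print(f"train up to t={train_idx.max()}, validate t={val_idx.min()}-{val_idx.max()}")
```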
Can I use the same random_state every time?
Yes. For reproducibility in research and shared code, always fix random_state to an integer (42 is a popular choice, but any fixed value works). If you deliberately want a different split on each run, omit random_state or pass random_state=None.
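A quick way to confirm that a fixed seed pins down the split is to split the row indices twice with the same random_state and compare; a different seed gives a different, but equally valid, split.

```python
import numpy as np
from sklearn.model_selection import train_test_split

indices = np.arange(150)  # row indices of a 150-sample dataset

# Same seed twice: identical test rows in identical order
_, test_a = train_test_split(indices, test_size=0.2, random_state=42)
_, test_b = train_test_split(indices, test_size=0.2, random_state=42)
print(np.array_equal(test_a, test_b))  # True

# Different seed: a different random split
_, test_c = train_test_split(indices, test_size=0.2, random_state=7)
print(np.array_equal(test_a, test_c))  # False
```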
What if my dataset is too small to split 80/20?
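Use k-fold cross-validation instead (covered in a later article), which makes more efficient use of limited data: every sample is used for training in some folds and for evaluation in exactly one. For datasets under 100 samples, cross-validation is almost always better than a single split.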
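As a preview, the sketch below evaluates a model on a deliberately small 60-sample subset of iris with stratified 5-fold cross-validation. The subset size and the choice of LogisticRegression are illustrative assumptions.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split

X, y = load_iris(return_X_y=True)
# Simulate a small dataset: a stratified sample of 60 rows
X_small, _, y_small, _ = train_test_split(
    X, y, train_size=60, stratify=y, random_state=0
)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X_small, y_small, cv=cv)

# Report the spread across folds, not just the mean
print(f"Accuracy: {scores.mean():.3f} +/- {scores.std():.3f} over {len(scores)} folds")
```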
How do I check for data leakage?
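Be suspicious of results that look too good: a test score far above simple baselines or domain expectations, or a complex model that scores almost as well on the test set as on the training set, can indicate leakage. A large train-test gap, by contrast, signals overfitting and does not by itself rule leakage in or out. The most reliable check is to audit every preprocessing and feature engineering step to confirm it uses only training data.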
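A classic way to see leakage is to run a workflow on pure noise, where honest accuracy should hover around chance (about 50% for two roughly balanced classes). In the sketch below, selecting the 20 "best" features on the full dataset before cross-validation leaks label information into every validation fold and inflates the score; doing the selection inside a Pipeline keeps it honest. The synthetic data and the use of SelectKBest are illustrative choices.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5000))  # 5,000 random features: no real signal
y = rng.integers(0, 2, size=100)  # random binary labels

# LEAKY: feature selection sees all labels, including every future validation fold
X_selected = SelectKBest(f_classif, k=20).fit_transform(X, y)
leaky_score = cross_val_score(LogisticRegression(), X_selected, y, cv=5).mean()

# HONEST: selection is refit inside each training fold only
pipe = make_pipeline(SelectKBest(f_classif, k=20), LogisticRegression())
honest_score = cross_val_score(pipe, X, y, cv=5).mean()

print(f"Leaky CV accuracy:  {leaky_score:.2f}")   # typically well above chance
print(f"Honest CV accuracy: {honest_score:.2f}")  # typically near chance
```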