Handling Class Imbalance in Python ML Models
Real-world classification datasets rarely have balanced classes. In fraud detection, fraudulent transactions are often 0.1-1% of the data; in disease diagnosis, positive cases might be 5% of samples. A naive model that predicts "not fraud" or "no disease" for every sample reaches 95-99% accuracy or more while being completely useless. Under class imbalance, accuracy stops being informative; you need techniques that improve minority-class recall without tanking precision.
Why Imbalance Matters
Accuracy is a misleading metric for imbalanced data. A model predicting the majority class for all samples achieves high accuracy but fails on every minority-class sample. Algorithms like logistic regression and SVM minimize an average loss in which every sample counts equally, so the majority class dominates the objective and the decision boundary drifts toward it. The minority class contributes too little training signal; the techniques below make the model focus on minority-class samples and treat errors on them as costly.
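To make the accuracy problem concrete, here is a minimal sketch using scikit-learn's DummyClassifier as an always-majority baseline. The labels (990 negatives, 10 positives) and the all-zero feature matrix are made up for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical labels: 990 negatives, 10 positives (1% minority)
y = np.array([0] * 990 + [1] * 10)
X = np.zeros((len(y), 1))  # features are irrelevant to this baseline

# Baseline that always predicts the most frequent class
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X, y)
y_pred = baseline.predict(X)

accuracy = accuracy_score(y, y_pred)
minority_recall = recall_score(y, y_pred, pos_label=1)

print(f"Accuracy: {accuracy:.3f}")
print(f"Minority recall: {minority_recall:.3f}")
```

The baseline scores 99% accuracy while finding none of the positive samples, which is why the rest of this article evaluates on minority-class metrics.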
Strategy 1: Adjust Class Weights
Many scikit-learn classifiers (including LogisticRegression, SVC, and RandomForestClassifier) accept a class_weight parameter. Set it to 'balanced' to weight classes inversely by frequency: rare classes get higher weight, common classes get lower weight.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
# Imbalanced target (90% class 0, 10% class 1)
y_train = np.array([0]*90 + [1]*10)
# Placeholder features so the example runs; substitute your own feature matrix
rng = np.random.default_rng(42)
X_train = rng.normal(size=(len(y_train), 5))
# Without class weighting: model tends to under-predict the minority class
model_unweighted = LogisticRegression()
model_unweighted.fit(X_train, y_train)
# With balanced class weights: minority class gets higher weight
model_weighted = LogisticRegression(class_weight='balanced')
model_weighted.fit(X_train, y_train)
# For Random Forest
rf_weighted = RandomForestClassifier(class_weight='balanced', random_state=42)
rf_weighted.fit(X_train, y_train)
# Manual weights: weight_class_0 : weight_class_1 = 1 : 9
model_custom = LogisticRegression(class_weight={0: 1, 1: 9})
model_custom.fit(X_train, y_train)
Class weighting is the simplest and fastest approach. It tells the optimizer that errors on the minority class are 9x more expensive than errors on the majority class, shifting the decision boundary.
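The 9x figure follows from how 'balanced' weights are computed: n_samples / (n_classes * count_per_class). The sketch below works this out by hand for the 90/10 target used above and compares it with scikit-learn's compute_class_weight helper; the variable names are illustrative.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_train = np.array([0] * 90 + [1] * 10)
classes = np.unique(y_train)

# Formula behind class_weight='balanced':
# n_samples / (n_classes * count_per_class)
manual_weights = len(y_train) / (len(classes) * np.bincount(y_train))

# Same computation via the scikit-learn helper
sk_weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)

weight_ratio = manual_weights[1] / manual_weights[0]
print(dict(zip(classes.tolist(), manual_weights.tolist())))
print(f"minority/majority weight ratio: {weight_ratio:.1f}")
```

For a 90/10 split, the ratio of the two weights equals the inverse ratio of the class counts, so the manual {0: 1, 1: 9} setting above produces the same relative weighting.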
Strategy 2: Resampling Methods
Resampling changes the training data distribution. Oversampling duplicates minority samples; undersampling removes majority samples. The goal is to balance classes artificially.
Undersampling
Remove majority-class samples randomly. Fast but loses data.
from sklearn.utils import resample
# Separate classes
X_majority = X_train[y_train == 0]
y_majority = y_train[y_train == 0]
X_minority = X_train[y_train == 1]
y_minority = y_train[y_train == 1]
# Undersample majority to match minority count
X_maj_resampled, y_maj_resampled = resample(
    X_majority, y_majority,
    n_samples=len(X_minority),
    replace=False,
    random_state=42
)
# Combine
X_resampled = np.vstack([X_maj_resampled, X_minority])
y_resampled = np.concatenate([y_maj_resampled, y_minority])
# Now train on balanced data
model = LogisticRegression()
model.fit(X_resampled, y_resampled)
Undersampling is fast but discards useful majority-class data, which can hurt generalization. Use it only when the majority class is large enough that the retained subset still represents it well.
Oversampling
Duplicate minority-class samples to match majority count. Simple but risks overfitting.
from sklearn.utils import resample
# Oversample minority to match majority
X_min_resampled, y_min_resampled = resample(
    X_minority, y_minority,
    n_samples=len(y_majority),
    replace=True,  # Sampling WITH replacement (allowing duplicates)
    random_state=42
)
# Combine
X_resampled = np.vstack([X_majority, X_min_resampled])
y_resampled = np.concatenate([y_majority, y_min_resampled])
model = LogisticRegression()
model.fit(X_resampled, y_resampled)
Oversampling keeps every original sample but adds exact duplicates of minority samples, so the model can memorize those few points and overfit to them.
Strategy 3: SMOTE (Synthetic Minority Oversampling)
SMOTE creates synthetic minority samples by interpolating between existing minority samples in feature space. It's more sophisticated than simple duplication.
from imblearn.over_sampling import SMOTE
# SMOTE: create synthetic minority samples
smote = SMOTE(random_state=42, k_neighbors=5)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
print(f"Before SMOTE: class 0: {(y_train == 0).sum()}, class 1: {(y_train == 1).sum()}")
print(f"After SMOTE: class 0: {(y_resampled == 0).sum()}, class 1: {(y_resampled == 1).sum()}")
# Train on balanced data
model = LogisticRegression()
model.fit(X_resampled, y_resampled)
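SMOTE generates synthetic samples by finding the K nearest minority-class neighbors of each minority sample and creating new samples at random points along the line between them. Because the new points are not exact copies, SMOTE reduces the overfitting risk of simple duplication, and it avoids the information loss of undersampling. The synthetic points are only as realistic as the interpolation assumption: where classes overlap, SMOTE can place minority samples inside majority regions.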
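The interpolation step can be sketched by hand. This is a simplified illustration of the idea, not imbalanced-learn's implementation; the minority points are random placeholders and make_synthetic is a hypothetical helper.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)

# Hypothetical minority-class points in 2-D
X_minority = rng.normal(loc=2.0, scale=0.5, size=(10, 2))

k = 5
# Ask for k + 1 neighbors because each point is its own nearest neighbor
nn = NearestNeighbors(n_neighbors=k + 1).fit(X_minority)
_, neighbor_idx = nn.kneighbors(X_minority)

def make_synthetic(i):
    """Create one synthetic point from minority sample i."""
    j = rng.choice(neighbor_idx[i, 1:])  # random neighbor, skipping self
    gap = rng.uniform(0.0, 1.0)          # position along the segment
    return X_minority[i] + gap * (X_minority[j] - X_minority[i])

# One synthetic point per original minority sample
X_synthetic = np.array([make_synthetic(i) for i in range(len(X_minority))])
print(X_synthetic.shape)
```

Every synthetic point lies on a segment between two existing minority points, so the method never extrapolates beyond the region the minority class already occupies.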
# Example: SMOTE with stratified cross-validation
# X, y are assumed to be the full feature matrix and label vector
from sklearn.model_selection import StratifiedKFold, cross_val_score
from imblearn.pipeline import Pipeline as ImbPipeline
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import LogisticRegression
# Pipeline: SMOTE then train
pipeline = ImbPipeline(steps=[
    ('smote', SMOTE(random_state=42)),
    ('classifier', LogisticRegression(max_iter=1000))
])
# Cross-validate: SMOTE fits on each fold's training split independently
skf = StratifiedKFold(n_splits=5)
scores = cross_val_score(pipeline, X, y, cv=skf, scoring='roc_auc')
print(f"Cross-validated ROC-AUC: {scores.mean():.4f}")
SMOTE inside a cross-validation pipeline prevents leakage: synthetic samples are created per fold on training data only, not test data.
Strategy 4: Combined Approaches
Combine over- and undersampling for balance without excessive duplication.
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline
# Create a pipeline: SMOTE oversamples minority, then random undersampling
# A float sampling_strategy is the target minority/majority ratio after that step
sampler = Pipeline(steps=[
    ('oversample', SMOTE(random_state=42, sampling_strategy=0.5)),
    ('undersample', RandomUnderSampler(random_state=42, sampling_strategy=0.8))
])
X_resampled, y_resampled = sampler.fit_resample(X_train, y_train)
print(f"After combined sampling: class 0: {(y_resampled == 0).sum()}, class 1: {(y_resampled == 1).sum()}")
Combined sampling balances computational cost and information preservation. Oversample minority first, then undersample majority to a reasonable ratio.
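It helps to work out the class counts that a pair of ratios implies before running the samplers. In imbalanced-learn, a float sampling_strategy is the desired minority-to-majority ratio after resampling. The helper below is a hypothetical back-of-the-envelope calculation, and its truncating rounding is an assumption rather than a guarantee about the library's exact output.

```python
def combined_counts(n_majority, n_minority, over_ratio, under_ratio):
    """Estimate class counts after oversampling, then undersampling.

    Both ratios are minority / majority after the respective step.
    Truncation via int() is an assumption about rounding.
    """
    # Oversampling grows the minority until minority / majority = over_ratio
    n_minority_after = max(n_minority, int(n_majority * over_ratio))
    # Undersampling shrinks the majority until minority / majority = under_ratio
    n_majority_after = min(n_majority, int(n_minority_after / under_ratio))
    return n_majority_after, n_minority_after

# 90/10 training set with the ratios used above (0.5, then 0.8)
n_maj, n_min = combined_counts(90, 10, over_ratio=0.5, under_ratio=0.8)
print(n_maj, n_min)
```

With these settings the minority class grows from 10 to 45 samples and the majority shrinks from 90 to roughly 56, so only 35 synthetic points are created and about a third of the majority class is dropped.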
Strategy 5: Stratified Sampling
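When splitting into train/test or cross-validating, use stratification to preserve the original class proportions in each split and fold. Stratification does not rebalance the classes; it keeps the imbalance consistent so that every split contains minority samples.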
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
# X, y are assumed to be NumPy arrays (features and integer labels)
# Stratified train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
# Check balance
print(f"Train class distribution: {np.bincount(y_train)}")
print(f"Test class distribution: {np.bincount(y_test)}")
# Both preserve the original imbalance
# Stratified K-fold cross-validation
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, val_idx in skf.split(X, y):
    X_fold_train, X_fold_val = X[train_idx], X[val_idx]
    y_fold_train, y_fold_val = y[train_idx], y[val_idx]
    # Each fold has similar class distribution
Stratification ensures that both train and test sets reflect the overall class distribution, preventing evaluation bias.
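A quick way to see the effect is to compare a stratified split with a plain random one on the same labels. This sketch uses a made-up 90/10 label vector and placeholder features.

```python
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 90 + [1] * 10)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

# Stratified split: class proportions are preserved in both parts
_, _, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
# Plain random split: the minority count in the test set is left to chance
_, _, y_train_r, y_test_r = train_test_split(
    X, y, test_size=0.2, random_state=0
)

print("stratified test counts:", np.bincount(y_test_s, minlength=2))
print("random test counts:    ", np.bincount(y_test_r, minlength=2))
```

The stratified test set always holds 18 majority and 2 minority samples here, while the random split can hold more, fewer, or no minority samples depending on the seed.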
Strategy 6: Threshold Adjustment
By default, logistic regression predicts class 1 if P(class 1) > 0.5. For imbalanced data, adjust this threshold (usually by lowering it) so that more minority-class samples are caught, at the cost of more false positives.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
# Train model (default threshold = 0.5)
model = LogisticRegression()
model.fit(X_train, y_train)
# Get predicted probabilities on a held-out validation set (X_val, y_val), not the test set
y_proba = model.predict_proba(X_val)[:, 1]
# Compute the precision-recall curve
precision, recall, thresholds = precision_recall_curve(y_val, y_proba)
# Plot to find optimal threshold
# (precision and recall have one more element than thresholds, hence [:-1])
plt.plot(thresholds, precision[:-1], label='Precision')
plt.plot(thresholds, recall[:-1], label='Recall')
plt.xlabel('Threshold')
plt.legend()
plt.show()
# Use threshold that maximizes F1 or balances precision-recall
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
best_threshold = thresholds[np.argmax(f1_scores[:-1])]
# Predict with adjusted threshold; >= matches how precision_recall_curve applies thresholds
y_pred_adjusted = (y_proba >= best_threshold).astype(int)
Threshold adjustment is powerful: it post-processes predictions without retraining, tuning the sensitivity-specificity trade-off for your use case.
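The trade-off is easy to see on a handful of scores. The labels and probabilities below are invented for illustration, and precision_recall_at is a hypothetical helper written in plain NumPy.

```python
import numpy as np

# Hypothetical validation labels and predicted P(class 1)
y_val = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])
y_proba = np.array([0.05, 0.10, 0.20, 0.25, 0.35, 0.40, 0.45, 0.30, 0.55, 0.80])

def precision_recall_at(threshold):
    """Precision and recall for class 1 when predicting 1 if proba >= threshold."""
    y_pred = (y_proba >= threshold).astype(int)
    tp = np.sum((y_pred == 1) & (y_val == 1))
    fp = np.sum((y_pred == 1) & (y_val == 0))
    fn = np.sum((y_pred == 0) & (y_val == 1))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall

for t in (0.5, 0.3):
    p, r = precision_recall_at(t)
    print(f"threshold={t}: precision={p:.2f}, recall={r:.2f}")
```

Lowering the threshold from 0.5 to 0.3 recovers the one positive that was missed, but it also admits three negatives, so recall rises while precision falls. Which point on that curve is acceptable depends on the relative cost of the two error types.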
Evaluation Metrics for Imbalanced Data
For imbalanced data, accuracy is misleading. Use precision, recall, F1-score, or ROC-AUC.
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, confusion_matrix
# Model predictions
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)[:, 1]
# Precision, recall, F1 per class
precision, recall, f1, support = precision_recall_fscore_support(y_test, y_pred)
print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
# ROC-AUC (threshold-independent)
roc_auc = roc_auc_score(y_test, y_proba)
print(f"ROC-AUC: {roc_auc:.4f}")
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
print(cm)
# [[TN FP]
# [FN TP]]
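ROC-AUC is threshold-independent, which makes it useful for comparing models. It measures the probability that the model ranks a random positive sample higher than a random negative sample. Under severe imbalance it can look optimistic, because the false positive rate is computed over a very large pool of negatives; average precision (the area under the precision-recall curve) is a common complement in that case.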
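The ranking interpretation can be checked directly by comparing every positive score with every negative score. The labels and scores below are invented for illustration.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical labels and model scores
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])
y_score = np.array([0.05, 0.10, 0.20, 0.25, 0.35, 0.40, 0.45, 0.30, 0.55, 0.80])

pos = y_score[y_true == 1]
neg = y_score[y_true == 0]

# Compare every positive with every negative; ties count as half a win
wins = (pos[:, None] > neg[None, :]).sum()
ties = (pos[:, None] == neg[None, :]).sum()
pairwise_auc = (wins + 0.5 * ties) / (len(pos) * len(neg))

sklearn_auc = roc_auc_score(y_true, y_score)
print(f"pairwise: {pairwise_auc:.4f}, roc_auc_score: {sklearn_auc:.4f}")
```

Both numbers agree: 18 of the 21 positive-negative pairs are ranked correctly, so the AUC is 18/21.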
Comparison Table
| Method | Pros | Cons | Best For |
|---|---|---|---|
| Class Weights | Simple, fast, no extra data | Less effective for severe imbalance | Mild imbalance (e.g., 90:10) |
| Undersampling | Fast, reduces data | Discards majority-class information | Large datasets |
| Oversampling | Preserves data | Overfitting, duplicates | Small datasets |
| SMOTE | Synthetic samples instead of exact duplicates | Slower, may not help on all problems | Moderate imbalance |
| Combined | Balances duplication and information | Tuning required | Severe imbalance |
| Threshold Tuning | Model-agnostic, flexible | Only shifts sensitivity/specificity | Custom precision-recall trade-off |
Key Takeaways
- Class imbalance makes accuracy misleading; use precision, recall, F1, and ROC-AUC instead.
- Start with class weighting (simplest, fastest); adjust with class_weight='balanced'.
- For severe imbalance, use SMOTE to generate synthetic minority samples or combine over- and undersampling.
- Always apply resampling inside cross-validation to prevent leakage: resample only the training portion of each fold and leave the validation portion untouched.
- Stratify train/test splits and cross-validation folds to preserve class distribution.
- Adjust decision thresholds post-training to tune the sensitivity-specificity trade-off for your domain.
Frequently Asked Questions
Should I resample before or after train/test split?
Split first, then resample only the training set. Never resample the combined dataset; that's data leakage. The test set should reflect real-world class imbalance.
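A minimal sketch of that order of operations, using random oversampling from sklearn.utils and made-up data (the features are random placeholders):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

rng = np.random.default_rng(0)
y = np.array([0] * 90 + [1] * 10)
X = rng.normal(size=(len(y), 3))  # placeholder features

# 1. Split first, preserving the class ratio in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# 2. Oversample the minority class in the training set only
is_minority = y_train == 1
X_min_up = resample(
    X_train[is_minority],
    n_samples=int((~is_minority).sum()),
    replace=True,
    random_state=0,
)
X_train_bal = np.vstack([X_train[~is_minority], X_min_up])
y_train_bal = np.concatenate(
    [y_train[~is_minority], np.ones(len(X_min_up), dtype=int)]
)

print("train (resampled):", np.bincount(y_train_bal))
print("test (untouched): ", np.bincount(y_test))
```

The training set ends up balanced while the test set keeps its original 18:2 distribution, so evaluation still reflects the imbalance the model will face in production.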
Can I use SMOTE on test data?
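No. SMOTE should only transform training data. Create synthetic samples on the training set, then evaluate on the original test set. Applying SMOTE to test data distorts the evaluation: the metrics are then computed on an artificial class distribution and on synthetic points that will never occur in production.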
What if my imbalance is 99:1 (99% majority, 1% minority)?
Severe imbalance usually calls for combining strategies: class weighting, SMOTE, undersampling the majority class, and threshold tuning. Consider cost-sensitive learning, where the misclassification costs come from your domain. Class-weighted linear models can be a strong baseline at this level of imbalance; compare them against tree-based models on your own data rather than assuming either will win.
Is ROC-AUC always better than F1 for imbalanced data?
ROC-AUC is threshold-independent, so it's useful for comparing models. F1 is threshold-dependent but more interpretable for a specific decision boundary. Use both: ROC-AUC to rank models, F1 to evaluate at your chosen threshold.