Batch Processing ML Predictions: Optimize Python
While real-time request-response APIs are essential for interactive applications, many ML workloads are inherently batch: scoring a million customer records nightly, classifying a bulk upload of images, or analyzing historical logs. Batch prediction is more efficient than handling the same data one request at a time because it amortizes model loading, batches matrix operations, and tolerates higher latency.
This article covers three batch prediction patterns: in-process batch processing (NumPy vectorization), asynchronous job queues (Celery, Redis), and file-based batch pipelines (Parquet, Arrow). You'll learn how to design batches for throughput, monitor progress, and handle failures gracefully.
Why Batch Processing Beats Single-Request APIs
When you call an API endpoint once per row, you pay overhead on each call: HTTP parsing, request routing, model loading (if not cached), and response serialization. For a million rows, this overhead dominates the actual inference time.
Batch processing loads the model once and runs inference on all rows in vectorized operations. Modern ML frameworks (scikit-learn, PyTorch, TensorFlow) are optimized for batches: they use SIMD, GPU acceleration, and cache-friendly memory layouts. As a rough guide, scoring 1,000 rows in one batch can be 10–100× faster than making 1,000 separate API calls, because the per-call overhead is paid once instead of 1,000 times. The exact speedup depends on the model, the hardware, and the transport.
For workloads where latency is not critical (nightly reporting, weekly analytics), batch processing can reduce infrastructure costs by 50–80%.
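The overhead argument above can be made concrete with a simple cost model: total time is a fixed overhead per call plus a per-row compute cost. The sketch below is illustrative only. The function name and both cost constants are hypothetical, and it assumes per-row compute is the same in both modes, which if anything understates the benefit of vectorization.

```python
def total_time_ms(n_rows, rows_per_call, overhead_ms, per_row_ms):
    """Estimated wall time: fixed overhead per call plus per-row compute."""
    n_calls = -(-n_rows // rows_per_call)  # ceiling division
    return n_calls * overhead_ms + n_rows * per_row_ms


# Hypothetical costs, chosen only for illustration (not measurements)
OVERHEAD_MS = 1.0   # HTTP parsing, routing, serialization per call
PER_ROW_MS = 0.02   # inference cost per row

one_by_one = total_time_ms(1_000_000, 1, OVERHEAD_MS, PER_ROW_MS)
batched = total_time_ms(1_000_000, 1_000, OVERHEAD_MS, PER_ROW_MS)
speedup = one_by_one / batched

print(f"one-by-one: {one_by_one / 1000:.0f} s")
print(f"batched:    {batched / 1000:.0f} s")
print(f"speedup:    {speedup:.1f}x")
```

With these made-up numbers the per-call overhead accounts for almost all of the one-by-one time. Plugging in overheads measured on your own system gives a quick estimate of whether batching is worth the added latency.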
In-Memory Batch Processing
For datasets that fit in RAM, load the data and score every row in a single vectorized call:
import joblib
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
# Load model and scaler (fitted during training)
model = joblib.load("model.joblib")
scaler = joblib.load("scaler.joblib")
# Load data
df = pd.read_csv("customers.csv")
X = df[["age", "income", "credit_score", "num_transactions"]].values
# Normalize features
X_scaled = scaler.transform(X)
# Predict in batch
predictions = model.predict(X_scaled) # All at once, vectorized
probabilities = model.predict_proba(X_scaled)
# Store results
df["prediction"] = predictions
df["confidence"] = probabilities.max(axis=1)
df.to_csv("predictions.csv", index=False)
print(f"Scored {len(df)} rows in a single vectorized batch")
This works for datasets up to a few GB. For larger data, you need to chunk:
def batch_predict(X, model, scaler, batch_size=1000):
    """Predict in batches to manage memory."""
    predictions = []
    probabilities = []
    for i in range(0, len(X), batch_size):
        X_batch = X[i:i + batch_size]
        X_batch_scaled = scaler.transform(X_batch)
        pred_batch = model.predict(X_batch_scaled)
        proba_batch = model.predict_proba(X_batch_scaled)
        predictions.append(pred_batch)
        probabilities.append(proba_batch)
    return np.concatenate(predictions), np.concatenate(probabilities)
# For a 100 GB file, read and process in chunks, appending each chunk's
# results to the output file so memory use stays bounded
feature_cols = ["age", "income", "credit_score", "num_transactions"]
first_chunk = True
for chunk in pd.read_csv("huge_file.csv", chunksize=10000):
    X_chunk = chunk[feature_cols].values
    preds, probas = batch_predict(X_chunk, model, scaler)
    result = pd.DataFrame({
        "id": chunk["id"].values,
        "prediction": preds,
        "confidence": probas.max(axis=1),
    })
    # Write the header once, then append subsequent chunks
    result.to_csv(
        "batch_predictions.csv",
        mode="w" if first_chunk else "a",
        header=first_chunk,
        index=False,
    )
    first_chunk = False
Asynchronous Job Queues: Celery and Redis
For production systems, use a task queue to decouple submission from processing. A user submits a batch job via an API endpoint; the system queues it, processes it asynchronously, and notifies the user when it's done.
Celery with a Redis broker is a widely used combination for this:
# tasks.py
from celery import Celery
import joblib
import pandas as pd

# A result backend is needed so the API can read task state and results
app = Celery(
    "ml_tasks",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0",
)

# Loaded once per worker process, when the module is imported
model = joblib.load("model.joblib")
scaler = joblib.load("scaler.joblib")

# Must match the columns the scaler and model were fitted on
FEATURES = ["age", "income", "credit_score", "num_transactions"]


@app.task(bind=True)
def batch_predict_task(self, csv_file_path: str):
    """Celery task: predict on a CSV and save results."""
    self.update_state(state="PROGRESS", meta={"current": 0, "total": 1})
    # Load and predict
    df = pd.read_csv(csv_file_path)
    X = df[FEATURES].values
    X_scaled = scaler.transform(X)
    predictions = model.predict(X_scaled)
    probabilities = model.predict_proba(X_scaled)
    # Save results
    df["prediction"] = predictions
    df["confidence"] = probabilities.max(axis=1)
    output_path = csv_file_path.replace(".csv", "_predictions.csv")
    df.to_csv(output_path, index=False)
    # No manual FAILURE state: Celery marks the task as failed and stores
    # the exception automatically when the task raises
    return {"status": "done", "output": output_path, "rows": len(df)}
Then expose a FastAPI endpoint to submit jobs:
import os
import shutil
import uuid

from fastapi import FastAPI, File, UploadFile
from tasks import batch_predict_task

app = FastAPI()

UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)


@app.post("/batch-predict")
async def submit_batch_job(file: UploadFile = File(...)):
    """Submit a batch prediction job and return task ID."""
    # Save under a generated name: the client-supplied filename is untrusted
    # and could collide with other uploads or contain path separators
    file_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}.csv")
    with open(file_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    # Submit to Celery
    task = batch_predict_task.delay(file_path)
    return {"task_id": task.id, "status": "queued"}


@app.get("/batch-predict/{task_id}")
async def check_job_status(task_id: str):
    """Check the status of a batch job."""
    task = batch_predict_task.AsyncResult(task_id)
    if task.state == "PENDING":
        # Celery also reports PENDING for task IDs it has never seen
        return {"status": "pending"}
    elif task.state == "PROGRESS":
        return {"status": "processing", "meta": task.info}
    elif task.state == "SUCCESS":
        return {"status": "done", "result": task.result}
    elif task.state == "FAILURE":
        return {"status": "failed", "error": str(task.info)}
    else:
        # Other Celery states such as STARTED or RETRY
        return {"status": task.state.lower()}
Users submit via POST /batch-predict (uploading a CSV), then poll /batch-predict/{task_id} to check status.
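On the client side, the polling step might look like the sketch below. It is a minimal illustration, not part of the API above: `poll_until_done` and its parameters are hypothetical names, and `get_status` stands for any callable that returns the JSON body of the status endpoint as a dict (for example, a thin wrapper around an HTTP GET). The interval doubles after each poll so long-running jobs do not flood the server.

```python
import time


def poll_until_done(get_status, interval_s=1.0, max_interval_s=30.0,
                    timeout_s=3600.0, sleep=time.sleep):
    """Call get_status() until it reports "done" or "failed".

    get_status: hypothetical callable returning a dict with a "status" key.
    sleep: injectable so the loop can be tested without real waiting.
    """
    waited = 0.0
    while True:
        response = get_status()
        if response.get("status") in ("done", "failed"):
            return response
        if waited >= timeout_s:
            raise TimeoutError(f"job not finished after {waited:.0f}s")
        sleep(interval_s)
        waited += interval_s
        # Exponential backoff, capped at max_interval_s
        interval_s = min(interval_s * 2, max_interval_s)


# Usage with a stand-in status source (no network needed)
fake_responses = iter([
    {"status": "pending"},
    {"status": "processing"},
    {"status": "done", "result": {"rows": 3}},
])
final = poll_until_done(lambda: next(fake_responses), sleep=lambda s: None)
print(final["status"])
```

For jobs that run for hours, a webhook or email notification is often a better fit than polling. Polling is simply the easiest pattern to add on top of the two endpoints shown here.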
Batch Processing with Apache Arrow and Parquet
For large-scale data engineering, use Arrow and Parquet (columnar storage) instead of CSV:
import pyarrow as pa
import pyarrow.parquet as pq

# Read the Parquet file into an Arrow table (this loads the whole file;
# pass columns=[...] to read_table to load only selected columns)
table = pq.read_table("data.parquet")
# Extract features as a 2-D NumPy array, converting via pandas
feature_names = ["age", "income", "credit_score", "num_transactions"]
X = table.select(feature_names).to_pandas().to_numpy()
# Predict (model and scaler loaded as in the earlier examples)
predictions = model.predict(scaler.transform(X))
# Write results back to Parquet
result_table = table.append_column("prediction", pa.array(predictions))
pq.write_table(result_table, "data_with_predictions.parquet")
Parquet is far more efficient than CSV for this: files are smaller, I/O is faster, and readers can load only the columns they need instead of parsing every field.
Comparison Table: Batch Processing Approaches
| Approach | Data Size | Latency | Complexity | Monitoring | Best For |
|---|---|---|---|---|---|
| In-memory vectorized | < 10 GB | Minutes | Low | Basic | Small-to-medium batch jobs |
| Chunked file processing | 10 GB - 1 TB | Hours | Low | Basic | Weekly/nightly exports |
| Celery queue | Any | Seconds-hours | Medium | Task history | Multiple concurrent jobs |
| Spark/Dask | 1 TB+ | Hours | High | Distributed logs | Petabyte-scale analytics |
| Cloud batch (BigQuery, SageMaker) | Any | Hours | Medium | Cloud console | Managed infrastructure |
Monitoring and Failure Handling
For production batch jobs, log progress and handle failures gracefully:
import logging
import time

import numpy as np
import pandas as pd

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# model and scaler are assumed to be loaded as in the earlier examples


def batch_predict_with_logging(
    input_file: str,
    output_file: str,
    batch_size: int = 1000,
    retry_count: int = 3,
):
    """Batch predict with logging and retry logic."""
    start_time = time.time()
    logger.info(f"Starting batch prediction on {input_file}")
    try:
        df = pd.read_csv(input_file)
        total_rows = len(df)
        logger.info(f"Loaded {total_rows} rows")
        # Must match the columns the scaler and model were fitted on
        X = df[["age", "income", "credit_score", "num_transactions"]].values
        X_scaled = scaler.transform(X)
        processed = 0
        all_predictions = []
        for i in range(0, len(X_scaled), batch_size):
            batch_end = min(i + batch_size, len(X_scaled))
            X_batch = X_scaled[i:batch_end]
            for attempt in range(retry_count):
                try:
                    preds = model.predict(X_batch)
                    all_predictions.append(preds)
                    processed = batch_end
                    logger.info(f"Processed {processed}/{total_rows} rows")
                    break
                except Exception as e:
                    logger.warning(f"Batch {i} failed (attempt {attempt+1}): {e}")
                    if attempt == retry_count - 1:
                        raise
                    time.sleep(2 ** attempt)  # Exponential backoff
        df["prediction"] = np.concatenate(all_predictions)
        df.to_csv(output_file, index=False)
        elapsed = time.time() - start_time
        logger.info(f"Done in {elapsed:.1f}s ({total_rows/elapsed:.0f} rows/sec)")
        return {"status": "success", "rows": total_rows, "elapsed_s": elapsed}
    except Exception as e:
        logger.error(f"Batch prediction failed: {e}")
        raise
Key Takeaways
- Batch processing can be 10–100× faster than request-response for high-volume scenarios because it amortizes overhead and exploits vectorization.
- For small datasets (< 10 GB), in-memory vectorized processing is simplest; switch to chunked reads once the data no longer fits in RAM.
- For production systems with multiple concurrent jobs, use Celery + Redis for asynchronous task queuing.
- Parquet and Arrow are superior to CSV for large-scale data due to compression, schema enforcement, and columnar I/O.
- Always implement logging, retry logic, and progress tracking for production batch jobs.
Frequently Asked Questions
What is the optimal batch size?
It depends on model memory requirements and hardware. Start with 1,000 rows; if GPU memory is plentiful, increase to 5,000–10,000. Profile to find the sweet spot where throughput (rows per second) is maximal.
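One way to do that profiling is to time the same data at several batch sizes and compare rows per second. The sketch below is a rough illustration: `measure_throughput` is a hypothetical helper, and `predict_fn` stands in for your real scoring call (for example `model.predict`). It keeps the best of a few repeats to reduce timing noise.

```python
import time


def measure_throughput(predict_fn, rows, batch_sizes, repeats=3):
    """Return {batch_size: best observed rows/sec} for predict_fn over rows."""
    results = {}
    for size in batch_sizes:
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            for i in range(0, len(rows), size):
                predict_fn(rows[i:i + size])
            best = min(best, time.perf_counter() - start)
        results[size] = len(rows) / best if best > 0 else float("inf")
    return results


# Usage with a stand-in "model" that sums each row (illustration only)
rows = [[float(i), float(i + 1)] for i in range(10_000)]
results = measure_throughput(
    lambda batch: [sum(r) for r in batch], rows, batch_sizes=[100, 1_000, 5_000]
)
best_size = max(results, key=results.get)
print(f"best batch size in this run: {best_size}")
```

Run the measurement on the hardware and model you will deploy, since the best size on a laptop CPU may differ from the best size on a GPU server.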
Can I distribute batch inference across multiple machines?
Yes. Use Spark, Dask, or a cloud platform like BigQuery or SageMaker. These are covered in advanced courses; for this series, Celery is sufficient for medium-scale workloads.
How do I handle data validation in batch jobs?
Add a validation step before inference: check for NaN/infinity, verify feature ranges, and log rejected rows. Store invalid rows in a separate "quarantine" file for manual review.
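A validation step along those lines could be sketched as follows. This is an illustration under stated assumptions: `valid_row_mask` is a hypothetical helper, and the bounds are made-up values for the `age`, `income`, `credit_score`, and `num_transactions` features used earlier, not recommended limits.

```python
import numpy as np


def valid_row_mask(X, lower, upper):
    """True for rows whose features are all finite and within [lower, upper]."""
    X = np.asarray(X, dtype=float)
    finite = np.isfinite(X).all(axis=1)
    # Comparisons involving NaN are False, so such rows fail this check too
    in_range = ((X >= lower) & (X <= upper)).all(axis=1)
    return finite & in_range


# Hypothetical bounds: age, income, credit_score, num_transactions
LOWER = np.array([0, 0, 300, 0])
UPPER = np.array([120, 1e7, 850, 1e6])

X = np.array([
    [35, 52_000, 710, 42],
    [np.nan, 48_000, 690, 10],  # missing age
    [29, np.inf, 640, 5],       # infinite income
    [41, 61_000, 900, 7],       # credit score out of range
])

mask = valid_row_mask(X, LOWER, UPPER)
X_valid, X_quarantine = X[mask], X[~mask]
print(f"{len(X_valid)} valid, {len(X_quarantine)} quarantined")
```

With a DataFrame, the same mask can select the rejected rows for the quarantine file, for example by indexing with `~mask` before calling `to_csv`, while only the valid rows go on to the scaler and model.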
Should I batch on the client side or server side?
For external clients, let them send individual requests; batch on the server (Celery worker) if beneficial. If the client controls the batch (e.g., internal script), batch on the client to reduce network overhead.
Further Reading
- Celery Official Documentation — task queue framework reference.
- Apache Parquet Format — columnar data storage specification.
- Apache Arrow in Python — efficient data serialization.
- Dask Distributed Computing — parallel batch processing for Python.