Monitor ML Model Performance: Tracking Predictions
A model that performs well in development can fail silently in production. Shifts in the input data distribution (data drift), changes in the relationship between inputs and the target (concept drift), or dependency failures can all degrade accuracy. Unlike traditional software, the failure is rarely a crash; it is a slow decline in accuracy that often surfaces only weeks later through customer complaints.
Production ML systems need continuous monitoring of prediction latency, throughput, error rates, and, importantly, data drift. This article shows how to instrument a Python inference service with metrics, detect distribution changes, and alert when a model needs retraining.
The Three Monitoring Pillars
Operational metrics (system health): latency, throughput, errors, CPU/memory. Standard application monitoring.
Predictive metrics (model quality): prediction accuracy (if labels arrive), bias, feature distributions, predictions over time. These reveal model degradation.
Data drift metrics: distribution changes in input features or outputs. If your model was trained on 2024 data and the data arriving in January 2026 looks different, accuracy is likely to drop.
Instrumenting Your FastAPI Service
Add Prometheus metrics to track predictions in real time:
from fastapi import FastAPI, Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, Gauge, generate_latest
from pydantic import BaseModel
import time
import numpy as np

app = FastAPI()

FEATURE_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

# `model` is assumed to be a fitted scikit-learn classifier loaded at startup
# (loading code omitted here).

class PredictionRequest(BaseModel):
    features: list[float]  # one value per entry in FEATURE_NAMES

# Define metrics
prediction_count = Counter(
    "ml_predictions_total",
    "Total predictions made",
    ["model_version", "class_label"]
)
prediction_latency = Histogram(
    "ml_prediction_latency_seconds",
    "Latency of predictions",
    ["model_version"],
    buckets=(0.01, 0.05, 0.1, 0.5, 1.0)
)
prediction_errors = Counter(
    "ml_prediction_errors_total",
    "Errors during prediction",
    ["model_version", "error_type"]
)
feature_min = Gauge(
    "ml_feature_min",
    "Min value of features",
    ["feature_name"]
)
feature_max = Gauge(
    "ml_feature_max",
    "Max value of features",
    ["feature_name"]
)

# Running (min, max) per feature since process start. A single request carries
# only one value per feature, so the extremes must be tracked across requests.
feature_extremes: dict[str, tuple[float, float]] = {}

@app.post("/predict")
async def predict(request: PredictionRequest):
    model_version = "1.2.0"
    start_time = time.perf_counter()
    try:
        X = np.array(request.features).reshape(1, -1)
        # Track feature ranges: update the running min/max for each feature
        for i, feature_name in enumerate(FEATURE_NAMES):
            value = float(X[0, i])
            low, high = feature_extremes.get(feature_name, (value, value))
            low, high = min(low, value), max(high, value)
            feature_extremes[feature_name] = (low, high)
            feature_min.labels(feature_name=feature_name).set(low)
            feature_max.labels(feature_name=feature_name).set(high)
        pred = model.predict(X)[0]
        pred_proba = model.predict_proba(X)[0]
        # Record metrics (only successful predictions are counted here)
        latency = time.perf_counter() - start_time
        prediction_latency.labels(model_version=model_version).observe(latency)
        prediction_count.labels(model_version=model_version, class_label=int(pred)).inc()
        return {
            "prediction": int(pred),
            "confidence": float(pred_proba.max()),
            "model_version": model_version,
            "latency_ms": latency * 1000
        }
    except Exception as e:
        prediction_errors.labels(model_version=model_version, error_type=type(e).__name__).inc()
        raise

# Prometheus endpoint: return the text exposition format rather than JSON
@app.get("/metrics")
async def metrics():
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
Prometheus scrapes the /metrics endpoint at a configured interval (15 seconds is a common setting) and stores the results as time-series data. Grafana visualizes these metrics in dashboards.
Detecting Data Drift
Data drift occurs when the input distribution changes. If your model was trained on sepal lengths between 4 and 8 cm but now sees 10 cm flowers, it has to extrapolate, and it usually does so badly.
Detect drift using statistical tests:
import numpy as np
from scipy import stats
import logging

logger = logging.getLogger(__name__)

FEATURE_NAMES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

class DriftDetector:
    def __init__(self, baseline_data: np.ndarray, threshold: float = 0.05):
        """
        baseline_data: Reference sample for one feature (1-D, from training data)
        threshold: p-value threshold for significance (0.05 = 5% significance level)
        """
        self.baseline_data = baseline_data
        self.threshold = threshold

    def detect_kolmogorov_smirnov(self, new_data: np.ndarray) -> bool:
        """
        Kolmogorov-Smirnov test: compares two 1-D samples.
        Returns True if statistically significant drift detected.
        """
        statistic, p_value = stats.ks_2samp(self.baseline_data, new_data)
        if p_value < self.threshold:
            logger.warning(f"Drift detected! KS statistic: {statistic:.4f}, p-value: {p_value:.4f}")
            return True
        return False

    def detect_wasserstein(self, new_data: np.ndarray) -> bool:
        """
        Wasserstein distance: measures the cost of transforming one distribution into another.
        This SciPy function compares 1-D samples, so apply it to one feature at a time.
        """
        distance = stats.wasserstein_distance(self.baseline_data, new_data)
        # Threshold depends on data scale; adjust empirically
        if distance > 0.1:
            logger.warning(f"Drift detected! Wasserstein distance: {distance:.4f}")
            return True
        return False

    def detect_multivariate(self, baseline: np.ndarray, new_data: np.ndarray) -> dict:
        """
        Check drift per feature (more detailed). Both arrays are 2-D with one
        column per feature; each column gets its own KS test.
        """
        drift_results = {}
        for i, feature_name in enumerate(FEATURE_NAMES):
            statistic, p_value = stats.ks_2samp(baseline[:, i], new_data[:, i])
            drifted = bool(p_value < self.threshold)
            drift_results[feature_name] = {
                "drifted": drifted,
                "p_value": float(p_value),
                "baseline_mean": baseline[:, i].mean(),
                "new_mean": new_data[:, i].mean()
            }
        return drift_results

# Usage in production
baseline_data = np.random.randn(1000, 4)  # Placeholder for the training data
# Every hour, check the incoming hour's data
hourly_data = np.random.randn(500, 4)  # Placeholder for the last hour of requests

# The KS test compares 1-D samples, so pass one feature column at a time
detector = DriftDetector(baseline_data[:, 0])
drift_detected = detector.detect_kolmogorov_smirnov(hourly_data[:, 0])
if drift_detected:
    print("Drift detected. Consider retraining.")
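To make the KS statistic less of a black box, the sketch below computes it by hand in plain Python: it is the largest vertical gap between the empirical cumulative distribution functions of the two samples. The function name `ks_statistic` and the sample values are illustrative assumptions, not part of the detector above; in practice `scipy.stats.ks_2samp` also supplies the p-value.

```python
from bisect import bisect_right

def ks_statistic(baseline, new_data):
    """Largest absolute gap between the two empirical CDFs (two-sample KS statistic)."""
    a, b = sorted(baseline), sorted(new_data)
    largest_gap = 0.0
    # The gap can only change at observed values, so checking those is enough.
    for x in a + b:
        cdf_a = bisect_right(a, x) / len(a)  # fraction of baseline values <= x
        cdf_b = bisect_right(b, x) / len(b)  # fraction of new values <= x
        largest_gap = max(largest_gap, abs(cdf_a - cdf_b))
    return largest_gap

# Hypothetical sepal lengths (cm): a training sample and a shifted production sample
training = [4.9, 5.1, 5.4, 5.8, 6.0, 6.3, 6.7, 7.0]
production = [6.5, 6.9, 7.2, 7.6, 8.1, 8.4, 9.0, 9.8]
gap = ks_statistic(training, production)
print(f"KS statistic: {gap:.2f}")
```

A statistic near 0 means the two samples follow almost the same distribution; a value near 1 means they barely overlap.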
Tracking Prediction Accuracy Over Time
If labels arrive with a delay (e.g., user clicks on a recommendation), track actual accuracy:
import json
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from uuid import uuid4

import pandas as pd

class AccuracyTracker:
    def __init__(self, db_path: str = "predictions.db"):
        self.db_path = db_path

    def log_prediction(self, prediction_id: str, model_version: str, prediction: int, confidence: float, features: list):
        """Log a prediction immediately."""
        df = pd.DataFrame([{
            "prediction_id": prediction_id,
            "model_version": model_version,
            "prediction": prediction,
            "confidence": confidence,
            "features": json.dumps(features),
            # ISO-formatted text so SQLite's strftime and string comparisons work
            "timestamp": datetime.now().isoformat(sep=" ", timespec="seconds"),
            "actual_label": None,  # Will be filled in later
            "correct": None
        }])
        with closing(sqlite3.connect(self.db_path)) as conn:
            # Declare the nullable columns as INTEGER; otherwise they are created as TEXT
            df.to_sql("predictions", conn, if_exists="append", index=False,
                      dtype={"actual_label": "INTEGER", "correct": "INTEGER"})

    def update_with_ground_truth(self, prediction_id: str, actual_label: int):
        """Update with ground truth when available (hours or days later)."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute("""
                UPDATE predictions
                SET actual_label = ?, correct = (CASE WHEN prediction = ? THEN 1 ELSE 0 END)
                WHERE prediction_id = ?
            """, (actual_label, actual_label, prediction_id))
            conn.commit()

    def get_accuracy_over_time(self, model_version: str, hours: int = 24) -> pd.DataFrame:
        """Calculate accuracy by hour for recent window."""
        cutoff_time = (datetime.now() - timedelta(hours=hours)).isoformat(sep=" ", timespec="seconds")
        # Bound parameters (?) avoid SQL injection through model_version
        query = """
            SELECT
                strftime('%Y-%m-%d %H:00', timestamp) as hour,
                COUNT(*) as predictions,
                SUM(correct) as correct_predictions,
                SUM(correct) * 1.0 / COUNT(*) as accuracy,
                AVG(confidence) as avg_confidence
            FROM predictions
            WHERE model_version = ?
              AND timestamp > ?
              AND actual_label IS NOT NULL
            GROUP BY hour
            ORDER BY hour DESC
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            return pd.read_sql_query(query, conn, params=(model_version, cutoff_time))

# Usage
tracker = AccuracyTracker()
# At prediction time
pred_id = str(uuid4())
tracker.log_prediction(pred_id, "1.2.0", prediction=0, confidence=0.95, features=[5.1, 3.5, 1.4, 0.2])
# Hours later, when ground truth arrives
tracker.update_with_ground_truth(pred_id, actual_label=0)
# Monitor accuracy
accuracy_df = tracker.get_accuracy_over_time("1.2.0", hours=24)
print(accuracy_df)
Monitoring Dashboard with Grafana
Create a Grafana dashboard to visualize metrics:
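- Add Prometheus as a data source pointing to http://prometheus:9090.
- Create panels:
  - Prediction latency (p50, p95, p99): histogram_quantile(0.95, sum by (le) (rate(ml_prediction_latency_seconds_bucket[5m]))), with 0.95 swapped for 0.5 or 0.99 to get the other percentiles
  - Throughput (predictions/sec): sum(rate(ml_predictions_total[1m]))
  - Error rate: sum(rate(ml_prediction_errors_total[1m])) / (sum(rate(ml_predictions_total[1m])) + sum(rate(ml_prediction_errors_total[1m]))). The sum() calls are needed because the two counters carry different labels, and errors are added to the denominator because ml_predictions_total counts only successful predictions.
  - Feature drift (min/max over time): ml_feature_min, ml_feature_max
  - Accuracy by hour: query from the SQL database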
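It helps to know how histogram_quantile arrives at a percentile, because the result is an estimate bounded by your bucket edges. The sketch below mirrors, in simplified form, the linear interpolation Prometheus applies to classic histograms; the function name `estimate_quantile` and the cumulative counts are made up for illustration.

```python
import math

def estimate_quantile(q, upper_bounds, cumulative_counts):
    """Estimate the q-quantile from cumulative bucket counts by linear interpolation."""
    total = cumulative_counts[-1]
    if total == 0:
        return float("nan")  # no observations yet
    rank = q * total
    for i, (upper, count) in enumerate(zip(upper_bounds, cumulative_counts)):
        if count >= rank:
            if math.isinf(upper):
                # Quantile falls in the +Inf bucket: report the largest finite bound
                return upper_bounds[i - 1]
            lower = upper_bounds[i - 1] if i > 0 else 0.0
            below = cumulative_counts[i - 1] if i > 0 else 0
            in_bucket = count - below
            # Assume observations are spread evenly inside the bucket
            fraction = (rank - below) / in_bucket if in_bucket else 0.0
            return lower + (upper - lower) * fraction
    return float("nan")

# Bucket edges from the latency histogram, plus the implicit +Inf bucket
bounds = [0.01, 0.05, 0.1, 0.5, 1.0, math.inf]
counts = [10, 60, 90, 98, 100, 100]  # hypothetical cumulative counts
p95 = estimate_quantile(0.95, bounds, counts)
print(f"Estimated p95 latency: {p95:.2f}s")
```

With these counts the p95 estimate lands at 0.35 s regardless of where in the 0.1 to 0.5 s bucket the slow requests actually fell, which is a reason to place bucket edges densely around the latencies you care about.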
Comparison Table: Monitoring Approaches
| Approach | Setup | Drift Detection | Latency | Cost | Best For |
|---|---|---|---|---|---|
| Prometheus + Grafana | Moderate | Manual | Real-time | Free | In-house deployment |
| Datadog | Easy | Built-in | Real-time | Paid | Managed/cloud services |
| Evidently AI | Moderate | Built-in | Batch | Free/paid | ML-specific monitoring |
| CloudWatch (AWS) | Easy | Manual | Real-time | Paid (usage-based) | AWS deployments |
| Custom SQL logs | Easy | Manual | Delayed | Free | Small-scale services |
Key Takeaways
- Instrument your API with Prometheus metrics: latency, throughput, errors, feature distributions.
- Use statistical tests and distance measures (the KS test, Wasserstein distance) to detect data drift automatically.
- Track prediction accuracy over time by logging predictions and updating with ground truth.
- Set up a Grafana dashboard to visualize operational and predictive metrics.
- Alert on drift and on accuracy degradation (e.g., "accuracy dropped 5% in the last week") to trigger retraining.
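The accuracy alert in the last takeaway can be sketched as a comparison between two consecutive windows. This is a minimal illustration, not a full alerting setup: the function name `accuracy_dropped`, the 7-day window, and the daily accuracy values are assumptions, and "5%" is read here as five percentage points.

```python
from statistics import mean

def accuracy_dropped(daily_accuracy, window=7, max_drop=0.05):
    """Return True if mean accuracy over the latest window fell by more than
    max_drop (absolute) compared with the window before it."""
    if len(daily_accuracy) < 2 * window:
        return False  # not enough history to compare two full windows
    previous = mean(daily_accuracy[-2 * window:-window])
    recent = mean(daily_accuracy[-window:])
    return previous - recent > max_drop

# Hypothetical daily accuracy values, oldest first
history = [0.94, 0.95, 0.93, 0.94, 0.95, 0.94, 0.93,
           0.90, 0.89, 0.88, 0.87, 0.86, 0.86, 0.85]
if accuracy_dropped(history):
    print("Accuracy alert: consider retraining.")
```

Comparing window averages rather than single days keeps one noisy day from firing the alert.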
Frequently Asked Questions
How often should I retrain a model?
It depends on the drift rate. If drift is slow (accuracy decays 1% per month), retrain monthly. If it is fast (5% per week), retrain weekly. Monitor drift and let the data guide you.
Can I automate retraining based on drift alerts?
Yes. Combine drift detection with a CI/CD pipeline: if drift is detected, trigger retraining, validate accuracy on holdout data, and automatically promote to production if metrics improve.
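The promotion step of such a pipeline can be reduced to a small gate function. The sketch below uses hypothetical names (`should_promote`, `min_gain`) and assumes your pipeline has already computed holdout accuracy for both models; it shows only the decision logic, not the retraining or deployment itself.

```python
def should_promote(drift_detected, current_accuracy, candidate_accuracy, min_gain=0.01):
    """Promote the retrained model only if retraining was triggered by drift
    and the candidate beats the current model on holdout data by min_gain."""
    if not drift_detected:
        return False  # no drift alert, so there is no retraining run to promote
    return candidate_accuracy - current_accuracy >= min_gain

# Hypothetical holdout accuracies for the current and retrained models
if should_promote(drift_detected=True, current_accuracy=0.90, candidate_accuracy=0.93):
    print("Promote the retrained model.")
```

Requiring a minimum gain, rather than any improvement at all, guards against promoting a model whose advantage is within evaluation noise.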
Should I monitor individual predictions or aggregate metrics?
Both. Aggregate metrics reveal systemic drift; individual predictions help debug corner cases (why did this prediction fail?). Log both for full observability.
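For the individual-prediction side, one JSON line per prediction is usually enough to debug a single case later. This is a sketch under assumptions: the function name `prediction_record`, the field names, and the "predictions" logger are illustrative choices, not a required schema.

```python
import json
import logging
from datetime import datetime, timezone

logger = logging.getLogger("predictions")

def prediction_record(prediction_id, model_version, features, prediction, confidence):
    """Serialize one prediction as a JSON line that can be searched later."""
    return json.dumps({
        "prediction_id": prediction_id,
        "model_version": model_version,
        "features": features,
        "prediction": prediction,
        "confidence": confidence,
        "logged_at": datetime.now(timezone.utc).isoformat(),
    })

# Hypothetical prediction for one iris sample
line = prediction_record("example-id", "1.2.0", [5.1, 3.5, 1.4, 0.2], 0, 0.95)
logger.info(line)
```

Keeping the prediction ID in every record lets you join a logged prediction with the ground truth that arrives later.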
How do I handle label imbalance in accuracy tracking?
Use stratified sampling and per-class metrics (precision, recall, F1 per class) instead of overall accuracy. Accuracy is misleading when one class dominates.
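A small example shows why. The sketch below computes per-class precision, recall, and F1 in plain Python; the function name `per_class_metrics` and the labels are hypothetical, and libraries such as scikit-learn provide equivalent reports.

```python
from collections import Counter

def per_class_metrics(actual, predicted):
    """Precision, recall, and F1 for each class label."""
    true_pos, pred_count, actual_count = Counter(), Counter(), Counter()
    for a, p in zip(actual, predicted):
        actual_count[a] += 1
        pred_count[p] += 1
        if a == p:
            true_pos[a] += 1
    metrics = {}
    for label in sorted(set(actual) | set(predicted)):
        precision = true_pos[label] / pred_count[label] if pred_count[label] else 0.0
        recall = true_pos[label] / actual_count[label] if actual_count[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[label] = {"precision": precision, "recall": recall, "f1": f1}
    return metrics

# Hypothetical imbalanced data: class 0 dominates, class 1 is rare
actual = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
predicted = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
overall_accuracy = sum(a == p for a, p in zip(actual, predicted)) / len(actual)
metrics = per_class_metrics(actual, predicted)
print(overall_accuracy, metrics[1]["recall"])
```

Overall accuracy is 0.9 here, yet recall for class 1 is only 0.5: half of the rare class is missed, which the single accuracy number hides.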
Further Reading
- Prometheus Monitoring — time-series metrics collection.
- Grafana Dashboards — visualization and alerting.
- Evidently AI — ML-specific monitoring and drift detection.
- MLflow Model Validation — model quality tracking.