Custom Middleware in FastAPI: Production Guide
Middleware in FastAPI is a layer of code that runs before your routes and after the response is generated. It intercepts HTTP requests, can modify them, run application logic, and then intercept responses before they reach the client. Common middleware handles authentication, logging, CORS headers, rate limiting, and request/response modification. Unlike decorators on individual routes, middleware applies globally—one middleware function controls behavior across your entire API.
As someone who built observability systems for production APIs, I've seen well-designed middleware transform raw requests into rich telemetry. This guide shows you how to write middleware that's both powerful and maintainable.
Understanding Middleware Order and Execution
Middleware is a stack. The last middleware added is the outermost layer: requests pass through it first, then through the inner layers, reach your routes, and responses travel back out in reverse. Order matters: authentication middleware should run before rate limiting so you can rate-limit per user, not per IP.
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import time
import uuid

app = FastAPI()

# Middleware 1: Timing
class TimingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # perf_counter is a monotonic clock, suited to measuring durations
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

# Middleware 2: Request ID
class RequestIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Reuse the caller's ID if one was sent, otherwise generate one
        request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response

app.add_middleware(RequestIDMiddleware)  # Added first, inner layer
app.add_middleware(TimingMiddleware)  # Added last, outer layer
A request flows TimingMiddleware → RequestIDMiddleware → route, and the response returns through RequestIDMiddleware → TimingMiddleware.
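The nesting can be seen without any framework. The following is a minimal sketch that mimics, rather than calls, how the stack is assembled: each newly added layer wraps everything added before it, so the layer added last runs first. The names make_layer, build_stack, and run_once are hypothetical helpers for this illustration only.

```python
import asyncio

def make_layer(name, trace):
    """Return a function that wraps an ASGI app and records when it runs."""
    def wrap(app):
        async def layer(scope, receive, send):
            trace.append(f"{name}: enter")
            await app(scope, receive, send)
            trace.append(f"{name}: exit")
        return layer
    return wrap

async def route(scope, receive, send):
    # Stand-in for the route handler at the centre of the stack.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})

def build_stack(trace):
    # Wrap in registration order: each later layer encloses everything
    # added before it, so the last layer added ends up outermost.
    app = route
    for name in ("RequestID", "Timing"):
        app = make_layer(name, trace)(app)
    return app

async def run_once():
    trace = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        pass  # A real server would write the message to the socket.

    scope = {"type": "http", "method": "GET", "path": "/"}
    await build_stack(trace)(scope, receive, send)
    return trace

if __name__ == "__main__":
    print(asyncio.run(run_once()))
```

Running it prints the layers in the order Timing, RequestID, RequestID, Timing, matching the flow described above.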
Writing a Simple Logging Middleware
Here's a production-grade logging middleware that records every request and response:
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        # Capture request details
        method = request.method
        path = request.url.path
        client_ip = request.client.host if request.client else "unknown"

        # Read the request body (if present) so its size can be logged
        body = b""
        if method in ("POST", "PUT", "PATCH"):
            body = await request.body()

            # Re-wrap the body so the route handler can read it.
            # Note: this sets a private attribute. Newer Starlette releases
            # cache the body for downstream handlers, which may make it unnecessary.
            async def receive():
                return {"type": "http.request", "body": body, "more_body": False}
            request._receive = receive

        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            process_time = time.perf_counter() - start_time
            # logger.exception records the traceback along with the message
            logger.exception(
                "method=%s path=%s duration=%.3fs client_ip=%s",
                method, path, process_time, client_ip,
            )
            raise

        process_time = time.perf_counter() - start_time
        logger.info(
            "method=%s path=%s status=%s duration=%.3fs client_ip=%s body_bytes=%d",
            method, path, response.status_code, process_time, client_ip, len(body),
        )
        response.headers["X-Process-Time"] = str(process_time)
        return response

app = FastAPI()
app.add_middleware(LoggingMiddleware)
This middleware captures the HTTP method, path, response status, duration, and request body size. It reads the request body for debugging and re-wraps it so the route handler can still read it.
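If you go further and log the body itself, it is worth bounding its size and masking secrets first. The helper below is a hedged sketch of one way to do that; summarize_body and the SENSITIVE_KEYS list are assumptions made for illustration, not part of FastAPI or Starlette, and only top-level JSON keys are masked.

```python
import json

# Assumed list of field names to mask; adjust to your own payloads.
SENSITIVE_KEYS = {"password", "token", "secret", "authorization"}

def summarize_body(body: bytes, limit: int = 200) -> str:
    """Return a short, log-safe description of a request body."""
    if not body:
        return "-"
    try:
        data = json.loads(body)
    except ValueError:
        # Not valid JSON (or not decodable): log the size, never the raw bytes.
        return f"<{len(body)} bytes, non-JSON>"
    if isinstance(data, dict):
        # Mask only top-level keys; nested secrets are not handled here.
        data = {
            key: ("***" if key.lower() in SENSITIVE_KEYS else value)
            for key, value in data.items()
        }
    text = json.dumps(data, separators=(",", ":"))
    return text if len(text) <= limit else text[:limit] + "...(truncated)"
```

Inside dispatch, the result could be passed to the logger as one more field, for example body=summarize_body(body).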
Async Middleware with Per-Request State
Store data on request.state to pass information through the middleware stack and into route handlers:
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
import uuid

class RequestContextMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Assign a unique ID and timestamp to every request
        request.state.request_id = str(uuid.uuid4())
        request.state.start_time = time.time()

        response = await call_next(request)

        # Add elapsed time to response headers
        elapsed = time.time() - request.state.start_time
        response.headers["X-Elapsed-Ms"] = str(int(elapsed * 1000))
        return response

app = FastAPI()
app.add_middleware(RequestContextMiddleware)

@app.get("/data")
async def get_data(request: Request):
    # Access request context set by middleware
    return {
        "request_id": request.state.request_id,
        "elapsed_ms": int((time.time() - request.state.start_time) * 1000),
    }
Middleware can enrich the request object with authentication claims, rate-limit counters, or trace IDs. Routes access these via request.state.
Middleware for Authentication and Authorization
For global authentication, middleware can check credentials before routes run:
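from fastapi import FastAPI, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
import jwt  # PyJWT

SECRET_KEY = "change-me"  # Placeholder; load the real key from secure configuration
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json"}  # /docs fetches /openapi.json

def unauthorized(detail: str) -> JSONResponse:
    # An HTTPException raised inside middleware does not reach FastAPI's
    # exception handlers (they sit inside the middleware stack) and would
    # surface as a 500, so the 401 response is built and returned directly.
    return JSONResponse(
        status_code=status.HTTP_401_UNAUTHORIZED,
        content={"detail": detail},
        headers={"WWW-Authenticate": "Bearer"},
    )

class AuthMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Skip auth for public routes
        if request.url.path in PUBLIC_PATHS:
            return await call_next(request)

        # Extract bearer token
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            return unauthorized("Missing or invalid authorization header")
        token = auth_header.split(" ", 1)[1]

        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
            request.state.user_id = payload["sub"]
            request.state.user_roles = payload.get("roles", [])
        except (jwt.InvalidTokenError, KeyError):
            # KeyError covers a valid token that lacks the "sub" claim
            return unauthorized("Invalid or expired token")

        return await call_next(request)

app = FastAPI()
app.add_middleware(AuthMiddleware)

@app.get("/profile")
async def get_profile(request: Request):
    user_id = request.state.user_id
    return {"user_id": user_id, "roles": request.state.user_roles}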
Middleware-based auth is useful when all or most routes require authentication. For selective auth, dependencies (as covered in article 1) are cleaner.
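Header parsing is a small step that is easy to get subtly wrong. As a sketch, the hypothetical helper below (extract_bearer_token is not a FastAPI function) pulls the token out of an Authorization header value, accepts the scheme name in any letter case, and returns None for anything malformed so the caller can answer with a 401.

```python
from typing import Optional

def extract_bearer_token(header_value: Optional[str]) -> Optional[str]:
    """Return the token from an Authorization header value, or None if malformed."""
    if not header_value:
        return None
    # Split once on the first space: "<scheme> <token>".
    scheme, _, token = header_value.strip().partition(" ")
    token = token.strip()
    # The auth scheme name is case-insensitive; the token must be one word.
    if scheme.lower() != "bearer" or not token or " " in token:
        return None
    return token
```

In the middleware above, the check on auth_header and the split could be replaced by a single call to this helper followed by a None test.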
Rate Limiting Middleware
Implement basic rate limiting using an in-memory counter:
from collections import defaultdict
from datetime import datetime, timedelta
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, requests_per_minute: int = 60):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.clients = defaultdict(list)  # IP -> [timestamp, timestamp, ...]

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host if request.client else "unknown"
        now = datetime.now()

        # Remove timestamps older than 1 minute
        minute_ago = now - timedelta(minutes=1)
        self.clients[client_ip] = [
            ts for ts in self.clients[client_ip]
            if ts > minute_ago
        ]

        # Check rate limit
        if len(self.clients[client_ip]) >= self.requests_per_minute:
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"},
            )

        # Record this request
        self.clients[client_ip].append(now)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
        response.headers["X-RateLimit-Remaining"] = str(
            self.requests_per_minute - len(self.clients[client_ip])
        )
        return response

app = FastAPI()
app.add_middleware(RateLimitMiddleware, requests_per_minute=100)
For production, use a library like slowapi or back the counters with Redis: this in-memory dictionary is private to each worker process and is lost on restart. The example shows the concept: middleware can intercept and reject requests based on custom logic.
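The counting logic is easier to test when it lives outside the middleware. Below is a minimal sketch of the same sliding-window idea as a standalone class; SlidingWindowLimiter and its clock parameter are hypothetical names for this illustration, and the injectable clock exists so tests can advance time without sleeping.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict

class SlidingWindowLimiter:
    """Allow at most `limit` calls per `window` seconds for each key."""

    def __init__(self, limit: int, window: float = 60.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.limit = limit
        self.window = window
        self.clock = clock
        self.hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        """Record a call for `key` and report whether it is within the limit."""
        now = self.clock()
        hits = self.hits[key]
        # Drop timestamps that have fallen out of the window.
        while hits and now - hits[0] >= self.window:
            hits.popleft()
        if len(hits) >= self.limit:
            return False
        hits.append(now)
        return True

    def remaining(self, key: str) -> int:
        """Calls left for `key`, as of its most recent allow() call (no pruning)."""
        return max(0, self.limit - len(self.hits[key]))
```

A dispatch method could call allow(client_ip) and return the 429 response when it yields False. Like the middleware above, this state is per process, so it does not replace a shared store.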
Middleware vs. Decorators: When to Use Each
Use middleware for cross-cutting concerns that apply globally (logging, CORS, rate limiting, blanket authentication). Use decorators or dependencies for route-specific logic (per-route authorization, validation):
| Aspect | Middleware | Decorator/Dependency |
|---|---|---|
| Scope | Application-wide | Per-route or per-router |
| Order | Clear, deterministic (reverse of registration) | Less visible |
| Access to route info | Limited (path, headers) | Full (path parameters, parsed body) |
| Performance | Runs on every request | Only on matched routes |
| Use case | Logging, CORS, blanket auth | Validation, per-route auth, enrichment |
Middleware for "everyone needs this"; dependencies for "this route needs this."
Key Takeaways
- Middleware runs before and after routes; order matters (last added = outermost).
- Use BaseHTTPMiddleware and async def dispatch() for custom middleware.
- Store per-request state in request.state and access it in routes.
- Middleware can modify requests, responses, or reject requests entirely.
- Common uses: logging, authentication, rate limiting, CORS, request enrichment.
- Choose middleware for global concerns; dependencies for route-specific logic.
Frequently Asked Questions
How do I read and re-send a request body in middleware?
Use await request.body() to read it, then wrap it in a receive() callable before passing to call_next(). This prevents the route handler from receiving an empty body. See the LoggingMiddleware example above.
What's the difference between BaseHTTPMiddleware and pure ASGI middleware?
BaseHTTPMiddleware is a convenience wrapper; it handles the ASGI callables for you. Pure ASGI middleware is lower-level and faster but requires more boilerplate. For most use cases, BaseHTTPMiddleware is fine.
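For comparison, here is a hedged sketch of what the pure ASGI form looks like. It adds a timing header by wrapping the send callable; ProcessTimeMiddleware and demo are illustrative names, and the demo drives the middleware with a hand-written ASGI app so it runs without a server.

```python
import asyncio
import time

class ProcessTimeMiddleware:
    """Pure ASGI middleware that adds an X-Process-Time response header."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Pass lifespan and WebSocket traffic through untouched.
            await self.app(scope, receive, send)
            return

        start = time.perf_counter()

        async def send_with_header(message):
            if message["type"] == "http.response.start":
                # Measures time until the response starts, not until it ends.
                elapsed = time.perf_counter() - start
                headers = list(message.get("headers", []))
                # ASGI headers are (name, value) pairs of bytes.
                headers.append((b"x-process-time", f"{elapsed:.6f}".encode("latin-1")))
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_header)

async def demo():
    async def hello_app(scope, receive, send):
        await send({"type": "http.response.start", "status": 200,
                    "headers": [(b"content-type", b"text/plain")]})
        await send({"type": "http.response.body", "body": b"hello"})

    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": "/"}
    await ProcessTimeMiddleware(hello_app)(scope, receive, send)
    return sent
```

A class with this shape can be registered the same way as the earlier examples, with app.add_middleware(ProcessTimeMiddleware).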
Can middleware modify the response body?
Yes, but it's tricky: by the time middleware sees the response, its body is a stream that is sent chunk by chunk. To change it you must buffer the entire body, modify it, update Content-Length, and send the result. Where possible, modify only response headers, which is simpler and more efficient.
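A sketch of the buffering approach, written as pure ASGI middleware so the individual response messages are visible. BufferedBodyMiddleware and its transform hook are hypothetical names; the sketch holds the whole body in memory, so it is unsuitable for large or streaming responses.

```python
import asyncio

class BufferedBodyMiddleware:
    """Pure ASGI middleware that buffers the response body and rewrites it."""

    def __init__(self, app, transform=bytes.upper):
        self.app = app
        self.transform = transform  # Hypothetical hook: bytes in, bytes out.

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        start_message = None
        chunks = []

        async def buffering_send(message):
            nonlocal start_message
            if message["type"] == "http.response.start":
                # Hold the status and headers until the final body is known.
                start_message = message
            elif message["type"] == "http.response.body":
                chunks.append(message.get("body", b""))
                if not message.get("more_body", False):
                    body = self.transform(b"".join(chunks))
                    # Replace Content-Length so it matches the new body.
                    headers = [(k, v) for k, v in start_message.get("headers", [])
                               if k.lower() != b"content-length"]
                    headers.append((b"content-length", str(len(body)).encode("ascii")))
                    await send({**start_message, "headers": headers})
                    await send({"type": "http.response.body", "body": body})
            else:
                await send(message)

        await self.app(scope, receive, buffering_send)

async def demo():
    async def chunked_app(scope, receive, send):
        await send({"type": "http.response.start", "status": 200,
                    "headers": [(b"content-length", b"5")]})
        await send({"type": "http.response.body", "body": b"he", "more_body": True})
        await send({"type": "http.response.body", "body": b"llo"})

    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    app = BufferedBodyMiddleware(chunked_app, transform=lambda b: b.upper() + b"!")
    await app({"type": "http", "method": "GET", "path": "/"}, receive, send)
    return sent
```

In the demo, the two chunks are joined, transformed, and sent as a single body with a recalculated Content-Length.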
How do I skip middleware for certain routes?
Check the path in middleware (if request.url.path.startswith("/public"): return await call_next(request)). Alternatively, put the protected routes in a separate sub-application with its own middleware and mount it, so the middleware only wraps requests under that mount path.
Is middleware synchronous or asynchronous?
Middleware is async. Use async def dispatch(). If you have sync code, run it in a thread pool using asyncio.to_thread().
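As a small illustration of the thread-pool advice, the sketch below offloads a blocking function with asyncio.to_thread (available from Python 3.9). slow_checksum is a made-up stand-in for whatever synchronous call you need to make, and the sleep only simulates blocking I/O.

```python
import asyncio
import hashlib
import time

def slow_checksum(payload: bytes) -> str:
    """Blocking stand-in for sync work (hypothetical, e.g. a legacy client call)."""
    time.sleep(0.05)  # Simulates blocking I/O.
    return hashlib.sha256(payload).hexdigest()

async def checksum_without_blocking(payload: bytes) -> str:
    # to_thread runs the function in a worker thread, so the event loop
    # stays free to serve other requests while it waits.
    return await asyncio.to_thread(slow_checksum, payload)

async def demo():
    # Both calls overlap because neither blocks the event loop.
    return await asyncio.gather(
        checksum_without_blocking(b"a"),
        checksum_without_blocking(b"b"),
    )
```

Inside dispatch, the same pattern applies: await asyncio.to_thread(...) around the blocking call, then continue with call_next as usual.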