WebSocket Security: Authentication and Authorization
WebSockets are stateful connections, making them targets for session hijacking, message injection, and DoS attacks. Unlike HTTP requests, which are inherently one-off, a compromised WebSocket connection gives an attacker a persistent foothold. This article secures the chat application with token-based authentication, origin validation, per-user rate limiting, and optional message encryption.
Token-Based Authentication
Users authenticate via HTTP with a username and password, receiving a JWT token. This token is then used to establish a WebSocket connection:
```python
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt  # PyJWT
from fastapi import FastAPI, HTTPException, Query, WebSocket
from pydantic import BaseModel

app = FastAPI()

# Load the signing key from configuration; never hard-code it in production
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
TOKEN_EXPIRY = timedelta(hours=1)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated)
    expire = datetime.now(timezone.utc) + (expires_delta or TOKEN_EXPIRY)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def verify_token(token: str) -> dict:
    """Verify JWT and return payload."""
    try:
        # Pin the accepted algorithm; never trust the token's own header
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    if payload.get("sub") is None:
        raise HTTPException(status_code=401, detail="Invalid token")
    return payload


class LoginRequest(BaseModel):
    username: str
    password: str


@app.post("/auth/login")
async def login(credentials: LoginRequest):
    """Authenticate and return a JWT token."""
    # The client sends a JSON body, so accept a model rather than query parameters.
    # Validate credentials (use proper password hashing in production)
    if not authenticate_user(credentials.username, credentials.password):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    access_token = create_access_token(data={"sub": credentials.username})
    return {"access_token": access_token, "token_type": "bearer"}


def authenticate_user(username: str, password: str) -> bool:
    """Check username and password. In production, use bcrypt."""
    # Placeholder; use a database and hashed passwords in reality
    users_db = {
        "alice": "password123",
        "bob": "secret456",
    }
    return users_db.get(username) == password


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    # Verify the token before accepting the connection;
    # closing before accept() rejects the handshake
    try:
        payload = verify_token(token)
    except HTTPException:
        await websocket.close(code=1008, reason="Unauthorized")
        return
    username = payload["sub"]

    # Token valid; accept connection
    await websocket.accept()
    # ... rest of chat logic using authenticated username ...
    # Never accept an unauthenticated connection
```
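The browser WebSocket API cannot set custom request headers such as Authorization, so the client obtains a token at login and passes it in the WebSocket URL's query string. Because URLs often end up in server and proxy logs, keep these tokens short-lived and use wss:// in production: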
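```javascript
async function login(username, password) {
  const response = await fetch('/auth/login', {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({username, password})
  });
  if (!response.ok) {
    throw new Error(`Login failed: ${response.status}`);
  }
  const data = await response.json();
  return data.access_token;
}

// Top-level await requires an ES module; otherwise wrap this in an async function
const token = await login('alice', 'password123');
// Use wss:// (WebSocket over TLS) in production
const ws = new WebSocket(`ws://localhost:8000/ws?token=${encodeURIComponent(token)}`);
```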
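To make concrete what verify_token checks on the server, here is a standard-library-only sketch of HS256 signing and verification. It is for illustration only: sign_hs256 and verify_hs256 are hypothetical helpers, they check just the algorithm, the signature, and exp, and real code should keep using PyJWT, which also validates claims such as nbf and aud.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def b64url_encode(raw: bytes) -> str:
    # JWT segments are base64url-encoded without "=" padding (RFC 7515)
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def b64url_decode(text: str) -> bytes:
    # Restore the padding that b64url_encode stripped
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(claims: dict, key: bytes) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        b64url_encode(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, claims)
    )
    signature = hmac.new(key, signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{b64url_encode(signature)}"


def verify_hs256(token: str, key: bytes, now: Optional[float] = None) -> dict:
    try:
        header_b64, claims_b64, signature_b64 = token.split(".")
    except ValueError:
        raise ValueError("malformed token") from None
    header = json.loads(b64url_decode(header_b64))
    # Pin the algorithm instead of trusting the token's own header
    if header.get("alg") != "HS256":
        raise ValueError("unexpected algorithm")
    expected = hmac.new(
        key, f"{header_b64}.{claims_b64}".encode("ascii"), hashlib.sha256
    ).digest()
    # Constant-time comparison avoids leaking how many bytes matched
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    claims = json.loads(b64url_decode(claims_b64))
    current = time.time() if now is None else now
    if "exp" in claims and current >= claims["exp"]:
        raise ValueError("token expired")
    return claims


if __name__ == "__main__":
    demo_key = b"demo-key-for-illustration-only"
    tok = sign_hs256({"sub": "alice", "exp": int(time.time()) + 3600}, demo_key)
    print(verify_hs256(tok, demo_key)["sub"])  # prints: alice
```

Pinning the algorithm matters: accepting whatever alg a token's header names is a well-known JWT pitfall, which is why verify_token passes algorithms=[ALGORITHM] to jwt.decode.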
CORS and Origin Validation
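Browsers do not apply CORS to WebSocket connections: a page on any domain can open a WebSocket to your server, and the browser will attach the user's cookies for your domain (cross-site WebSocket hijacking). Browsers do, however, send an Origin header with every WebSocket handshake, so validate it on the server: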
```python
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com"]

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    # Validate origin
    origin = websocket.headers.get("origin")
    if origin not in ALLOWED_ORIGINS:
        await websocket.close(code=1008, reason="Origin not allowed")
        return

    # Verify token
    try:
        payload = verify_token(token)
    except HTTPException:
        await websocket.close(code=1008, reason="Unauthorized")
        return

    await websocket.accept()
    # ...
```
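This stops pages served from evil.com from opening connections in a victim's browser. It does not stop non-browser clients: a script can send any Origin header it likes, so Origin validation complements token authentication rather than replacing it.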
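The exact-membership check above is already the right shape; the usual mistake is loosening it with prefix or substring tests, since "https://example.com.evil.net" starts with "https://example.com". If you need normalization (for example, tolerating mixed-case values), parse the header instead of pattern-matching it. A sketch, where is_origin_allowed is a hypothetical helper and the allowlist reuses the values above:

```python
from typing import Optional
from urllib.parse import urlsplit

ALLOWED_ORIGINS = {"https://example.com", "https://app.example.com"}


def is_origin_allowed(origin: Optional[str]) -> bool:
    """Exact-match an Origin header against ALLOWED_ORIGINS after normalizing case."""
    if not origin:
        return False  # no Origin header: not a normal browser handshake
    try:
        parts = urlsplit(origin.strip())
        port = parts.port  # raises ValueError for a non-numeric or out-of-range port
    except ValueError:
        return False
    # A serialized origin is only scheme://host[:port]; reject anything else
    if parts.scheme not in ("http", "https") or not parts.hostname:
        return False
    if parts.path or parts.query or parts.fragment or parts.username or parts.password:
        return False
    # urlsplit lowercases the scheme and .hostname is lowercase
    normalized = f"{parts.scheme}://{parts.hostname}"
    if port is not None:
        normalized += f":{port}"
    return normalized in ALLOWED_ORIGINS
```

In the endpoint, `if not is_origin_allowed(websocket.headers.get("origin"))` would replace the list membership test.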
Per-User Rate Limiting
Prevent spam and DoS by limiting messages per user:
```python
from collections import defaultdict
from time import time

from fastapi import Query, WebSocket, WebSocketDisconnect


class RateLimiter:
    def __init__(self, max_messages: int = 10, window_seconds: int = 60):
        self.max_messages = max_messages
        self.window = window_seconds
        self.user_messages: defaultdict[str, list[float]] = defaultdict(list)

    def is_allowed(self, username: str) -> bool:
        """Return False if the user has exceeded the rate limit."""
        now = time()
        messages = self.user_messages[username]
        # Remove old messages outside the window
        messages[:] = [ts for ts in messages if now - ts < self.window]
        # Check limit
        if len(messages) >= self.max_messages:
            return False
        # Record new message
        messages.append(now)
        return True


limiter = RateLimiter(max_messages=10, window_seconds=60)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    # ... auth and accept (payload = verify_token(token), then websocket.accept()) ...
    username = payload["sub"]
    try:
        while True:
            data = await websocket.receive_text()
            # Check rate limit
            if not limiter.is_allowed(username):
                await websocket.send_json({
                    "type": "error",
                    "message": f"Rate limit exceeded. Max {limiter.max_messages} "
                               f"messages per {limiter.window} seconds.",
                })
                continue
            # Process message
            # ...
    except WebSocketDisconnect:
        pass  # client closed the connection; other errors still surface
```
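Each user may send at most 10 messages in any 60-second window. Messages over the limit are rejected with an error, but the connection stays open: a legitimate user who briefly sends too fast keeps their session, while a flood is still throttled. The limiter's state lives in process memory, so with several server workers each one enforces its own limit; use a shared store such as Redis if you need a single global limit.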
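The RateLimiter above keeps an entry for every user it has ever seen and reads the wall clock. The sketch below is a variant, not a drop-in replacement: SlidingWindowLimiter, its clock parameter, and prune() are hypothetical names. It stores timestamps in a deque, defaults to time.monotonic (unaffected by system clock changes), and can drop idle users; the injectable clock also makes the window logic easy to test.

```python
import time
from collections import deque
from typing import Callable, Deque, Dict


class SlidingWindowLimiter:
    """Per-user sliding-window limiter with an injectable clock and idle cleanup."""

    def __init__(self, max_messages: int = 10, window_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_messages = max_messages
        self.window = window_seconds
        self.clock = clock
        self.user_messages: Dict[str, Deque[float]] = {}

    def is_allowed(self, username: str) -> bool:
        now = self.clock()
        timestamps = self.user_messages.setdefault(username, deque())
        # Timestamps are appended in order, so expired ones sit at the left end
        while timestamps and now - timestamps[0] >= self.window:
            timestamps.popleft()
        if len(timestamps) >= self.max_messages:
            return False
        timestamps.append(now)
        return True

    def prune(self) -> None:
        """Drop users with no messages in the current window (call periodically)."""
        now = self.clock()
        for username in list(self.user_messages):
            timestamps = self.user_messages[username]
            if not timestamps or now - timestamps[-1] >= self.window:
                del self.user_messages[username]
```

Calling prune() from a periodic background task keeps memory proportional to the number of recently active users rather than every user ever seen.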
Message Encryption (Optional)
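For sensitive applications, you can also encrypt message payloads at the application layer, on top of TLS. Note that this is not end-to-end encryption: the server holds the key and decrypts every message. True end-to-end encryption requires keys that only the communicating clients hold.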
```python
import json
import os

from cryptography.fernet import Fernet, InvalidToken
from fastapi import Query, WebSocket, WebSocketDisconnect

# Client and server share a key (or derive it from a shared secret).
# A Fernet key is 32 random bytes, url-safe base64 encoded (44 characters):
# create one with Fernet.generate_key() and load it from configuration.
ENCRYPTION_KEY = os.environ["FERNET_KEY"]
cipher = Fernet(ENCRYPTION_KEY)


def encrypt_message(message: str) -> str:
    return cipher.encrypt(message.encode()).decode()


def decrypt_message(encrypted: str) -> str:
    return cipher.decrypt(encrypted.encode()).decode()


# Server-side
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    # ... auth and accept ...
    try:
        while True:
            encrypted_data = await websocket.receive_text()
            try:
                message = json.loads(decrypt_message(encrypted_data))
            except (InvalidToken, ValueError):
                # Wrong key, tampered ciphertext, or invalid JSON
                await websocket.send_json({"type": "error", "message": "Decryption failed"})
                continue
            # Process message
            # ...
            # Send an encrypted acknowledgement back to this client
            await websocket.send_text(encrypt_message(json.dumps({"type": "ack"})))
    except WebSocketDisconnect:
        pass
```
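Client-side, the cipher must match the server's. TweetNaCl's secretbox (XSalsa20-Poly1305) is not compatible with Fernet, so this client only interoperates with a server that also decrypts with NaCl secretbox (for example, PyNaCl's nacl.secret.SecretBox) instead of the Fernet cipher above: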
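```javascript
// Requires TweetNaCl.js and tweetnacl-util; sharedKey is a 32-byte Uint8Array
function encryptMessage(msg, sharedKey) {
  const nonce = nacl.randomBytes(nacl.secretbox.nonceLength); // fresh per message
  const box = nacl.secretbox(nacl.util.decodeUTF8(msg), nonce, sharedKey);
  const packed = new Uint8Array(nonce.length + box.length);
  packed.set(nonce);
  packed.set(box, nonce.length); // nonce || ciphertext, as PyNaCl's SecretBox.decrypt accepts
  return nacl.util.encodeBase64(packed);
}
ws.send(encryptMessage(JSON.stringify({text: "Hello"}), sharedKey));
```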
Note: encryption adds latency and complexity, and a shared key embedded in client code is visible to anyone who can load that code. Use it only if your threat model requires it (e.g., legal or financial data); for most applications, TLS (wss://) is sufficient.
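Fernet keys must be exactly 32 bytes, url-safe base64 encoded, and Fernet.generate_key() produces one directly. If you instead need to derive the key from a shared passphrase, as the comment in the server code allows, one option is PBKDF2 from Python's standard library. A sketch, assuming a hypothetical derive_fernet_key helper; the iteration count is an illustrative value to tune for your hardware:

```python
import base64
import hashlib
import os


def derive_fernet_key(passphrase: str, salt: bytes, iterations: int = 600_000) -> bytes:
    """Derive 32 bytes with PBKDF2-HMAC-SHA256 and encode them in Fernet's key format."""
    raw = hashlib.pbkdf2_hmac(
        "sha256", passphrase.encode("utf-8"), salt, iterations, dklen=32
    )
    return base64.urlsafe_b64encode(raw)  # 44 url-safe base64 characters


if __name__ == "__main__":
    # The salt is not secret, but both sides must use the same one; store it with the config
    salt = os.urandom(16)
    key = derive_fernet_key("correct horse battery staple", salt)
    print(len(key))  # prints: 44
```

The result can be passed straight to Fernet(key). The derived key is only as strong as the passphrase, so a randomly generated key remains the better default.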
Blocking Malicious Users
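If a user repeatedly abuses the service (for example, by hitting the rate limit over and over), block them temporarily. The blocklist below is keyed by username, so it only applies to authenticated users; clients that keep sending invalid tokens have no username and must be tracked by IP instead: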
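```python
from collections import defaultdict
from time import time

from fastapi import HTTPException, Query, WebSocket, WebSocketDisconnect


class BlockList:
    def __init__(self, block_duration: int = 3600):
        self.blocked_users: dict[str, float] = {}
        self.duration = block_duration

    def is_blocked(self, username: str) -> bool:
        blocked_at = self.blocked_users.get(username)
        if blocked_at is None:
            return False
        if time() - blocked_at < self.duration:
            return True
        del self.blocked_users[username]  # block has expired
        return False

    def block(self, username: str):
        self.blocked_users[username] = time()


blocklist = BlockList(block_duration=3600)
# Count rate-limit violations separately: the limiter never stores more than
# max_messages timestamps per user, so it cannot identify repeat offenders
MAX_VIOLATIONS = 20
violations: defaultdict[str, int] = defaultdict(int)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    try:
        payload = verify_token(token)
    except HTTPException:
        await websocket.close(code=1008, reason="Unauthorized")
        return
    username = payload["sub"]
    # Check blocklist before accepting
    if blocklist.is_blocked(username):
        await websocket.close(code=1008, reason="User blocked")
        return
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            if not limiter.is_allowed(username):
                violations[username] += 1
                # Multiple rate limit violations; block user
                if violations[username] > MAX_VIOLATIONS:
                    blocklist.block(username)
                    violations.pop(username, None)
                    await websocket.close(code=1008, reason="Blocked for abuse")
                    return
                continue
            # Process message
            # ...
    except WebSocketDisconnect:
        pass
```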
Users blocked for abuse can't reconnect for 1 hour. Log their IP and username for audit purposes.
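BlockList cannot act on clients that never present a valid token. One option, sketched here with hypothetical names (FailureTracker, record_failure), is to count authentication failures per client IP within a time window and block the IP once it crosses a threshold. Like the classes above it is in-memory and per-process, and behind a reverse proxy websocket.client.host is the proxy's address unless the server is configured to trust forwarded headers.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class FailureTracker:
    """Count failures per key (e.g. client IP) and block keys that exceed a threshold."""

    def __init__(self, max_failures: int = 5, window_seconds: float = 300.0,
                 block_seconds: float = 3600.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_failures = max_failures
        self.window = window_seconds
        self.block_seconds = block_seconds
        self.clock = clock
        self.failures: Dict[str, Deque[float]] = defaultdict(deque)
        self.blocked_until: Dict[str, float] = {}

    def is_blocked(self, key: str) -> bool:
        until = self.blocked_until.get(key)
        if until is None:
            return False
        if self.clock() >= until:
            del self.blocked_until[key]  # block has expired
            return False
        return True

    def record_failure(self, key: str) -> bool:
        """Record one failure; return True if this key is now blocked."""
        now = self.clock()
        events = self.failures[key]
        events.append(now)
        # Forget failures older than the window
        while events and now - events[0] >= self.window:
            events.popleft()
        if len(events) >= self.max_failures:
            self.blocked_until[key] = now + self.block_seconds
            del self.failures[key]
            return True
        return False


# Hypothetical wiring inside the endpoint:
#   if ip_tracker.is_blocked(ip): close with code 1008 and return
#   on verify_token failure:       ip_tracker.record_failure(ip)
```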
Logging and Audit Trail
Log authentication events and suspicious activity:
```python
import logging

from fastapi import HTTPException, Query, WebSocket, WebSocketDisconnect

logging.basicConfig(filename="websocket_security.log", level=logging.INFO)
security_log = logging.getLogger("security")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    # websocket.client may be None, so fall back to a placeholder
    ip = websocket.client.host if websocket.client else "unknown"
    try:
        payload = verify_token(token)
    except HTTPException as e:
        security_log.warning("WebSocket auth failed: reason=%s, ip=%s", e.detail, ip)
        await websocket.close(code=1008)
        return
    username = payload["sub"]

    if blocklist.is_blocked(username):
        security_log.warning("WebSocket blocked user: user=%s, ip=%s", username, ip)
        await websocket.close(code=1008)
        return

    await websocket.accept()
    security_log.info("WebSocket connected: user=%s, ip=%s", username, ip)
    # ... rest of handler ...
    try:
        while True:
            data = await websocket.receive_text()
            # ... process ...
    except WebSocketDisconnect:
        pass
    finally:
        # Runs on normal disconnects and on unexpected errors alike
        security_log.info("WebSocket disconnected: user=%s, ip=%s", username, ip)
```
Monitor this log for suspicious patterns: repeated auth failures, rate limit violations from the same IP, etc.
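Monitoring can start as simply as counting failures per IP in the log file. A sketch, assuming the default basicConfig line format and the exact "WebSocket auth failed: reason=..., ip=..." message used above; auth_failures_by_ip and suspicious_ips are hypothetical helpers, and a real deployment would more likely feed these logs into a log-analysis or alerting system:

```python
import re
from collections import Counter
from typing import Iterable, List

# Matches the message logged on authentication failure, e.g.
# "WARNING:security:WebSocket auth failed: reason=Token expired, ip=203.0.113.7"
AUTH_FAILED = re.compile(r"WebSocket auth failed: reason=(?P<reason>.*?), ip=(?P<ip>\S+)")


def auth_failures_by_ip(lines: Iterable[str]) -> Counter:
    """Count authentication-failure log lines per client IP."""
    counts: Counter = Counter()
    for line in lines:
        match = AUTH_FAILED.search(line)
        if match:
            counts[match.group("ip")] += 1
    return counts


def suspicious_ips(lines: Iterable[str], threshold: int = 10) -> List[str]:
    """Return IPs with at least `threshold` authentication failures."""
    return [ip for ip, n in auth_failures_by_ip(lines).items() if n >= threshold]


# Example (hypothetical threshold):
# with open("websocket_security.log") as log_file:
#     print(suspicious_ips(log_file, threshold=10))
```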
Key Takeaways
- Require authentication before accepting WebSocket connections; use JWT tokens from an HTTP login endpoint.
- Validate the Origin header to prevent cross-origin hijacking.
- Implement per-user rate limiting to prevent spam and DoS.
- Optionally encrypt message payloads at the application layer for sensitive data.
- Maintain a blocklist for repeat offenders, temporarily banning them.
- Log all authentication and suspicious activity for audit and incident response.
Frequently Asked Questions
Can I use OAuth 2.0 / OpenID Connect instead of JWT?
Yes. During the OAuth flow, users authenticate with a third-party provider (Google, GitHub), receive an OAuth token, and exchange it for a short-lived JWT from your server. Your WebSocket endpoint then validates the JWT. This allows SSO without managing passwords.
What if a JWT token is compromised?
Short-lived tokens (1 hour) limit exposure. To revoke a token early (for example, on logout), give each token a unique ID in its jti claim and add that ID to a denylist until the token's expiry. Check the denylist during WebSocket authentication and reject any listed token. Redis works well for this: lookups are fast, and keys can expire automatically.
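A minimal sketch of the revocation check, using an in-memory dict as a stand-in for Redis. It assumes tokens carry a unique jti claim (a registered JWT claim that create_access_token above does not set yet); TokenDenylist and new_jti are hypothetical names. With Redis, the same idea maps to SET with an EX expiry on revoke and EXISTS on lookup.

```python
import time
import uuid
from typing import Callable, Dict


def new_jti() -> str:
    # Issue each token with a unique "jti" claim so it can be revoked individually
    return uuid.uuid4().hex


class TokenDenylist:
    """In-memory stand-in for a Redis denylist: revoked jti -> token expiry (epoch seconds)."""

    def __init__(self, clock: Callable[[], float] = time.time):
        self.clock = clock  # wall clock, because JWT "exp" is a Unix timestamp
        self._revoked: Dict[str, float] = {}

    def revoke(self, jti: str, exp: float) -> None:
        # Keep the entry only until the token would have expired anyway
        self._revoked[jti] = exp

    def is_revoked(self, jti: str) -> bool:
        exp = self._revoked.get(jti)
        if exp is None:
            return False
        if self.clock() >= exp:
            # Past its expiry the token fails the exp check, so the entry can go
            del self._revoked[jti]
            return False
        return True
```

During authentication, a check such as `if denylist.is_revoked(payload["jti"])` would run right after verify_token succeeds.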
How do I handle token expiry during a long WebSocket session?
Implement token refresh: before expiry, the client requests a new token via HTTP and resends it via a WebSocket message (e.g., {"type": "refresh_token", "token": "new-token"}). The server validates and updates its session.
Can WebSocket connections share authentication state with HTTP endpoints?
Yes. Both use the same JWT, so a user logged in via HTTP can immediately open a WebSocket with their token. However, the server checks the token only once, during the handshake, and a WebSocket connection can outlive the token. Re-validate it periodically during the session (e.g., every 5 minutes) to catch expired or revoked tokens.
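Token refresh and periodic re-validation can share one per-connection object. The sketch below is hypothetical (SessionAuth is not part of FastAPI); it accepts any verify function that raises on a bad token, such as verify_token combined with a denylist check. Call still_valid() before handling each message and close the connection (e.g., with code 1008) when it returns False; pass the token from a refresh_token message to refresh().

```python
import time
from typing import Callable


class SessionAuth:
    """Tracks the token behind one WebSocket connection."""

    def __init__(self, token: str, verify: Callable[[str], dict],
                 recheck_seconds: float = 300.0,
                 clock: Callable[[], float] = time.time):
        self.verify = verify
        self.recheck_seconds = recheck_seconds
        self.clock = clock
        self.token = token
        self.payload = verify(token)  # raises if the initial token is invalid
        self.last_checked = clock()

    def refresh(self, new_token: str) -> None:
        """Handle {"type": "refresh_token", "token": ...}; the user must not change."""
        payload = self.verify(new_token)
        if payload.get("sub") != self.payload.get("sub"):
            raise PermissionError("token belongs to a different user")
        self.token, self.payload = new_token, payload
        self.last_checked = self.clock()

    def still_valid(self) -> bool:
        """Cheap expiry check every call; full re-verification every recheck_seconds."""
        now = self.clock()
        exp = self.payload.get("exp")
        if exp is not None and now >= exp:
            return False
        if now - self.last_checked >= self.recheck_seconds:
            try:
                self.payload = self.verify(self.token)  # catches revocation
            except Exception:
                return False
            self.last_checked = now
        return True
```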