# Source: stupa-pdf-api/backend/src/api/middleware/rate_limit.py
# (original listing: 166 lines, 5.2 KiB, Python)

"""Rate limiting middleware for API endpoints."""
import time
from typing import Dict, Optional
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import status
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client, in-memory rate limiter for API requests.

    Every accepted request is recorded under the client's IP address. Once a
    client has reached either the per-minute or the per-hour limit, further
    requests are answered with HTTP 429 until enough recorded requests have
    aged out of the window. Rejected requests are not recorded.

    All state is held on this instance, so limits are enforced independently
    by each middleware instance.
    """

    # Paths served without rate limiting (health/readiness checks, API docs).
    EXEMPT_PATHS = frozenset(
        {"/health", "/ready", "/docs", "/redoc", "/openapi.json"}
    )
    # Window lengths, in seconds.
    MINUTE_WINDOW = 60
    HOUR_WINDOW = 3600
    # Minimum number of seconds between two sweeps of idle clients.
    SWEEP_INTERVAL = 60

    def __init__(self, app, settings=None):
        """
        Initialize rate limit middleware.

        Args:
            app: The FastAPI application
            settings: Rate limit settings object. The optional attributes
                ``requests_per_minute``, ``requests_per_hour`` and
                ``trust_proxy_headers`` are read from it when present.
        """
        super().__init__(app)
        self.settings = settings
        # client IP -> {str(unix timestamp): number of requests recorded at
        # that timestamp}. Only accepted requests are recorded.
        self.request_counts: Dict[str, Dict[str, int]] = defaultdict(dict)
        # Defaults used when no settings object is provided.
        self.requests_per_minute = 60
        self.requests_per_hour = 1000
        # When True, X-Forwarded-For / X-Real-IP are used to identify the
        # client. Those headers are client-controlled unless a trusted proxy
        # overwrites them, so deployments without such a proxy should set
        # ``trust_proxy_headers = False`` to prevent limit evasion.
        self.trust_proxy_headers = True
        if settings:
            self.requests_per_minute = getattr(settings, 'requests_per_minute', 60)
            self.requests_per_hour = getattr(settings, 'requests_per_hour', 1000)
            self.trust_proxy_headers = getattr(settings, 'trust_proxy_headers', True)
        # Time of the last idle-client sweep (see _sweep_stale_clients).
        self._last_sweep = time.time()

    def _get_client_ip(self, request: Request) -> str:
        """
        Get the client IP address from the request.

        Proxy headers are consulted first (only when ``trust_proxy_headers``
        is enabled), then the address of the direct peer.

        Args:
            request: The incoming request

        Returns:
            The client IP address, or "unknown" if none can be determined
        """
        if self.trust_proxy_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # The left-most entry is the originating client.
                first_hop = forwarded_for.split(",")[0].strip()
                # A blank entry must not become a shared "" bucket.
                if first_hop:
                    return first_hop
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                real_ip = real_ip.strip()
                if real_ip:
                    return real_ip
        # Fallback to client host
        if request.client:
            return request.client.host
        return "unknown"

    def _sweep_stale_clients(self, now: float) -> None:
        """
        Drop clients that have no recorded request within the last hour.

        Without this, an entry would be kept forever for every client IP
        ever seen. The sweep runs at most once per ``SWEEP_INTERVAL`` seconds.

        Args:
            now: Current time as returned by ``time.time()``
        """
        if now - self._last_sweep < self.SWEEP_INTERVAL:
            return
        self._last_sweep = now
        hour_ago = now - self.HOUR_WINDOW
        # Newest entries are inserted last, so scanning in reverse lets
        # any() stop at the first entry for clients that are still active.
        stale_clients = [
            ip
            for ip, stamps in self.request_counts.items()
            if not any(float(key) > hour_ago for key in reversed(stamps))
        ]
        for ip in stale_clients:
            del self.request_counts[ip]

    def _is_rate_limited(self, client_ip: str) -> tuple[bool, Optional[str]]:
        """
        Check if the client has exceeded rate limits.

        If the client is within both limits, the current request is recorded
        as a side effect. Entries older than one hour are discarded either way.

        Args:
            client_ip: The client IP address

        Returns:
            Tuple of (is_limited, reason); reason is None when not limited
        """
        now = time.time()
        self._sweep_stale_clients(now)
        minute_ago = now - self.MINUTE_WINDOW
        hour_ago = now - self.HOUR_WINDOW

        # Keep only entries inside the one-hour window. .get() avoids
        # creating an empty bucket for a client that ends up not recorded.
        recent = {
            key: count
            for key, count in self.request_counts.get(client_ip, {}).items()
            if float(key) > hour_ago
        }
        minute_requests = sum(
            count for key, count in recent.items()
            if float(key) > minute_ago
        )
        hour_requests = sum(recent.values())

        # The per-minute limit takes precedence when both are exceeded.
        if minute_requests >= self.requests_per_minute:
            reason = f"Rate limit exceeded: {self.requests_per_minute} requests per minute"
        elif hour_requests >= self.requests_per_hour:
            reason = f"Rate limit exceeded: {self.requests_per_hour} requests per hour"
        else:
            reason = None
            # Record the accepted request.
            timestamp_key = str(now)
            recent[timestamp_key] = recent.get(timestamp_key, 0) + 1

        # Always store the pruned data so expired entries do not linger
        # while a client is being rejected; never keep an empty bucket.
        if recent:
            self.request_counts[client_ip] = recent
        else:
            self.request_counts.pop(client_ip, None)
        return reason is not None, reason

    def _retry_after_seconds(self, client_ip: str, now: float) -> int:
        """
        Estimate how long a rejected client must wait before retrying.

        For each window whose limit is currently reached, the wait is the
        time until the oldest recorded request in that window expires; the
        longest such wait is returned, rounded up to whole seconds.

        Args:
            client_ip: The client IP address
            now: Current time as returned by ``time.time()``

        Returns:
            Seconds to wait; ``MINUTE_WINDOW`` if no wait can be derived
            from the recorded requests
        """
        minute_ago = now - self.MINUTE_WINDOW
        hour_ago = now - self.HOUR_WINDOW
        minute_requests = 0
        hour_requests = 0
        oldest_in_minute: Optional[float] = None
        oldest_in_hour: Optional[float] = None

        for key, count in self.request_counts.get(client_ip, {}).items():
            stamp = float(key)
            if stamp <= hour_ago:
                continue
            hour_requests += count
            if oldest_in_hour is None or stamp < oldest_in_hour:
                oldest_in_hour = stamp
            if stamp > minute_ago:
                minute_requests += count
                if oldest_in_minute is None or stamp < oldest_in_minute:
                    oldest_in_minute = stamp

        wait = 0.0
        if oldest_in_minute is not None and minute_requests >= self.requests_per_minute:
            wait = max(wait, oldest_in_minute + self.MINUTE_WINDOW - now)
        if oldest_in_hour is not None and hour_requests >= self.requests_per_hour:
            wait = max(wait, oldest_in_hour + self.HOUR_WINDOW - now)

        if wait <= 0:
            # Nothing recorded explains the rejection (e.g. a limit of 0);
            # fall back to the fixed one-minute suggestion.
            return self.MINUTE_WINDOW
        # Round up so the client is never told to retry too early.
        return int(wait) + 1

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and apply rate limiting.

        Rejected requests receive a 429 JSON response with a ``Retry-After``
        header. The ``X-RateLimit-*`` headers on all responses describe the
        per-minute window.

        Args:
            request: The incoming request
            call_next: The next middleware or endpoint

        Returns:
            The response
        """
        # Skip rate limiting for health check and documentation endpoints
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        is_limited, reason = self._is_rate_limited(client_ip)
        if is_limited:
            now = time.time()
            retry_after = self._retry_after_seconds(client_ip, now)
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": reason,
                    "type": "rate_limit_exceeded"
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(self.requests_per_minute),
                    "X-RateLimit-Remaining": "0",
                    # Kept consistent with Retry-After.
                    "X-RateLimit-Reset": str(int(now) + retry_after)
                }
            )

        # Process the request
        response = await call_next(request)

        # Add rate limit headers to successful responses
        if client_ip in self.request_counts:
            now = time.time()
            minute_ago = now - self.MINUTE_WINDOW
            minute_requests = sum(
                count
                for key, count in self.request_counts[client_ip].items()
                if float(key) > minute_ago
            )
            remaining = max(0, self.requests_per_minute - minute_requests)
            response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Reset"] = str(int(now) + self.MINUTE_WINDOW)
        return response