"""Rate limiting middleware for API endpoints."""

import math
import time
from collections import defaultdict
from typing import Dict, Optional

from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for rate limiting API requests per client IP.

    Every client may make ``requests_per_minute`` requests in any rolling
    60-second window and ``requests_per_hour`` requests in any rolling hour.
    A request that would exceed either limit is answered with HTTP 429 and is
    not forwarded to the application.

    Counters are kept in a plain dict on this instance, so they are not
    shared between processes and are lost when the process restarts.
    """

    # Lengths of the two rolling windows, in seconds.
    MINUTE_WINDOW = 60
    HOUR_WINDOW = 3600

    # Minimum number of seconds between two sweeps for idle clients.
    SWEEP_INTERVAL = 60

    # Paths that bypass rate limiting entirely (probes and API docs).
    EXEMPT_PATHS = frozenset(
        {"/health", "/ready", "/docs", "/redoc", "/openapi.json"}
    )

    # Time of the last idle-client sweep; 0.0 makes the first request sweep.
    _last_sweep: float = 0.0

    def __init__(self, app, settings=None):
        """
        Initialize rate limit middleware.

        Args:
            app: The FastAPI application
            settings: Rate limit settings object. Its ``requests_per_minute``
                and ``requests_per_hour`` attributes are used when present;
                otherwise the defaults (60 and 1000) apply.
        """
        super().__init__(app)
        self.settings = settings

        # client IP -> {str(unix timestamp): requests recorded at that time}
        self.request_counts: Dict[str, Dict[str, int]] = defaultdict(dict)

        # Default settings if not provided
        self.requests_per_minute = 60
        self.requests_per_hour = 1000

        # ``is not None`` rather than truthiness, so a settings object that
        # happens to evaluate as falsy is still honoured.
        if settings is not None:
            self.requests_per_minute = getattr(settings, "requests_per_minute", 60)
            self.requests_per_hour = getattr(settings, "requests_per_hour", 1000)

    def _get_client_ip(self, request: Request) -> str:
        """
        Get the client IP address from the request.

        Order of preference: first entry of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the socket peer address, then ``"unknown"``.

        Args:
            request: The incoming request

        Returns:
            The client IP address
        """
        # SECURITY NOTE(review): both headers are supplied by the caller
        # unless a trusted reverse proxy overwrites them. If this service is
        # reachable directly, a client can rotate the header value to dodge
        # the limiter -- confirm the deployment strips/sets these headers.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            # A header such as ", 10.0.0.1" has an empty first element; fall
            # through instead of lumping such callers into one "" bucket.
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            real_ip = real_ip.strip()
            if real_ip:
                return real_ip

        # Fallback to client host
        if request.client:
            return request.client.host

        return "unknown"

    @staticmethod
    def _count_since(entries: Dict[str, int], cutoff: float) -> int:
        """Sum the requests in ``entries`` recorded strictly after ``cutoff``."""
        return sum(
            count for stamp, count in entries.items() if float(stamp) > cutoff
        )

    @staticmethod
    def _seconds_until_allowed(
        entries: Dict[str, int], now: float, window: int, limit: int
    ) -> Optional[int]:
        """Return how long a client must wait for ``window`` to admit a request.

        Returns ``None`` when the window is currently under ``limit``;
        otherwise the number of whole seconds (rounded up, clamped to
        ``1..window``) until enough recorded requests have aged out.
        """
        cutoff = now - window
        in_window = sorted(
            (float(stamp), count)
            for stamp, count in entries.items()
            if float(stamp) > cutoff
        )
        used = sum(count for _, count in in_window)
        if used < limit:
            return None

        # This many of the oldest requests must leave the window before the
        # count drops below the limit again.
        must_expire = used - limit + 1
        frees_at = now + window  # fallback, e.g. nothing recorded or limit <= 0
        for stamp, count in in_window:
            must_expire -= count
            if must_expire <= 0:
                frees_at = stamp + window
                break

        # Round up so the client is never told to retry too early.
        return max(1, min(window, math.ceil(frees_at - now)))

    def _throttle_hint(self, client_ip: str, now: float) -> tuple[int, int]:
        """Return ``(limit, retry_after_seconds)`` for a throttled client.

        The values describe whichever window keeps the client blocked the
        longest. When neither window appears exhausted (for instance because
        ``_is_rate_limited`` was overridden), the per-minute limit and a
        60-second wait are reported.
        """
        entries = self.request_counts.get(client_ip, {})
        minute_wait = self._seconds_until_allowed(
            entries, now, self.MINUTE_WINDOW, self.requests_per_minute
        )
        hour_wait = self._seconds_until_allowed(
            entries, now, self.HOUR_WINDOW, self.requests_per_hour
        )

        if hour_wait is not None and (minute_wait is None or hour_wait > minute_wait):
            return self.requests_per_hour, hour_wait
        if minute_wait is None:
            minute_wait = self.MINUTE_WINDOW
        return self.requests_per_minute, minute_wait

    def _evict_idle_clients(self, now: float) -> None:
        """Drop clients with no request in the last hour.

        Without this, ``request_counts`` keeps one entry for every client
        ever seen. The sweep runs at most once per ``SWEEP_INTERVAL``.
        """
        elapsed = now - self._last_sweep
        if 0 <= elapsed < self.SWEEP_INTERVAL:
            return
        self._last_sweep = now

        cutoff = now - self.HOUR_WINDOW
        idle = [
            ip
            for ip, entries in self.request_counts.items()
            if all(float(stamp) <= cutoff for stamp in entries)
        ]
        for ip in idle:
            del self.request_counts[ip]

    def _is_rate_limited(self, client_ip: str) -> tuple[bool, Optional[str]]:
        """
        Check if the client has exceeded rate limits.

        When the client is within both limits the current request is
        recorded, so each call that returns ``(False, None)`` consumes one
        unit of the client's quota.

        Args:
            client_ip: The client IP address

        Returns:
            Tuple of (is_limited, reason)
        """
        now = time.time()
        self._evict_idle_clients(now)

        # Remove entries older than an hour
        hour_cutoff = now - self.HOUR_WINDOW
        entries = {
            stamp: count
            for stamp, count in self.request_counts.get(client_ip, {}).items()
            if float(stamp) > hour_cutoff
        }
        # Store the pruned mapping before any early return, so a throttled
        # client does not keep dragging expired entries along.
        self.request_counts[client_ip] = entries

        minute_requests = self._count_since(entries, now - self.MINUTE_WINDOW)
        hour_requests = sum(entries.values())

        # Check minute limit
        if minute_requests >= self.requests_per_minute:
            return True, f"Rate limit exceeded: {self.requests_per_minute} requests per minute"

        # Check hour limit
        if hour_requests >= self.requests_per_hour:
            return True, f"Rate limit exceeded: {self.requests_per_hour} requests per hour"

        # Update request count
        timestamp_key = str(now)
        entries[timestamp_key] = entries.get(timestamp_key, 0) + 1

        return False, None

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and apply rate limiting.

        Args:
            request: The incoming request
            call_next: The next middleware or endpoint

        Returns:
            The response: either the downstream response with
            ``X-RateLimit-*`` headers added, or a 429 JSON response when the
            client is over a limit.
        """
        # Skip rate limiting for health check and documentation endpoints
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        is_limited, reason = self._is_rate_limited(client_ip)

        if is_limited:
            now = time.time()
            limit, retry_after = self._throttle_hint(client_ip, now)
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": reason,
                    "type": "rate_limit_exceeded",
                },
                headers={
                    # Derived from the window that is actually exhausted
                    # rather than a fixed 60 seconds.
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(now) + retry_after),
                },
            )

        # Process the request
        response = await call_next(request)

        # Add per-minute rate limit headers to forwarded responses
        entries = self.request_counts.get(client_ip)
        if entries is not None:
            now = time.time()
            minute_requests = self._count_since(entries, now - self.MINUTE_WINDOW)
            remaining = max(0, self.requests_per_minute - minute_requests)

            response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Reset"] = str(int(now) + self.MINUTE_WINDOW)

        return response