"""Rate limiting middleware for API endpoints."""

import math
import time
from collections import defaultdict
from typing import Dict, Optional

from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for rate limiting API requests.

    Applies two sliding-window limits per client key: requests in the last
    60 seconds and requests in the last 3600 seconds. Only requests that are
    let through are recorded; rejected requests do not count against the
    client.

    All state lives in ``self.request_counts`` on this instance, so the
    limits are per middleware instance and are lost when it is discarded.
    """

    MINUTE_WINDOW = 60
    HOUR_WINDOW = 3600

    # Minimum number of seconds between two sweeps for idle clients.
    PURGE_INTERVAL = 60

    # Paths that are never rate limited (health checks and API docs).
    EXEMPT_PATHS = frozenset(
        {"/health", "/ready", "/docs", "/redoc", "/openapi.json"}
    )

    # Class-level defaults so the helpers also work on an instance whose
    # __init__ did not set them.
    trust_proxy_headers = True
    _last_purge = 0.0

    def __init__(self, app, settings=None):
        """
        Initialize rate limit middleware.

        Args:
            app: The FastAPI application
            settings: Rate limit settings object. The optional attributes
                ``requests_per_minute`` (default 60), ``requests_per_hour``
                (default 1000) and ``trust_proxy_headers`` (default True)
                are read from it when present.
        """
        super().__init__(app)
        self.settings = settings

        # client key -> {str(unix timestamp): number of requests recorded
        # at that timestamp}
        self.request_counts: Dict[str, Dict[str, int]] = defaultdict(dict)

        # Default settings if not provided
        self.requests_per_minute = 60
        self.requests_per_hour = 1000

        # "is not None" rather than truthiness, so a settings object that
        # happens to evaluate falsy is still honoured.
        if settings is not None:
            self.requests_per_minute = getattr(
                settings, "requests_per_minute", 60
            )
            self.requests_per_hour = getattr(
                settings, "requests_per_hour", 1000
            )
            self.trust_proxy_headers = getattr(
                settings, "trust_proxy_headers", True
            )

    def _get_client_ip(self, request: Request) -> str:
        """
        Get the client IP address from the request.

        Args:
            request: The incoming request

        Returns:
            The first entry of ``X-Forwarded-For``, else ``X-Real-IP``, else
            the connection's peer host, else ``"unknown"``. The two headers
            are skipped when ``trust_proxy_headers`` is false.
        """
        # SECURITY: both headers are supplied by the caller. Unless a proxy
        # in front of this app overwrites them, a client can send a new
        # value on every request and get a fresh rate-limit bucket each
        # time. Set ``trust_proxy_headers = False`` in the settings when
        # the app is reachable directly.
        if self.trust_proxy_headers:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                # An empty first element (e.g. ", 1.2.3.4") is not a usable
                # key; fall through to the next source instead.
                if first_hop:
                    return first_hop

            real_ip = (request.headers.get("X-Real-IP") or "").strip()
            if real_ip:
                return real_ip

        # Fallback to client host
        if request.client:
            return request.client.host

        return "unknown"

    def _purge_stale_clients(self, now: float) -> None:
        """Drop clients with no request in the last hour.

        A client's own entries are only pruned when that client sends
        another request, so without this sweep every key ever seen would
        stay in ``request_counts`` indefinitely. The sweep runs at most
        once per ``PURGE_INTERVAL`` seconds.
        """
        if now - self._last_purge < self.PURGE_INTERVAL:
            return
        self._last_purge = now

        hour_ago = now - self.HOUR_WINDOW
        stale_clients = [
            client
            for client, entries in self.request_counts.items()
            if not any(float(stamp) > hour_ago for stamp in entries)
        ]
        for client in stale_clients:
            del self.request_counts[client]

    def _recent_requests(self, client_ip: str, now: float) -> Dict[str, int]:
        """Prune entries older than an hour and return the stored remainder.

        The returned dict is the one kept in ``request_counts``, so changes
        made to it by the caller are persisted.
        """
        hour_ago = now - self.HOUR_WINDOW
        recent = {
            stamp: count
            for stamp, count in self.request_counts[client_ip].items()
            if float(stamp) > hour_ago
        }
        # Stored unconditionally so pruning also sticks when the request
        # ends up being rejected.
        self.request_counts[client_ip] = recent
        return recent

    @staticmethod
    def _seconds_until_below(entries, window: int, limit: int, now: float) -> float:
        """Return how long until fewer than ``limit`` requests are in ``window``.

        Args:
            entries: ``(timestamp, count)`` pairs sorted oldest first
            window: Window length in seconds
            limit: Maximum number of requests allowed in the window
            now: Current time

        Returns:
            0.0 if the window is already below the limit, otherwise the
            number of seconds until enough of the oldest requests have left
            the window.
        """
        cutoff = now - window
        in_window = [(ts, count) for ts, count in entries if ts > cutoff]
        total = sum(count for _, count in in_window)
        if total < limit:
            return 0.0

        # This many of the oldest requests must leave the window before a
        # new one fits.
        must_expire = total - limit + 1
        expired = 0
        for ts, count in in_window:
            expired += count
            if expired >= must_expire:
                return max(0.0, ts + window - now)

        # Only reachable when limit <= 0: no request ever fits, so report
        # the full window.
        return float(window)

    def _is_rate_limited(self, client_ip: str) -> tuple[bool, Optional[str]]:
        """
        Check if the client has exceeded rate limits.

        When the client is within both limits, the current request is
        recorded as a side effect. A rejected request is not recorded.

        Args:
            client_ip: The client IP address

        Returns:
            Tuple of (is_limited, reason)
        """
        current_time = time.time()
        self._purge_stale_clients(current_time)

        # Everything left after pruning falls inside the hour window.
        client_data = self._recent_requests(client_ip, current_time)

        minute_ago = current_time - self.MINUTE_WINDOW
        minute_requests = sum(
            count
            for stamp, count in client_data.items()
            if float(stamp) > minute_ago
        )
        hour_requests = sum(client_data.values())

        # Check minute limit
        if minute_requests >= self.requests_per_minute:
            return True, f"Rate limit exceeded: {self.requests_per_minute} requests per minute"

        # Check hour limit
        if hour_requests >= self.requests_per_hour:
            return True, f"Rate limit exceeded: {self.requests_per_hour} requests per hour"

        # Record this request; requests sharing the exact same timestamp
        # share one key.
        timestamp_key = str(current_time)
        client_data[timestamp_key] = client_data.get(timestamp_key, 0) + 1

        return False, None

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and apply rate limiting.

        Args:
            request: The incoming request
            call_next: The next middleware or endpoint

        Returns:
            The downstream response with ``X-RateLimit-*`` headers added, or
            a 429 JSON response when the client is over a limit.
        """
        # Skip rate limiting for health check and documentation endpoints
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        # Get client IP
        client_ip = self._get_client_ip(request)

        # Check rate limit
        is_limited, reason = self._is_rate_limited(client_ip)

        if is_limited:
            now = time.time()
            entries = sorted(
                (float(stamp), count)
                for stamp, count in self.request_counts.get(client_ip, {}).items()
            )
            minute_wait = self._seconds_until_below(
                entries, self.MINUTE_WINDOW, self.requests_per_minute, now
            )
            hour_wait = self._seconds_until_below(
                entries, self.HOUR_WINDOW, self.requests_per_hour, now
            )

            # Wait for whichever limit frees up last; always at least one
            # second so clients never see "Retry-After: 0".
            retry_after = max(1, math.ceil(max(minute_wait, hour_wait)))

            # Report the limit that governs the wait.
            if hour_wait > minute_wait:
                limit = self.requests_per_hour
            else:
                limit = self.requests_per_minute

            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": reason,
                    "type": "rate_limit_exceeded",
                },
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(int(now) + retry_after),
                },
            )

        # Process the request
        response = await call_next(request)

        # Add rate limit headers to successful responses. ``.get`` is used
        # so the lookup cannot create an entry in the defaultdict.
        client_data = self.request_counts.get(client_ip)
        if client_data is not None:
            current_time = time.time()
            minute_ago = current_time - self.MINUTE_WINDOW
            minute_requests = sum(
                count
                for stamp, count in client_data.items()
                if float(stamp) > minute_ago
            )

            remaining = max(0, self.requests_per_minute - minute_requests)

            response.headers["X-RateLimit-Limit"] = str(self.requests_per_minute)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Reset"] = str(
                int(current_time) + self.MINUTE_WINDOW
            )

        return response