"""Simple rate-limiting middleware.

In-memory sliding-window log: for each client key, the timestamps of the
requests accepted within the last ``window_seconds`` are kept and counted.
"""

import math
import time
from collections import defaultdict
from typing import Iterable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP request rate limiting.

    - Default: 60 requests / minute per IP.
    - Login / register: 10 requests / minute per IP, counted in a separate
      bucket per (IP, path) and not against the default bucket.

    State lives in this instance's memory, so each server process enforces
    its own limits independently.
    """

    def __init__(
        self,
        app,
        default_limit: int = 60,
        window_seconds: int = 60,
        *,
        strict_limit: int = 10,
        strict_paths: Optional[Iterable[str]] = None,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            default_limit: Max accepted requests per window for ordinary paths.
            window_seconds: Length of the sliding window, in seconds.
            strict_limit: Max accepted requests per window for each strict path.
            strict_paths: Exact request paths subject to ``strict_limit``;
                defaults to the login and register endpoints.
        """
        super().__init__(app)
        self.default_limit = default_limit
        self.window_seconds = window_seconds
        # key -> monotonic timestamps of accepted requests, oldest first.
        self.requests: dict[str, list[float]] = defaultdict(list)
        # Stricter limits for auth endpoints.
        if strict_paths is None:
            strict_paths = ("/api/v1/auth/login", "/api/v1/auth/register")
        self.strict_paths = set(strict_paths)
        self.strict_limit = strict_limit
        # Time of the last full sweep of expired keys (see dispatch/_cleanup).
        self._last_cleanup = time.monotonic()

    async def dispatch(self, request: Request, call_next):
        """Return 429 if the client's bucket is full; otherwise forward the request."""
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Monotonic clock: interval arithmetic must not be affected by
        # wall-clock steps (e.g. NTP corrections), which could keep "future"
        # timestamps alive or purge live ones.
        now = time.monotonic()

        # Determine bucket key and rate limit.
        if path in self.strict_paths:
            key = f"{client_ip}:{path}"
            limit = self.strict_limit
        else:
            key = client_ip
            limit = self.default_limit

        # Drop timestamps that have slid out of the window.
        window_start = now - self.window_seconds
        timestamps = [t for t in self.requests[key] if t > window_start]
        self.requests[key] = timestamps

        # Check limit.
        if len(timestamps) >= limit:
            retry_after = self._retry_after(timestamps, limit, now)
            return JSONResponse(
                status_code=429,
                content={"detail": "请求过于频繁,请稍后再试"},
                headers={"Retry-After": str(retry_after)},
            )

        # Record the accepted request. There is no await between the prune,
        # the check and this append, so requests handled on the same event
        # loop cannot interleave inside this section.
        timestamps.append(now)

        # Periodic sweep of idle keys: at most once per window, so memory stays
        # bounded by recently active keys and the O(keys) scan never runs on
        # every request.
        if now - self._last_cleanup >= self.window_seconds:
            self._cleanup(now)

        return await call_next(request)

    def _retry_after(self, timestamps: list[float], limit: int, now: float) -> int:
        """Return whole seconds until the bucket can accept another request.

        Assumes ``len(timestamps) >= limit`` and ascending timestamps.
        """
        if limit <= 0 or not timestamps:
            # Nothing will ever free up a slot (or nothing to base it on).
            return max(1, math.ceil(self.window_seconds))
        # Once this entry expires, the in-window count drops below ``limit``.
        freeing = timestamps[len(timestamps) - limit]
        return max(1, math.ceil(freeing + self.window_seconds - now))

    def _cleanup(self, now: float):
        """Delete keys whose timestamps have all left the window ending at ``now``."""
        window_start = now - self.window_seconds
        # Same expiry rule as dispatch: a timestamp <= window_start is expired.
        # v[-1] is the newest entry because timestamps are appended in order.
        expired_keys = [
            k for k, v in self.requests.items() if not v or v[-1] <= window_start
        ]
        for k in expired_keys:
            del self.requests[k]
        self._last_cleanup = now