114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
"""Rate limiting with sliding window."""
|
|
|
|
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
|
|
|
|
class BaseLimiter:
    """Base sliding window rate limiter.

    Keeps, per client IP, the ``time.time()`` timestamps of recent requests
    (in the order they were appended) and treats an IP as blocked while at
    least ``max_requests`` of them fall inside the last ``window`` seconds.
    All reads and writes of ``requests`` happen while holding ``lock``.
    """

    def __init__(self, window, max_requests, enabled, recording=False):
        """
        Initialize rate limiter.

        Args:
            window: Time window in seconds.
            max_requests: Maximum requests allowed in window.
            enabled: Whether rate limiting is enabled.
            recording: Whether IPs are recorded when hitting limit.
        """
        self.window = window
        self.max_requests = max_requests
        self.enabled = enabled
        self.recording = recording
        # ip -> request timestamps still inside the window (append order).
        self.requests = defaultdict(list)
        self.lock = threading.Lock()

    def _cleanup_old(self, timestamps, now=None):
        """Return a new list of the timestamps newer than ``now - window``.

        Args:
            timestamps: Iterable of request timestamps.
            now: Reference time in seconds; defaults to ``time.time()``.
        """
        if now is None:
            now = time.time()
        cutoff = now - self.window
        return [t for t in timestamps if t > cutoff]

    def record(self, ip):
        """Record a request from ``ip`` without checking the limit.

        Does nothing when rate limiting is disabled.
        """
        if not self.enabled:
            return

        with self.lock:
            now = time.time()
            # .get() does not trigger the defaultdict factory.
            recent = self._cleanup_old(self.requests.get(ip, ()), now)
            recent.append(now)
            self.requests[ip] = recent

    def cleanup(self):
        """Remove expired timestamps and drop IPs that have none left.

        Prevents unbounded growth of ``requests``. Nothing in this class
        calls it, so the owner is responsible for invoking it.
        """
        with self.lock:
            now = time.time()
            for ip in list(self.requests):
                recent = self._cleanup_old(self.requests[ip], now)
                if recent:
                    self.requests[ip] = recent
                else:
                    del self.requests[ip]

    def is_blocked(self, ip):
        """
        Check if IP is blocked.

        An IP is blocked while it has at least ``max_requests`` timestamps
        inside the window. When ``recording`` is enabled, each blocked check
        is itself recorded, so a client that keeps retrying pushes its own
        retry time further out.

        Args:
            ip: Client IP address.

        Returns:
            Tuple of (blocked, retry_at_datetime or None). ``retry_at`` is a
            naive local-time datetime (``datetime.fromtimestamp``) at which,
            absent further requests, the IP falls back below the limit.
        """
        if not self.enabled:
            return False, None

        with self.lock:
            now = time.time()
            recent = self._cleanup_old(self.requests.get(ip, ()), now)

            if len(recent) < self.max_requests:
                # Keep the pruned list, but don't leave an empty entry behind
                # for IPs that were only checked.
                if recent:
                    self.requests[ip] = recent
                else:
                    self.requests.pop(ip, None)
                return False, None

            if self.recording:
                recent.append(now)
            self.requests[ip] = recent

            # The IP is unblocked once at most max_requests - 1 timestamps
            # remain, i.e. when the max_requests-th newest one expires. Using
            # recent[0] would be too early whenever len(recent) exceeds
            # max_requests (record() appends without checking).
            # NOTE(review): assumes max_requests >= 1.
            pivot = recent[-self.max_requests]
            retry_at = datetime.fromtimestamp(pivot + self.window)
            return True, retry_at
|
|
|
|
|
|
class GoodLimiter(BaseLimiter):
    """Rate limiter for good (authenticated) requests."""

    def __init__(self, config):
        """
        Initialize good request rate limiter from config.

        Reads ``good_window_seconds`` (default 60), ``good_max_requests``
        (default 30) and ``enabled`` (default False) from the ``rate_limit``
        section. Blocked checks are not recorded.

        Args:
            config: Full configuration dictionary.
        """
        # ``or {}`` also covers a section that is present but null
        # (e.g. an empty ``rate_limit:`` section in a YAML config).
        rl = config.get("rate_limit") or {}
        super().__init__(
            rl.get("good_window_seconds", 60),
            rl.get("good_max_requests", 30),
            rl.get("enabled", False),
            recording=False,
        )
|
|
|
|
|
|
class BadLimiter(BaseLimiter):
    """Rate limiter for bad (failed auth) requests."""

    def __init__(self, config):
        """
        Initialize bad request rate limiter from config.

        Reads ``bad_window_seconds`` (default 60), ``bad_max_requests``
        (default 5) and ``enabled`` (default False) from the ``rate_limit``
        section. Blocked checks are recorded, so repeated attempts while
        blocked extend the block.

        Args:
            config: Full configuration dictionary.
        """
        # ``or {}`` also covers a section that is present but null
        # (e.g. an empty ``rate_limit:`` section in a YAML config).
        rl = config.get("rate_limit") or {}
        super().__init__(
            rl.get("bad_window_seconds", 60),
            rl.get("bad_max_requests", 5),
            rl.get("enabled", False),
            recording=True,
        )
|