Refactor rate limiter and datetime formatting

This commit is contained in:
2026-01-22 02:24:03 +01:00
parent f297a8d740
commit 04114ab659
8 changed files with 133 additions and 135 deletions

View File

@@ -3,127 +3,111 @@
import threading
import time
from collections import defaultdict
from datetime import datetime
class RateLimiter:
"""Sliding window rate limiter with separate good/bad request tracking."""
class BaseLimiter:
"""Base sliding window rate limiter."""
def __init__(self, window, max_requests, enabled, recording=False):
"""
Initialize rate limiter.
Args:
window: Time window in seconds.
max_requests: Maximum requests allowed in window.
enabled: Whether rate limiting is enabled.
recording: Wheter IPs are recorded when hitting limit.
"""
self.window = window
self.max_requests = max_requests
self.enabled = enabled
self.recording = recording
self.requests = defaultdict(list)
self.lock = threading.Lock()
def _cleanup_old(self, timestamps):
"""Remove timestamps older than window."""
cutoff = time.time() - self.window
return [t for t in timestamps if t > cutoff]
def record(self, ip):
"""Record a request (without checking)."""
if not self.enabled:
return
with self.lock:
self.requests[ip] = self._cleanup_old(self.requests[ip])
self.requests[ip].append(time.time())
def cleanup(self):
"""Remove stale entries to prevent memory leak."""
with self.lock:
for ip in list(self.requests.keys()):
self.requests[ip] = self._cleanup_old(self.requests[ip])
if not self.requests[ip]:
del self.requests[ip]
def is_blocked(self, ip):
"""
Check if IP is blocked.
Args:
ip: Client IP address.
Returns:
Tuple of (blocked, retry_at_datetime or None).
"""
if not self.enabled:
return False, None
with self.lock:
self.requests[ip] = self._cleanup_old(self.requests[ip])
if len(self.requests[ip]) >= self.max_requests:
if self.recording:
self.requests[ip].append(time.time())
oldest = self.requests[ip][-self.max_requests]
else:
oldest = self.requests[ip][0]
retry_at = datetime.fromtimestamp(oldest + self.window)
return True, retry_at
return False, None
class GoodLimiter(BaseLimiter):
"""Rate limiter for good (authenticated) requests."""
def __init__(self, config):
"""
Initialize rate limiter from config.
Initialize good request rate limiter from config.
Args:
config: Full configuration dictionary.
"""
rl_config = config.get("rate_limit", {})
self.enabled = rl_config.get("enabled", False)
self.good_window = rl_config.get("good_window_seconds", 60)
self.good_max = rl_config.get("good_max_requests", 30)
self.bad_window = rl_config.get("bad_window_seconds", 60)
self.bad_max = rl_config.get("bad_max_requests", 5)
rl = config.get("rate_limit", {})
super().__init__(
rl.get("good_window_seconds", 60),
rl.get("good_max_requests", 30),
rl.get("enabled", False),
False,
)
self.bad_requests = defaultdict(list)
self.good_requests = defaultdict(list)
self.bad_lock = threading.Lock()
self.good_lock = threading.Lock()
def _cleanup_old(self, timestamps, window):
"""Remove timestamps older than window."""
cutoff = time.time() - window
return [t for t in timestamps if t > cutoff]
class BadLimiter(BaseLimiter):
"""Rate limiter for bad (failed auth) requests."""
def is_blocked_bad(self, ip):
def __init__(self, config):
"""
Check if IP is blocked by the bad request rate limiter with recording
when IP is already limited.
Initialize bad request rate limiter from config.
Args:
ip: Client IP address.
Returns:
Tuple of (blocked, retry_after_seconds).
config: Full configuration dictionary.
"""
if not self.enabled:
return False, 0
now = time.time()
with self.bad_lock:
# Check bad requests
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
)
if len(self.bad_requests[ip]) >= self.bad_max:
self.bad_requests[ip].append(time.time())
oldest = min(self.bad_requests[ip][-self.bad_max:])
retry_after = int(oldest + self.bad_window - now) + 1
return True, max(1, retry_after)
return False, 0
def is_blocked_good(self, ip):
"""
Check if IP is blocked by the good request rate limiter without recording.
Args:
ip: Client IP address.
Returns:
Tuple of (blocked, retry_after_seconds).
"""
if not self.enabled:
return False, 0
now = time.time()
with self.good_lock:
# Check good requests
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
if len(self.good_requests[ip]) >= self.good_max:
oldest = min(self.good_requests[ip])
retry_after = int(oldest + self.good_window - now) + 1
return True, max(1, retry_after)
return False, 0
def record_bad(self, ip):
"""Record a bad request (without checking)."""
if not self.enabled:
return
with self.bad_lock:
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
)
self.bad_requests[ip].append(time.time())
def record_good(self, ip):
"""Record a good request (without checking)."""
if not self.enabled:
return
with self.good_lock:
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
self.good_requests[ip].append(time.time())
def cleanup(self):
"""Remove stale entries to prevent memory leak."""
with self.good_lock:
for ip in list(self.good_requests.keys()):
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
if not self.good_requests[ip]:
del self.good_requests[ip]
with self.bad_lock:
for ip in list(self.bad_requests.keys()):
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
)
if not self.bad_requests[ip]:
del self.bad_requests[ip]
rl = config.get("rate_limit", {})
super().__init__(
rl.get("bad_window_seconds", 60),
rl.get("bad_max_requests", 5),
rl.get("enabled", False),
True,
)