diff --git a/src/ddns_service/__init__.py b/src/ddns_service/__init__.py index 9b51eba..888cddf 100644 --- a/src/ddns_service/__init__.py +++ b/src/ddns_service/__init__.py @@ -13,6 +13,7 @@ __all__ = [ "cleanup", "cli", "config", + "datetime_str", "dns", "email", "logging", @@ -20,5 +21,8 @@ __all__ = [ "models", "ratelimit", "server", - "validation" ] + + +def datetime_str(datetime): + return datetime.strftime("%Y-%m-%d %H:%M:%S") diff --git a/src/ddns_service/app.py b/src/ddns_service/app.py index 5c8b319..b4ef356 100644 --- a/src/ddns_service/app.py +++ b/src/ddns_service/app.py @@ -8,7 +8,7 @@ import argon2 from .dns import DNSService from .email import EmailService from .models import create_tables, init_database -from .ratelimit import RateLimiter +from .ratelimit import BadLimiter, GoodLimiter class Application: @@ -32,7 +32,8 @@ class Application: # Service instances (initialized separately) self.dns_service = None self.email_service = None - self.rate_limiter = None + self.good_limiter = None + self.bad_limiter = None def init_database(self): """Initialize database connection and run migrations.""" @@ -50,10 +51,11 @@ class Application: self.email_service = EmailService(self.config) logging.info("Email service initialized") - def init_rate_limiter(self): - """Initialize rate limiter.""" - self.rate_limiter = RateLimiter(self.config) - logging.info("Rate limiter initialized") + def init_rate_limiters(self): + """Initialize rate limiters.""" + self.good_limiter = GoodLimiter(self.config) + self.bad_limiter = BadLimiter(self.config) + logging.info("Rate limiters initialized") def signal_shutdown(self): """Signal the application to shut down.""" diff --git a/src/ddns_service/cleanup.py b/src/ddns_service/cleanup.py index cf21205..1f77217 100644 --- a/src/ddns_service/cleanup.py +++ b/src/ddns_service/cleanup.py @@ -135,8 +135,10 @@ class RateLimitCleanupThread(threading.Thread): while not self.stop_event.wait(self.interval): try: - if self.app.rate_limiter: - self.app.rate_limiter.cleanup() + if self.app.good_limiter: + self.app.good_limiter.cleanup() + if self.app.bad_limiter: + self.app.bad_limiter.cleanup() except Exception as e: logging.error(f"Rate limit cleanup error: {e}") diff --git a/src/ddns_service/cli.py b/src/ddns_service/cli.py index 67112b1..8955209 100644 --- a/src/ddns_service/cli.py +++ b/src/ddns_service/cli.py @@ -3,6 +3,7 @@ import getpass import logging +from . import datetime_str from .cleanup import cleanup_expired from .models import ( DoesNotExist, @@ -27,7 +28,7 @@ def cmd_user_list(args, app): hostname_count = Hostname.select().where( Hostname.user == user ).count() - created_at = user.created_at.strftime("%Y-%m-%d %H:%M:%S") + created_at = datetime_str(user.created_at) print( f"{user.username:<20} {user.email:<30} " @@ -153,8 +154,8 @@ def cmd_hostname_list(args, app): ) print("-" * 132) for h in hostnames: - last_ipv4_update = h.last_ipv4_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv4_update else "Never" - last_ipv6_update = h.last_ipv6_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv6_update else "Never" + last_ipv4_update = datetime_str(h.last_ipv4_update) if h.last_ipv4_update else "Never" + last_ipv6_update = datetime_str(h.last_ipv6_update) if h.last_ipv6_update else "Never" print( f"{h.hostname:<35} {h.user.username:<15} {h.zone:<20} " f"{h.dns_ttl:<8} {h.expiry_ttl:<8} {last_ipv4_update:<21} {last_ipv6_update}" diff --git a/src/ddns_service/email.py b/src/ddns_service/email.py index 12088d9..3e64888 100644 --- a/src/ddns_service/email.py +++ b/src/ddns_service/email.py @@ -2,6 +2,8 @@ import logging import smtplib + +from . import datetime_str from email.mime.text import MIMEText @@ -91,11 +93,11 @@ Hostname: {hostname} """ if last_ipv4: ip = last_ipv4[0] - last_update = last_ipv4[1].strftime("%Y-%m-%d %H:%M:%S") + last_update = datetime_str(last_ipv4[1]) body += f"IPv4 address: {ip} (last update: {last_update})\n" if last_ipv6: ip = last_ipv6[0] - last_update = last_ipv6[1].strftime("%Y-%m-%d %H:%M:%S") + last_update = datetime_str(last_ipv6[1]) body += f"IPv6 address: {ip} (last update: {last_update})\n" body += f"""Expiry TTL: {expiry_ttl} seconds diff --git a/src/ddns_service/main.py b/src/ddns_service/main.py index 8ba9b96..3c2b6ec 100644 --- a/src/ddns_service/main.py +++ b/src/ddns_service/main.py @@ -172,7 +172,7 @@ def main(): # Initialize all services for daemon mode app.init_dns() app.init_email() - app.init_rate_limiter() + app.init_rate_limiters() run_daemon(app) return 0 except Exception as e: diff --git a/src/ddns_service/ratelimit.py b/src/ddns_service/ratelimit.py index 5943e0a..400258e 100644 --- a/src/ddns_service/ratelimit.py +++ b/src/ddns_service/ratelimit.py @@ -3,127 +3,111 @@ import threading import time from collections import defaultdict +from datetime import datetime -class RateLimiter: - """Sliding window rate limiter with separate good/bad request tracking.""" +class BaseLimiter: + """Base sliding window rate limiter.""" + + def __init__(self, window, max_requests, enabled, recording=False): + """ + Initialize rate limiter. + + Args: + window: Time window in seconds. + max_requests: Maximum requests allowed in window. + enabled: Whether rate limiting is enabled. + recording: Wheter IPs are recorded when hitting limit. + """ + self.window = window + self.max_requests = max_requests + self.enabled = enabled + self.recording = recording + self.requests = defaultdict(list) + self.lock = threading.Lock() + + def _cleanup_old(self, timestamps): + """Remove timestamps older than window.""" + cutoff = time.time() - self.window + return [t for t in timestamps if t > cutoff] + + def record(self, ip): + """Record a request (without checking).""" + if not self.enabled: + return + + with self.lock: + self.requests[ip] = self._cleanup_old(self.requests[ip]) + self.requests[ip].append(time.time()) + + def cleanup(self): + """Remove stale entries to prevent memory leak.""" + with self.lock: + for ip in list(self.requests.keys()): + self.requests[ip] = self._cleanup_old(self.requests[ip]) + if not self.requests[ip]: + del self.requests[ip] + + def is_blocked(self, ip): + """ + Check if IP is blocked. + + Args: + ip: Client IP address. + + Returns: + Tuple of (blocked, retry_at_datetime or None). + """ + if not self.enabled: + return False, None + + with self.lock: + self.requests[ip] = self._cleanup_old(self.requests[ip]) + if len(self.requests[ip]) >= self.max_requests: + if self.recording: + self.requests[ip].append(time.time()) + oldest = self.requests[ip][-self.max_requests] + else: + oldest = self.requests[ip][0] + retry_at = datetime.fromtimestamp(oldest + self.window) + return True, retry_at + + return False, None + + +class GoodLimiter(BaseLimiter): + """Rate limiter for good (authenticated) requests.""" def __init__(self, config): """ - Initialize rate limiter from config. + Initialize good request rate limiter from config. Args: config: Full configuration dictionary. """ - rl_config = config.get("rate_limit", {}) - self.enabled = rl_config.get("enabled", False) - self.good_window = rl_config.get("good_window_seconds", 60) - self.good_max = rl_config.get("good_max_requests", 30) - self.bad_window = rl_config.get("bad_window_seconds", 60) - self.bad_max = rl_config.get("bad_max_requests", 5) + rl = config.get("rate_limit", {}) + super().__init__( + rl.get("good_window_seconds", 60), + rl.get("good_max_requests", 30), + rl.get("enabled", False), + False, + ) - self.bad_requests = defaultdict(list) - self.good_requests = defaultdict(list) - self.bad_lock = threading.Lock() - self.good_lock = threading.Lock() - def _cleanup_old(self, timestamps, window): - """Remove timestamps older than window.""" - cutoff = time.time() - window - return [t for t in timestamps if t > cutoff] +class BadLimiter(BaseLimiter): + """Rate limiter for bad (failed auth) requests.""" - def is_blocked_bad(self, ip): + def __init__(self, config): """ - Check if IP is blocked by the bad request rate limiter with recording - when IP is already limited. + Initialize bad request rate limiter from config. Args: - ip: Client IP address. - - Returns: - Tuple of (blocked, retry_after_seconds). + config: Full configuration dictionary. """ - if not self.enabled: - return False, 0 - - now = time.time() - - with self.bad_lock: - # Check bad requests - self.bad_requests[ip] = self._cleanup_old( - self.bad_requests[ip], self.bad_window - ) - if len(self.bad_requests[ip]) >= self.bad_max: - self.bad_requests[ip].append(time.time()) - oldest = min(self.bad_requests[ip][-self.bad_max:]) - retry_after = int(oldest + self.bad_window - now) + 1 - return True, max(1, retry_after) - - return False, 0 - - def is_blocked_good(self, ip): - """ - Check if IP is blocked by the good request rate limiter without recording. - - Args: - ip: Client IP address. - - Returns: - Tuple of (blocked, retry_after_seconds). - """ - if not self.enabled: - return False, 0 - - now = time.time() - - with self.good_lock: - # Check good requests - self.good_requests[ip] = self._cleanup_old( - self.good_requests[ip], self.good_window - ) - if len(self.good_requests[ip]) >= self.good_max: - oldest = min(self.good_requests[ip]) - retry_after = int(oldest + self.good_window - now) + 1 - return True, max(1, retry_after) - - return False, 0 - - def record_bad(self, ip): - """Record a bad request (without checking).""" - if not self.enabled: - return - - with self.bad_lock: - self.bad_requests[ip] = self._cleanup_old( - self.bad_requests[ip], self.bad_window - ) - self.bad_requests[ip].append(time.time()) - - def record_good(self, ip): - """Record a good request (without checking).""" - if not self.enabled: - return - - with self.good_lock: - self.good_requests[ip] = self._cleanup_old( - self.good_requests[ip], self.good_window - ) - self.good_requests[ip].append(time.time()) - - def cleanup(self): - """Remove stale entries to prevent memory leak.""" - with self.good_lock: - for ip in list(self.good_requests.keys()): - self.good_requests[ip] = self._cleanup_old( - self.good_requests[ip], self.good_window - ) - if not self.good_requests[ip]: - del self.good_requests[ip] - - with self.bad_lock: - for ip in list(self.bad_requests.keys()): - self.bad_requests[ip] = self._cleanup_old( - self.bad_requests[ip], self.bad_window - ) - if not self.bad_requests[ip]: - del self.bad_requests[ip] + rl = config.get("rate_limit", {}) + super().__init__( + rl.get("bad_window_seconds", 60), + rl.get("bad_max_requests", 5), + rl.get("enabled", False), + True, + ) diff --git a/src/ddns_service/server.py b/src/ddns_service/server.py index 15c2b17..12035f7 100644 --- a/src/ddns_service/server.py +++ b/src/ddns_service/server.py @@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse import argon2 +from . import datetime_str from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread from .logging import clear_txn_id, set_txn_id from .models import DoesNotExist, get_hostname_for_user, get_user @@ -176,12 +177,12 @@ class DDNSRequestHandler(BaseHTTPRequestHandler): return # Bad rate limit check - if self.app.rate_limiter: - blocked, retry = self.app.rate_limiter.is_blocked_bad(client_ip) + if self.app.bad_limiter: + blocked, retry_at = self.app.bad_limiter.is_blocked(client_ip) if blocked: logging.warning( - f"Rate limited (bad requests): client={client_ip}, " - f"retry_after={retry}") + f"Rate limited (bad): client={client_ip}, " + f"retry_at={datetime_str(retry_at)}") self.respond(429, "abuse") return @@ -246,16 +247,18 @@ class DDNSRequestHandler(BaseHTTPRequestHandler): return # Good rate limit check - if self.app.rate_limiter: - blocked, retry = self.app.rate_limiter.is_blocked_good(client_ip) + if self.app.good_limiter: + blocked, retry_at = self.app.good_limiter.is_blocked(client_ip) if blocked: - logging.warning(f"Rate limited: client={client_ip}, retry_after={retry}") - self.respond(429, "abuse", retry_after=retry) + logging.warning( + f"Rate limited: client={client_ip}, " + f"retry_at={datetime_str(retry_at)}") + self.respond(429, "abuse") return # Record good request - if self.app.rate_limiter: - self.app.rate_limiter.record_good(client_ip) + if self.app.good_limiter: + self.app.good_limiter.record(client_ip) # Determine IPs to update result = self._process_ip_update(hostname, params, endpoint, client_ip) @@ -268,8 +271,8 @@ class DDNSRequestHandler(BaseHTTPRequestHandler): def _handle_bad_request(self, client_ip, code, status): """Handle bad request and record in rate limiter.""" - if self.app.rate_limiter: - self.app.rate_limiter.record_bad(client_ip) + if self.app.bad_limiter: + self.app.bad_limiter.record(client_ip) self.respond(code, status) def _process_ip_update(self, hostname, params, endpoint, client_ip):