Refactor rate limiter and datetime formatting

This commit is contained in:
2026-01-22 02:24:03 +01:00
parent f297a8d740
commit 04114ab659
8 changed files with 133 additions and 135 deletions

View File

@@ -13,6 +13,7 @@ __all__ = [
"cleanup", "cleanup",
"cli", "cli",
"config", "config",
"datetime_str",
"dns", "dns",
"email", "email",
"logging", "logging",
@@ -20,5 +21,8 @@ __all__ = [
"models", "models",
"ratelimit", "ratelimit",
"server", "server",
"validation"
] ]
def datetime_str(datetime):
return datetime.strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -8,7 +8,7 @@ import argon2
from .dns import DNSService from .dns import DNSService
from .email import EmailService from .email import EmailService
from .models import create_tables, init_database from .models import create_tables, init_database
from .ratelimit import RateLimiter from .ratelimit import BadLimiter, GoodLimiter
class Application: class Application:
@@ -32,7 +32,8 @@ class Application:
# Service instances (initialized separately) # Service instances (initialized separately)
self.dns_service = None self.dns_service = None
self.email_service = None self.email_service = None
self.rate_limiter = None self.good_limiter = None
self.bad_limiter = None
def init_database(self): def init_database(self):
"""Initialize database connection and run migrations.""" """Initialize database connection and run migrations."""
@@ -50,10 +51,11 @@ class Application:
self.email_service = EmailService(self.config) self.email_service = EmailService(self.config)
logging.info("Email service initialized") logging.info("Email service initialized")
def init_rate_limiter(self): def init_rate_limiters(self):
"""Initialize rate limiter.""" """Initialize rate limiters."""
self.rate_limiter = RateLimiter(self.config) self.good_limiter = GoodLimiter(self.config)
logging.info("Rate limiter initialized") self.bad_limiter = BadLimiter(self.config)
logging.info("Rate limiters initialized")
def signal_shutdown(self): def signal_shutdown(self):
"""Signal the application to shut down.""" """Signal the application to shut down."""

View File

@@ -135,8 +135,10 @@ class RateLimitCleanupThread(threading.Thread):
while not self.stop_event.wait(self.interval): while not self.stop_event.wait(self.interval):
try: try:
if self.app.rate_limiter: if self.app.good_limiter:
self.app.rate_limiter.cleanup() self.app.good_limiter.cleanup()
if self.app.bad_limiter:
self.app.bad_limiter.cleanup()
except Exception as e: except Exception as e:
logging.error(f"Rate limit cleanup error: {e}") logging.error(f"Rate limit cleanup error: {e}")

View File

@@ -3,6 +3,7 @@
import getpass import getpass
import logging import logging
from . import datetime_str
from .cleanup import cleanup_expired from .cleanup import cleanup_expired
from .models import ( from .models import (
DoesNotExist, DoesNotExist,
@@ -27,7 +28,7 @@ def cmd_user_list(args, app):
hostname_count = Hostname.select().where( hostname_count = Hostname.select().where(
Hostname.user == user Hostname.user == user
).count() ).count()
created_at = user.created_at.strftime("%Y-%m-%d %H:%M:%S") created_at = datetime_str(user.created_at)
print( print(
f"{user.username:<20} {user.email:<30} " f"{user.username:<20} {user.email:<30} "
@@ -153,8 +154,8 @@ def cmd_hostname_list(args, app):
) )
print("-" * 132) print("-" * 132)
for h in hostnames: for h in hostnames:
last_ipv4_update = h.last_ipv4_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv4_update else "Never" last_ipv4_update = datetime_str(h.last_ipv4_update) if h.last_ipv4_update else "Never"
last_ipv6_update = h.last_ipv6_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv6_update else "Never" last_ipv6_update = datetime_str(h.last_ipv6_update) if h.last_ipv6_update else "Never"
print( print(
f"{h.hostname:<35} {h.user.username:<15} {h.zone:<20} " f"{h.hostname:<35} {h.user.username:<15} {h.zone:<20} "
f"{h.dns_ttl:<8} {h.expiry_ttl:<8} {last_ipv4_update:<21} {last_ipv6_update}" f"{h.dns_ttl:<8} {h.expiry_ttl:<8} {last_ipv4_update:<21} {last_ipv6_update}"

View File

@@ -2,6 +2,8 @@
import logging import logging
import smtplib import smtplib
from . import datetime_str
from email.mime.text import MIMEText from email.mime.text import MIMEText
@@ -91,11 +93,11 @@ Hostname: {hostname}
""" """
if last_ipv4: if last_ipv4:
ip = last_ipv4[0] ip = last_ipv4[0]
last_update = last_ipv4[1].strftime("%Y-%m-%d %H:%M:%S") last_update = datetime_str(last_ipv4[1])
body += f"IPv4 address: {ip} (last update: {last_update})\n" body += f"IPv4 address: {ip} (last update: {last_update})\n"
if last_ipv6: if last_ipv6:
ip = last_ipv6[0] ip = last_ipv6[0]
last_update = last_ipv6[1].strftime("%Y-%m-%d %H:%M:%S") last_update = datetime_str(last_ipv6[1])
body += f"IPv6 address: {ip} (last update: {last_update})\n" body += f"IPv6 address: {ip} (last update: {last_update})\n"
body += f"""Expiry TTL: {expiry_ttl} seconds body += f"""Expiry TTL: {expiry_ttl} seconds

View File

@@ -172,7 +172,7 @@ def main():
# Initialize all services for daemon mode # Initialize all services for daemon mode
app.init_dns() app.init_dns()
app.init_email() app.init_email()
app.init_rate_limiter() app.init_rate_limiters()
run_daemon(app) run_daemon(app)
return 0 return 0
except Exception as e: except Exception as e:

View File

@@ -3,127 +3,111 @@
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime
class RateLimiter: class BaseLimiter:
"""Sliding window rate limiter with separate good/bad request tracking.""" """Base sliding window rate limiter."""
def __init__(self, window, max_requests, enabled, recording=False):
"""
Initialize rate limiter.
Args:
window: Time window in seconds.
max_requests: Maximum requests allowed in window.
enabled: Whether rate limiting is enabled.
recording: Wheter IPs are recorded when hitting limit.
"""
self.window = window
self.max_requests = max_requests
self.enabled = enabled
self.recording = recording
self.requests = defaultdict(list)
self.lock = threading.Lock()
def _cleanup_old(self, timestamps):
"""Remove timestamps older than window."""
cutoff = time.time() - self.window
return [t for t in timestamps if t > cutoff]
def record(self, ip):
"""Record a request (without checking)."""
if not self.enabled:
return
with self.lock:
self.requests[ip] = self._cleanup_old(self.requests[ip])
self.requests[ip].append(time.time())
def cleanup(self):
"""Remove stale entries to prevent memory leak."""
with self.lock:
for ip in list(self.requests.keys()):
self.requests[ip] = self._cleanup_old(self.requests[ip])
if not self.requests[ip]:
del self.requests[ip]
def is_blocked(self, ip):
"""
Check if IP is blocked.
Args:
ip: Client IP address.
Returns:
Tuple of (blocked, retry_at_datetime or None).
"""
if not self.enabled:
return False, None
with self.lock:
self.requests[ip] = self._cleanup_old(self.requests[ip])
if len(self.requests[ip]) >= self.max_requests:
if self.recording:
self.requests[ip].append(time.time())
oldest = self.requests[ip][-self.max_requests]
else:
oldest = self.requests[ip][0]
retry_at = datetime.fromtimestamp(oldest + self.window)
return True, retry_at
return False, None
class GoodLimiter(BaseLimiter):
"""Rate limiter for good (authenticated) requests."""
def __init__(self, config): def __init__(self, config):
""" """
Initialize rate limiter from config. Initialize good request rate limiter from config.
Args: Args:
config: Full configuration dictionary. config: Full configuration dictionary.
""" """
rl_config = config.get("rate_limit", {}) rl = config.get("rate_limit", {})
self.enabled = rl_config.get("enabled", False) super().__init__(
self.good_window = rl_config.get("good_window_seconds", 60) rl.get("good_window_seconds", 60),
self.good_max = rl_config.get("good_max_requests", 30) rl.get("good_max_requests", 30),
self.bad_window = rl_config.get("bad_window_seconds", 60) rl.get("enabled", False),
self.bad_max = rl_config.get("bad_max_requests", 5) False,
)
self.bad_requests = defaultdict(list)
self.good_requests = defaultdict(list)
self.bad_lock = threading.Lock()
self.good_lock = threading.Lock()
def _cleanup_old(self, timestamps, window): class BadLimiter(BaseLimiter):
"""Remove timestamps older than window.""" """Rate limiter for bad (failed auth) requests."""
cutoff = time.time() - window
return [t for t in timestamps if t > cutoff]
def is_blocked_bad(self, ip): def __init__(self, config):
""" """
Check if IP is blocked by the bad request rate limiter with recording Initialize bad request rate limiter from config.
when IP is already limited.
Args: Args:
ip: Client IP address. config: Full configuration dictionary.
Returns:
Tuple of (blocked, retry_after_seconds).
""" """
if not self.enabled: rl = config.get("rate_limit", {})
return False, 0 super().__init__(
rl.get("bad_window_seconds", 60),
now = time.time() rl.get("bad_max_requests", 5),
rl.get("enabled", False),
with self.bad_lock: True,
# Check bad requests
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
) )
if len(self.bad_requests[ip]) >= self.bad_max:
self.bad_requests[ip].append(time.time())
oldest = min(self.bad_requests[ip][-self.bad_max:])
retry_after = int(oldest + self.bad_window - now) + 1
return True, max(1, retry_after)
return False, 0
def is_blocked_good(self, ip):
"""
Check if IP is blocked by the good request rate limiter without recording.
Args:
ip: Client IP address.
Returns:
Tuple of (blocked, retry_after_seconds).
"""
if not self.enabled:
return False, 0
now = time.time()
with self.good_lock:
# Check good requests
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
if len(self.good_requests[ip]) >= self.good_max:
oldest = min(self.good_requests[ip])
retry_after = int(oldest + self.good_window - now) + 1
return True, max(1, retry_after)
return False, 0
def record_bad(self, ip):
"""Record a bad request (without checking)."""
if not self.enabled:
return
with self.bad_lock:
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
)
self.bad_requests[ip].append(time.time())
def record_good(self, ip):
"""Record a good request (without checking)."""
if not self.enabled:
return
with self.good_lock:
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
self.good_requests[ip].append(time.time())
def cleanup(self):
"""Remove stale entries to prevent memory leak."""
with self.good_lock:
for ip in list(self.good_requests.keys()):
self.good_requests[ip] = self._cleanup_old(
self.good_requests[ip], self.good_window
)
if not self.good_requests[ip]:
del self.good_requests[ip]
with self.bad_lock:
for ip in list(self.bad_requests.keys()):
self.bad_requests[ip] = self._cleanup_old(
self.bad_requests[ip], self.bad_window
)
if not self.bad_requests[ip]:
del self.bad_requests[ip]

View File

@@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse
import argon2 import argon2
from . import datetime_str
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
from .logging import clear_txn_id, set_txn_id from .logging import clear_txn_id, set_txn_id
from .models import DoesNotExist, get_hostname_for_user, get_user from .models import DoesNotExist, get_hostname_for_user, get_user
@@ -176,12 +177,12 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
return return
# Bad rate limit check # Bad rate limit check
if self.app.rate_limiter: if self.app.bad_limiter:
blocked, retry = self.app.rate_limiter.is_blocked_bad(client_ip) blocked, retry_at = self.app.bad_limiter.is_blocked(client_ip)
if blocked: if blocked:
logging.warning( logging.warning(
f"Rate limited (bad requests): client={client_ip}, " f"Rate limited (bad): client={client_ip}, "
f"retry_after={retry}") f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse") self.respond(429, "abuse")
return return
@@ -246,16 +247,18 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
return return
# Good rate limit check # Good rate limit check
if self.app.rate_limiter: if self.app.good_limiter:
blocked, retry = self.app.rate_limiter.is_blocked_good(client_ip) blocked, retry_at = self.app.good_limiter.is_blocked(client_ip)
if blocked: if blocked:
logging.warning(f"Rate limited: client={client_ip}, retry_after={retry}") logging.warning(
self.respond(429, "abuse", retry_after=retry) f"Rate limited: client={client_ip}, "
f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse")
return return
# Record good request # Record good request
if self.app.rate_limiter: if self.app.good_limiter:
self.app.rate_limiter.record_good(client_ip) self.app.good_limiter.record(client_ip)
# Determine IPs to update # Determine IPs to update
result = self._process_ip_update(hostname, params, endpoint, client_ip) result = self._process_ip_update(hostname, params, endpoint, client_ip)
@@ -268,8 +271,8 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
def _handle_bad_request(self, client_ip, code, status): def _handle_bad_request(self, client_ip, code, status):
"""Handle bad request and record in rate limiter.""" """Handle bad request and record in rate limiter."""
if self.app.rate_limiter: if self.app.bad_limiter:
self.app.rate_limiter.record_bad(client_ip) self.app.bad_limiter.record(client_ip)
self.respond(code, status) self.respond(code, status)
def _process_ip_update(self, hostname, params, endpoint, client_ip): def _process_ip_update(self, hostname, params, endpoint, client_ip):