Refactor rate limiter and datetime formatting
This commit is contained in:
@@ -13,6 +13,7 @@ __all__ = [
|
|||||||
"cleanup",
|
"cleanup",
|
||||||
"cli",
|
"cli",
|
||||||
"config",
|
"config",
|
||||||
|
"datetime_str",
|
||||||
"dns",
|
"dns",
|
||||||
"email",
|
"email",
|
||||||
"logging",
|
"logging",
|
||||||
@@ -20,5 +21,8 @@ __all__ = [
|
|||||||
"models",
|
"models",
|
||||||
"ratelimit",
|
"ratelimit",
|
||||||
"server",
|
"server",
|
||||||
"validation"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def datetime_str(datetime):
|
||||||
|
return datetime.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import argon2
|
|||||||
from .dns import DNSService
|
from .dns import DNSService
|
||||||
from .email import EmailService
|
from .email import EmailService
|
||||||
from .models import create_tables, init_database
|
from .models import create_tables, init_database
|
||||||
from .ratelimit import RateLimiter
|
from .ratelimit import BadLimiter, GoodLimiter
|
||||||
|
|
||||||
|
|
||||||
class Application:
|
class Application:
|
||||||
@@ -32,7 +32,8 @@ class Application:
|
|||||||
# Service instances (initialized separately)
|
# Service instances (initialized separately)
|
||||||
self.dns_service = None
|
self.dns_service = None
|
||||||
self.email_service = None
|
self.email_service = None
|
||||||
self.rate_limiter = None
|
self.good_limiter = None
|
||||||
|
self.bad_limiter = None
|
||||||
|
|
||||||
def init_database(self):
|
def init_database(self):
|
||||||
"""Initialize database connection and run migrations."""
|
"""Initialize database connection and run migrations."""
|
||||||
@@ -50,10 +51,11 @@ class Application:
|
|||||||
self.email_service = EmailService(self.config)
|
self.email_service = EmailService(self.config)
|
||||||
logging.info("Email service initialized")
|
logging.info("Email service initialized")
|
||||||
|
|
||||||
def init_rate_limiter(self):
|
def init_rate_limiters(self):
|
||||||
"""Initialize rate limiter."""
|
"""Initialize rate limiters."""
|
||||||
self.rate_limiter = RateLimiter(self.config)
|
self.good_limiter = GoodLimiter(self.config)
|
||||||
logging.info("Rate limiter initialized")
|
self.bad_limiter = BadLimiter(self.config)
|
||||||
|
logging.info("Rate limiters initialized")
|
||||||
|
|
||||||
def signal_shutdown(self):
|
def signal_shutdown(self):
|
||||||
"""Signal the application to shut down."""
|
"""Signal the application to shut down."""
|
||||||
|
|||||||
@@ -135,8 +135,10 @@ class RateLimitCleanupThread(threading.Thread):
|
|||||||
|
|
||||||
while not self.stop_event.wait(self.interval):
|
while not self.stop_event.wait(self.interval):
|
||||||
try:
|
try:
|
||||||
if self.app.rate_limiter:
|
if self.app.good_limiter:
|
||||||
self.app.rate_limiter.cleanup()
|
self.app.good_limiter.cleanup()
|
||||||
|
if self.app.bad_limiter:
|
||||||
|
self.app.bad_limiter.cleanup()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Rate limit cleanup error: {e}")
|
logging.error(f"Rate limit cleanup error: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import getpass
|
import getpass
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from . import datetime_str
|
||||||
from .cleanup import cleanup_expired
|
from .cleanup import cleanup_expired
|
||||||
from .models import (
|
from .models import (
|
||||||
DoesNotExist,
|
DoesNotExist,
|
||||||
@@ -27,7 +28,7 @@ def cmd_user_list(args, app):
|
|||||||
hostname_count = Hostname.select().where(
|
hostname_count = Hostname.select().where(
|
||||||
Hostname.user == user
|
Hostname.user == user
|
||||||
).count()
|
).count()
|
||||||
created_at = user.created_at.strftime("%Y-%m-%d %H:%M:%S")
|
created_at = datetime_str(user.created_at)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"{user.username:<20} {user.email:<30} "
|
f"{user.username:<20} {user.email:<30} "
|
||||||
@@ -153,8 +154,8 @@ def cmd_hostname_list(args, app):
|
|||||||
)
|
)
|
||||||
print("-" * 132)
|
print("-" * 132)
|
||||||
for h in hostnames:
|
for h in hostnames:
|
||||||
last_ipv4_update = h.last_ipv4_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv4_update else "Never"
|
last_ipv4_update = datetime_str(h.last_ipv4_update) if h.last_ipv4_update else "Never"
|
||||||
last_ipv6_update = h.last_ipv6_update.strftime("%Y-%m-%d %H:%M:%S") if h.last_ipv6_update else "Never"
|
last_ipv6_update = datetime_str(h.last_ipv6_update) if h.last_ipv6_update else "Never"
|
||||||
print(
|
print(
|
||||||
f"{h.hostname:<35} {h.user.username:<15} {h.zone:<20} "
|
f"{h.hostname:<35} {h.user.username:<15} {h.zone:<20} "
|
||||||
f"{h.dns_ttl:<8} {h.expiry_ttl:<8} {last_ipv4_update:<21} {last_ipv6_update}"
|
f"{h.dns_ttl:<8} {h.expiry_ttl:<8} {last_ipv4_update:<21} {last_ipv6_update}"
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import smtplib
|
import smtplib
|
||||||
|
|
||||||
|
from . import datetime_str
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
|
|
||||||
|
|
||||||
@@ -91,11 +93,11 @@ Hostname: {hostname}
|
|||||||
"""
|
"""
|
||||||
if last_ipv4:
|
if last_ipv4:
|
||||||
ip = last_ipv4[0]
|
ip = last_ipv4[0]
|
||||||
last_update = last_ipv4[1].strftime("%Y-%m-%d %H:%M:%S")
|
last_update = datetime_str(last_ipv4[1])
|
||||||
body += f"IPv4 address: {ip} (last update: {last_update})\n"
|
body += f"IPv4 address: {ip} (last update: {last_update})\n"
|
||||||
if last_ipv6:
|
if last_ipv6:
|
||||||
ip = last_ipv6[0]
|
ip = last_ipv6[0]
|
||||||
last_update = last_ipv6[1].strftime("%Y-%m-%d %H:%M:%S")
|
last_update = datetime_str(last_ipv6[1])
|
||||||
body += f"IPv6 address: {ip} (last update: {last_update})\n"
|
body += f"IPv6 address: {ip} (last update: {last_update})\n"
|
||||||
body += f"""Expiry TTL: {expiry_ttl} seconds
|
body += f"""Expiry TTL: {expiry_ttl} seconds
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ def main():
|
|||||||
# Initialize all services for daemon mode
|
# Initialize all services for daemon mode
|
||||||
app.init_dns()
|
app.init_dns()
|
||||||
app.init_email()
|
app.init_email()
|
||||||
app.init_rate_limiter()
|
app.init_rate_limiters()
|
||||||
run_daemon(app)
|
run_daemon(app)
|
||||||
return 0
|
return 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -3,127 +3,111 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class BaseLimiter:
|
||||||
"""Sliding window rate limiter with separate good/bad request tracking."""
|
"""Base sliding window rate limiter."""
|
||||||
|
|
||||||
|
def __init__(self, window, max_requests, enabled, recording=False):
|
||||||
|
"""
|
||||||
|
Initialize rate limiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
window: Time window in seconds.
|
||||||
|
max_requests: Maximum requests allowed in window.
|
||||||
|
enabled: Whether rate limiting is enabled.
|
||||||
|
recording: Wheter IPs are recorded when hitting limit.
|
||||||
|
"""
|
||||||
|
self.window = window
|
||||||
|
self.max_requests = max_requests
|
||||||
|
self.enabled = enabled
|
||||||
|
self.recording = recording
|
||||||
|
self.requests = defaultdict(list)
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def _cleanup_old(self, timestamps):
|
||||||
|
"""Remove timestamps older than window."""
|
||||||
|
cutoff = time.time() - self.window
|
||||||
|
return [t for t in timestamps if t > cutoff]
|
||||||
|
|
||||||
|
def record(self, ip):
|
||||||
|
"""Record a request (without checking)."""
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
self.requests[ip] = self._cleanup_old(self.requests[ip])
|
||||||
|
self.requests[ip].append(time.time())
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Remove stale entries to prevent memory leak."""
|
||||||
|
with self.lock:
|
||||||
|
for ip in list(self.requests.keys()):
|
||||||
|
self.requests[ip] = self._cleanup_old(self.requests[ip])
|
||||||
|
if not self.requests[ip]:
|
||||||
|
del self.requests[ip]
|
||||||
|
|
||||||
|
def is_blocked(self, ip):
|
||||||
|
"""
|
||||||
|
Check if IP is blocked.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: Client IP address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (blocked, retry_at_datetime or None).
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
self.requests[ip] = self._cleanup_old(self.requests[ip])
|
||||||
|
if len(self.requests[ip]) >= self.max_requests:
|
||||||
|
if self.recording:
|
||||||
|
self.requests[ip].append(time.time())
|
||||||
|
oldest = self.requests[ip][-self.max_requests]
|
||||||
|
else:
|
||||||
|
oldest = self.requests[ip][0]
|
||||||
|
retry_at = datetime.fromtimestamp(oldest + self.window)
|
||||||
|
return True, retry_at
|
||||||
|
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
|
class GoodLimiter(BaseLimiter):
|
||||||
|
"""Rate limiter for good (authenticated) requests."""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
"""
|
"""
|
||||||
Initialize rate limiter from config.
|
Initialize good request rate limiter from config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Full configuration dictionary.
|
config: Full configuration dictionary.
|
||||||
"""
|
"""
|
||||||
rl_config = config.get("rate_limit", {})
|
rl = config.get("rate_limit", {})
|
||||||
self.enabled = rl_config.get("enabled", False)
|
super().__init__(
|
||||||
self.good_window = rl_config.get("good_window_seconds", 60)
|
rl.get("good_window_seconds", 60),
|
||||||
self.good_max = rl_config.get("good_max_requests", 30)
|
rl.get("good_max_requests", 30),
|
||||||
self.bad_window = rl_config.get("bad_window_seconds", 60)
|
rl.get("enabled", False),
|
||||||
self.bad_max = rl_config.get("bad_max_requests", 5)
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
self.bad_requests = defaultdict(list)
|
|
||||||
self.good_requests = defaultdict(list)
|
|
||||||
self.bad_lock = threading.Lock()
|
|
||||||
self.good_lock = threading.Lock()
|
|
||||||
|
|
||||||
def _cleanup_old(self, timestamps, window):
|
class BadLimiter(BaseLimiter):
|
||||||
"""Remove timestamps older than window."""
|
"""Rate limiter for bad (failed auth) requests."""
|
||||||
cutoff = time.time() - window
|
|
||||||
return [t for t in timestamps if t > cutoff]
|
|
||||||
|
|
||||||
def is_blocked_bad(self, ip):
|
def __init__(self, config):
|
||||||
"""
|
"""
|
||||||
Check if IP is blocked by the bad request rate limiter with recording
|
Initialize bad request rate limiter from config.
|
||||||
when IP is already limited.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ip: Client IP address.
|
config: Full configuration dictionary.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (blocked, retry_after_seconds).
|
|
||||||
"""
|
"""
|
||||||
if not self.enabled:
|
rl = config.get("rate_limit", {})
|
||||||
return False, 0
|
super().__init__(
|
||||||
|
rl.get("bad_window_seconds", 60),
|
||||||
now = time.time()
|
rl.get("bad_max_requests", 5),
|
||||||
|
rl.get("enabled", False),
|
||||||
with self.bad_lock:
|
True,
|
||||||
# Check bad requests
|
|
||||||
self.bad_requests[ip] = self._cleanup_old(
|
|
||||||
self.bad_requests[ip], self.bad_window
|
|
||||||
)
|
)
|
||||||
if len(self.bad_requests[ip]) >= self.bad_max:
|
|
||||||
self.bad_requests[ip].append(time.time())
|
|
||||||
oldest = min(self.bad_requests[ip][-self.bad_max:])
|
|
||||||
retry_after = int(oldest + self.bad_window - now) + 1
|
|
||||||
return True, max(1, retry_after)
|
|
||||||
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
def is_blocked_good(self, ip):
|
|
||||||
"""
|
|
||||||
Check if IP is blocked by the good request rate limiter without recording.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ip: Client IP address.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (blocked, retry_after_seconds).
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
with self.good_lock:
|
|
||||||
# Check good requests
|
|
||||||
self.good_requests[ip] = self._cleanup_old(
|
|
||||||
self.good_requests[ip], self.good_window
|
|
||||||
)
|
|
||||||
if len(self.good_requests[ip]) >= self.good_max:
|
|
||||||
oldest = min(self.good_requests[ip])
|
|
||||||
retry_after = int(oldest + self.good_window - now) + 1
|
|
||||||
return True, max(1, retry_after)
|
|
||||||
|
|
||||||
return False, 0
|
|
||||||
|
|
||||||
def record_bad(self, ip):
|
|
||||||
"""Record a bad request (without checking)."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.bad_lock:
|
|
||||||
self.bad_requests[ip] = self._cleanup_old(
|
|
||||||
self.bad_requests[ip], self.bad_window
|
|
||||||
)
|
|
||||||
self.bad_requests[ip].append(time.time())
|
|
||||||
|
|
||||||
def record_good(self, ip):
|
|
||||||
"""Record a good request (without checking)."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.good_lock:
|
|
||||||
self.good_requests[ip] = self._cleanup_old(
|
|
||||||
self.good_requests[ip], self.good_window
|
|
||||||
)
|
|
||||||
self.good_requests[ip].append(time.time())
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Remove stale entries to prevent memory leak."""
|
|
||||||
with self.good_lock:
|
|
||||||
for ip in list(self.good_requests.keys()):
|
|
||||||
self.good_requests[ip] = self._cleanup_old(
|
|
||||||
self.good_requests[ip], self.good_window
|
|
||||||
)
|
|
||||||
if not self.good_requests[ip]:
|
|
||||||
del self.good_requests[ip]
|
|
||||||
|
|
||||||
with self.bad_lock:
|
|
||||||
for ip in list(self.bad_requests.keys()):
|
|
||||||
self.bad_requests[ip] = self._cleanup_old(
|
|
||||||
self.bad_requests[ip], self.bad_window
|
|
||||||
)
|
|
||||||
if not self.bad_requests[ip]:
|
|
||||||
del self.bad_requests[ip]
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse
|
|||||||
|
|
||||||
import argon2
|
import argon2
|
||||||
|
|
||||||
|
from . import datetime_str
|
||||||
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
|
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
|
||||||
from .logging import clear_txn_id, set_txn_id
|
from .logging import clear_txn_id, set_txn_id
|
||||||
from .models import DoesNotExist, get_hostname_for_user, get_user
|
from .models import DoesNotExist, get_hostname_for_user, get_user
|
||||||
@@ -176,12 +177,12 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Bad rate limit check
|
# Bad rate limit check
|
||||||
if self.app.rate_limiter:
|
if self.app.bad_limiter:
|
||||||
blocked, retry = self.app.rate_limiter.is_blocked_bad(client_ip)
|
blocked, retry_at = self.app.bad_limiter.is_blocked(client_ip)
|
||||||
if blocked:
|
if blocked:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Rate limited (bad requests): client={client_ip}, "
|
f"Rate limited (bad): client={client_ip}, "
|
||||||
f"retry_after={retry}")
|
f"retry_at={datetime_str(retry_at)}")
|
||||||
self.respond(429, "abuse")
|
self.respond(429, "abuse")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -246,16 +247,18 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Good rate limit check
|
# Good rate limit check
|
||||||
if self.app.rate_limiter:
|
if self.app.good_limiter:
|
||||||
blocked, retry = self.app.rate_limiter.is_blocked_good(client_ip)
|
blocked, retry_at = self.app.good_limiter.is_blocked(client_ip)
|
||||||
if blocked:
|
if blocked:
|
||||||
logging.warning(f"Rate limited: client={client_ip}, retry_after={retry}")
|
logging.warning(
|
||||||
self.respond(429, "abuse", retry_after=retry)
|
f"Rate limited: client={client_ip}, "
|
||||||
|
f"retry_at={datetime_str(retry_at)}")
|
||||||
|
self.respond(429, "abuse")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Record good request
|
# Record good request
|
||||||
if self.app.rate_limiter:
|
if self.app.good_limiter:
|
||||||
self.app.rate_limiter.record_good(client_ip)
|
self.app.good_limiter.record(client_ip)
|
||||||
|
|
||||||
# Determine IPs to update
|
# Determine IPs to update
|
||||||
result = self._process_ip_update(hostname, params, endpoint, client_ip)
|
result = self._process_ip_update(hostname, params, endpoint, client_ip)
|
||||||
@@ -268,8 +271,8 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def _handle_bad_request(self, client_ip, code, status):
|
def _handle_bad_request(self, client_ip, code, status):
|
||||||
"""Handle bad request and record in rate limiter."""
|
"""Handle bad request and record in rate limiter."""
|
||||||
if self.app.rate_limiter:
|
if self.app.bad_limiter:
|
||||||
self.app.rate_limiter.record_bad(client_ip)
|
self.app.bad_limiter.record(client_ip)
|
||||||
self.respond(code, status)
|
self.respond(code, status)
|
||||||
|
|
||||||
def _process_ip_update(self, hostname, params, endpoint, client_ip):
|
def _process_ip_update(self, hostname, params, endpoint, client_ip):
|
||||||
|
|||||||
Reference in New Issue
Block a user