Compare commits

...

11 Commits

8 changed files with 276 additions and 63 deletions

View File

@@ -10,6 +10,15 @@ import datetime
__version__ = "1.0.0" __version__ = "1.0.0"
__author__ = "Thomas Oettli <spacefreak@noop.ch>" __author__ = "Thomas Oettli <spacefreak@noop.ch>"
# DynDNS-compatible response statuses
STATUS_GOOD = "good"
STATUS_NOCHG = "nochg"
STATUS_BADAUTH = "badauth"
STATUS_NOHOST = "nohost"
STATUS_DNSERR = "dnserr"
STATUS_ABUSE = "abuse"
STATUS_BADIP = "badip"
__all__ = [ __all__ = [
"app", "app",
"cleanup", "cleanup",
@@ -23,12 +32,37 @@ __all__ = [
"models", "models",
"ratelimit", "ratelimit",
"server", "server",
"STATUS_GOOD",
"STATUS_NOCHG",
"STATUS_BADAUTH",
"STATUS_NOHOST",
"STATUS_DNSERR",
"STATUS_ABUSE",
"STATUS_BADIP",
] ]
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z" DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
# Datetime convention:
# All datetime objects in this codebase are naive UTC to match database storage.
# - utc_now(): returns naive UTC datetime
# - datetime_str(): converts naive UTC to display string (adds tzinfo for formatting)
# - Database stores/returns naive datetimes (always UTC by convention)
def datetime_str(dt, utc=False): def datetime_str(dt, utc=False):
"""
Convert datetime to display string.
Assumes naive datetimes are UTC per codebase convention.
Args:
dt: Datetime object (naive UTC or timezone-aware).
utc: If True, display in UTC; otherwise convert to local timezone.
Returns:
Formatted datetime string, or "Never" if dt is not a datetime.
"""
if not isinstance(dt, datetime.datetime): if not isinstance(dt, datetime.datetime):
return "Never" return "Never"
@@ -41,4 +75,10 @@ def datetime_str(dt, utc=False):
def utc_now(): def utc_now():
"""
Get current time as naive UTC datetime.
Returns naive datetime to match database storage behavior.
All naive datetimes in this codebase are assumed to be UTC.
"""
return datetime.datetime.now(datetime.UTC).replace(tzinfo=None) return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

View File

@@ -1,12 +1,13 @@
"""Application class - central dependency holder.""" """Application class - central dependency holder."""
import argon2
import logging import logging
import threading import threading
import argon2 from .config import load_config
from .dns import DNSService from .dns import DNSService
from .email import EmailService from .email import EmailService
from .logging import setup_logging
from .models import create_tables, init_database from .models import create_tables, init_database
from .ratelimit import BadLimiter, GoodLimiter from .ratelimit import BadLimiter, GoodLimiter
@@ -18,14 +19,16 @@ class Application:
Holds configuration and all service instances. Holds configuration and all service instances.
""" """
def __init__(self, config): def __init__(self, config, config_path=None):
""" """
Initialize application with configuration. Initialize application with configuration.
Args: Args:
config: Configuration dictionary from TOML file. config: Configuration dictionary from TOML file.
config_path: Path to configuration file (for reload).
""" """
self.config = config self.config = config
self.config_path = config_path
self.password_hasher = argon2.PasswordHasher() self.password_hasher = argon2.PasswordHasher()
self.shutdown_event = threading.Event() self.shutdown_event = threading.Event()
@@ -57,6 +60,39 @@ class Application:
self.bad_limiter = BadLimiter(self.config) self.bad_limiter = BadLimiter(self.config)
logging.info("Rate limiters initialized") logging.info("Rate limiters initialized")
def reload_config(self):
"""
Reload configuration from file.
Does not reload: database settings, host, port.
"""
new_config = load_config(self.config_path)
# Preserve DB and bind settings
new_config["database"] = self.config["database"]
new_config["daemon"]["host"] = self.config["daemon"]["host"]
new_config["daemon"]["port"] = self.config["daemon"]["port"]
self.config = new_config
# Reconfigure logging
setup_logging(
level=self.config["daemon"]["log_level"],
target=self.config["daemon"]["log_target"],
syslog_socket=self.config["daemon"]["syslog_socket"],
syslog_facility=self.config["daemon"]["syslog_facility"],
log_file=self.config["daemon"]["log_file"],
log_file_size=self.config["daemon"]["log_file_size"],
log_versions=self.config["daemon"]["log_versions"],
)
# Re-init services
self.init_dns()
self.init_email()
self.init_rate_limiters()
logging.info("Configuration reloaded")
def signal_shutdown(self): def signal_shutdown(self):
"""Signal the application to shut down.""" """Signal the application to shut down."""
logging.info("Shutdown signaled") logging.info("Shutdown signaled")

View File

@@ -141,9 +141,12 @@ def load_config(config_path):
cfg["daemon"].setdefault("ssl", False) cfg["daemon"].setdefault("ssl", False)
cfg["daemon"].setdefault("proxy_header", "") cfg["daemon"].setdefault("proxy_header", "")
cfg["daemon"].setdefault("trusted_proxies", []) cfg["daemon"].setdefault("trusted_proxies", [])
cfg["daemon"].setdefault("thread_pool_size", 10)
cfg["daemon"].setdefault("request_timeout", 10)
cfg.setdefault("database", {}) cfg.setdefault("database", {})
cfg["database"].setdefault("backend", "sqlite") cfg["database"].setdefault("backend", "sqlite")
cfg["database"].setdefault("pool_size", 5)
cfg.setdefault("dns_service", {}) cfg.setdefault("dns_service", {})
cfg["dns_service"].setdefault("dns_server", "127.0.0.1") cfg["dns_service"].setdefault("dns_server", "127.0.0.1")

View File

@@ -37,6 +37,12 @@ def encode_dnsname(hostname):
Raises: Raises:
EncodingError: If hostname is invalid. EncodingError: If hostname is invalid.
Example:
>>> encode_dnsname("münchen")
'xn--mnchen-3ya'
>>> encode_dnsname("example.com.")
'example.com'
""" """
hostname = hostname.lower().strip() hostname = hostname.lower().strip()
@@ -84,6 +90,24 @@ def encode_dnsname(hostname):
def detect_ip_type(ip): def detect_ip_type(ip):
"""
Detect IP address type and normalize.
Args:
ip: IP address string.
Returns:
Tuple of (record_type, normalized_ip).
Raises:
ValueError: If IP address is invalid.
Example:
>>> detect_ip_type("192.168.1.1")
('A', '192.168.1.1')
>>> detect_ip_type("2001:db8::1")
('AAAA', '2001:db8::1')
"""
try: try:
addr = ipaddress.ip_address(ip) addr = ipaddress.ip_address(ip)
if isinstance(addr, ipaddress.IPv4Address): if isinstance(addr, ipaddress.IPv4Address):
@@ -127,6 +151,16 @@ def parse_bind_key_file(path):
Raises: Raises:
DNSError: If parsing fails. DNSError: If parsing fails.
Example:
Key file contents::
key "ddns-key." {
algorithm hmac-sha256;
secret "base64secret==";
};
>>> keyring, algo = parse_bind_key_file("/etc/bind/ddns.key")
""" """
if not path: if not path:
return None, None return None, None
@@ -324,13 +358,17 @@ class DNSService:
Update a DNS record for the given hostname. Update a DNS record for the given hostname.
Args: Args:
hostname: Fully qualified hostname. hostname: Hostname (without zone suffix).
zone: DNS zone name. zone: DNS zone name.
ip: IP address to set. ip: IP address to set.
ttl: DNS record TTL. ttl: DNS record TTL.
Raises: Raises:
DNSError: If update fails. DNSError: If update fails.
Example:
>>> dns_service.update_record("myhost", "example.com", "192.168.1.1", 60)
>>> dns_service.update_record("myhost", "example.com", "2001:db8::1", 60)
""" """
try: try:
record_type, normalized_ip = detect_ip_type(ip) record_type, normalized_ip = detect_ip_type(ip)

View File

@@ -153,7 +153,7 @@ def main():
) )
# Create application instance # Create application instance
app = Application(config) app = Application(config, config_path)
# Initialize database # Initialize database
try: try:

View File

@@ -4,21 +4,39 @@ import logging
import os import os
from . import utc_now from . import utc_now
from .dns import encode_dnsname, EncodingError
from peewee import ( from peewee import (
AutoField, AutoField,
CharField, CharField,
DatabaseProxy,
DateTimeField, DateTimeField,
DoesNotExist, DoesNotExist,
fn, fn,
ForeignKeyField, ForeignKeyField,
IntegerField, IntegerField,
Model, Model,
MySQLDatabase,
SqliteDatabase, SqliteDatabase,
) )
from playhouse.pool import PooledMySQLDatabase
# Database instance (initialized later) # Re-export DoesNotExist and EncodingError for convenience
db = SqliteDatabase(None) __all__ = [
'db',
'DATABASE_VERSION',
'User',
'Hostname',
'Version',
'init_database',
'create_tables',
'get_user',
'get_hostname',
'get_hostname_for_user',
'DoesNotExist',
'EncodingError',
]
# Database proxy (initialized later with actual backend)
db = DatabaseProxy()
# Current database schema version # Current database schema version
DATABASE_VERSION = 2 DATABASE_VERSION = 2
@@ -83,6 +101,14 @@ class Hostname(BaseModel):
(('hostname', 'zone'), True), (('hostname', 'zone'), True),
) )
def save(self, *args, **kwargs):
"""Validate and encode hostname/zone before saving."""
if self.hostname:
self.hostname = encode_dnsname(self.hostname)
if self.zone:
self.zone = encode_dnsname(self.zone)
return super().save(*args, **kwargs)
class Version(BaseModel): class Version(BaseModel):
"""Database schema version for migrations.""" """Database schema version for migrations."""
@@ -110,7 +136,6 @@ def init_database(config: dict):
Raises: Raises:
ValueError: If unknown database backend. ValueError: If unknown database backend.
""" """
global db
backend = config["database"].get("backend", "sqlite") backend = config["database"].get("backend", "sqlite")
@@ -119,21 +144,20 @@ def init_database(config: dict):
db_dir = os.path.dirname(db_path) db_dir = os.path.dirname(db_path)
if db_dir: if db_dir:
os.makedirs(db_dir, exist_ok=True) os.makedirs(db_dir, exist_ok=True)
db.init(db_path) actual_db = SqliteDatabase(db_path)
db.initialize(actual_db)
logging.debug(f"Database backend: SQLite path={db_path}") logging.debug(f"Database backend: SQLite path={db_path}")
elif backend == "mariadb": elif backend == "mariadb":
db = MySQLDatabase( actual_db = PooledMySQLDatabase(
config["database"]["database"], config["database"]["database"],
host=config["database"].get("host", "localhost"), host=config["database"].get("host", "localhost"),
port=config["database"].get("port", 3306), port=config["database"].get("port", 3306),
user=config["database"]["user"], user=config["database"]["user"],
password=config["database"]["password"], password=config["database"]["password"],
max_connections=config["database"].get("pool_size", 5),
) )
# Re-bind models to new database db.initialize(actual_db)
User._meta.database = db
Hostname._meta.database = db
Version._meta.database = db
db_name = config['database']['database'] db_name = config['database']['database']
logging.debug(f"Database backend: MariaDB db={db_name}") logging.debug(f"Database backend: MariaDB db={db_name}")
@@ -261,6 +285,11 @@ def get_user(username: str):
Raises: Raises:
DoesNotExist: If user not found. DoesNotExist: If user not found.
Example:
>>> user = get_user("alice")
>>> print(user.email)
'alice@example.com'
""" """
return User.get(User.username == username) return User.get(User.username == username)
@@ -278,6 +307,11 @@ def get_hostname(hostname, zone):
Raises: Raises:
DoesNotExist: If hostname not found. DoesNotExist: If hostname not found.
Example:
>>> host = get_hostname("myhost", "example.com")
>>> print(host.last_ipv4)
'192.168.1.1'
""" """
return Hostname.get( return Hostname.get(
(Hostname.hostname == hostname) & (Hostname.zone == zone) (Hostname.hostname == hostname) & (Hostname.zone == zone)
@@ -289,7 +323,7 @@ def get_hostname_for_user(hostname: str, user: User):
Get hostname owned by specific user. Get hostname owned by specific user.
Args: Args:
hostname: Hostname to look up. hostname: Hostname to look up (FQDN).
user: User who should own the hostname. user: User who should own the hostname.
Returns: Returns:
@@ -297,22 +331,10 @@ def get_hostname_for_user(hostname: str, user: User):
Raises: Raises:
DoesNotExist: If hostname not found or not owned by user. DoesNotExist: If hostname not found or not owned by user.
Example:
>>> user = get_user("alice")
>>> host = get_hostname_for_user("myhost.example.com", user)
""" """
fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone) fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone)
return Hostname.get((fqdn == hostname) & (Hostname.user == user)) return Hostname.get((fqdn == hostname) & (Hostname.user == user))
# Re-export DoesNotExist for convenience
__all__ = [
'db',
'DATABASE_VERSION',
'User',
'Hostname',
'Version',
'init_database',
'create_tables',
'get_user',
'get_hostname',
'get_hostname_for_user',
'DoesNotExist',
]

View File

@@ -85,11 +85,11 @@ class GoodLimiter(BaseLimiter):
Args: Args:
config: Full configuration dictionary. config: Full configuration dictionary.
""" """
rl = config.get("rate_limit", {}) rl = config["rate_limit"]
super().__init__( super().__init__(
rl.get("good_window_seconds", 60), rl["good_window_seconds"],
rl.get("good_max_requests", 30), rl["good_max_requests"],
rl.get("enabled", False), rl["enabled"],
False, False,
) )
@@ -104,10 +104,10 @@ class BadLimiter(BaseLimiter):
Args: Args:
config: Full configuration dictionary. config: Full configuration dictionary.
""" """
rl = config.get("rate_limit", {}) rl = config["rate_limit"]
super().__init__( super().__init__(
rl.get("bad_window_seconds", 60), rl["bad_window_seconds"],
rl.get("bad_max_requests", 5), rl["bad_max_requests"],
rl.get("enabled", False), rl["enabled"],
True, True,
) )

View File

@@ -8,12 +8,24 @@ import json
import logging import logging
import signal import signal
import ssl import ssl
import threading
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import argon2 import argon2
from . import datetime_str, utc_now from . import (
datetime_str,
utc_now,
STATUS_GOOD,
STATUS_NOCHG,
STATUS_BADAUTH,
STATUS_NOHOST,
STATUS_DNSERR,
STATUS_ABUSE,
STATUS_BADIP,
)
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
from .logging import clear_txn_id, set_txn_id from .logging import clear_txn_id, set_txn_id
from .models import DoesNotExist, get_hostname_for_user, get_user from .models import DoesNotExist, get_hostname_for_user, get_user
@@ -64,7 +76,7 @@ def _is_trusted_proxy(client_ip, trusted_networks):
class DDNSServer(ThreadingHTTPServer): class DDNSServer(ThreadingHTTPServer):
"""HTTP server with Application instance.""" """HTTP server with Application instance and thread pool."""
def __init__(self, app, address): def __init__(self, app, address):
""" """
@@ -79,8 +91,45 @@ class DDNSServer(ThreadingHTTPServer):
self.trusted_networks = _parse_trusted_proxies( self.trusted_networks = _parse_trusted_proxies(
app.config["daemon"].get("trusted_proxies", []) app.config["daemon"].get("trusted_proxies", [])
) )
self.pool_size = app.config["daemon"]["thread_pool_size"]
self.request_timeout = app.config["daemon"]["request_timeout"]
self.executor = ThreadPoolExecutor(max_workers=self.pool_size)
self.active_requests = 0
self.requests_lock = threading.Lock()
self.requests_done = threading.Condition(self.requests_lock)
super().__init__(address, DDNSRequestHandler) super().__init__(address, DDNSRequestHandler)
def process_request(self, request, client_address):
"""Submit request to thread pool."""
with self.requests_lock:
self.active_requests += 1
request.settimeout(self.request_timeout)
self.executor.submit(self._handle_request_wrapper, request, client_address)
def _handle_request_wrapper(self, request, client_address):
"""Wrap request handling to track active requests."""
try:
self.process_request_thread(request, client_address)
finally:
with self.requests_lock:
self.active_requests -= 1
if self.active_requests == 0:
self.requests_done.notify_all()
def wait_for_requests(self, timeout=5):
"""Wait for active requests to complete."""
with self.requests_lock:
if self.active_requests > 0:
logging.info(f"Waiting for {self.active_requests} active request(s)")
self.requests_done.wait(timeout=timeout)
if self.active_requests > 0:
logging.warning(f"Shutdown timeout, {self.active_requests} request(s) still active")
def server_close(self):
"""Shutdown thread pool and close server."""
self.executor.shutdown(wait=True)
super().server_close()
class DDNSRequestHandler(BaseHTTPRequestHandler): class DDNSRequestHandler(BaseHTTPRequestHandler):
"""HTTP request handler for DDNS updates.""" """HTTP request handler for DDNS updates."""
@@ -182,7 +231,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning( logging.warning(
f"Rate limited (bad): client={client_ip}, " f"Rate limited (bad): client={client_ip}, "
f"retry_at={datetime_str(retry_at)}") f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse") self.respond(429, STATUS_ABUSE)
return return
# Parse URL # Parse URL
@@ -205,7 +254,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
if not username or not password: if not username or not password:
logging.warning(f"Auth failed: client={client_ip} user=anonymous") logging.warning(f"Auth failed: client={client_ip} user=anonymous")
self._handle_bad_request(client_ip, 401, "badauth") self._handle_bad_request(client_ip, 401, STATUS_BADAUTH)
return return
# Validate credentials # Validate credentials
@@ -214,14 +263,14 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
self.app.password_hasher.verify(user.password_hash, password) self.app.password_hasher.verify(user.password_hash, password)
except (DoesNotExist, argon2.exceptions.VerifyMismatchError): except (DoesNotExist, argon2.exceptions.VerifyMismatchError):
logging.warning(f"Auth failed: client={client_ip} user={username}") logging.warning(f"Auth failed: client={client_ip} user={username}")
self._handle_bad_request(client_ip, 401, "badauth") self._handle_bad_request(client_ip, 401, STATUS_BADAUTH)
return return
# Get hostname parameter # Get hostname parameter
hostname_param = extract_param(params, endpoint["params"]["hostname"]) hostname_param = extract_param(params, endpoint["params"]["hostname"])
if not hostname_param: if not hostname_param:
logging.warning(f"Missing hostname: client={client_ip} user={username}") logging.warning(f"Missing hostname: client={client_ip} user={username}")
self._handle_bad_request(client_ip, 400, "nohost") self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return return
# Validate and encode hostname # Validate and encode hostname
@@ -231,7 +280,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning( logging.warning(
f"Invalid hostname: client={client_ip}, " f"Invalid hostname: client={client_ip}, "
f"hostname={hostname_param}") f"hostname={hostname_param}")
self._handle_bad_request(client_ip, 400, "nohost") self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return return
# Check hostname ownership # Check hostname ownership
@@ -242,7 +291,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"Access denied: client={client_ip} user={username} " f"Access denied: client={client_ip} user={username} "
f"hostname={hostname_param}" f"hostname={hostname_param}"
) )
self._handle_bad_request(client_ip, 403, "nohost") self._handle_bad_request(client_ip, 403, STATUS_NOHOST)
return return
# Good rate limit check # Good rate limit check
@@ -252,7 +301,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning( logging.warning(
f"Rate limited: client={client_ip}, " f"Rate limited: client={client_ip}, "
f"retry_at={datetime_str(retry_at)}") f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse") self.respond(429, STATUS_ABUSE)
return return
# Record good request # Record good request
@@ -262,11 +311,8 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
# Determine IPs to update # Determine IPs to update
result = self._process_ip_update(hostname, params, endpoint, client_ip) result = self._process_ip_update(hostname, params, endpoint, client_ip)
if result: if result:
code, status, *kwargs = result code, status, kwargs = result
if kwargs: self.respond(code, status, **kwargs)
self.respond(code, status, **kwargs[0])
else:
self.respond(code, status)
def _handle_bad_request(self, client_ip, code, status): def _handle_bad_request(self, client_ip, code, status):
"""Handle bad request and record in rate limiter.""" """Handle bad request and record in rate limiter."""
@@ -291,7 +337,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
else: else:
ipv6 = myip ipv6 = myip
except ValueError: except ValueError:
return (400, "badip") return (400, STATUS_BADIP, {})
# Process myip6 parameter # Process myip6 parameter
if myip6: if myip6:
@@ -300,9 +346,9 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
if rtype == "AAAA": if rtype == "AAAA":
ipv6 = myip6 ipv6 = myip6
else: else:
return (400, "badip") return (400, STATUS_BADIP, {})
except ValueError: except ValueError:
return (400, "badip") return (400, STATUS_BADIP, {})
# Auto-detect from client IP if no params # Auto-detect from client IP if no params
if ipv4 is None and ipv6 is None: if ipv4 is None and ipv6 is None:
@@ -313,7 +359,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
else: else:
ipv6 = ip ipv6 = ip
except ValueError: except ValueError:
return (400, "badip") return (400, STATUS_BADIP, {})
now = utc_now() now = utc_now()
@@ -338,7 +384,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"DNS update failed: client={client_ip} hostname={hostname.hostname} " f"DNS update failed: client={client_ip} hostname={hostname.hostname} "
f"zone={hostname.zone} ipv4={ipv4} error={e}" f"zone={hostname.zone} ipv4={ipv4} error={e}"
) )
return (500, "dnserr") return (500, STATUS_DNSERR, {})
if ipv6: if ipv6:
hostname.last_ipv6_update = now hostname.last_ipv6_update = now
@@ -359,7 +405,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"DNS update failed: client={client_ip} hostname={hostname.hostname} " f"DNS update failed: client={client_ip} hostname={hostname.hostname} "
f"zone={hostname.zone} ipv6={ipv6} error={e}" f"zone={hostname.zone} ipv6={ipv6} error={e}"
) )
return (500, "dnserr") return (500, STATUS_DNSERR, {})
# Update database # Update database
hostname.save() hostname.save()
@@ -380,7 +426,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"zone={hostname.zone}{changed_addrs} notify_change={str(notify_change).lower()}" f"zone={hostname.zone}{changed_addrs} notify_change={str(notify_change).lower()}"
) )
return ( return (
200, "nochg", 200, STATUS_NOCHG,
{"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6} {"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6}
) )
@@ -401,7 +447,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.error(f"Sending change notification error: {e}") logging.error(f"Sending change notification error: {e}")
return ( return (
200, "good", 200, STATUS_GOOD,
{"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6} {"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6}
) )
@@ -440,14 +486,39 @@ def run_daemon(app):
expired_cleanup_thread = ExpiredRecordsCleanupThread(app) expired_cleanup_thread = ExpiredRecordsCleanupThread(app)
expired_cleanup_thread.start() expired_cleanup_thread.start()
# Setup signal handlers # Setup signal handlers
def signal_handler(signum, frame): def signal_handler(signum, frame):
logging.info(f"Signal received: {signum}, shutting down") logging.info(f"Signal received: {signum}, shutting down")
app.signal_shutdown() app.signal_shutdown()
def sighup_handler(signum, frame):
logging.info("SIGHUP received, reloading configuration")
try:
app.reload_config()
# Update server attributes
server.proxy_header = app.config["daemon"].get("proxy_header", "")
server.trusted_networks = _parse_trusted_proxies(
app.config["daemon"].get("trusted_proxies", [])
)
# Reload SSL if enabled
if app.config["daemon"]["ssl"]:
new_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
new_context.load_cert_chain(
app.config["daemon"]["ssl_cert_file"],
app.config["daemon"]["ssl_key_file"]
)
# Note: existing connections use old cert, new connections use new
server.socket = new_context.wrap_socket(
server.socket.detach(), server_side=True
)
except Exception as e:
logging.error(f"Config reload failed: {e}")
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGHUP, sighup_handler)
paths = ", ".join(ep["path"] for ep in config["endpoints"]) paths = ", ".join(ep["path"] for ep in config["endpoints"])
logging.info(f"Daemon started: {proto}://{host}:{port} endpoints=[{paths}]") logging.info(f"Daemon started: {proto}://{host}:{port} endpoints=[{paths}]")
@@ -457,6 +528,9 @@ def run_daemon(app):
while not app.is_shutting_down(): while not app.is_shutting_down():
server.handle_request() server.handle_request()
# Graceful shutdown - wait for active requests
server.wait_for_requests(5)
# Cleanup # Cleanup
expired_cleanup_thread.stop() expired_cleanup_thread.stop()
ratelimit_cleanup_thread.stop() ratelimit_cleanup_thread.stop()