Compare commits

..

11 Commits

8 changed files with 276 additions and 63 deletions

View File

@@ -10,6 +10,15 @@ import datetime
__version__ = "1.0.0"
__author__ = "Thomas Oettli <spacefreak@noop.ch>"
# DynDNS-compatible response statuses
STATUS_GOOD = "good"
STATUS_NOCHG = "nochg"
STATUS_BADAUTH = "badauth"
STATUS_NOHOST = "nohost"
STATUS_DNSERR = "dnserr"
STATUS_ABUSE = "abuse"
STATUS_BADIP = "badip"
__all__ = [
"app",
"cleanup",
@@ -23,12 +32,37 @@ __all__ = [
"models",
"ratelimit",
"server",
"STATUS_GOOD",
"STATUS_NOCHG",
"STATUS_BADAUTH",
"STATUS_NOHOST",
"STATUS_DNSERR",
"STATUS_ABUSE",
"STATUS_BADIP",
]
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
# Datetime convention:
# All datetime objects in this codebase are naive UTC to match database storage.
# - utc_now(): returns naive UTC datetime
# - datetime_str(): converts naive UTC to display string (adds tzinfo for formatting)
# - Database stores/returns naive datetimes (always UTC by convention)
def datetime_str(dt, utc=False):
"""
Convert datetime to display string.
Assumes naive datetimes are UTC per codebase convention.
Args:
dt: Datetime object (naive UTC or timezone-aware).
utc: If True, display in UTC; otherwise convert to local timezone.
Returns:
Formatted datetime string, or "Never" if dt is not a datetime.
"""
if not isinstance(dt, datetime.datetime):
return "Never"
@@ -41,4 +75,10 @@ def datetime_str(dt, utc=False):
def utc_now():
"""
Get current time as naive UTC datetime.
Returns naive datetime to match database storage behavior.
All naive datetimes in this codebase are assumed to be UTC.
"""
return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

View File

@@ -1,12 +1,13 @@
"""Application class - central dependency holder."""
import argon2
import logging
import threading
import argon2
from .config import load_config
from .dns import DNSService
from .email import EmailService
from .logging import setup_logging
from .models import create_tables, init_database
from .ratelimit import BadLimiter, GoodLimiter
@@ -18,14 +19,16 @@ class Application:
Holds configuration and all service instances.
"""
def __init__(self, config):
def __init__(self, config, config_path=None):
"""
Initialize application with configuration.
Args:
config: Configuration dictionary from TOML file.
config_path: Path to configuration file (for reload).
"""
self.config = config
self.config_path = config_path
self.password_hasher = argon2.PasswordHasher()
self.shutdown_event = threading.Event()
@@ -57,6 +60,39 @@ class Application:
self.bad_limiter = BadLimiter(self.config)
logging.info("Rate limiters initialized")
def reload_config(self):
"""
Reload configuration from file.
Does not reload: database settings, host, port.
"""
new_config = load_config(self.config_path)
# Preserve DB and bind settings
new_config["database"] = self.config["database"]
new_config["daemon"]["host"] = self.config["daemon"]["host"]
new_config["daemon"]["port"] = self.config["daemon"]["port"]
self.config = new_config
# Reconfigure logging
setup_logging(
level=self.config["daemon"]["log_level"],
target=self.config["daemon"]["log_target"],
syslog_socket=self.config["daemon"]["syslog_socket"],
syslog_facility=self.config["daemon"]["syslog_facility"],
log_file=self.config["daemon"]["log_file"],
log_file_size=self.config["daemon"]["log_file_size"],
log_versions=self.config["daemon"]["log_versions"],
)
# Re-init services
self.init_dns()
self.init_email()
self.init_rate_limiters()
logging.info("Configuration reloaded")
def signal_shutdown(self):
"""Signal the application to shut down."""
logging.info("Shutdown signaled")

View File

@@ -141,9 +141,12 @@ def load_config(config_path):
cfg["daemon"].setdefault("ssl", False)
cfg["daemon"].setdefault("proxy_header", "")
cfg["daemon"].setdefault("trusted_proxies", [])
cfg["daemon"].setdefault("thread_pool_size", 10)
cfg["daemon"].setdefault("request_timeout", 10)
cfg.setdefault("database", {})
cfg["database"].setdefault("backend", "sqlite")
cfg["database"].setdefault("pool_size", 5)
cfg.setdefault("dns_service", {})
cfg["dns_service"].setdefault("dns_server", "127.0.0.1")

View File

@@ -37,6 +37,12 @@ def encode_dnsname(hostname):
Raises:
EncodingError: If hostname is invalid.
Example:
>>> encode_dnsname("münchen")
'xn--mnchen-3ya'
>>> encode_dnsname("example.com.")
'example.com'
"""
hostname = hostname.lower().strip()
@@ -84,6 +90,24 @@ def encode_dnsname(hostname):
def detect_ip_type(ip):
"""
Detect IP address type and normalize.
Args:
ip: IP address string.
Returns:
Tuple of (record_type, normalized_ip).
Raises:
ValueError: If IP address is invalid.
Example:
>>> detect_ip_type("192.168.1.1")
('A', '192.168.1.1')
>>> detect_ip_type("2001:db8::1")
('AAAA', '2001:db8::1')
"""
try:
addr = ipaddress.ip_address(ip)
if isinstance(addr, ipaddress.IPv4Address):
@@ -127,6 +151,16 @@ def parse_bind_key_file(path):
Raises:
DNSError: If parsing fails.
Example:
Key file contents::
key "ddns-key." {
algorithm hmac-sha256;
secret "base64secret==";
};
>>> keyring, algo = parse_bind_key_file("/etc/bind/ddns.key")
"""
if not path:
return None, None
@@ -324,13 +358,17 @@ class DNSService:
Update a DNS record for the given hostname.
Args:
hostname: Fully qualified hostname.
hostname: Hostname (without zone suffix).
zone: DNS zone name.
ip: IP address to set.
ttl: DNS record TTL.
Raises:
DNSError: If update fails.
Example:
>>> dns_service.update_record("myhost", "example.com", "192.168.1.1", 60)
>>> dns_service.update_record("myhost", "example.com", "2001:db8::1", 60)
"""
try:
record_type, normalized_ip = detect_ip_type(ip)

View File

@@ -153,7 +153,7 @@ def main():
)
# Create application instance
app = Application(config)
app = Application(config, config_path)
# Initialize database
try:

View File

@@ -4,21 +4,39 @@ import logging
import os
from . import utc_now
from .dns import encode_dnsname, EncodingError
from peewee import (
AutoField,
CharField,
DatabaseProxy,
DateTimeField,
DoesNotExist,
fn,
ForeignKeyField,
IntegerField,
Model,
MySQLDatabase,
SqliteDatabase,
)
from playhouse.pool import PooledMySQLDatabase
# Database instance (initialized later)
db = SqliteDatabase(None)
# Re-export DoesNotExist and EncodingError for convenience
__all__ = [
'db',
'DATABASE_VERSION',
'User',
'Hostname',
'Version',
'init_database',
'create_tables',
'get_user',
'get_hostname',
'get_hostname_for_user',
'DoesNotExist',
'EncodingError',
]
# Database proxy (initialized later with actual backend)
db = DatabaseProxy()
# Current database schema version
DATABASE_VERSION = 2
@@ -83,6 +101,14 @@ class Hostname(BaseModel):
(('hostname', 'zone'), True),
)
def save(self, *args, **kwargs):
"""Validate and encode hostname/zone before saving."""
if self.hostname:
self.hostname = encode_dnsname(self.hostname)
if self.zone:
self.zone = encode_dnsname(self.zone)
return super().save(*args, **kwargs)
class Version(BaseModel):
"""Database schema version for migrations."""
@@ -110,7 +136,6 @@ def init_database(config: dict):
Raises:
ValueError: If unknown database backend.
"""
global db
backend = config["database"].get("backend", "sqlite")
@@ -119,21 +144,20 @@ def init_database(config: dict):
db_dir = os.path.dirname(db_path)
if db_dir:
os.makedirs(db_dir, exist_ok=True)
db.init(db_path)
actual_db = SqliteDatabase(db_path)
db.initialize(actual_db)
logging.debug(f"Database backend: SQLite path={db_path}")
elif backend == "mariadb":
db = MySQLDatabase(
actual_db = PooledMySQLDatabase(
config["database"]["database"],
host=config["database"].get("host", "localhost"),
port=config["database"].get("port", 3306),
user=config["database"]["user"],
password=config["database"]["password"],
max_connections=config["database"].get("pool_size", 5),
)
# Re-bind models to new database
User._meta.database = db
Hostname._meta.database = db
Version._meta.database = db
db.initialize(actual_db)
db_name = config['database']['database']
logging.debug(f"Database backend: MariaDB db={db_name}")
@@ -261,6 +285,11 @@ def get_user(username: str):
Raises:
DoesNotExist: If user not found.
Example:
>>> user = get_user("alice")
>>> print(user.email)
'alice@example.com'
"""
return User.get(User.username == username)
@@ -278,6 +307,11 @@ def get_hostname(hostname, zone):
Raises:
DoesNotExist: If hostname not found.
Example:
>>> host = get_hostname("myhost", "example.com")
>>> print(host.last_ipv4)
'192.168.1.1'
"""
return Hostname.get(
(Hostname.hostname == hostname) & (Hostname.zone == zone)
@@ -289,7 +323,7 @@ def get_hostname_for_user(hostname: str, user: User):
Get hostname owned by specific user.
Args:
hostname: Hostname to look up.
hostname: Hostname to look up (FQDN).
user: User who should own the hostname.
Returns:
@@ -297,22 +331,10 @@ def get_hostname_for_user(hostname: str, user: User):
Raises:
DoesNotExist: If hostname not found or not owned by user.
Example:
>>> user = get_user("alice")
>>> host = get_hostname_for_user("myhost.example.com", user)
"""
fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone)
return Hostname.get((fqdn == hostname) & (Hostname.user == user))
# Re-export DoesNotExist for convenience
__all__ = [
'db',
'DATABASE_VERSION',
'User',
'Hostname',
'Version',
'init_database',
'create_tables',
'get_user',
'get_hostname',
'get_hostname_for_user',
'DoesNotExist',
]

View File

@@ -85,11 +85,11 @@ class GoodLimiter(BaseLimiter):
Args:
config: Full configuration dictionary.
"""
rl = config.get("rate_limit", {})
rl = config["rate_limit"]
super().__init__(
rl.get("good_window_seconds", 60),
rl.get("good_max_requests", 30),
rl.get("enabled", False),
rl["good_window_seconds"],
rl["good_max_requests"],
rl["enabled"],
False,
)
@@ -104,10 +104,10 @@ class BadLimiter(BaseLimiter):
Args:
config: Full configuration dictionary.
"""
rl = config.get("rate_limit", {})
rl = config["rate_limit"]
super().__init__(
rl.get("bad_window_seconds", 60),
rl.get("bad_max_requests", 5),
rl.get("enabled", False),
rl["bad_window_seconds"],
rl["bad_max_requests"],
rl["enabled"],
True,
)

View File

@@ -8,12 +8,24 @@ import json
import logging
import signal
import ssl
import threading
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse
import argon2
from . import datetime_str, utc_now
from . import (
datetime_str,
utc_now,
STATUS_GOOD,
STATUS_NOCHG,
STATUS_BADAUTH,
STATUS_NOHOST,
STATUS_DNSERR,
STATUS_ABUSE,
STATUS_BADIP,
)
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
from .logging import clear_txn_id, set_txn_id
from .models import DoesNotExist, get_hostname_for_user, get_user
@@ -64,7 +76,7 @@ def _is_trusted_proxy(client_ip, trusted_networks):
class DDNSServer(ThreadingHTTPServer):
"""HTTP server with Application instance."""
"""HTTP server with Application instance and thread pool."""
def __init__(self, app, address):
"""
@@ -79,8 +91,45 @@ class DDNSServer(ThreadingHTTPServer):
self.trusted_networks = _parse_trusted_proxies(
app.config["daemon"].get("trusted_proxies", [])
)
self.pool_size = app.config["daemon"]["thread_pool_size"]
self.request_timeout = app.config["daemon"]["request_timeout"]
self.executor = ThreadPoolExecutor(max_workers=self.pool_size)
self.active_requests = 0
self.requests_lock = threading.Lock()
self.requests_done = threading.Condition(self.requests_lock)
super().__init__(address, DDNSRequestHandler)
def process_request(self, request, client_address):
"""Submit request to thread pool."""
with self.requests_lock:
self.active_requests += 1
request.settimeout(self.request_timeout)
self.executor.submit(self._handle_request_wrapper, request, client_address)
def _handle_request_wrapper(self, request, client_address):
"""Wrap request handling to track active requests."""
try:
self.process_request_thread(request, client_address)
finally:
with self.requests_lock:
self.active_requests -= 1
if self.active_requests == 0:
self.requests_done.notify_all()
def wait_for_requests(self, timeout=5):
"""Wait for active requests to complete."""
with self.requests_lock:
if self.active_requests > 0:
logging.info(f"Waiting for {self.active_requests} active request(s)")
self.requests_done.wait(timeout=timeout)
if self.active_requests > 0:
logging.warning(f"Shutdown timeout, {self.active_requests} request(s) still active")
def server_close(self):
"""Shutdown thread pool and close server."""
self.executor.shutdown(wait=True)
super().server_close()
class DDNSRequestHandler(BaseHTTPRequestHandler):
"""HTTP request handler for DDNS updates."""
@@ -182,7 +231,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning(
f"Rate limited (bad): client={client_ip}, "
f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse")
self.respond(429, STATUS_ABUSE)
return
# Parse URL
@@ -205,7 +254,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
if not username or not password:
logging.warning(f"Auth failed: client={client_ip} user=anonymous")
self._handle_bad_request(client_ip, 401, "badauth")
self._handle_bad_request(client_ip, 401, STATUS_BADAUTH)
return
# Validate credentials
@@ -214,14 +263,14 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
self.app.password_hasher.verify(user.password_hash, password)
except (DoesNotExist, argon2.exceptions.VerifyMismatchError):
logging.warning(f"Auth failed: client={client_ip} user={username}")
self._handle_bad_request(client_ip, 401, "badauth")
self._handle_bad_request(client_ip, 401, STATUS_BADAUTH)
return
# Get hostname parameter
hostname_param = extract_param(params, endpoint["params"]["hostname"])
if not hostname_param:
logging.warning(f"Missing hostname: client={client_ip} user={username}")
self._handle_bad_request(client_ip, 400, "nohost")
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return
# Validate and encode hostname
@@ -231,7 +280,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning(
f"Invalid hostname: client={client_ip}, "
f"hostname={hostname_param}")
self._handle_bad_request(client_ip, 400, "nohost")
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return
# Check hostname ownership
@@ -242,7 +291,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"Access denied: client={client_ip} user={username} "
f"hostname={hostname_param}"
)
self._handle_bad_request(client_ip, 403, "nohost")
self._handle_bad_request(client_ip, 403, STATUS_NOHOST)
return
# Good rate limit check
@@ -252,7 +301,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.warning(
f"Rate limited: client={client_ip}, "
f"retry_at={datetime_str(retry_at)}")
self.respond(429, "abuse")
self.respond(429, STATUS_ABUSE)
return
# Record good request
@@ -262,11 +311,8 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
# Determine IPs to update
result = self._process_ip_update(hostname, params, endpoint, client_ip)
if result:
code, status, *kwargs = result
if kwargs:
self.respond(code, status, **kwargs[0])
else:
self.respond(code, status)
code, status, kwargs = result
self.respond(code, status, **kwargs)
def _handle_bad_request(self, client_ip, code, status):
"""Handle bad request and record in rate limiter."""
@@ -291,7 +337,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
else:
ipv6 = myip
except ValueError:
return (400, "badip")
return (400, STATUS_BADIP, {})
# Process myip6 parameter
if myip6:
@@ -300,9 +346,9 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
if rtype == "AAAA":
ipv6 = myip6
else:
return (400, "badip")
return (400, STATUS_BADIP, {})
except ValueError:
return (400, "badip")
return (400, STATUS_BADIP, {})
# Auto-detect from client IP if no params
if ipv4 is None and ipv6 is None:
@@ -313,7 +359,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
else:
ipv6 = ip
except ValueError:
return (400, "badip")
return (400, STATUS_BADIP, {})
now = utc_now()
@@ -338,7 +384,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"DNS update failed: client={client_ip} hostname={hostname.hostname} "
f"zone={hostname.zone} ipv4={ipv4} error={e}"
)
return (500, "dnserr")
return (500, STATUS_DNSERR, {})
if ipv6:
hostname.last_ipv6_update = now
@@ -359,7 +405,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"DNS update failed: client={client_ip} hostname={hostname.hostname} "
f"zone={hostname.zone} ipv6={ipv6} error={e}"
)
return (500, "dnserr")
return (500, STATUS_DNSERR, {})
# Update database
hostname.save()
@@ -380,7 +426,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
f"zone={hostname.zone}{changed_addrs} notify_change={str(notify_change).lower()}"
)
return (
200, "nochg",
200, STATUS_NOCHG,
{"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6}
)
@@ -401,7 +447,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
logging.error(f"Sending change notification error: {e}")
return (
200, "good",
200, STATUS_GOOD,
{"ipv4": hostname.last_ipv4, "ipv6": hostname.last_ipv6}
)
@@ -440,14 +486,39 @@ def run_daemon(app):
expired_cleanup_thread = ExpiredRecordsCleanupThread(app)
expired_cleanup_thread.start()
# Setup signal handlers
def signal_handler(signum, frame):
logging.info(f"Signal received: {signum}, shutting down")
app.signal_shutdown()
def sighup_handler(signum, frame):
logging.info("SIGHUP received, reloading configuration")
try:
app.reload_config()
# Update server attributes
server.proxy_header = app.config["daemon"].get("proxy_header", "")
server.trusted_networks = _parse_trusted_proxies(
app.config["daemon"].get("trusted_proxies", [])
)
# Reload SSL if enabled
if app.config["daemon"]["ssl"]:
new_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
new_context.load_cert_chain(
app.config["daemon"]["ssl_cert_file"],
app.config["daemon"]["ssl_key_file"]
)
# Note: existing connections use old cert, new connections use new
server.socket = new_context.wrap_socket(
server.socket.detach(), server_side=True
)
except Exception as e:
logging.error(f"Config reload failed: {e}")
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGHUP, sighup_handler)
paths = ", ".join(ep["path"] for ep in config["endpoints"])
logging.info(f"Daemon started: {proto}://{host}:{port} endpoints=[{paths}]")
@@ -457,6 +528,9 @@ def run_daemon(app):
while not app.is_shutting_down():
server.handle_request()
# Graceful shutdown - wait for active requests
server.wait_for_requests(5)
# Cleanup
expired_cleanup_thread.stop()
ratelimit_cleanup_thread.stop()