From 07e37e525c80caa670d34ce06d63c912da114e07 Mon Sep 17 00:00:00 2001 From: Thomas Oettli Date: Sat, 24 Jan 2026 00:46:48 +0100 Subject: [PATCH] Use db model hostname validation in cli and improve exception handling --- src/ddns_service/cli.py | 200 ++++++++++++++++++++----------------- src/ddns_service/models.py | 19 +++- src/ddns_service/server.py | 88 +++++++++------- 3 files changed, 175 insertions(+), 132 deletions(-) diff --git a/src/ddns_service/cli.py b/src/ddns_service/cli.py index 957353b..c8a928a 100644 --- a/src/ddns_service/cli.py +++ b/src/ddns_service/cli.py @@ -1,18 +1,19 @@ """CLI commands for user and hostname management.""" import getpass -import logging from . import datetime_str from .cleanup import cleanup_expired +from .dns import encode_dnsname from .models import ( + DatabaseError, DoesNotExist, + EncodingError, get_hostname, get_user, Hostname, User, ) -from .dns import encode_dnsname, EncodingError def cmd_user_list(args, app): @@ -166,123 +167,135 @@ def cmd_hostname_list(args, app): def cmd_hostname_add(args, app): """Add a hostname.""" username = args.username - config = app.config - # Validate and encode hostname/zone try: - hostname_str = encode_dnsname(args.hostname) - zone = encode_dnsname(args.zone) - except EncodingError as e: - print(f"Error: {e}") + # Get user + try: + user = get_user(username) + except DoesNotExist: + print(f"Error: User '{username}' not found.") + return 1 + + # Check if hostname+zone exists + try: + hostname = get_hostname(args.hostname, args.zone) + print(f"Error: Hostname '{hostname.hostname}' in zone '{hostname.zone}' exists.") + return 1 + except EncodingError as e: + print(f"Error: {e}") + return 1 + except DoesNotExist: + pass + + # Get TTLs from args or config defaults + config = app.config + + dns_ttl = args.dns_ttl + if dns_ttl is None: + dns_ttl = config["defaults"]["dns_ttl"] + expiry_ttl = args.expiry_ttl + if expiry_ttl is None: + expiry_ttl = config["defaults"]["expiry_ttl"] + + # Create hostname + hostname = Hostname.create( + user=user, + hostname=args.hostname, + zone=args.zone, + dns_ttl=dns_ttl, + expiry_ttl=expiry_ttl + ) + print( + f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' added " + f"for user '{username}'." + ) + except DatabaseError as e: + print(f"Database error: {e}") return 1 - # Get TTLs from args or config defaults - dns_ttl = args.dns_ttl - if dns_ttl is None: - dns_ttl = config["defaults"]["dns_ttl"] - expiry_ttl = args.expiry_ttl - if expiry_ttl is None: - expiry_ttl = config["defaults"]["expiry_ttl"] - - # Get user - try: - user = get_user(username) - except DoesNotExist: - print(f"Error: User '{username}' not found.") - return 1 - - # Check if hostname+zone exists - if Hostname.select().where( - (Hostname.hostname == hostname_str) & (Hostname.zone == zone) - ).exists(): - print(f"Error: Hostname '{hostname_str}' in zone '{zone}' exists.") - return 1 - - # Create hostname - Hostname.create( - user=user, - hostname=hostname_str, - zone=zone, - dns_ttl=dns_ttl, - expiry_ttl=expiry_ttl - ) - print(f"Hostname '{hostname_str}' added for user '{username}'.") return 0 def cmd_hostname_delete(args, app): """Delete a hostname.""" - # Validate and encode hostname and zone try: - hostname_str = encode_dnsname(args.hostname) - zone = encode_dnsname(args.zone) - except EncodingError as e: - print(f"Error: {e}") - return 1 + try: + hostname = get_hostname(args.hostname, args.zone) + except DoesNotExist: + hostname = encode_dnsname(args.hostname) + zone = encode_dnsname(args.zone) + print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.") + return 1 + except EncodingError as e: + print(f"Error: {e}") + return 1 - try: - hostname = get_hostname(hostname_str, zone) - except DoesNotExist: - print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.") - return 1 + # Delete DNS records if active + if hostname.last_ipv4 or hostname.last_ipv6: + # Initialize DNS service if not already + if app.dns_service is None: + try: + app.init_dns() + except Exception as e: + print(f"DNS init failed: {e}") + return 1 - # Delete DNS records if active - if hostname.last_ipv4 or hostname.last_ipv6: - # Initialize DNS service if not already - if app.dns_service is None: - try: - app.init_dns() - except Exception as e: - logging.warning(f"DNS init failed: {e}") - - if app.dns_service: if hostname.last_ipv4: try: app.dns_service.delete_record( hostname.hostname, hostname.zone, "A" ) except Exception as e: - logging.warning(f"DNS delete failed: type=A error={e}") + print(f"DNS delete failed: type=A error={e}") + return 1 + if hostname.last_ipv6: try: app.dns_service.delete_record( hostname.hostname, hostname.zone, "AAAA" ) except Exception as e: - logging.warning(f"DNS delete failed: type=AAAA error={e}") + print(f"DNS delete failed: type=AAAA error={e}") + return 1 + + hostname.delete_instance() + print(f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' deleted.") + except DatabaseError as e: + print(f"Database error: {e}") + return 1 - hostname.delete_instance() - print(f"Hostname '{hostname_str}' in zone '{zone}' deleted.") return 0 def cmd_hostname_modify(args, app): """Modify hostname settings.""" - # Validate and encode hostname and zone try: - hostname_str = encode_dnsname(args.hostname) - zone = encode_dnsname(args.zone) - except EncodingError as e: - print(f"Error: {e}") + try: + hostname = get_hostname(args.hostname, args.zone) + except DoesNotExist: + hostname = encode_dnsname(args.hostname) + zone = encode_dnsname(args.zone) + print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.") + return 1 + except EncodingError as e: + print(f"Error: {e}") + return 1 + + # Get new TTLs + dns_ttl = args.dns_ttl if args.dns_ttl is not None else hostname.dns_ttl + expiry_ttl = args.expiry_ttl if args.expiry_ttl is not None else hostname.expiry_ttl + + hostname.dns_ttl = dns_ttl + hostname.expiry_ttl = expiry_ttl + hostname.save() + print( + f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' updated: " + f"dns_ttl={dns_ttl}, expiry_ttl={expiry_ttl}" + ) + except DatabaseError as e: + print(f"Database error: {e}") return 1 - try: - hostname = get_hostname(hostname_str, zone) - except DoesNotExist: - print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.") - return 1 - - # Get new TTLs - dns_ttl = args.dns_ttl if args.dns_ttl is not None else hostname.dns_ttl - expiry_ttl = args.expiry_ttl if args.expiry_ttl is not None else hostname.expiry_ttl - - hostname.dns_ttl = dns_ttl - hostname.expiry_ttl = expiry_ttl - hostname.save() - print( - f"Hostname '{hostname_str}' updated: " - f"dns_ttl={dns_ttl}, expiry_ttl={expiry_ttl}" - ) return 0 @@ -293,11 +306,20 @@ def cmd_cleanup(args, app): try: app.init_dns() except Exception as e: - logging.warning(f"DNS init failed: {e}") + print(f"DNS init failed: {e}") + return 1 if app.email_service is None: app.init_email() - count = cleanup_expired(app) - print(f"Cleanup complete: {count} expired hostname(s) processed.") + try: + count = cleanup_expired(app) + print(f"Cleanup complete: {count} expired hostname(s) processed.") + except DatabaseError as e: + print(f"Database error: {e}") + return 1 + except Exception as e: + print(f"Error: {e}") + return 1 + return 0 diff --git a/src/ddns_service/models.py b/src/ddns_service/models.py index 0aea2c8..a18ec70 100644 --- a/src/ddns_service/models.py +++ b/src/ddns_service/models.py @@ -16,10 +16,12 @@ from peewee import ( IntegerField, Model, SqliteDatabase, + PeeweeException as DatabaseError, ) from playhouse.pool import PooledMySQLDatabase -# Re-export DoesNotExist and EncodingError for convenience +# Re-export PeeweeException as DatabseException, DoesNotExist and +# EncodingError for convenience __all__ = [ 'db', 'DATABASE_VERSION', @@ -33,6 +35,7 @@ __all__ = [ 'get_hostname_for_user', 'DoesNotExist', 'EncodingError', + 'DatabaseError', ] # Database proxy (initialized later with actual backend) @@ -273,7 +276,7 @@ def create_tables(): logging.debug("Database tables created") -def get_user(username: str): +def get_user(username: str) -> User: """ Get user by username. @@ -294,7 +297,7 @@ def get_user(username: str): return User.get(User.username == username) -def get_hostname(hostname, zone): +def get_hostname(hostname: str, zone: str) -> Hostname: """ Get hostname by name and zone. @@ -307,6 +310,7 @@ def get_hostname(hostname, zone): Raises: DoesNotExist: If hostname not found. + EncodingError: If hostname or zone is invalid. Example: >>> host = get_hostname("myhost", "example.com") @@ -314,7 +318,8 @@ def get_hostname(hostname, zone): '192.168.1.1' """ return Hostname.get( - (Hostname.hostname == hostname) & (Hostname.zone == zone) + (Hostname.hostname == encode_dnsname(hostname)) & + (Hostname.zone == encode_dnsname(zone)) ) @@ -331,10 +336,14 @@ def get_hostname_for_user(hostname: str, user: User): Raises: DoesNotExist: If hostname not found or not owned by user. + EncodingError: If hostname is invalid. Example: >>> user = get_user("alice") >>> host = get_hostname_for_user("myhost.example.com", user) """ fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone) - return Hostname.get((fqdn == hostname) & (Hostname.user == user)) + return Hostname.get( + (fqdn == encode_dnsname(hostname)) & + (Hostname.user == user) + ) diff --git a/src/ddns_service/server.py b/src/ddns_service/server.py index 02e2b35..8888f21 100644 --- a/src/ddns_service/server.py +++ b/src/ddns_service/server.py @@ -2,6 +2,7 @@ from __future__ import annotations +import argon2 import base64 import ipaddress import json @@ -9,11 +10,6 @@ import logging import signal import ssl import threading -from concurrent.futures import ThreadPoolExecutor -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from urllib.parse import parse_qs, urlparse - -import argon2 from . import ( datetime_str, @@ -27,9 +23,18 @@ from . import ( STATUS_BADIP, ) from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread +from .dns import detect_ip_type from .logging import clear_txn_id, set_txn_id -from .models import DoesNotExist, get_hostname_for_user, get_user -from .dns import detect_ip_type, encode_dnsname, EncodingError +from .models import ( + DatabaseError, + DoesNotExist, + EncodingError, + get_hostname_for_user, + get_user +) +from concurrent.futures import ThreadPoolExecutor +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse def extract_param(params, aliases): @@ -123,7 +128,9 @@ class DDNSServer(ThreadingHTTPServer): logging.info(f"Waiting for {self.active_requests} active request(s)") self.requests_done.wait(timeout=timeout) if self.active_requests > 0: - logging.warning(f"Shutdown timeout, {self.active_requests} request(s) still active") + logging.warning( + f"Shutdown timeout, {self.active_requests} request(s) still active" + ) def server_close(self): """Shutdown thread pool and close server.""" @@ -212,6 +219,9 @@ class DDNSRequestHandler(BaseHTTPRequestHandler): set_txn_id() try: self._handle_get_request() + except Exception as e: + logging.exception(f"Uncaught exception: {e}") + self.respond(500, "Internal Server Error") finally: clear_txn_id() @@ -259,39 +269,41 @@ class DDNSRequestHandler(BaseHTTPRequestHandler): # Validate credentials try: - user = get_user(username) - self.app.password_hasher.verify(user.password_hash, password) - except (DoesNotExist, argon2.exceptions.VerifyMismatchError): - logging.warning(f"Auth failed: client={client_ip} user={username}") - self._handle_bad_request(client_ip, 401, STATUS_BADAUTH) - return + try: + user = get_user(username) + self.app.password_hasher.verify(user.password_hash, password) + except (DoesNotExist, argon2.exceptions.VerifyMismatchError): + logging.warning(f"Auth failed: client={client_ip} user={username}") + self._handle_bad_request(client_ip, 401, STATUS_BADAUTH) + return - # Get hostname parameter - hostname_param = extract_param(params, endpoint["params"]["hostname"]) - if not hostname_param: - logging.warning(f"Missing hostname: client={client_ip} user={username}") - self._handle_bad_request(client_ip, 400, STATUS_NOHOST) - return + # Get hostname parameter + hostname_param = extract_param(params, endpoint["params"]["hostname"]) + if not hostname_param: + logging.warning(f"Missing hostname: client={client_ip} user={username}") + self._handle_bad_request(client_ip, 400, STATUS_NOHOST) + return - # Validate and encode hostname - try: - hostname_param = encode_dnsname(hostname_param) - except EncodingError: - logging.warning( - f"Invalid hostname: client={client_ip}, " - f"hostname={hostname_param}") - self._handle_bad_request(client_ip, 400, STATUS_NOHOST) - return + # Check hostname ownership + try: + hostname = get_hostname_for_user(hostname_param, user) + except DoesNotExist: + logging.warning( + f"Access denied: client={client_ip} user={username} " + f"hostname={hostname_param}" + ) + self._handle_bad_request(client_ip, 403, STATUS_NOHOST) + return + except EncodingError: + logging.warning( + f"Invalid hostname: client={client_ip}, " + f"hostname={hostname_param}") + self._handle_bad_request(client_ip, 400, STATUS_NOHOST) + return - # Check hostname ownership - try: - hostname = get_hostname_for_user(hostname_param, user) - except DoesNotExist: - logging.warning( - f"Access denied: client={client_ip} user={username} " - f"hostname={hostname_param}" - ) - self._handle_bad_request(client_ip, 403, STATUS_NOHOST) + except DatabaseError as e: + logging.error(f"Database error: {e}") + self.respond(500, "Internal Server Error") return # Good rate limit check