Compare commits

...

2 Commits

5 changed files with 264 additions and 160 deletions

View File

@@ -10,26 +10,20 @@ import datetime
__version__ = "1.0.0" __version__ = "1.0.0"
__author__ = "Thomas Oettli <spacefreak@noop.ch>" __author__ = "Thomas Oettli <spacefreak@noop.ch>"
# DynDNS-compatible response statuses
STATUS_GOOD = "good"
STATUS_NOCHG = "nochg"
STATUS_BADAUTH = "badauth"
STATUS_NOHOST = "nohost"
STATUS_DNSERR = "dnserr"
STATUS_ABUSE = "abuse"
STATUS_BADIP = "badip"
__all__ = [ __all__ = [
"app", "app",
"cleanup", "cleanup",
"cli", "cli",
"config", "config",
"datetime_aware_utc",
"datetime_naive_utc",
"datetime_str", "datetime_str",
"dns", "dns",
"email", "email",
"logging", "logging",
"main", "main",
"models", "models",
"now_utc"
"ratelimit", "ratelimit",
"server", "server",
"STATUS_GOOD", "STATUS_GOOD",
@@ -41,13 +35,75 @@ __all__ = [
"STATUS_BADIP", "STATUS_BADIP",
] ]
# DynDNS-compatible response statuses
STATUS_GOOD = "good"
STATUS_NOCHG = "nochg"
STATUS_BADAUTH = "badauth"
STATUS_NOHOST = "nohost"
STATUS_DNSERR = "dnserr"
STATUS_ABUSE = "abuse"
STATUS_BADIP = "badip"
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z" DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S %Z"
# Datetime convention: # Datetime convention:
# All datetime objects in this codebase are naive UTC to match database storage. # All datetime objects in this codebase are timezone-aware.
# - utc_now(): returns naive UTC datetime # - now_utc(): returns timezone-aware UTC datetime
# - datetime_str(): converts naive UTC to display string (adds tzinfo for formatting) # - datetime_str(): converts naive UTC (adds tzinfo for formatting)
# or timezone-aware datetime to display string
# - Database stores/returns naive datetimes (always UTC by convention) # - Database stores/returns naive datetimes (always UTC by convention)
# - Database models automatically convert between naive/timezone-aware datetimes
def now_utc():
"""
Get current date and time in UTC.
Returns:
Timezone-aware datetime object in UTC.
"""
return datetime.datetime.now(datetime.UTC)
def datetime_naive_utc(dt):
"""
Convert datetime to naive UTC datetime.
Args:
dt: Datetime object (naive UTC or timezone-aware or None).
Returns:
Naive datetime object in UTC or None if dt is not a datetime.
"""
if not isinstance(dt, datetime.datetime):
return None
if not dt.tzinfo:
return dt
return dt.astimezone(datetime.UTC).replace(tzinfo=None)
def datetime_aware_utc(dt):
"""
Convert datetime to UTC datetime.
Args:
dt: Datetime object (naive UTC or timezone-aware or None).
Returns:
Timzone-aware datetime object in UTC or None if dt is not a datetime.
"""
if not isinstance(dt, datetime.datetime):
return None
if not dt.tzinfo:
return dt.replace(tzinfo=datetime.UTC)
if dt.tzinfo == datetime.UTC:
return dt
return dt.astimezone(datetime.UTC)
def datetime_str(dt, utc=False): def datetime_str(dt, utc=False):
@@ -72,13 +128,3 @@ def datetime_str(dt, utc=False):
return aware_dt.strftime(DATETIME_FORMAT) return aware_dt.strftime(DATETIME_FORMAT)
else: else:
return aware_dt.astimezone().strftime(DATETIME_FORMAT) return aware_dt.astimezone().strftime(DATETIME_FORMAT)
def utc_now():
"""
Get current time as naive UTC datetime.
Returns naive datetime to match database storage behavior.
All naive datetimes in this codebase are assumed to be UTC.
"""
return datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

View File

@@ -3,7 +3,7 @@
import logging import logging
import threading import threading
from . import utc_now from . import now_utc
from .models import Hostname, User from .models import Hostname, User
from datetime import timedelta from datetime import timedelta
@@ -18,7 +18,7 @@ def cleanup_expired(app):
Returns: Returns:
Number of expired hostnames processed. Number of expired hostnames processed.
""" """
now = utc_now() now = now_utc()
expired_count = 0 expired_count = 0
for hostname in Hostname.select().join(User).where( for hostname in Hostname.select().join(User).where(

View File

@@ -1,18 +1,19 @@
"""CLI commands for user and hostname management.""" """CLI commands for user and hostname management."""
import getpass import getpass
import logging
from . import datetime_str from . import datetime_str
from .cleanup import cleanup_expired from .cleanup import cleanup_expired
from .dns import encode_dnsname
from .models import ( from .models import (
DatabaseError,
DoesNotExist, DoesNotExist,
EncodingError,
get_hostname, get_hostname,
get_user, get_user,
Hostname, Hostname,
User, User,
) )
from .dns import encode_dnsname, EncodingError
def cmd_user_list(args, app): def cmd_user_list(args, app):
@@ -166,24 +167,8 @@ def cmd_hostname_list(args, app):
def cmd_hostname_add(args, app): def cmd_hostname_add(args, app):
"""Add a hostname.""" """Add a hostname."""
username = args.username username = args.username
config = app.config
# Validate and encode hostname/zone
try: try:
hostname_str = encode_dnsname(args.hostname)
zone = encode_dnsname(args.zone)
except EncodingError as e:
print(f"Error: {e}")
return 1
# Get TTLs from args or config defaults
dns_ttl = args.dns_ttl
if dns_ttl is None:
dns_ttl = config["defaults"]["dns_ttl"]
expiry_ttl = args.expiry_ttl
if expiry_ttl is None:
expiry_ttl = config["defaults"]["expiry_ttl"]
# Get user # Get user
try: try:
user = get_user(username) user = get_user(username)
@@ -192,40 +177,59 @@ def cmd_hostname_add(args, app):
return 1 return 1
# Check if hostname+zone exists # Check if hostname+zone exists
if Hostname.select().where( try:
(Hostname.hostname == hostname_str) & (Hostname.zone == zone) hostname = get_hostname(args.hostname, args.zone)
).exists(): print(f"Error: Hostname '{hostname.hostname}' in zone '{hostname.zone}' exists.")
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' exists.")
return 1 return 1
except EncodingError as e:
print(f"Error: {e}")
return 1
except DoesNotExist:
pass
# Get TTLs from args or config defaults
config = app.config
dns_ttl = args.dns_ttl
if dns_ttl is None:
dns_ttl = config["defaults"]["dns_ttl"]
expiry_ttl = args.expiry_ttl
if expiry_ttl is None:
expiry_ttl = config["defaults"]["expiry_ttl"]
# Create hostname # Create hostname
Hostname.create( hostname = Hostname.create(
user=user, user=user,
hostname=hostname_str, hostname=args.hostname,
zone=zone, zone=args.zone,
dns_ttl=dns_ttl, dns_ttl=dns_ttl,
expiry_ttl=expiry_ttl expiry_ttl=expiry_ttl
) )
print(f"Hostname '{hostname_str}' added for user '{username}'.") print(
f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' added "
f"for user '{username}'."
)
except DatabaseError as e:
print(f"Database error: {e}")
return 1
return 0 return 0
def cmd_hostname_delete(args, app): def cmd_hostname_delete(args, app):
"""Delete a hostname.""" """Delete a hostname."""
# Validate and encode hostname and zone
try: try:
hostname_str = encode_dnsname(args.hostname) try:
hostname = get_hostname(args.hostname, args.zone)
except DoesNotExist:
hostname = encode_dnsname(args.hostname)
zone = encode_dnsname(args.zone) zone = encode_dnsname(args.zone)
print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.")
return 1
except EncodingError as e: except EncodingError as e:
print(f"Error: {e}") print(f"Error: {e}")
return 1 return 1
try:
hostname = get_hostname(hostname_str, zone)
except DoesNotExist:
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
return 1
# Delete DNS records if active # Delete DNS records if active
if hostname.last_ipv4 or hostname.last_ipv6: if hostname.last_ipv4 or hostname.last_ipv6:
# Initialize DNS service if not already # Initialize DNS service if not already
@@ -233,45 +237,50 @@ def cmd_hostname_delete(args, app):
try: try:
app.init_dns() app.init_dns()
except Exception as e: except Exception as e:
logging.warning(f"DNS init failed: {e}") print(f"DNS init failed: {e}")
return 1
if app.dns_service:
if hostname.last_ipv4: if hostname.last_ipv4:
try: try:
app.dns_service.delete_record( app.dns_service.delete_record(
hostname.hostname, hostname.zone, "A" hostname.hostname, hostname.zone, "A"
) )
except Exception as e: except Exception as e:
logging.warning(f"DNS delete failed: type=A error={e}") print(f"DNS delete failed: type=A error={e}")
return 1
if hostname.last_ipv6: if hostname.last_ipv6:
try: try:
app.dns_service.delete_record( app.dns_service.delete_record(
hostname.hostname, hostname.zone, "AAAA" hostname.hostname, hostname.zone, "AAAA"
) )
except Exception as e: except Exception as e:
logging.warning(f"DNS delete failed: type=AAAA error={e}") print(f"DNS delete failed: type=AAAA error={e}")
return 1
hostname.delete_instance() hostname.delete_instance()
print(f"Hostname '{hostname_str}' in zone '{zone}' deleted.") print(f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' deleted.")
except DatabaseError as e:
print(f"Database error: {e}")
return 1
return 0 return 0
def cmd_hostname_modify(args, app): def cmd_hostname_modify(args, app):
"""Modify hostname settings.""" """Modify hostname settings."""
# Validate and encode hostname and zone
try: try:
hostname_str = encode_dnsname(args.hostname) try:
hostname = get_hostname(args.hostname, args.zone)
except DoesNotExist:
hostname = encode_dnsname(args.hostname)
zone = encode_dnsname(args.zone) zone = encode_dnsname(args.zone)
print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.")
return 1
except EncodingError as e: except EncodingError as e:
print(f"Error: {e}") print(f"Error: {e}")
return 1 return 1
try:
hostname = get_hostname(hostname_str, zone)
except DoesNotExist:
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
return 1
# Get new TTLs # Get new TTLs
dns_ttl = args.dns_ttl if args.dns_ttl is not None else hostname.dns_ttl dns_ttl = args.dns_ttl if args.dns_ttl is not None else hostname.dns_ttl
expiry_ttl = args.expiry_ttl if args.expiry_ttl is not None else hostname.expiry_ttl expiry_ttl = args.expiry_ttl if args.expiry_ttl is not None else hostname.expiry_ttl
@@ -280,9 +289,13 @@ def cmd_hostname_modify(args, app):
hostname.expiry_ttl = expiry_ttl hostname.expiry_ttl = expiry_ttl
hostname.save() hostname.save()
print( print(
f"Hostname '{hostname_str}' updated: " f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' updated: "
f"dns_ttl={dns_ttl}, expiry_ttl={expiry_ttl}" f"dns_ttl={dns_ttl}, expiry_ttl={expiry_ttl}"
) )
except DatabaseError as e:
print(f"Database error: {e}")
return 1
return 0 return 0
@@ -293,11 +306,20 @@ def cmd_cleanup(args, app):
try: try:
app.init_dns() app.init_dns()
except Exception as e: except Exception as e:
logging.warning(f"DNS init failed: {e}") print(f"DNS init failed: {e}")
return 1
if app.email_service is None: if app.email_service is None:
app.init_email() app.init_email()
try:
count = cleanup_expired(app) count = cleanup_expired(app)
print(f"Cleanup complete: {count} expired hostname(s) processed.") print(f"Cleanup complete: {count} expired hostname(s) processed.")
except DatabaseError as e:
print(f"Database error: {e}")
return 1
except Exception as e:
print(f"Error: {e}")
return 1
return 0 return 0

View File

@@ -3,7 +3,7 @@
import logging import logging
import os import os
from . import utc_now from . import datetime_naive_utc, datetime_aware_utc, now_utc
from .dns import encode_dnsname, EncodingError from .dns import encode_dnsname, EncodingError
from peewee import ( from peewee import (
AutoField, AutoField,
@@ -16,10 +16,12 @@ from peewee import (
IntegerField, IntegerField,
Model, Model,
SqliteDatabase, SqliteDatabase,
PeeweeException as DatabaseError,
) )
from playhouse.pool import PooledMySQLDatabase from playhouse.pool import PooledMySQLDatabase
# Re-export DoesNotExist and EncodingError for convenience # Re-export PeeweeException as DatabseException, DoesNotExist and
# EncodingError for convenience
__all__ = [ __all__ = [
'db', 'db',
'DATABASE_VERSION', 'DATABASE_VERSION',
@@ -33,6 +35,7 @@ __all__ = [
'get_hostname_for_user', 'get_hostname_for_user',
'DoesNotExist', 'DoesNotExist',
'EncodingError', 'EncodingError',
'DatabaseError',
] ]
# Database proxy (initialized later with actual backend) # Database proxy (initialized later with actual backend)
@@ -75,11 +78,19 @@ class User(BaseModel):
username = CharField(max_length=64, unique=True) username = CharField(max_length=64, unique=True)
password_hash = CharField(max_length=128) password_hash = CharField(max_length=128)
email = CharField(max_length=255) email = CharField(max_length=255)
created_at = DateTimeField(default=utc_now) created_at = DateTimeField(default=now_utc)
class Meta: class Meta:
table_name = "users" table_name = "users"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.created_at = datetime_aware_utc(self.created_at)
def save(self, *args, **kwargs):
self.created_at = datetime_naive_utc(self.created_at)
return super().save(*args, **kwargs)
class Hostname(BaseModel): class Hostname(BaseModel):
"""Hostname model for DNS records.""" """Hostname model for DNS records."""
@@ -101,12 +112,19 @@ class Hostname(BaseModel):
(('hostname', 'zone'), True), (('hostname', 'zone'), True),
) )
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.last_ipv4_update = datetime_aware_utc(self.last_ipv4_update)
self.last_ipv6_update = datetime_aware_utc(self.last_ipv6_update)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Validate and encode hostname/zone before saving.""" """Validate and encode hostname/zone before saving."""
if self.hostname: if self.hostname:
self.hostname = encode_dnsname(self.hostname) self.hostname = encode_dnsname(self.hostname)
if self.zone: if self.zone:
self.zone = encode_dnsname(self.zone) self.zone = encode_dnsname(self.zone)
self.last_ipv4_update = datetime_naive_utc(self.last_ipv4_update)
self.last_ipv6_update = datetime_naive_utc(self.last_ipv6_update)
return super().save(*args, **kwargs) return super().save(*args, **kwargs)
@@ -273,7 +291,7 @@ def create_tables():
logging.debug("Database tables created") logging.debug("Database tables created")
def get_user(username: str): def get_user(username: str) -> User:
""" """
Get user by username. Get user by username.
@@ -294,7 +312,7 @@ def get_user(username: str):
return User.get(User.username == username) return User.get(User.username == username)
def get_hostname(hostname, zone): def get_hostname(hostname: str, zone: str) -> Hostname:
""" """
Get hostname by name and zone. Get hostname by name and zone.
@@ -307,6 +325,7 @@ def get_hostname(hostname, zone):
Raises: Raises:
DoesNotExist: If hostname not found. DoesNotExist: If hostname not found.
EncodingError: If hostname or zone is invalid.
Example: Example:
>>> host = get_hostname("myhost", "example.com") >>> host = get_hostname("myhost", "example.com")
@@ -314,7 +333,8 @@ def get_hostname(hostname, zone):
'192.168.1.1' '192.168.1.1'
""" """
return Hostname.get( return Hostname.get(
(Hostname.hostname == hostname) & (Hostname.zone == zone) (Hostname.hostname == encode_dnsname(hostname)) &
(Hostname.zone == encode_dnsname(zone))
) )
@@ -331,10 +351,14 @@ def get_hostname_for_user(hostname: str, user: User):
Raises: Raises:
DoesNotExist: If hostname not found or not owned by user. DoesNotExist: If hostname not found or not owned by user.
EncodingError: If hostname is invalid.
Example: Example:
>>> user = get_user("alice") >>> user = get_user("alice")
>>> host = get_hostname_for_user("myhost.example.com", user) >>> host = get_hostname_for_user("myhost.example.com", user)
""" """
fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone) fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone)
return Hostname.get((fqdn == hostname) & (Hostname.user == user)) return Hostname.get(
(fqdn == encode_dnsname(hostname)) &
(Hostname.user == user)
)

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import argon2
import base64 import base64
import ipaddress import ipaddress
import json import json
@@ -9,15 +10,10 @@ import logging
import signal import signal
import ssl import ssl
import threading import threading
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse
import argon2
from . import ( from . import (
now_utc,
datetime_str, datetime_str,
utc_now,
STATUS_GOOD, STATUS_GOOD,
STATUS_NOCHG, STATUS_NOCHG,
STATUS_BADAUTH, STATUS_BADAUTH,
@@ -27,9 +23,18 @@ from . import (
STATUS_BADIP, STATUS_BADIP,
) )
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
from .dns import detect_ip_type
from .logging import clear_txn_id, set_txn_id from .logging import clear_txn_id, set_txn_id
from .models import DoesNotExist, get_hostname_for_user, get_user from .models import (
from .dns import detect_ip_type, encode_dnsname, EncodingError DatabaseError,
DoesNotExist,
EncodingError,
get_hostname_for_user,
get_user
)
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse
def extract_param(params, aliases): def extract_param(params, aliases):
@@ -123,7 +128,9 @@ class DDNSServer(ThreadingHTTPServer):
logging.info(f"Waiting for {self.active_requests} active request(s)") logging.info(f"Waiting for {self.active_requests} active request(s)")
self.requests_done.wait(timeout=timeout) self.requests_done.wait(timeout=timeout)
if self.active_requests > 0: if self.active_requests > 0:
logging.warning(f"Shutdown timeout, {self.active_requests} request(s) still active") logging.warning(
f"Shutdown timeout, {self.active_requests} request(s) still active"
)
def server_close(self): def server_close(self):
"""Shutdown thread pool and close server.""" """Shutdown thread pool and close server."""
@@ -212,6 +219,9 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
set_txn_id() set_txn_id()
try: try:
self._handle_get_request() self._handle_get_request()
except Exception as e:
logging.exception(f"Uncaught exception: {e}")
self.respond(500, "Internal Server Error")
finally: finally:
clear_txn_id() clear_txn_id()
@@ -258,6 +268,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
return return
# Validate credentials # Validate credentials
try:
try: try:
user = get_user(username) user = get_user(username)
self.app.password_hasher.verify(user.password_hash, password) self.app.password_hasher.verify(user.password_hash, password)
@@ -273,16 +284,6 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
self._handle_bad_request(client_ip, 400, STATUS_NOHOST) self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return return
# Validate and encode hostname
try:
hostname_param = encode_dnsname(hostname_param)
except EncodingError:
logging.warning(
f"Invalid hostname: client={client_ip}, "
f"hostname={hostname_param}")
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return
# Check hostname ownership # Check hostname ownership
try: try:
hostname = get_hostname_for_user(hostname_param, user) hostname = get_hostname_for_user(hostname_param, user)
@@ -293,6 +294,17 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
) )
self._handle_bad_request(client_ip, 403, STATUS_NOHOST) self._handle_bad_request(client_ip, 403, STATUS_NOHOST)
return return
except EncodingError:
logging.warning(
f"Invalid hostname: client={client_ip}, "
f"hostname={hostname_param}")
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
return
except DatabaseError as e:
logging.error(f"Database error: {e}")
self.respond(500, "Internal Server Error")
return
# Good rate limit check # Good rate limit check
if self.app.good_limiter: if self.app.good_limiter:
@@ -361,7 +373,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
except ValueError: except ValueError:
return (400, STATUS_BADIP, {}) return (400, STATUS_BADIP, {})
now = utc_now() now = now_utc()
ipv4_changed = False ipv4_changed = False
ipv6_changed = False ipv6_changed = False