Use db model hostname validation in cli and improve exception handling
This commit is contained in:
@@ -1,18 +1,19 @@
|
||||
"""CLI commands for user and hostname management."""
|
||||
|
||||
import getpass
|
||||
import logging
|
||||
|
||||
from . import datetime_str
|
||||
from .cleanup import cleanup_expired
|
||||
from .dns import encode_dnsname
|
||||
from .models import (
|
||||
DatabaseError,
|
||||
DoesNotExist,
|
||||
EncodingError,
|
||||
get_hostname,
|
||||
get_user,
|
||||
Hostname,
|
||||
User,
|
||||
)
|
||||
from .dns import encode_dnsname, EncodingError
|
||||
|
||||
|
||||
def cmd_user_list(args, app):
|
||||
@@ -166,24 +167,8 @@ def cmd_hostname_list(args, app):
|
||||
def cmd_hostname_add(args, app):
|
||||
"""Add a hostname."""
|
||||
username = args.username
|
||||
config = app.config
|
||||
|
||||
# Validate and encode hostname/zone
|
||||
try:
|
||||
hostname_str = encode_dnsname(args.hostname)
|
||||
zone = encode_dnsname(args.zone)
|
||||
except EncodingError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
# Get TTLs from args or config defaults
|
||||
dns_ttl = args.dns_ttl
|
||||
if dns_ttl is None:
|
||||
dns_ttl = config["defaults"]["dns_ttl"]
|
||||
expiry_ttl = args.expiry_ttl
|
||||
if expiry_ttl is None:
|
||||
expiry_ttl = config["defaults"]["expiry_ttl"]
|
||||
|
||||
# Get user
|
||||
try:
|
||||
user = get_user(username)
|
||||
@@ -192,40 +177,59 @@ def cmd_hostname_add(args, app):
|
||||
return 1
|
||||
|
||||
# Check if hostname+zone exists
|
||||
if Hostname.select().where(
|
||||
(Hostname.hostname == hostname_str) & (Hostname.zone == zone)
|
||||
).exists():
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' exists.")
|
||||
try:
|
||||
hostname = get_hostname(args.hostname, args.zone)
|
||||
print(f"Error: Hostname '{hostname.hostname}' in zone '{hostname.zone}' exists.")
|
||||
return 1
|
||||
except EncodingError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
except DoesNotExist:
|
||||
pass
|
||||
|
||||
# Get TTLs from args or config defaults
|
||||
config = app.config
|
||||
|
||||
dns_ttl = args.dns_ttl
|
||||
if dns_ttl is None:
|
||||
dns_ttl = config["defaults"]["dns_ttl"]
|
||||
expiry_ttl = args.expiry_ttl
|
||||
if expiry_ttl is None:
|
||||
expiry_ttl = config["defaults"]["expiry_ttl"]
|
||||
|
||||
# Create hostname
|
||||
Hostname.create(
|
||||
hostname = Hostname.create(
|
||||
user=user,
|
||||
hostname=hostname_str,
|
||||
zone=zone,
|
||||
hostname=args.hostname,
|
||||
zone=args.zone,
|
||||
dns_ttl=dns_ttl,
|
||||
expiry_ttl=expiry_ttl
|
||||
)
|
||||
print(f"Hostname '{hostname_str}' added for user '{username}'.")
|
||||
print(
|
||||
f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' added "
|
||||
f"for user '{username}'."
|
||||
)
|
||||
except DatabaseError as e:
|
||||
print(f"Database error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_hostname_delete(args, app):
|
||||
"""Delete a hostname."""
|
||||
# Validate and encode hostname and zone
|
||||
try:
|
||||
hostname_str = encode_dnsname(args.hostname)
|
||||
try:
|
||||
hostname = get_hostname(args.hostname, args.zone)
|
||||
except DoesNotExist:
|
||||
hostname = encode_dnsname(args.hostname)
|
||||
zone = encode_dnsname(args.zone)
|
||||
print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
except EncodingError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
hostname = get_hostname(hostname_str, zone)
|
||||
except DoesNotExist:
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
|
||||
# Delete DNS records if active
|
||||
if hostname.last_ipv4 or hostname.last_ipv6:
|
||||
# Initialize DNS service if not already
|
||||
@@ -233,45 +237,50 @@ def cmd_hostname_delete(args, app):
|
||||
try:
|
||||
app.init_dns()
|
||||
except Exception as e:
|
||||
logging.warning(f"DNS init failed: {e}")
|
||||
print(f"DNS init failed: {e}")
|
||||
return 1
|
||||
|
||||
if app.dns_service:
|
||||
if hostname.last_ipv4:
|
||||
try:
|
||||
app.dns_service.delete_record(
|
||||
hostname.hostname, hostname.zone, "A"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"DNS delete failed: type=A error={e}")
|
||||
print(f"DNS delete failed: type=A error={e}")
|
||||
return 1
|
||||
|
||||
if hostname.last_ipv6:
|
||||
try:
|
||||
app.dns_service.delete_record(
|
||||
hostname.hostname, hostname.zone, "AAAA"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"DNS delete failed: type=AAAA error={e}")
|
||||
print(f"DNS delete failed: type=AAAA error={e}")
|
||||
return 1
|
||||
|
||||
hostname.delete_instance()
|
||||
print(f"Hostname '{hostname_str}' in zone '{zone}' deleted.")
|
||||
print(f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' deleted.")
|
||||
except DatabaseError as e:
|
||||
print(f"Database error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def cmd_hostname_modify(args, app):
|
||||
"""Modify hostname settings."""
|
||||
# Validate and encode hostname and zone
|
||||
try:
|
||||
hostname_str = encode_dnsname(args.hostname)
|
||||
try:
|
||||
hostname = get_hostname(args.hostname, args.zone)
|
||||
except DoesNotExist:
|
||||
hostname = encode_dnsname(args.hostname)
|
||||
zone = encode_dnsname(args.zone)
|
||||
print(f"Error: Hostname '{hostname}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
except EncodingError as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
hostname = get_hostname(hostname_str, zone)
|
||||
except DoesNotExist:
|
||||
print(f"Error: Hostname '{hostname_str}' in zone '{zone}' not found.")
|
||||
return 1
|
||||
|
||||
# Get new TTLs
|
||||
dns_ttl = args.dns_ttl if args.dns_ttl is not None else hostname.dns_ttl
|
||||
expiry_ttl = args.expiry_ttl if args.expiry_ttl is not None else hostname.expiry_ttl
|
||||
@@ -280,9 +289,13 @@ def cmd_hostname_modify(args, app):
|
||||
hostname.expiry_ttl = expiry_ttl
|
||||
hostname.save()
|
||||
print(
|
||||
f"Hostname '{hostname_str}' updated: "
|
||||
f"Hostname '{hostname.hostname}' in zone '{hostname.zone}' updated: "
|
||||
f"dns_ttl={dns_ttl}, expiry_ttl={expiry_ttl}"
|
||||
)
|
||||
except DatabaseError as e:
|
||||
print(f"Database error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@@ -293,11 +306,20 @@ def cmd_cleanup(args, app):
|
||||
try:
|
||||
app.init_dns()
|
||||
except Exception as e:
|
||||
logging.warning(f"DNS init failed: {e}")
|
||||
print(f"DNS init failed: {e}")
|
||||
return 1
|
||||
|
||||
if app.email_service is None:
|
||||
app.init_email()
|
||||
|
||||
try:
|
||||
count = cleanup_expired(app)
|
||||
print(f"Cleanup complete: {count} expired hostname(s) processed.")
|
||||
except DatabaseError as e:
|
||||
print(f"Database error: {e}")
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
@@ -16,10 +16,12 @@ from peewee import (
|
||||
IntegerField,
|
||||
Model,
|
||||
SqliteDatabase,
|
||||
PeeweeException as DatabaseError,
|
||||
)
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
# Re-export DoesNotExist and EncodingError for convenience
|
||||
# Re-export PeeweeException as DatabseException, DoesNotExist and
|
||||
# EncodingError for convenience
|
||||
__all__ = [
|
||||
'db',
|
||||
'DATABASE_VERSION',
|
||||
@@ -33,6 +35,7 @@ __all__ = [
|
||||
'get_hostname_for_user',
|
||||
'DoesNotExist',
|
||||
'EncodingError',
|
||||
'DatabaseError',
|
||||
]
|
||||
|
||||
# Database proxy (initialized later with actual backend)
|
||||
@@ -273,7 +276,7 @@ def create_tables():
|
||||
logging.debug("Database tables created")
|
||||
|
||||
|
||||
def get_user(username: str):
|
||||
def get_user(username: str) -> User:
|
||||
"""
|
||||
Get user by username.
|
||||
|
||||
@@ -294,7 +297,7 @@ def get_user(username: str):
|
||||
return User.get(User.username == username)
|
||||
|
||||
|
||||
def get_hostname(hostname, zone):
|
||||
def get_hostname(hostname: str, zone: str) -> Hostname:
|
||||
"""
|
||||
Get hostname by name and zone.
|
||||
|
||||
@@ -307,6 +310,7 @@ def get_hostname(hostname, zone):
|
||||
|
||||
Raises:
|
||||
DoesNotExist: If hostname not found.
|
||||
EncodingError: If hostname or zone is invalid.
|
||||
|
||||
Example:
|
||||
>>> host = get_hostname("myhost", "example.com")
|
||||
@@ -314,7 +318,8 @@ def get_hostname(hostname, zone):
|
||||
'192.168.1.1'
|
||||
"""
|
||||
return Hostname.get(
|
||||
(Hostname.hostname == hostname) & (Hostname.zone == zone)
|
||||
(Hostname.hostname == encode_dnsname(hostname)) &
|
||||
(Hostname.zone == encode_dnsname(zone))
|
||||
)
|
||||
|
||||
|
||||
@@ -331,10 +336,14 @@ def get_hostname_for_user(hostname: str, user: User):
|
||||
|
||||
Raises:
|
||||
DoesNotExist: If hostname not found or not owned by user.
|
||||
EncodingError: If hostname is invalid.
|
||||
|
||||
Example:
|
||||
>>> user = get_user("alice")
|
||||
>>> host = get_hostname_for_user("myhost.example.com", user)
|
||||
"""
|
||||
fqdn = fn.Concat(Hostname.hostname, '.', Hostname.zone)
|
||||
return Hostname.get((fqdn == hostname) & (Hostname.user == user))
|
||||
return Hostname.get(
|
||||
(fqdn == encode_dnsname(hostname)) &
|
||||
(Hostname.user == user)
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argon2
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
@@ -9,11 +10,6 @@ import logging
|
||||
import signal
|
||||
import ssl
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import argon2
|
||||
|
||||
from . import (
|
||||
datetime_str,
|
||||
@@ -27,9 +23,18 @@ from . import (
|
||||
STATUS_BADIP,
|
||||
)
|
||||
from .cleanup import ExpiredRecordsCleanupThread, RateLimitCleanupThread
|
||||
from .dns import detect_ip_type
|
||||
from .logging import clear_txn_id, set_txn_id
|
||||
from .models import DoesNotExist, get_hostname_for_user, get_user
|
||||
from .dns import detect_ip_type, encode_dnsname, EncodingError
|
||||
from .models import (
|
||||
DatabaseError,
|
||||
DoesNotExist,
|
||||
EncodingError,
|
||||
get_hostname_for_user,
|
||||
get_user
|
||||
)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
|
||||
def extract_param(params, aliases):
|
||||
@@ -123,7 +128,9 @@ class DDNSServer(ThreadingHTTPServer):
|
||||
logging.info(f"Waiting for {self.active_requests} active request(s)")
|
||||
self.requests_done.wait(timeout=timeout)
|
||||
if self.active_requests > 0:
|
||||
logging.warning(f"Shutdown timeout, {self.active_requests} request(s) still active")
|
||||
logging.warning(
|
||||
f"Shutdown timeout, {self.active_requests} request(s) still active"
|
||||
)
|
||||
|
||||
def server_close(self):
|
||||
"""Shutdown thread pool and close server."""
|
||||
@@ -212,6 +219,9 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
||||
set_txn_id()
|
||||
try:
|
||||
self._handle_get_request()
|
||||
except Exception as e:
|
||||
logging.exception(f"Uncaught exception: {e}")
|
||||
self.respond(500, "Internal Server Error")
|
||||
finally:
|
||||
clear_txn_id()
|
||||
|
||||
@@ -258,6 +268,7 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
||||
return
|
||||
|
||||
# Validate credentials
|
||||
try:
|
||||
try:
|
||||
user = get_user(username)
|
||||
self.app.password_hasher.verify(user.password_hash, password)
|
||||
@@ -273,16 +284,6 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
||||
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
|
||||
return
|
||||
|
||||
# Validate and encode hostname
|
||||
try:
|
||||
hostname_param = encode_dnsname(hostname_param)
|
||||
except EncodingError:
|
||||
logging.warning(
|
||||
f"Invalid hostname: client={client_ip}, "
|
||||
f"hostname={hostname_param}")
|
||||
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
|
||||
return
|
||||
|
||||
# Check hostname ownership
|
||||
try:
|
||||
hostname = get_hostname_for_user(hostname_param, user)
|
||||
@@ -293,6 +294,17 @@ class DDNSRequestHandler(BaseHTTPRequestHandler):
|
||||
)
|
||||
self._handle_bad_request(client_ip, 403, STATUS_NOHOST)
|
||||
return
|
||||
except EncodingError:
|
||||
logging.warning(
|
||||
f"Invalid hostname: client={client_ip}, "
|
||||
f"hostname={hostname_param}")
|
||||
self._handle_bad_request(client_ip, 400, STATUS_NOHOST)
|
||||
return
|
||||
|
||||
except DatabaseError as e:
|
||||
logging.error(f"Database error: {e}")
|
||||
self.respond(500, "Internal Server Error")
|
||||
return
|
||||
|
||||
# Good rate limit check
|
||||
if self.app.good_limiter:
|
||||
|
||||
Reference in New Issue
Block a user