Files

456 lines
13 KiB
Python

"""DNS operations using RFC 2136 dynamic updates."""
import ipaddress
import logging
import os
import re
import stat

import dns.exception
import dns.name
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver
import dns.tsigkeyring
import dns.update
# Valid hostname label after punycode encoding: 1-63 letters, digits or
# hyphens, not starting or ending with a hyphen. Anchored with \Z rather
# than $, because $ also matches just before a trailing newline and would
# accept a label such as "abc\n".
LABEL_PATTERN = re.compile(
    r'^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\Z', re.IGNORECASE
)


class EncodingError(Exception):
    """Raised when hostname encoding fails."""


def encode_dnsname(hostname):
    """
    Encode hostname to ASCII using punycode (IDNA).

    The name is lower-cased and stripped of surrounding whitespace, and a
    single trailing dot (fully-qualified form) is dropped. Each label is
    IDNA-encoded and must then match LABEL_PATTERN. The 253-character limit
    is enforced on the encoded result.

    Args:
        hostname: Hostname string, possibly with unicode characters.

    Returns:
        ASCII-encoded hostname, without a trailing dot.

    Raises:
        EncodingError: If the hostname is empty, contains an empty or
            invalid label, or its encoded form is too long.

    Example:
        >>> encode_dnsname("münchen")
        'xn--mnchen-3ya'
        >>> encode_dnsname("example.com.")
        'example.com'
    """
    hostname = hostname.lower().strip()
    # Drop the root label of a fully-qualified name ("example.com.").
    if hostname.endswith('.'):
        hostname = hostname[:-1]
    # Checked after the dot is removed so that "." alone counts as empty.
    if not hostname:
        raise EncodingError("Hostname cannot be empty")

    encoded_labels = []
    for label in hostname.split('.'):
        if not label:
            raise EncodingError("Empty label in hostname")
        try:
            # Python's built-in "idna" codec (RFC 3490 / IDNA 2003).
            encoded = label.encode('idna').decode('ascii')
        except UnicodeError as e:
            raise EncodingError(f"Invalid label '{label}': {e}") from e
        if len(encoded) > 63:
            raise EncodingError(
                f"Label '{label}' too long (max 63 characters)"
            )
        # fullmatch: the whole label must match, so no trailing "\n" can
        # slip through regardless of how LABEL_PATTERN is anchored.
        if not LABEL_PATTERN.fullmatch(encoded):
            raise EncodingError(f"Invalid label format: '{label}'")
        encoded_labels.append(encoded)

    encoded_name = '.'.join(encoded_labels)
    # The length limit applies to the ASCII form; punycode makes non-ASCII
    # names longer than their Unicode spelling, so check after encoding.
    if len(encoded_name) > 253:
        raise EncodingError("Hostname too long (max 253 characters)")
    return encoded_name
def detect_ip_type(ip):
    """
    Detect IP address type and normalize.

    Args:
        ip: IP address string (anything ``ipaddress.ip_address`` accepts).

    Returns:
        Tuple of (record_type, normalized_ip): record_type is 'A' for IPv4
        and 'AAAA' for IPv6; normalized_ip is the canonical string form.

    Raises:
        ValueError: If ip is not a valid IP address, or is an IPv6 address
            with a zone index (e.g. "fe80::1%eth0"), which only has meaning
            on the local host and cannot be published in DNS.

    Example:
        >>> detect_ip_type("192.168.1.1")
        ('A', '192.168.1.1')
        >>> detect_ip_type("2001:db8::1")
        ('AAAA', '2001:db8::1')
    """
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError as e:
        raise ValueError(f"Invalid IP address: {ip}") from e
    # IPv6Address.scope_id exists on Python 3.9+ and is None when no zone
    # index was given; IPv4Address has no such attribute.
    if getattr(addr, "scope_id", None):
        raise ValueError(f"Invalid IP address: {ip}")
    rdtype = 'A' if addr.version == 4 else 'AAAA'
    return (rdtype, str(addr))
class DNSError(Exception):
    """Raised when DNS operations fail."""


# TSIG algorithm name mapping (BIND short names -> dnspython FQDN form).
# hmac-md5 is the only entry whose full name is not "short name + '.'".
TSIG_ALGORITHMS = {
    "hmac-md5": "hmac-md5.sig-alg.reg.int.",
    "hmac-sha1": "hmac-sha1.",
    "hmac-sha224": "hmac-sha224.",
    "hmac-sha256": "hmac-sha256.",
    "hmac-sha384": "hmac-sha384.",
    "hmac-sha512": "hmac-sha512.",
}

# Matches either a double-quoted string or one of BIND's comment forms.
# Quoted strings are listed first so that "//" or "#" inside a quoted value
# (base64 secrets may contain "/") is never mistaken for a comment.
_BIND_COMMENT_RE = re.compile(
    r'"[^"]*"'       # quoted string: kept verbatim
    r'|/\*.*?\*/'    # C-style block comment
    r'|//[^\n]*'     # C++-style line comment
    r'|#[^\n]*',     # shell-style line comment
    re.DOTALL,
)


def _strip_bind_comments(text):
    """Return text with BIND comments replaced by spaces; quoted strings kept."""
    return _BIND_COMMENT_RE.sub(
        lambda m: m.group(0) if m.group(0).startswith('"') else " ",
        text,
    )


def parse_bind_key_file(path):
    """
    Parse BIND TSIG key file.

    Format: key "keyname" { algorithm hmac-sha256; secret "base64..."; };

    Comments (``/* */``, ``//`` and ``#``) are ignored. The key name may be
    quoted or bare; the algorithm may be quoted and may be given either as a
    short name from TSIG_ALGORITHMS or in its fully-qualified form. The
    first occurrence of each field is used, so the file should hold a
    single key.

    Args:
        path: Path to key file, or None/empty for no key.

    Returns:
        Tuple of (keyring, algorithm), where keyring is built by
        ``dns.tsigkeyring.from_text`` and algorithm is a ``dns.name.Name``;
        (None, None) if path is falsy.

    Raises:
        DNSError: If the file cannot be read or parsed.

    Example:
        Key file contents::

            key "ddns-key." {
                algorithm hmac-sha256;
                secret "base64secret==";
            };

        Usage::

            keyring, algo = parse_bind_key_file("/etc/bind/ddns.key")
    """
    if not path:
        return None, None
    try:
        with open(path, "r") as f:
            # fstat on the open descriptor checks the file actually being
            # read, instead of stat-ing the path and then opening it.
            if os.fstat(f.fileno()).st_mode & stat.S_IROTH:
                logging.warning("TSIG key file is world-readable: %s", path)
            content = _strip_bind_comments(f.read())

        # Key name: quoted ("ddns-key.") or bare (ddns-key).
        name_match = re.search(r'\bkey\s+(?:"([^"]+)"|([^\s"{;]+))', content)
        if not name_match:
            raise DNSError(f"Invalid key file {path}: no key name found")
        key_name = name_match.group(1) or name_match.group(2)
        # Ensure key name ends with dot for FQDN
        if not key_name.endswith("."):
            key_name += "."

        algo_match = re.search(r'algorithm\s+([^;]+);', content)
        if not algo_match:
            raise DNSError(f"Invalid key file {path}: no algorithm found")
        algo_str = algo_match.group(1).strip().strip('"').lower()
        bare_algo = algo_str.rstrip(".")
        algorithm = TSIG_ALGORITHMS.get(bare_algo)
        if algorithm is None and bare_algo + "." in TSIG_ALGORITHMS.values():
            # Already written in fully-qualified form,
            # e.g. hmac-md5.sig-alg.reg.int
            algorithm = bare_algo + "."
        if not algorithm:
            raise DNSError(f"Unsupported TSIG algorithm in {path}: {algo_str}")

        secret_match = re.search(r'secret\s+"([^"]+)"', content)
        if not secret_match:
            raise DNSError(f"Invalid key file {path}: no secret found")

        keyring = dns.tsigkeyring.from_text({key_name: secret_match.group(1)})
        return keyring, dns.name.from_text(algorithm)
    except DNSError:
        raise
    except Exception as e:
        raise DNSError(f"Failed to parse key file {path}: {e}") from e
class DNSService:
    """DNS service for RFC 2136 dynamic updates.

    UPDATE messages are sent over TCP to the configured server. Each message
    is signed with the zone's own TSIG key if one is configured, otherwise
    with the default key, otherwise it is sent unsigned. ``query_record``
    sends its lookups to the same server.
    """

    def __init__(self, config):
        """
        Initialize DNS service.

        Args:
            config: Application configuration dictionary. Its
                ``"dns_service"`` section must contain ``"dns_server"`` and
                may contain ``"dns_port"`` (default 53), ``"dns_timeout"``
                (default 5), ``"ddns_default_key_file"`` and ``"zone_keys"``
                (zone name -> BIND key file path).

        Raises:
            DNSError: If initialization fails (missing configuration or an
                unparsable key file).
        """
        try:
            dns_cfg = config["dns_service"]
            self.server = dns_cfg["dns_server"]
            self.port = dns_cfg.get("dns_port", 53)
            self.timeout = dns_cfg.get("dns_timeout", 5)

            # Default TSIG key; used for every zone without its own key.
            self.default_keyring, self.default_algorithm = parse_bind_key_file(
                dns_cfg.get("ddns_default_key_file")
            )

            # Per-zone TSIG keys, indexed by normalized zone name so that
            # lookups ignore letter case and the optional trailing dot.
            self.zone_keys = {}
            for zone, key_path in dns_cfg.get("zone_keys", {}).items():
                keyring, algorithm = parse_bind_key_file(key_path)
                self.zone_keys[self._normalize_zone(zone)] = (keyring, algorithm)

            if self.default_keyring or self.zone_keys:
                logging.debug(
                    "DNS service initialized: server=%s:%s tsig=enabled zones=%s",
                    self.server, self.port, len(self.zone_keys),
                )
            else:
                logging.debug(
                    "DNS service initialized: server=%s:%s tsig=disabled",
                    self.server, self.port,
                )
        except DNSError:
            raise
        except Exception as e:
            raise DNSError(f"Failed to initialize DNS service: {e}") from e

    @staticmethod
    def _normalize_zone(zone):
        """Return zone lower-cased and ending with a dot (key-lookup form)."""
        zone = zone.lower()
        return zone if zone.endswith(".") else zone + "."

    def _get_key_for_zone(self, zone):
        """
        Get TSIG key for a zone.

        Args:
            zone: Zone name; letter case and a trailing dot are ignored.

        Returns:
            Tuple of (keyring, algorithm): the zone-specific key if one is
            configured, else the default key, which is (None, None) when
            updates are unauthenticated.
        """
        default = (self.default_keyring, self.default_algorithm)
        return self.zone_keys.get(self._normalize_zone(zone), default)

    def _make_update(self, zone):
        """
        Create DNS update message for zone.

        Args:
            zone: Zone name.

        Returns:
            dns.update.Update message object, TSIG-signed when a key is
            configured for the zone (or a default key exists).
        """
        # Normalize zone name
        if not zone.endswith("."):
            zone = zone + "."
        keyring, algorithm = self._get_key_for_zone(zone)
        if not keyring:
            return dns.update.Update(zone)
        return dns.update.Update(zone, keyring=keyring, keyalgorithm=algorithm)

    def _send_update(self, update):
        """
        Send DNS update message to server over TCP.

        Args:
            update: dns.update.Update message object.

        Raises:
            DNSError: On timeout, transport failure, or any response rcode
                other than NOERROR.
        """
        try:
            response = dns.query.tcp(
                update,
                self.server,
                port=self.port,
                timeout=self.timeout,
            )
            rcode = response.rcode()
            if rcode != dns.rcode.NOERROR:
                raise DNSError(f"DNS update failed: {dns.rcode.to_text(rcode)}")
        except dns.exception.Timeout as e:
            raise DNSError(f"DNS update timeout: {self.server}:{self.port}") from e
        except DNSError:
            raise
        except Exception as e:
            raise DNSError(f"DNS update failed: {e}") from e

    def _get_relative_name(self, hostname, zone):
        """
        Get hostname relative to zone.

        The comparison ignores letter case and trailing dots on either
        argument. A hostname equal to the zone maps to "@" (the apex). A
        hostname not inside the zone is returned unchanged.

        Args:
            hostname: Hostname, relative or fully qualified.
            zone: Zone name.

        Returns:
            Relative name for use in DNS update.
        """
        host = hostname.rstrip(".")
        origin = zone.rstrip(".")
        if not origin:
            return hostname
        if host.lower() == origin.lower():
            return "@"
        suffix = "." + origin.lower()
        if host.lower().endswith(suffix):
            return host[:-len(suffix)]
        return hostname

    def _get_fqdn(self, hostname, zone):
        """
        Return the absolute (trailing-dot) name of hostname within zone.

        Args:
            hostname: Hostname, relative or fully qualified.
            zone: Zone name.

        Returns:
            Fully-qualified domain name ending with a dot.
        """
        relative = self._get_relative_name(hostname, zone)
        origin = zone.rstrip(".")
        if relative == "@":
            return origin + "."
        if relative.endswith("."):
            # Already absolute (a fully-qualified name outside the zone).
            return relative
        return f"{relative}.{origin}." if origin else f"{relative}."

    def query_record(self, hostname, zone, record_type):
        """
        Check if DNS record exists.

        Args:
            hostname: Hostname (relative to the zone or fully qualified).
            zone: DNS zone name.
            record_type: Record type string (A or AAAA).

        Returns:
            String form of the first record if one exists, None otherwise
            (including when the query itself fails, which is logged).
        """
        fqdn = self._get_fqdn(hostname, zone)
        try:
            # configure=False: the nameserver is set explicitly, so the
            # system resolver configuration is not read on every call.
            resolver = dns.resolver.Resolver(configure=False)
            resolver.nameservers = [self.server]
            resolver.port = self.port
            resolver.lifetime = self.timeout
            answers = resolver.resolve(fqdn, record_type)
            return str(answers[0]) if answers else None
        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer,
                dns.resolver.NoNameservers):
            return None
        except Exception as e:
            # Best-effort lookup: report and treat as "no record".
            logging.warning(
                "DNS query failed: hostname=%s zone=%s type=%s: %s",
                hostname, zone, record_type, e,
            )
            return None

    def update_record(self, hostname, zone, ip, ttl):
        """
        Update a DNS record for the given hostname.

        The record type (A or AAAA) is chosen from the address family of ip.

        Args:
            hostname: Hostname (relative to the zone or fully qualified).
            zone: DNS zone name.
            ip: IP address to set.
            ttl: DNS record TTL.

        Raises:
            DNSError: If the address is invalid or the update fails.

        Example::

            dns_service.update_record("myhost", "example.com", "192.168.1.1", 60)
            dns_service.update_record("myhost", "example.com", "2001:db8::1", 60)
        """
        try:
            record_type, normalized_ip = detect_ip_type(ip)
            name = self._get_relative_name(hostname, zone)
            update = self._make_update(zone)
            # dnspython's Update.replace() deletes the existing RRset of this
            # type at the name before adding the new record.
            update.replace(name, ttl, record_type, normalized_ip)
            self._send_update(update)
            logging.debug(
                "DNS record updated: hostname=%s zone=%s type=%s ip=%s ttl=%s",
                hostname, zone, record_type, normalized_ip, ttl,
            )
        except DNSError:
            raise
        except Exception as e:
            raise DNSError(f"Failed to update DNS record for {hostname}: {e}") from e

    def delete_record(self, hostname, zone, record_type):
        """
        Delete DNS record(s) for the given hostname and record type.

        Args:
            hostname: Hostname (relative to the zone or fully qualified).
            zone: DNS zone name.
            record_type: Record type (A or AAAA).

        Returns:
            True (for compatibility).

        Raises:
            DNSError: If delete fails.
        """
        try:
            name = self._get_relative_name(hostname, zone)
            update = self._make_update(zone)
            update.delete(name, record_type)
            self._send_update(update)
            logging.debug(
                "DNS record deleted: hostname=%s zone=%s type=%s",
                hostname, zone, record_type,
            )
            return True
        except DNSError:
            raise
        except Exception as e:
            raise DNSError(f"Failed to delete DNS record for {hostname}: {e}") from e