Files
dns-manager/dnsmgr/__init__.py

782 lines
25 KiB
Python

#!/usr/bin/env python3
import dns.name
import dns.tsig
import dns.rcode
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.query
import dns.update
import dns.zone
import dns.xfr
import os
import re
import shutil
import subprocess
import yaml
from hashlib import sha1
from prettytable import PrettyTable
from shutil import chown
DEFAULT_CFGFILE = '/etc/dns-manager/config.yml'
# View name compared against Zone.view for zones outside of any named view
# (presumably BIND's implicit default view name).
NAMED_DEFAULT_VIEW = '_default'
# Record types accepted by type_from_text() unless all_types is requested.
RECORD_TYPES = (
    dns.rdatatype.A,
    dns.rdatatype.AAAA,
    dns.rdatatype.CAA,
    dns.rdatatype.CDS,
    dns.rdatatype.CNAME,
    dns.rdatatype.DNAME,
    dns.rdatatype.DS,
    dns.rdatatype.MX,
    dns.rdatatype.NS,
    dns.rdatatype.PTR,
    dns.rdatatype.SRV,
    dns.rdatatype.TLSA,
    dns.rdatatype.TXT)
# Every record type known to dnspython. Materialized as a tuple (not a
# generator) so it can be iterated repeatedly and indexed.
ALL_RECORD_TYPES = tuple(dns.rdatatype.RdataType)
def printe(msg):
    """Write *msg* to standard output, prefixed with ``ERROR: ``."""
    print('ERROR: {}'.format(msg))
def prettytable(field_names, rows, truncate=False):
    """Return a left-aligned PrettyTable built from *field_names* and *rows*.

    Args:
        field_names: column headers.
        rows: sequence of rows. When *truncate* is true every cell must be a
            string, because ``len()`` of each cell is used to estimate widths.
        truncate: when true and the estimated table width reaches the width
            of the terminal attached to stdout, values in the last column are
            shortened and suffixed with ``...``. Nothing is truncated when
            stdout is not a terminal (e.g. piped output).

    The caller's *rows* are never modified; truncation works on copies.

    Raises:
        RuntimeError: if the terminal leaves fewer than 5 characters for the
            last column.
    """
    table = PrettyTable()
    table.field_names = field_names
    for field in table.field_names:
        table.align[field] = 'l'
    # Copy each row so truncation below cannot alter the caller's data.
    rows = [list(row) for row in rows]
    if truncate and rows:
        # Estimated width per column: longest cell plus 3 characters
        # (presumably PrettyTable's cell padding and separator).
        widths = [0] * len(rows[0])
        for row in rows:
            for index, value in enumerate(row):
                widths[index] = max(widths[index], len(value) + 3)
        try:
            terminal_width = os.get_terminal_size().columns
        except OSError:
            # stdout is not a terminal: there is no width to fit into.
            terminal_width = None
        if terminal_width is not None and sum(widths) + 2 >= terminal_width:
            max_value_len = terminal_width - (sum(widths[:-1]) + 2) - 3
            if max_value_len < 5:
                raise RuntimeError('terminal is too small')
            for row in rows:
                value = row[-1]
                if len(value) > max_value_len:
                    row[-1] = f'{value[:max_value_len - 3]}...'
    for row in rows:
        table.add_row(row)
    return table
def prettyselect(field_names, rows, prompt='Select entry', also_valid=None, truncate=False):
    """Show *rows* as a numbered table and let the user pick one entry.

    Args:
        field_names: column headers; a leading ``#`` column is added for display.
        rows: rows to choose from (lists of strings).
        prompt: text shown before the list of valid answers.
        also_valid: extra literal answers (e.g. ``['*']``) accepted verbatim.
        truncate: forwarded to prettytable().

    The caller's *field_names* and *rows* are not modified.

    Returns:
        The zero-based index of the chosen row, or the answer itself when it
        is one of *also_valid*.

    Raises:
        RuntimeError: if *rows* is empty.
        KeyboardInterrupt: re-raised after printing 'Aborted ...'.
    """
    also_valid = list(also_valid) if also_valid else []
    count = len(rows)
    if count < 1:
        raise RuntimeError('no entries to select from')
    numbered_fields = ['#', *field_names]
    numbered_rows = [[str(number), *row] for number, row in enumerate(rows, start=1)]
    print(prettytable(numbered_fields, numbered_rows, truncate))
    print()
    valid_str = f'1 - {count}'
    if also_valid:
        valid_str = f"{valid_str}, {', '.join(also_valid)}"
    while True:
        try:
            answer = input(f'{prompt} ({valid_str}): ')
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise
        if answer in also_valid:
            print()
            return answer
        try:
            index = int(answer) - 1
        except ValueError:
            # Not a number: ask again.
            continue
        if 0 <= index < count:
            print()
            return index
def name_from_text(txt, origin=None):
    """Parse *txt* into a dns.name.Name.

    Without *origin* the text is treated as an absolute name: it is
    lower-cased and a trailing dot is appended when missing. With *origin*
    the text must be relative (no trailing dot) and is resolved against it.

    Raises:
        RuntimeError: for empty input, an absolute name where a relative one
            is expected, or anything dnspython rejects.
    """
    if not txt:
        raise RuntimeError('empty value')
    if origin is not None:
        if txt.endswith('.'):
            raise RuntimeError('record name is absolute (ends with dot)')
    else:
        txt = txt.lower()
        txt = txt if txt.endswith('.') else txt + '.'
    try:
        return dns.name.from_text(txt, origin)
    except Exception as err:
        raise RuntimeError(f'{err}')
def name_views_from_text(txt):
    """Split ``name[@view[,view...]]`` into a DNS name and a view selection.

    Returns:
        ``(name, views)`` where *name* comes from name_from_text() and
        *views* is one of:
        - None when there is no ``@`` or nothing follows it;
        - ``'*'`` when any listed view is ``*``;
        - a list of stripped, de-duplicated view names in the order given.
          The order is stable, so errors about the first bad view are
          reproducible.

    Raises:
        RuntimeError: propagated from name_from_text() for an invalid name.
    """
    name, _, view_txt = txt.partition('@')
    dns_name = name_from_text(name)
    if not view_txt:
        return dns_name, None
    views = list(dict.fromkeys(view.strip() for view in view_txt.split(',')))
    if '*' in views:
        return dns_name, '*'
    return dns_name, views
def input_name(origin=None, prompt='Zone name'):
    """Prompt until the user enters a name that name_from_text() accepts.

    Empty answers re-prompt silently; invalid names print the error and
    re-prompt. Ctrl-C prints 'Aborted ...' and raises KeyboardInterrupt.
    """
    while True:
        try:
            answer = input(f'{prompt}: ')
            if answer:
                parsed = name_from_text(answer, origin)
                print()
                return parsed
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise KeyboardInterrupt
        except RuntimeError as err:
            print(f'ERROR: {err}\n')
def ttl_from_text(txt):
    """Parse a TTL and check that it lies within 5 - 604800 seconds.

    Returns:
        The TTL as an int.

    Raises:
        RuntimeError: if *txt* is not an integer (including None) or the
            value is out of range.
    """
    try:
        ttl = int(txt)
    except (TypeError, ValueError) as e:
        raise RuntimeError(f'{e}') from e
    if ttl < 5:
        raise RuntimeError('TTL is too low (<5 seconds)')
    if ttl > 604800:
        raise RuntimeError('TTL is too high (>604800 seconds)')
    return ttl
def input_ttl():
    """Prompt until the user enters a TTL that ttl_from_text() accepts.

    Empty answers re-prompt silently; invalid values print the error and
    re-prompt. Ctrl-C prints 'Aborted ...' and raises KeyboardInterrupt.
    """
    while True:
        try:
            answer = input('TTL (5 - 604800): ')
            if answer:
                parsed = ttl_from_text(answer)
                print()
                return parsed
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise KeyboardInterrupt
        except RuntimeError as err:
            print(f'ERROR: {err}\n')
def type_from_text(txt, all_types=False):
    """Parse a record type mnemonic (e.g. ``'MX'``) into an rdatatype.

    Unless *all_types* is true, only types listed in RECORD_TYPES are allowed.

    Raises:
        RuntimeError: for unknown or unsupported types.
    """
    try:
        rdtype = dns.rdatatype.from_text(txt)
    except Exception as err:
        raise RuntimeError(f'{err}')
    if not (all_types or rdtype in RECORD_TYPES):
        raise RuntimeError('record type is not supported')
    return rdtype
def select_type(all_types=False):
    """Let the user pick a record type from an alphabetically sorted table.

    Args:
        all_types: offer every type in ALL_RECORD_TYPES instead of only
            RECORD_TYPES.

    Returns:
        The selected rdatatype.
    """
    candidates = ALL_RECORD_TYPES if all_types else RECORD_TYPES
    # Keep each label next to its type so the index of the (sorted) row the
    # user picks maps back to the right type.
    choices = sorted((dns.rdatatype.to_text(rdtype), rdtype) for rdtype in candidates)
    rows = [[label] for label, _ in choices]
    index = prettyselect(['Record type'], rows, prompt='Select record type')
    return choices[index][1]
def encode_txt_value(txt):
    """Quote *txt* as TXT record presentation data.

    A value that already contains a double quote is assumed to be quoted by
    the user and is returned unchanged. Otherwise the text is split into
    pieces of at most 255 bytes each (the size limit of one TXT
    character-string), measured in UTF-8. Each piece is wrapped in quotes,
    e.g. ``'abc'`` -> ``'"abc"'``; an empty value yields ``'""'``.
    """
    if '"' in txt:
        return txt
    chunks = []
    current = []
    size = 0
    for char in txt:
        # surrogatepass: never fail on odd input here; dnspython reports
        # real encoding problems later, inside rdata_from_text()'s handler.
        char_size = len(char.encode('utf-8', 'surrogatepass'))
        if size + char_size > 255:
            chunks.append(''.join(current))
            current = []
            size = 0
        current.append(char)
        size += char_size
    if current or not chunks:
        chunks.append(''.join(current))
    return '"' + '" "'.join(chunks) + '"'
def rdata_from_text(rdtype, txt, origin):
    """Parse *txt* as IN-class rdata of *rdtype*, relative to *origin*.

    TXT values are first quoted/split by encode_txt_value().

    NOTE(review): this module's imports do not list ``dns.rdata`` explicitly.
    It is presumably loaded as a side effect of ``import dns.rdataset``;
    consider importing it explicitly.

    Raises:
        RuntimeError: if dnspython cannot parse the value.
    """
    value = encode_txt_value(txt) if rdtype == dns.rdatatype.TXT else txt
    try:
        return dns.rdata.from_text(dns.rdataclass.IN, rdtype, value, origin)
    except Exception as err:
        raise RuntimeError(f'{err}')
def input_rdata(rdtype, origin):
    """Prompt until the user enters a value that parses as *rdtype* rdata.

    Empty answers re-prompt silently; invalid values print the error and
    re-prompt. Ctrl-C prints 'Aborted ...' and raises KeyboardInterrupt.
    """
    while True:
        try:
            answer = input('Record value: ')
            if answer:
                parsed = rdata_from_text(rdtype, answer, origin)
                print()
                return parsed
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise KeyboardInterrupt
        except RuntimeError as err:
            print(f'ERROR: {err}\n')
def input_yes_no(prompt='Confirm?'):
    """Ask a yes/no question until the answer is 'yes' or 'no' (any case).

    Returns:
        True for 'yes', False for 'no'. Ctrl-C prints 'Aborted ...' and
        counts as False.
    """
    while True:
        try:
            answer = input(f'{prompt} (yes/no): ').lower()
            if answer == 'yes':
                print()
                return True
            if answer == 'no':
                return False
        except KeyboardInterrupt:
            print('\nAborted ...')
            return False
class DNSViewConfig:
    """Validated settings of one view from the ``zones_config`` section.

    Attributes:
        config_dir: directory holding per-zone ``*.conf`` snippets; defaults
            to ``<config_dir argument>/<name>.zones``.
        config_file: file generated from the snippets (mandatory).
        zone_dir: directory holding the zone files (mandatory).
        catalog_zone: optional catalog zone name.
        config_template / zone_template: optional template paths.

    Raises:
        RuntimeError: for a missing mandatory key or a value of the wrong type.
    """

    def __init__(self, name, config, config_dir):
        where = f'views: {name}'
        if not isinstance(config, dict):
            raise RuntimeError(f'{where}: value is not an associative array')

        def checked_str(value, label, optional=False):
            # Return *value* if it is a string (or None when optional).
            if optional and value is None:
                return value
            if not isinstance(value, str):
                raise RuntimeError(f'{where}: {label}: value is not a string')
            return value

        def required_str(key):
            value = config.get(key)
            if value is None:
                raise RuntimeError(f'{where}: missing mandatory parameter: {key}')
            return checked_str(value, key)

        default_dir = os.path.join(config_dir, f'{name}.zones')
        self.config_dir = checked_str(config.get('config_dir', default_dir), 'config_dir')
        self.config_file = required_str('config_file')
        self.zone_dir = required_str('zone_dir')
        self.catalog_zone = checked_str(config.get('catalog_zone'), 'catalog_zone', optional=True)

        templates = config.get('templates')
        if templates is None:
            templates = {}
        elif not isinstance(templates, dict):
            raise RuntimeError(f'{where}: templates: value is not an associative array')
        self.config_template = checked_str(templates.get('config'), 'templates: config', optional=True)
        self.zone_template = checked_str(templates.get('zone'), 'templates: zone', optional=True)
class DNSManagerConfig:
    """Global dns-manager settings, read from a YAML file and validated.

    Attributes:
        etc_dir: base directory for per-view defaults (default '/etc/dns-manager').
        control_key: optional value passed to rndc as ``-k`` (presumably a
            key file path).
        dns_ip: address of the DNS server used for transfers and updates
            (default '127.0.0.1').
        named_checkconf: named-checkconf executable; searched on PATH when
            not configured.
        named_conf: optional named.conf path handed to named-checkconf.
        rndc: optional path of the rndc executable.
        dns_keyfiles: mapping of ``view`` or ``zone@view`` to key file paths.
        zones_config: mapping of view name to DNSViewConfig.

    Raises:
        RuntimeError: for a value of the wrong type or when named-checkconf
            cannot be found.
    """

    def __init__(self, cfgfile):
        # Binary mode lets PyYAML detect the file encoding itself instead of
        # depending on the process locale.
        with open(cfgfile, 'rb') as file:
            config = yaml.safe_load(file)
        if not isinstance(config, dict):
            raise RuntimeError('config is not an associative array')

        self.etc_dir = config.get('etc_dir', '/etc/dns-manager')
        if not isinstance(self.etc_dir, str):
            raise RuntimeError('etc_dir: value is not a string')

        self.control_key = config.get('control_key')
        if not isinstance(self.control_key, (str, type(None))):
            raise RuntimeError('control_key: value is not a string')

        self.dns_ip = config.get('dns_ip', '127.0.0.1')
        if not isinstance(self.dns_ip, str):
            raise RuntimeError('dns_ip: value is not a string')

        self.named_checkconf = config.get('named_checkconf')
        if self.named_checkconf is None:
            self.named_checkconf = shutil.which('named-checkconf')
            if self.named_checkconf is None:
                raise RuntimeError('named-checkconf: executable not found')
        elif not isinstance(self.named_checkconf, str):
            raise RuntimeError('named_checkconf: value is not a string')

        self.named_conf = config.get('named_conf')
        if not isinstance(self.named_conf, (str, type(None))):
            raise RuntimeError('named_conf: value is not a string')

        self.rndc = config.get('rndc')
        if not isinstance(self.rndc, (str, type(None))):
            raise RuntimeError('rndc: value is not a string')

        self.dns_keyfiles = config.get('dns_keyfiles')
        if self.dns_keyfiles is None:
            self.dns_keyfiles = {}
        elif not isinstance(self.dns_keyfiles, dict):
            raise RuntimeError('dns_keyfiles: value is not an associative array')
        for view, keyfile in self.dns_keyfiles.items():
            if not isinstance(keyfile, str):
                raise RuntimeError(f'dns_keyfiles: {view}: value is not a string')

        self.zones_config = {}
        zones_config = config.get('zones_config')
        if zones_config is not None:
            if not isinstance(zones_config, dict):
                raise RuntimeError('zones_config: value is not a dictionary')
            for name, view_config in zones_config.items():
                self.zones_config[name] = DNSViewConfig(name, view_config, self.etc_dir)
class Zone(dns.zone.Zone):
    """A dnspython zone annotated with where dns-manager found it.

    Attributes:
        view: view the zone belongs to (NAMED_DEFAULT_VIEW by default).
        status: zone type as reported by named-checkconf, None, or 'unknown'.
        cfgfile: path of the managed zone config snippet, if any.
        zonefile: path of the managed zone file, if any.

    The filter_* methods prune the zone's in-memory contents in place.
    """

    def __init__(self, view=NAMED_DEFAULT_VIEW, status='unknown', cfgfile=None, zonefile=None, **kwargs):
        self.view = view
        self.status = status
        self.cfgfile = cfgfile
        self.zonefile = zonefile
        super().__init__(**kwargs)

    def filter_by_name(self, name, origin):
        """Keep only the node whose owner name equals *name* relativized to *origin*."""
        wanted = name.relativize(origin)
        self.nodes = {rdname: node for rdname, node in self.items() if rdname == wanted}

    def filter_by_rdtype(self, rdtype):
        """Keep only IN rdatasets of *rdtype*; drop nodes that end up empty."""
        for _, node in self.items():
            rdataset = node.get_rdataset(dns.rdataclass.IN, rdtype)
            node.rdatasets.clear()
            if rdataset is not None:
                node.rdatasets.append(rdataset)
        self.nodes = {rdname: node for rdname, node in self.items() if node.rdatasets}

    def filter_by_rdata(self, rdata):
        """Keep only rdatas equal to *rdata*; drop empty rdatasets and nodes."""
        for _, node in self.items():
            for rdataset in node:
                # Remove through the dns.set.Set API rather than assigning to
                # ``rdataset.items``: that attribute's container type is an
                # implementation detail (a dict in recent dnspython), and
                # replacing it with a list breaks later set operations.
                unwanted = [item for item in rdataset if rdata != item]
                for item in unwanted:
                    rdataset.discard(item)
            node.rdatasets = [rdataset for rdataset in node.rdatasets if len(rdataset)]
        self.nodes = {rdname: node for rdname, node in self.items() if node.rdatasets}

    def nfz(self):
        """Return the SHA-1 hex digest of the zone origin in DNS wire format.

        NOTE(review): presumably used as a zone's member label in a catalog
        zone (cf. the per-view ``catalog_zone`` setting). Confirm with callers.
        """
        return sha1(self.origin.to_wire()).hexdigest()
def named_zones(named_checkconf, named_conf=None):
    """List the IN zones BIND serves as primary/secondary via ``named-checkconf -l``.

    Each output line is expected to read ``<zone> <class> <view> <type>``.
    Both the legacy (master/slave) and the current (primary/secondary) type
    keywords are accepted. Blank lines are ignored.

    Args:
        named_checkconf: path of the named-checkconf executable.
        named_conf: optional named.conf to check instead of the default.

    Returns:
        list of Zone objects carrying origin, view and status (zone type).

    Raises:
        RuntimeError: if named-checkconf exits with a non-zero status (its
            diagnostics go to stderr as before) or prints an unparsable line.
    """
    cmd = [named_checkconf, '-l']
    if named_conf is not None:
        cmd.append(named_conf)
    result = subprocess.run(cmd, stdout=subprocess.PIPE)
    if result.returncode != 0:
        raise RuntimeError(f'named-checkconf failed with exit status {result.returncode}')
    zones = []
    for line in result.stdout.decode().splitlines():
        if not line.strip():
            continue
        try:
            name, zclass, view, status = line.split()
        except ValueError:
            raise RuntimeError(f"named-checkconf returned invalid line: '{line}'") from None
        if zclass.upper() == 'IN' and status.lower() in ('master', 'slave', 'primary', 'secondary'):
            zones.append(Zone(origin=name, view=view, status=status))
    return zones
def managed_zones(config, bind_zones=None):
    """Discover the zones managed by dns-manager in every configured view.

    A zone is managed when ``<config_dir>/<zone>.conf`` exists for a view.
    Its zone file must exist as ``<zone_dir>/<zone>.zone``. Files are scanned
    in sorted order so the result is deterministic.

    Args:
        config: mapping of view name to DNSViewConfig (``zones_config``).
        bind_zones: zones from named_zones(). When given, each managed zone's
            status is taken from the BIND zone with the same name in the
            same view.

    Returns:
        list of Zone objects (status None when BIND does not list the zone).

    Raises:
        RuntimeError: if two views share a config or zone directory, or a
            zone file is missing.
    """
    bind_zones = bind_zones or []
    config_dirs = set()
    zone_dirs = set()
    zones = []
    for view, cfg in config.items():
        if cfg.config_dir in config_dirs:
            raise RuntimeError(f'config directory used in multiple views: {cfg.config_dir}')
        config_dirs.add(cfg.config_dir)
        if cfg.zone_dir in zone_dirs:
            raise RuntimeError(f'zone directory used in multiple views: {cfg.zone_dir}')
        zone_dirs.add(cfg.zone_dir)
        for file in sorted(os.listdir(cfg.config_dir)):
            if file.startswith('.') or not file.endswith('.conf'):
                continue
            cfgfile = os.path.join(cfg.config_dir, file)
            if not os.path.isfile(cfgfile):
                continue
            name = file.removesuffix('.conf')
            status = None
            if bind_zones:
                dns_name = dns.name.from_text(name)
                # Match on the view too: the same zone may be defined in
                # several views with different types.
                status = next((z.status for z in bind_zones
                               if z.origin == dns_name and z.view == view), None)
            zonefile = os.path.join(cfg.zone_dir, f'{name}.zone')
            if not os.path.isfile(zonefile):
                raise RuntimeError(f'missing zone file: {zonefile}')
            zones.append(Zone(origin=name, view=view, status=status, cfgfile=cfgfile, zonefile=zonefile))
    return zones
def keys_from_file(path):
    """Parse the TSIG keys defined in a BIND-style key file.

    Recognizes blocks of the form
    ``key "<name>" { algorithm <alg>; secret "<secret>"; };`` (quotes optional).

    Returns:
        list of dns.tsig.Key, one per key block, in file order. When a block
        has no ``algorithm`` statement, dnspython's default algorithm is used.

    Raises:
        RuntimeError: if the file contains no key block, or a block has no
            secret.
        OSError: if the file cannot be read.
    """
    with open(path, 'r') as f:
        content = f.read()
    key_block_re = re.compile(r'^\s*key\s+"?(?P<name>[^"]+?)"?\s*{(?P<config>.*?)}\s*;',
                              re.MULTILINE | re.DOTALL)
    # The greedy leading ``.*`` makes the last matching statement in a block win.
    secret_re = re.compile(r'.*secret\s+"?(?P<secret>[^"]+?)"?\s*;', re.DOTALL)
    algorithm_re = re.compile(r'.*algorithm\s+"?(?P<algorithm>[^"]+?)"?\s*;', re.DOTALL)
    # finditer() returns an iterator, which is always truthy: materialize it
    # so an empty result is actually detected.
    blocks = list(key_block_re.finditer(content))
    if not blocks:
        raise RuntimeError(f'no key section found in config file: {path}')
    keys = []
    for block in blocks:
        name = block.group('name')
        config = block.group('config')
        secret_match = secret_re.match(config)
        if not secret_match:
            raise RuntimeError(f"missing secret in config of key '{name}'")
        kwargs = {}
        algorithm_match = algorithm_re.match(config)
        if algorithm_match:
            kwargs['algorithm'] = algorithm_match.group('algorithm')
        # Omit ``algorithm`` when absent: passing None would disable
        # dnspython's default algorithm.
        keys.append(dns.tsig.Key(name, secret_match.group('secret'), **kwargs))
    return keys
class DNSManager:
def __init__(self, cfgfile=DEFAULT_CFGFILE):
    """Load the configuration from *cfgfile*.

    Zone lists are discovered lazily by the bind_zones, zones and all_zones
    properties; these attributes are their caches.
    """
    self.config = DNSManagerConfig(cfgfile)
    self._bind_zones = self._zones = self._all_zones = None
@property
def bind_zones(self):
    """Zones reported by named-checkconf, computed on first access and cached."""
    cached = self._bind_zones
    if cached is not None:
        return cached
    cfg = self.config
    self._bind_zones = named_zones(cfg.named_checkconf, named_conf=cfg.named_conf)
    return self._bind_zones
@property
def zones(self):
    """Zones managed by dns-manager, computed on first access and cached."""
    cached = self._zones
    if cached is not None:
        return cached
    self._zones = managed_zones(self.config.zones_config, self.bind_zones)
    return self._zones
@property
def all_zones(self):
    """Managed zones plus every BIND zone that is not managed (cached).

    A BIND zone counts as managed when a managed zone with the same origin
    and view exists. The result is a new list: the list behind ``zones`` is
    never extended with unmanaged zones.
    """
    if self._all_zones is None:
        managed = self.zones
        seen = {(zone.origin, zone.view) for zone in managed}
        unmanaged = [zone for zone in self.bind_zones if (zone.origin, zone.view) not in seen]
        self._all_zones = list(managed) + unmanaged
    return self._all_zones
def get_zones(self, name_view, all_zones=False):
    """Resolve ``zone[@view[,view...]]`` to the matching Zone objects.

    Without a view part, the zone must exist exactly once and in the default
    view. ``@*`` selects the zone in every view. An explicit list selects
    those views, each of which must exist and contain the zone.

    Args:
        name_view: zone name with an optional view selection.
        all_zones: search unmanaged BIND zones as well as managed ones.

    Returns:
        list of matching Zone objects.

    Raises:
        RuntimeError: when the zone or a requested view cannot be resolved.
    """
    dns_name, views = name_views_from_text(name_view)
    candidates = self.all_zones if all_zones else self.zones
    matching = [zone for zone in candidates if zone.origin == dns_name]
    if not matching:
        raise RuntimeError('zone not found')

    if views == '*':
        return matching

    if views is None:
        if len(matching) > 1:
            raise RuntimeError('zone is part of multiple views')
        if matching[0].view != NAMED_DEFAULT_VIEW:
            raise RuntimeError('zone is not part of the default view')
        return matching

    known_views = {zone.view for zone in candidates}
    member_views = {zone.view for zone in matching}
    for view in views:
        if view not in known_views:
            raise RuntimeError(f"view does not exist -- '{view}'")
        if view not in member_views:
            raise RuntimeError(f"zone is not part of view -- '{view}'")
    return [zone for zone in matching if zone.view in views]
def select_zones(self, all_zones=False):
    """Let the user pick a zone and, when ambiguous, its view(s).

    The zone is chosen from a sorted table of names. If get_zones() cannot
    resolve the name on its own (the zone is in several views, or only in a
    non-default view), a second table asks for the view; ``*`` selects all.

    Args:
        all_zones: offer unmanaged BIND zones as well.

    Returns:
        list of Zone objects for the chosen zone and view(s).
    """
    zones = self.all_zones if all_zones else self.zones
    names = sorted({zone.origin.to_unicode(omit_final_dot=True) for zone in zones})
    index = prettyselect(['Zone'], [[name] for name in names], prompt='Select zone')
    name = names[index]
    try:
        return self.get_zones(name, all_zones)
    except RuntimeError:
        # get_zones() reports ambiguity as RuntimeError: ask for the view.
        pass
    dns_name = dns.name.from_text(name)
    zones = [zone for zone in zones if zone.origin == dns_name]
    views = sorted({zone.view for zone in zones})
    index = prettyselect(['View'], [[view] for view in views], prompt='Select view', also_valid=['*'])
    if index == '*':
        return zones
    view = views[index]
    return [zone for zone in zones if zone.view == view]
def generate_config(self, view):
    """Rebuild *view*'s ``config_file`` from its per-zone config snippets.

    Every regular, non-hidden ``*.conf`` file in the view's ``config_dir`` is
    appended in sorted order, each followed by a newline. All snippets are
    read before ``config_file`` is opened for writing, so a read failure
    leaves the previously generated file intact. ``config_file`` itself is
    skipped if it happens to live in ``config_dir``.

    Raises:
        KeyError: if *view* is not configured.
        RuntimeError: if reading a snippet or writing the config fails.
    """
    view_cfg = self.config.zones_config[view]
    try:
        target = os.path.realpath(view_cfg.config_file)
        parts = []
        for file in sorted(os.listdir(view_cfg.config_dir)):
            if file.startswith('.') or not file.endswith('.conf'):
                continue
            cfgfile = os.path.join(view_cfg.config_dir, file)
            if not os.path.isfile(cfgfile) or os.path.realpath(cfgfile) == target:
                continue
            with open(cfgfile, 'r') as fh:
                parts.append(fh.read())
            parts.append('\n')
        with open(view_cfg.config_file, 'w') as cfh:
            cfh.write(''.join(parts))
    except Exception as e:
        raise RuntimeError(f'unable to generate view config: {e}') from e
def named_reload(self):
    """Run ``rndc reconfig`` and return its standard output.

    The rndc path comes from the config or, when unset, from PATH. A
    configured ``control_key`` is passed as ``-k <control_key>``.

    Raises:
        RuntimeError: if rndc cannot be found or exits with a non-zero status.
    """
    rndc = self.config.rndc or shutil.which('rndc')
    if rndc is None:
        raise RuntimeError('rndc executable not found')
    key_args = ['-k', self.config.control_key] if self.config.control_key else []
    proc = subprocess.run([rndc, *key_args, 'reconfig'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode:
        raise RuntimeError(f'error reloading named config: {proc.stderr.decode()}')
    return proc.stdout.decode()
def get_keyfile(self, zone):
    """Return the key file configured for *zone*, or None.

    A ``<zone>@<view>`` entry in ``dns_keyfiles`` takes precedence over a
    plain ``<view>`` entry. None is returned only for zones in the default
    view.

    Raises:
        RuntimeError: if no key is configured for a zone in a non-default view.
    """
    name = zone.origin.to_text(omit_final_dot=True)
    keyfiles = self.config.dns_keyfiles
    lookup_order = (f'{name}@{zone.view}', f'{zone.view}')
    keyfile = next((keyfiles[key_id] for key_id in lookup_order if key_id in keyfiles), None)
    if keyfile is None and zone.view != NAMED_DEFAULT_VIEW:
        raise RuntimeError(f"no key configured for zone '{name}' in view '{zone.view}'")
    return keyfile
def get_zone_content(self, zone):
    """Load *zone*'s records from the DNS server by zone transfer.

    The transfer is TSIG-signed with the first key from the zone's key file
    when one is configured. Default-view zones without a key are transferred
    unsigned.

    Raises:
        RuntimeError: if no key is configured for a non-default view, the key
            file is invalid, or building/running the transfer fails.
    """
    keyfile = self.get_keyfile(zone)
    # get_keyfile() returns None for default-view zones without a key; do not
    # hand None to keys_from_file() (open(None) would fail).
    keys = keys_from_file(keyfile) if keyfile is not None else []
    key = keys[0] if keys else None
    try:
        query, _ = dns.xfr.make_query(zone, keyring=key)
        dns.query.inbound_xfr(where=self.config.dns_ip, txn_manager=zone, query=query)
    except Exception as e:
        raise RuntimeError(e) from e
def add_zone_record(self, zone, rdname, rdataset):
    """Add *rdataset* at *rdname* in *zone* on the server via dynamic update.

    The update is TSIG-signed with the first key from the zone's key file
    when one is configured. Default-view zones without a key are updated
    unsigned.

    Returns:
        The server's response message.

    Raises:
        RuntimeError: on missing key configuration, transport failure, or a
            non-NOERROR response (message text is the response).
    """
    keyfile = self.get_keyfile(zone)
    # get_keyfile() may return None (default view without key).
    keys = keys_from_file(keyfile) if keyfile is not None else []
    key = keys[0] if keys else None
    update = dns.update.Update(zone.origin, keyring=key)
    update.add(rdname, rdataset)
    try:
        response = dns.query.tcp(update, self.config.dns_ip, timeout=10)
    except Exception as e:
        raise RuntimeError(e) from e
    if response.rcode() != dns.rcode.NOERROR:
        raise RuntimeError(response.to_text())
    return response
def delete_zone_record(self, zone, rdname, rdataset):
    """Delete *rdataset* at *rdname* in *zone* on the server via dynamic update.

    The update is TSIG-signed with the first key from the zone's key file
    when one is configured. Default-view zones without a key are updated
    unsigned.

    Returns:
        The server's response message.

    Raises:
        RuntimeError: on missing key configuration, transport failure, or a
            non-NOERROR response (message text is the response).
    """
    keyfile = self.get_keyfile(zone)
    # get_keyfile() may return None (default view without key).
    keys = keys_from_file(keyfile) if keyfile is not None else []
    key = keys[0] if keys else None
    update = dns.update.Update(zone.origin, keyring=key)
    update.delete(rdname, rdataset)
    try:
        response = dns.query.tcp(update, self.config.dns_ip, timeout=10)
    except Exception as e:
        raise RuntimeError(e) from e
    if response.rcode() != dns.rcode.NOERROR:
        raise RuntimeError(response.to_text())
    return response
def add_zone(self, name, view, config_template=None, zone_template=None, owner='named', group='named'):
    """Create the config snippet and zone file for a new zone in *view*.

    In the config template, ``%ZONE%``, ``%ZONE_FILE%`` and ``%ZONE_FILENAME%``
    are replaced by the zone name, the zone file path and its base name. In
    the zone template, ``%ZONE%`` is replaced by the zone name. Both templates
    are read before anything is written. If a later step fails, every file
    this call created is removed again, so a retry does not trip over the
    "already exists" checks.

    Args:
        name: dns.name.Name of the new zone.
        view: view name (a key of ``zones_config``).
        config_template: config template path overriding the view default.
        zone_template: zone template path overriding the view default.
        owner: user that will own the zone file (default 'named').
        group: group that will own the zone file (default 'named').

    Returns:
        The new Zone (status None).

    Raises:
        KeyError: if *view* is not configured.
        RuntimeError: if a target file exists, a template is missing or
            unreadable, or writing/chown fails.
    """
    config = self.config.zones_config[view]
    origin = name.to_text(omit_final_dot=True)
    zone_filename = f'{origin}.zone'
    cfgfile = os.path.join(config.config_dir, f'{origin}.conf')
    zonefile = os.path.join(config.zone_dir, zone_filename)
    if os.path.exists(cfgfile):
        raise RuntimeError(f'config file already exists: {cfgfile}')
    if os.path.exists(zonefile):
        raise RuntimeError(f'zone file already exists: {zonefile}')
    if config_template is None:
        config_template = config.config_template
    if config_template is None:
        raise RuntimeError('no config template file configured')
    if zone_template is None:
        zone_template = config.zone_template
    if zone_template is None:
        raise RuntimeError('no zone template file configured')

    try:
        with open(config_template, 'r') as f:
            zone_config = f.read()
    except Exception as e:
        raise RuntimeError(f'unable to open/read config template: {e}') from e
    try:
        with open(zone_template, 'r') as f:
            zone_content = f.read()
    except Exception as e:
        raise RuntimeError(f'unable to open/read zone template: {e}') from e
    zone_config = zone_config.replace('%ZONE%', origin) \
        .replace('%ZONE_FILE%', zonefile) \
        .replace('%ZONE_FILENAME%', zone_filename)
    zone_content = zone_content.replace('%ZONE%', origin)

    created = []

    def rollback():
        # Best-effort removal of the files created so far; never masks the
        # original error.
        for path in reversed(created):
            try:
                os.remove(path)
            except OSError:
                pass

    try:
        with open(cfgfile, 'w') as f:
            created.append(cfgfile)
            f.write(zone_config)
    except Exception as e:
        rollback()
        raise RuntimeError(f'unable to open/write config file: {e}') from e
    try:
        with open(zonefile, 'w') as f:
            created.append(zonefile)
            f.write(zone_content)
    except Exception as e:
        rollback()
        raise RuntimeError(f'unable to open/write zone file: {e}') from e
    try:
        chown(zonefile, owner, group)
    except Exception as e:
        rollback()
        raise RuntimeError(f'unable to change ownership of zone file: {e}') from e
    return Zone(origin=name, view=view, status=None, cfgfile=cfgfile, zonefile=zonefile)
def delete_zone(self, zone):
    """Remove *zone*'s config snippet and regenerate its view's config file.

    The zone file itself is not touched here.

    Raises:
        RuntimeError: if the config snippet cannot be removed, or from
            generate_config().
    """
    try:
        os.remove(zone.cfgfile)
    except Exception as err:
        raise RuntimeError(f'unable to delete zone config file: {err}')
    self.generate_config(zone.view)
def cleanup_zone(self, zone):
    """Delete *zone*'s zone file and every regular file named ``<zonefile>.*``.

    The sibling files are looked up in the view's ``zone_dir``; presumably
    they are journals or backups that BIND keeps next to the zone file.

    Raises:
        RuntimeError: if any removal (or the view lookup) fails.
    """
    try:
        os.remove(zone.zonefile)
        zone_dir = self.config.zones_config[zone.view].zone_dir
        prefix = zone.zonefile + '.'
        candidates = [os.path.join(zone_dir, entry) for entry in os.listdir(zone_dir)]
        for path in candidates:
            if path.startswith(prefix) and os.path.isfile(path):
                os.remove(path)
    except Exception as err:
        raise RuntimeError(f'unable to delete zone file: {err}')