#!/usr/bin/env python3
"""Helpers for managing BIND (named) zones.

Covers loading the dns-manager YAML configuration, discovering zones known to
named (via ``named-checkconf -l``) and zones managed by this tool (per-view
config/zone directories), interactive prompts, and TSIG-signed zone transfers
and dynamic updates sent with dnspython.
"""

import os
import re
import shutil
import subprocess
from hashlib import sha1
from shutil import chown

import dns.name
import dns.query
import dns.rcode
import dns.rdata
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.tsig
import dns.update
import dns.xfr
import dns.zone
import yaml
from prettytable import PrettyTable

DEFAULT_CFGFILE = '/etc/dns-manager/config.yml'
NAMED_DEFAULT_VIEW = '_default'

# Zone types from ``named-checkconf -l`` that hold zone data on this server.
# named-checkconf prints the type as written in named.conf, and BIND accepts
# both the legacy (master/slave) and current (primary/secondary) spelling.
NAMED_ZONE_TYPES = ('master', 'slave', 'primary', 'secondary')

# Record types offered for interactive editing.
RECORD_TYPES = (
    dns.rdatatype.A,
    dns.rdatatype.AAAA,
    dns.rdatatype.CAA,
    dns.rdatatype.CDS,
    dns.rdatatype.CNAME,
    dns.rdatatype.DNAME,
    dns.rdatatype.DS,
    dns.rdatatype.MX,
    dns.rdatatype.NS,
    dns.rdatatype.PTR,
    dns.rdatatype.SRV,
    dns.rdatatype.TLSA,
    dns.rdatatype.TXT)

# Every record type known to dnspython. This must be a tuple, not a generator:
# it is iterated repeatedly and indexed (see select_type).
ALL_RECORD_TYPES = tuple(dns.rdatatype.RdataType)

# Parsers for BIND ``key "<name>" { algorithm ...; secret "..."; };`` blocks.
_KEY_BLOCK_RE = re.compile(
    r'^\s*key\s+"?(?P<name>[^"]+?)"?\s*{(?P<config>.*?)}\s*;',
    re.MULTILINE | re.DOTALL)
_KEY_SECRET_RE = re.compile(r'\bsecret\s+"?(?P<secret>[^"]+?)"?\s*;')
_KEY_ALGORITHM_RE = re.compile(r'\balgorithm\s+"?(?P<algorithm>[^"]+?)"?\s*;')


def printe(msg):
    """Print *msg* prefixed with ``ERROR:``."""
    print(f'ERROR: {msg}')


def _remove_files(*paths):
    """Best-effort removal of *paths*, used to roll back partially created files.

    Errors are ignored on purpose so that a failing rollback never hides the
    error that triggered it.
    """
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            pass


def prettytable(field_names, rows, truncate=False):
    """Build a left-aligned PrettyTable from *field_names* and *rows*.

    :param field_names: column headers.
    :param rows: sequence of rows; each row is a sequence of strings.
    :param truncate: when true, shorten values of the last column with ``...``
        so the table fits the terminal width.
    :returns: the populated ``PrettyTable``.
    :raises RuntimeError: if *truncate* is set and the terminal is too narrow to
        show at least 5 characters of the last column.
    """
    table = PrettyTable()
    table.field_names = field_names
    for field in table.field_names:
        table.align[field] = 'l'

    # Work on copies so that truncation never modifies the caller's rows.
    rows = [list(row) for row in rows]

    if truncate and rows:
        n_cols = len(rows[0])
        # Widest cell per column plus 3 characters of padding/border.
        max_col_lengths = [max(len(row[index]) for row in rows) + 3
                           for index in range(n_cols)]
        max_total_len = sum(max_col_lengths) + 2
        max_pre_len = sum(max_col_lengths[:-1]) + 2
        # shutil falls back to $COLUMNS / 80 columns when stdout is not a
        # terminal, where os.get_terminal_size() would raise OSError.
        terminal_width = shutil.get_terminal_size().columns
        if max_total_len >= terminal_width:
            max_value_len = terminal_width - max_pre_len - 3
            if max_value_len < 5:
                raise RuntimeError('terminal is too small')
            for row in rows:
                value = row[-1]
                if len(value) > max_value_len:
                    row[-1] = f'{value[:max_value_len - 3]}...'

    for row in rows:
        table.add_row(row)
    return table


def prettyselect(field_names, rows, prompt='Select entry', also_valid=(), truncate=False):
    """Print *rows* as a numbered table and let the user pick one.

    :param field_names: column headers (a ``#`` column is prepended).
    :param rows: rows to choose from; the caller's lists are not modified.
    :param prompt: prompt text shown before the valid range.
    :param also_valid: extra literal answers (e.g. ``'*'``) accepted verbatim.
    :param truncate: passed through to :func:`prettytable`.
    :returns: the 0-based index of the chosen row, or the literal answer if it
        is one of *also_valid*.
    :raises RuntimeError: if *rows* is empty.
    :raises KeyboardInterrupt: if the user aborts with Ctrl-C.
    """
    length = len(rows)
    if length < 1:
        raise RuntimeError('no entries to select from')

    numbered_fields = ['#', *field_names]
    numbered_rows = [[str(number), *row] for number, row in enumerate(rows, start=1)]
    print(prettytable(numbered_fields, numbered_rows, truncate))
    print()

    valid_str = f'1 - {length}'
    if also_valid:
        also_valid_str = ', '.join(also_valid)
        valid_str = f'{valid_str}, {also_valid_str}'

    # Re-prompt silently until the answer is a valid number or extra answer.
    while True:
        try:
            answer = input(f'{prompt} ({valid_str}): ')
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise
        if answer in also_valid:
            print()
            return answer
        try:
            index = int(answer) - 1
        except ValueError:
            continue
        if 0 <= index < length:
            print()
            return index


def name_from_text(txt, origin=None):
    """Parse *txt* into a ``dns.name.Name``.

    Without *origin*, the text is treated as an absolute name: it is
    lower-cased and a trailing dot is appended if missing. With *origin*, the
    text must be relative (must not end with a dot) and is resolved against
    *origin*.

    :raises RuntimeError: on empty input, an absolute name where a relative one
        is required, or text dnspython cannot parse.
    """
    if not txt:
        raise RuntimeError('empty value')
    if origin is None:
        txt = txt.lower()
        if not txt.endswith('.'):
            txt += '.'
    elif txt.endswith('.'):
        raise RuntimeError('record name is absolute (ends with dot)')
    try:
        return dns.name.from_text(txt, origin)
    except Exception as e:
        raise RuntimeError(f'{e}') from e


def name_views_from_text(txt):
    """Split ``name[@view[,view...]]`` into a DNS name and a view selector.

    :returns: ``(name, views)``, where *views* is None when no view was given,
        the string ``'*'`` when any listed view is ``*``, or a list of unique,
        whitespace-stripped view names otherwise.
    :raises RuntimeError: if the name part is invalid (see :func:`name_from_text`).
    """
    name_txt, _, view_txt = txt.partition('@')
    name = name_from_text(name_txt)
    if not view_txt:
        return name, None
    views = sorted({view.strip() for view in view_txt.split(',')})
    if '*' in views:
        return name, '*'
    return name, views


def _input_parsed(prompt, parse):
    """Prompt until *parse* accepts a non-empty answer, then return its result.

    Parse errors (RuntimeError) are printed and the prompt is repeated. Ctrl-C
    prints ``Aborted ...`` and re-raises KeyboardInterrupt.
    """
    while True:
        try:
            value = input(prompt)
            if not value:
                continue
            result = parse(value)
        except KeyboardInterrupt:
            print('\nAborted ...')
            raise
        except RuntimeError as e:
            print(f'ERROR: {e}\n')
            continue
        print()
        return result


def input_name(origin=None, prompt='Zone name'):
    """Interactively read a DNS name (relative to *origin* if given)."""
    return _input_parsed(f'{prompt}: ', lambda value: name_from_text(value, origin))


def ttl_from_text(txt):
    """Parse a TTL in seconds, accepting 5 to 604800 inclusive.

    :raises RuntimeError: if *txt* is not an integer or is out of range.
    """
    try:
        ttl = int(txt)
    except ValueError as e:
        raise RuntimeError(f'{e}') from e
    if ttl < 5:
        raise RuntimeError('TTL is too low (<5 seconds)')
    if ttl > 604800:
        raise RuntimeError('TTL is too high (>604800 seconds)')
    return ttl


def input_ttl():
    """Interactively read a TTL (see :func:`ttl_from_text`)."""
    return _input_parsed('TTL (5 - 604800): ', ttl_from_text)


def type_from_text(txt, all_types=False):
    """Parse a record type mnemonic.

    :param all_types: when false, only types in ``RECORD_TYPES`` are accepted.
    :raises RuntimeError: if the type is unknown or not supported.
    """
    try:
        rdtype = dns.rdatatype.from_text(txt)
    except Exception as e:
        raise RuntimeError(f'{e}') from e
    if not all_types and rdtype not in RECORD_TYPES:
        raise RuntimeError('record type is not supported')
    return rdtype


def select_type(all_types=False):
    """Let the user pick a record type from an alphabetically sorted table."""
    rdtypes = ALL_RECORD_TYPES if all_types else RECORD_TYPES
    # Sort the types themselves (not just the rows) so the selected row index
    # maps back to the type that was displayed.
    ordered = sorted(rdtypes, key=dns.rdatatype.to_text)
    rows = [[dns.rdatatype.to_text(rdtype)] for rdtype in ordered]
    index = prettyselect(['Record type'], rows, prompt='Select record type')
    return ordered[index]


def encode_txt_value(txt, max_len=255):
    """Quote *txt* as one or more TXT character-strings.

    Values that already contain a double quote are returned unchanged (they
    are treated as already formatted by the user). Otherwise the value is
    split into quoted chunks of at most *max_len* bytes of UTF-8, never
    splitting a character. 255 bytes is the DNS limit for one
    character-string, and dnspython measures it in encoded bytes, not
    characters.
    """
    if '"' in txt:
        return txt
    chunks = []
    chunk = []
    chunk_len = 0
    for char in txt:
        char_len = len(char.encode('utf-8'))
        if chunk and chunk_len + char_len > max_len:
            chunks.append(''.join(chunk))
            chunk = []
            chunk_len = 0
        chunk.append(char)
        chunk_len += char_len
    chunks.append(''.join(chunk))
    return '"' + '" "'.join(chunks) + '"'


def rdata_from_text(rdtype, txt, origin):
    """Parse *txt* as IN-class rdata of *rdtype*, relative to *origin*.

    TXT values are quoted/chunked first via :func:`encode_txt_value`.

    :raises RuntimeError: if dnspython cannot parse the value.
    """
    if rdtype == dns.rdatatype.TXT:
        txt = encode_txt_value(txt)
    try:
        return dns.rdata.from_text(dns.rdataclass.IN, rdtype, txt, origin)
    except Exception as e:
        raise RuntimeError(f'{e}') from e


def input_rdata(rdtype, origin):
    """Interactively read a record value of *rdtype* (see :func:`rdata_from_text`)."""
    return _input_parsed('Record value: ', lambda value: rdata_from_text(rdtype, value, origin))


def input_yes_no(prompt='Confirm?'):
    """Ask a yes/no question until the answer is ``yes`` or ``no``.

    :returns: True for ``yes``; False for ``no`` or when aborted with Ctrl-C.
    """
    while True:
        try:
            value = input(f'{prompt} (yes/no): ').lower()
        except KeyboardInterrupt:
            print('\nAborted ...')
            return False
        if value == 'yes':
            print()
            return True
        if value == 'no':
            return False


class DNSViewConfig:
    """Validated settings for one view, taken from ``zones_config.<name>``.

    Attributes:
        config_dir: directory with one ``<zone>.conf`` snippet per managed zone
            (defaults to ``<config_dir>/<name>.zones``).
        config_file: file the snippets are concatenated into (mandatory).
        zone_dir: directory with the ``<zone>.zone`` files (mandatory).
        catalog_zone: optional string, or None.
        config_template, zone_template: optional template paths, or None.

    :raises RuntimeError: if *config* is malformed.
    """

    def __init__(self, name, config, config_dir):
        if not isinstance(config, dict):
            raise RuntimeError(f'views: {name}: value is not an associative array')

        self.config_dir = config.get('config_dir', os.path.join(config_dir, f'{name}.zones'))
        if not isinstance(self.config_dir, str):
            raise RuntimeError(f'views: {name}: config_dir: value is not a string')

        self.config_file = config.get('config_file')
        if self.config_file is None:
            raise RuntimeError(f'views: {name}: missing mandatory parameter: config_file')
        if not isinstance(self.config_file, str):
            raise RuntimeError(f'views: {name}: config_file: value is not a string')

        self.zone_dir = config.get('zone_dir')
        if self.zone_dir is None:
            raise RuntimeError(f'views: {name}: missing mandatory parameter: zone_dir')
        if not isinstance(self.zone_dir, str):
            raise RuntimeError(f'views: {name}: zone_dir: value is not a string')

        self.catalog_zone = config.get('catalog_zone')
        if not isinstance(self.catalog_zone, (str, type(None))):
            raise RuntimeError(f'views: {name}: catalog_zone: value is not a string')

        self.config_template = None
        self.zone_template = None
        templates = config.get('templates')
        if templates is not None:
            if not isinstance(templates, dict):
                raise RuntimeError(f'views: {name}: templates: value is not an associative array')
            self.config_template = templates.get('config')
            if not isinstance(self.config_template, (str, type(None))):
                raise RuntimeError(f'views: {name}: templates: config: value is not a string')
            self.zone_template = templates.get('zone')
            if not isinstance(self.zone_template, (str, type(None))):
                raise RuntimeError(f'views: {name}: templates: zone: value is not a string')


class DNSManagerConfig:
    """Validated dns-manager configuration loaded from a YAML file.

    Attributes:
        etc_dir: base directory (default ``/etc/dns-manager``).
        control_key: optional key passed to ``rndc -k``.
        dns_ip: server used for transfers and updates (default ``127.0.0.1``).
        named_checkconf: path to named-checkconf (looked up in PATH if unset).
        named_conf: optional named.conf path passed to named-checkconf.
        rndc: optional rndc path (looked up in PATH at use time if unset).
        dns_keyfiles: maps ``<view>`` or ``<zone>@<view>`` to a TSIG key file.
        zones_config: maps view name to :class:`DNSViewConfig`.

    :raises RuntimeError: if the file content is malformed or named-checkconf
        cannot be found.
    """

    def __init__(self, cfgfile):
        with open(cfgfile, 'r') as file:
            config = yaml.safe_load(file)
        if not isinstance(config, dict):
            raise RuntimeError('config is not an associative array')

        self.etc_dir = config.get('etc_dir', '/etc/dns-manager')
        if not isinstance(self.etc_dir, str):
            raise RuntimeError('etc_dir: value is not a string')

        self.control_key = config.get('control_key')
        if not isinstance(self.control_key, (str, type(None))):
            raise RuntimeError('control_key: value is not a string')

        self.dns_ip = config.get('dns_ip', '127.0.0.1')
        if not isinstance(self.dns_ip, str):
            raise RuntimeError('dns_ip: value is not a string')

        self.named_checkconf = config.get('named_checkconf', None)
        if self.named_checkconf is None:
            self.named_checkconf = shutil.which('named-checkconf')
            if self.named_checkconf is None:
                raise RuntimeError('named-checkconf: executable not found')
        elif not isinstance(self.named_checkconf, str):
            raise RuntimeError('named_checkconf: value is not a string')

        self.named_conf = config.get('named_conf')
        if not isinstance(self.named_conf, (str, type(None))):
            raise RuntimeError('named_conf: value is not a string')

        self.rndc = config.get('rndc')
        if not isinstance(self.rndc, (str, type(None))):
            raise RuntimeError('rndc: value is not a string')

        self.dns_keyfiles = config.get('dns_keyfiles')
        if self.dns_keyfiles is None:
            self.dns_keyfiles = {}
        elif not isinstance(self.dns_keyfiles, dict):
            raise RuntimeError('dns_keyfiles: value is not an associative array')
        for view, keyfile in self.dns_keyfiles.items():
            if not isinstance(keyfile, str):
                raise RuntimeError(f'dns_keyfiles: {view}: value is not a string')

        self.zones_config = {}
        zones_config = config.get('zones_config')
        if zones_config is not None:
            if not isinstance(zones_config, dict):
                raise RuntimeError('zones_config: value is not a dictionary')
            for name, view_config in zones_config.items():
                self.zones_config[name] = DNSViewConfig(name, view_config, self.etc_dir)


class Zone(dns.zone.Zone):
    """A ``dns.zone.Zone`` annotated with its view, named status and file paths.

    :param view: view the zone belongs to.
    :param status: zone type reported by named (e.g. ``master``), or None.
    :param cfgfile: path of the per-zone config snippet, if managed.
    :param zonefile: path of the zone file, if managed.
    Other keyword arguments (``origin``, ...) go to ``dns.zone.Zone``.
    """

    def __init__(self, view=NAMED_DEFAULT_VIEW, status='unknown', cfgfile=None, zonefile=None, **kwargs):
        self.view = view
        self.status = status
        self.cfgfile = cfgfile
        self.zonefile = zonefile
        super().__init__(**kwargs)

    def filter_by_name(self, name, origin):
        """Keep only the node whose name equals *name* (made relative to *origin*)."""
        name_relative = name.relativize(origin)
        self.nodes = {rdname: node for rdname, node in self.items() if rdname == name_relative}

    def filter_by_rdtype(self, rdtype):
        """Keep only IN rdatasets of *rdtype*; drop nodes left empty."""
        for _, node in self.items():
            rdataset = node.get_rdataset(dns.rdataclass.IN, rdtype)
            node.rdatasets.clear()
            if rdataset is not None:
                node.rdatasets.append(rdataset)
        self.nodes = {rdname: node for rdname, node in self.items() if node.rdatasets}

    def filter_by_rdata(self, rdata):
        """Keep only records equal to *rdata*; drop empty rdatasets and nodes."""
        for _, node in self.items():
            for rdataset in node:
                # Use the public Set API instead of replacing rdataset.items,
                # whose container type is a dnspython implementation detail.
                for item in list(rdataset):
                    if item != rdata:
                        rdataset.discard(item)
            node.rdatasets = [rdataset for rdataset in node.rdatasets if len(rdataset)]
        self.nodes = {rdname: node for rdname, node in self.items() if node.rdatasets}

    def nfz(self):
        """Return the SHA-1 hex digest of the zone origin in DNS wire format.

        NOTE(review): presumably used as a catalog-zone member label — confirm.
        """
        return sha1(self.origin.to_wire()).hexdigest()


def named_zones(named_checkconf, named_conf=None):
    """List IN-class primary/secondary zones known to named.

    Runs ``named-checkconf -l [named_conf]`` and parses its
    ``<zone> <class> <view> <type>`` lines.

    :returns: list of :class:`Zone` with ``view`` and ``status`` set.
    :raises RuntimeError: if named-checkconf fails or prints a malformed line.
    """
    cmd = [named_checkconf, '-l']
    if named_conf is not None:
        cmd.append(named_conf)
    # stderr is left attached to the terminal so named's diagnostics stay visible.
    res = subprocess.run(cmd, stdout=subprocess.PIPE)
    if res.returncode != 0:
        raise RuntimeError(f'named-checkconf failed with exit status {res.returncode}')

    zones = []
    for line in res.stdout.decode().splitlines():
        if not line.strip():
            continue
        try:
            name, zclass, view, status = line.split()
        except ValueError:
            raise RuntimeError(f"named-checkconf returned invalid line: '{line}'") from None
        if zclass.upper() == 'IN' and status.lower() in NAMED_ZONE_TYPES:
            zones.append(Zone(origin=name, view=view, status=status))
    return zones


def managed_zones(config, bind_zones=()):
    """List zones managed by this tool, one per ``<zone>.conf`` in each view.

    :param config: mapping of view name to :class:`DNSViewConfig`.
    :param bind_zones: zones known to named (see :func:`named_zones`). Used to
        fill in ``status`` for the zone with the same origin in the same view.
    :returns: list of :class:`Zone` with ``cfgfile`` and ``zonefile`` set.
    :raises RuntimeError: if two views share a config or zone directory, or a
        zone's ``.zone`` file is missing.
    """
    # Status per (origin, view); the first entry wins, like a first-match search.
    bind_status = {}
    for bind_zone in bind_zones:
        bind_status.setdefault((bind_zone.origin, bind_zone.view), bind_zone.status)

    config_dirs = []
    zone_dirs = []
    zones = []
    for view, cfg in config.items():
        if cfg.config_dir in config_dirs:
            raise RuntimeError(f'config directory used in multiple views: {cfg.config_dir}')
        config_dirs.append(cfg.config_dir)
        if cfg.zone_dir in zone_dirs:
            raise RuntimeError(f'zone directory used in multiple views: {cfg.zone_dir}')
        zone_dirs.append(cfg.zone_dir)

        for file in sorted(os.listdir(cfg.config_dir)):
            if file.startswith('.') or not file.endswith('.conf'):
                continue
            cfgfile = os.path.join(cfg.config_dir, file)
            if not os.path.isfile(cfgfile):
                continue
            name = file.removesuffix('.conf')
            status = None
            if bind_status:
                status = bind_status.get((dns.name.from_text(name), view))
            zonefile = os.path.join(cfg.zone_dir, f'{name}.zone')
            if not os.path.isfile(zonefile):
                raise RuntimeError(f'missing zone file: {zonefile}')
            zones.append(Zone(origin=name, view=view, status=status, cfgfile=cfgfile, zonefile=zonefile))
    return zones


def keys_from_file(path):
    """Read TSIG keys from BIND ``key "<name>" { ... };`` statements in *path*.

    Keys without an ``algorithm`` statement use dnspython's default TSIG
    algorithm.

    :returns: list of ``dns.tsig.Key``, in file order (never empty).
    :raises RuntimeError: if the file has no key statement or a key has no
        secret.
    """
    with open(path, 'r') as f:
        content = f.read()

    # finditer() returns an iterator, which is always truthy: materialize it
    # so the "no key section" check can actually fire.
    matches = list(_KEY_BLOCK_RE.finditer(content))
    if not matches:
        raise RuntimeError(f'no key section found in config file: {path}')

    keys = []
    for match in matches:
        name = match.group('name')
        key_config = match.group('config')
        secret_match = _KEY_SECRET_RE.search(key_config)
        if not secret_match:
            raise RuntimeError(f"missing secret in config of key '{name}'")
        algorithm_match = _KEY_ALGORITHM_RE.search(key_config)
        # dns.tsig.Key stores algorithm=None as-is, so pass the library default
        # explicitly instead of None.
        algorithm = algorithm_match.group('algorithm') if algorithm_match else dns.tsig.default_algorithm
        keys.append(dns.tsig.Key(name, secret_match.group('secret'), algorithm=algorithm))
    return keys


class DNSManager:
    """Front end for managing zones of a named instance.

    Zone lists are computed lazily and cached per instance.
    """

    def __init__(self, cfgfile=DEFAULT_CFGFILE):
        self.config = DNSManagerConfig(cfgfile)
        self._bind_zones = None
        self._zones = None
        self._all_zones = None

    @property
    def bind_zones(self):
        """Zones reported by named-checkconf (cached)."""
        if self._bind_zones is None:
            self._bind_zones = named_zones(named_checkconf=self.config.named_checkconf,
                                           named_conf=self.config.named_conf)
        return self._bind_zones

    @property
    def zones(self):
        """Zones managed by this tool (cached)."""
        if self._zones is None:
            self._zones = managed_zones(self.config.zones_config, self.bind_zones)
        return self._zones

    @property
    def all_zones(self):
        """Managed zones plus named zones not managed here (same origin and view)."""
        if self._all_zones is None:
            # Copy: extending self.zones in place would leak unmanaged zones
            # into the managed-zone list.
            all_zones = list(self.zones)
            known = {(zone.origin, zone.view) for zone in all_zones}
            for zone in self.bind_zones:
                key = (zone.origin, zone.view)
                if key not in known:
                    known.add(key)
                    all_zones.append(zone)
            self._all_zones = all_zones
        return self._all_zones

    def get_zones(self, name_view, all_zones=False):
        """Resolve ``zone[@view[,view...]]`` (or ``zone@*``) to matching zones.

        Without a view, the zone must exist in exactly one view, and that view
        must be the default view.

        :param all_zones: also consider named zones not managed here.
        :raises RuntimeError: if the zone/view cannot be resolved unambiguously.
        """
        dns_name, views = name_views_from_text(name_view)
        zone_base = self.all_zones if all_zones else self.zones
        zones = [zone for zone in zone_base if zone.origin == dns_name]
        if not zones:
            raise RuntimeError('zone not found')
        if views is None:
            if len(zones) > 1:
                raise RuntimeError('zone is part of multiple views')
            elif zones[0].view != NAMED_DEFAULT_VIEW:
                raise RuntimeError('zone is not part of the default view')
        elif views != '*':
            all_views = {zone.view for zone in zone_base}
            zone_views = [zone.view for zone in zones]
            for view in views:
                if view not in all_views:
                    raise RuntimeError(f"view does not exist -- '{view}'")
                if view not in zone_views:
                    raise RuntimeError(f"zone is not part of view -- '{view}'")
            zones = [zone for zone in zones if zone.view in views]
        return zones

    def select_zones(self, all_zones=False):
        """Interactively pick a zone, and a view if that is ambiguous.

        :returns: list of the selected zones (several if ``*`` is chosen as view).
        """
        zones = self.all_zones if all_zones else self.zones
        names = sorted({zone.origin.to_unicode(omit_final_dot=True) for zone in zones})
        index = prettyselect(['Zone'], [[name] for name in names], prompt='Select zone')
        name = names[index]
        try:
            return self.get_zones(name, all_zones)
        except RuntimeError:
            # get_zones raises RuntimeError when the zone is in several views
            # or only in non-default views: ask the user which view to use.
            pass

        dns_name = dns.name.from_text(name)
        zones = [zone for zone in zones if zone.origin == dns_name]
        views = sorted({zone.view for zone in zones})
        index = prettyselect(['View'], [[view] for view in views], prompt='Select view', also_valid=['*'])
        if index == '*':
            return zones
        view = views[index]
        return [zone for zone in zones if zone.view == view]

    def generate_config(self, view):
        """Rebuild the view's config file by concatenating its zone snippets.

        Snippets are read in sorted file-name order. Hidden files and
        non-``.conf`` files are skipped, as in :func:`managed_zones`. All
        snippets are read before the target file is opened, so a read error
        no longer leaves a truncated config behind.

        :raises RuntimeError: on any I/O error.
        """
        view_cfg = self.config.zones_config[view]
        try:
            target = os.path.abspath(view_cfg.config_file)
            parts = []
            for file in sorted(os.listdir(view_cfg.config_dir)):
                cfgfile = os.path.join(view_cfg.config_dir, file)
                if file.startswith('.') or not file.endswith('.conf') or not os.path.isfile(cfgfile):
                    continue
                # Never include the output file itself should it live in config_dir.
                if os.path.abspath(cfgfile) == target:
                    continue
                with open(cfgfile, 'r') as fh:
                    parts.append(fh.read())
                parts.append('\n')
            with open(view_cfg.config_file, 'w') as cfh:
                cfh.write(''.join(parts))
        except Exception as e:
            raise RuntimeError(f'unable to generate view config: {e}') from e

    def named_reload(self):
        """Run ``rndc [-k control_key] reconfig`` and return its stdout.

        :raises RuntimeError: if rndc is not found or exits non-zero.
        """
        rndc = self.config.rndc if self.config.rndc else shutil.which('rndc')
        if rndc is None:
            raise RuntimeError('rndc executable not found')
        cmd = [rndc]
        if self.config.control_key:
            cmd.extend(['-k', self.config.control_key])
        cmd.append('reconfig')
        res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if res.returncode != 0:
            raise RuntimeError(f'error reloading named config: {res.stderr.decode()}')
        return res.stdout.decode()

    def get_keyfile(self, zone):
        """Return the TSIG key file for *zone*, or None for an unkeyed default-view zone.

        ``<zone>@<view>`` entries in ``dns_keyfiles`` take precedence over
        ``<view>`` entries.

        :raises RuntimeError: if no key is configured for a non-default view.
        """
        name = zone.origin.to_text(omit_final_dot=True)
        keyfile = None
        for key_id in (f'{name}@{zone.view}', f'{zone.view}'):
            keyfile = self.config.dns_keyfiles.get(key_id)
            if keyfile is not None:
                break
        if keyfile is None and zone.view != NAMED_DEFAULT_VIEW:
            raise RuntimeError(f"no key configured for zone '{name}' in view '{zone.view}'")
        return keyfile

    def _get_key(self, zone):
        """Return the first TSIG key for *zone*, or None if it has no key file."""
        keyfile = self.get_keyfile(zone)
        if keyfile is None:
            # get_keyfile allows default-view zones without a key; do not try
            # to open(None).
            return None
        return keys_from_file(keyfile)[0]

    def _send_update(self, update):
        """Send a dynamic *update* over TCP to ``dns_ip`` and check the rcode.

        :raises RuntimeError: on transport errors or a non-NOERROR response.
        """
        try:
            response = dns.query.tcp(update, self.config.dns_ip, timeout=10)
        except Exception as e:
            raise RuntimeError(e) from e
        if response.rcode() != dns.rcode.NOERROR:
            raise RuntimeError(response.to_text())
        return response

    def get_zone_content(self, zone):
        """Load *zone*'s records via an inbound zone transfer from ``dns_ip``.

        The transfer query is built by ``dns.xfr.make_query`` and TSIG-signed
        when the zone has a key.

        :raises RuntimeError: if the transfer fails.
        """
        key = self._get_key(zone)
        query, _ = dns.xfr.make_query(zone, keyring=key)
        try:
            dns.query.inbound_xfr(where=self.config.dns_ip, txn_manager=zone, query=query)
        except Exception as e:
            raise RuntimeError(e) from e

    def add_zone_record(self, zone, rdname, rdataset):
        """Add *rdataset* at *rdname* in *zone* via a dynamic update.

        :returns: the server's response message.
        :raises RuntimeError: on failure (see :meth:`_send_update`).
        """
        update = dns.update.Update(zone.origin, keyring=self._get_key(zone))
        update.add(rdname, rdataset)
        return self._send_update(update)

    def delete_zone_record(self, zone, rdname, rdataset):
        """Delete *rdataset* at *rdname* in *zone* via a dynamic update.

        :returns: the server's response message.
        :raises RuntimeError: on failure (see :meth:`_send_update`).
        """
        update = dns.update.Update(zone.origin, keyring=self._get_key(zone))
        update.delete(rdname, rdataset)
        return self._send_update(update)

    def add_zone(self, name, view, config_template=None, zone_template=None):
        """Create the config snippet and zone file for a new zone in *view*.

        Templates (the explicit arguments, else the view's configured ones)
        get ``%ZONE%``, ``%ZONE_FILE%`` and ``%ZONE_FILENAME%`` substituted. The
        zone file is chowned to ``named:named``. If a step fails, files created
        by this call are removed again. The view config is not regenerated
        here.

        :returns: the new :class:`Zone`.
        :raises RuntimeError: if a target file exists, no template is
            configured, or a template/file operation fails.
        """
        config = self.config.zones_config[view]
        origin = name.to_text(omit_final_dot=True)
        cfgfile = os.path.join(config.config_dir, f'{origin}.conf')
        zonefile = os.path.join(config.zone_dir, f'{origin}.zone')
        zone = Zone(origin=name, view=view, status=None, cfgfile=cfgfile, zonefile=zonefile)

        if os.path.exists(cfgfile):
            raise RuntimeError(f'config file already exists: {cfgfile}')
        if os.path.exists(zonefile):
            raise RuntimeError(f'zone file already exists: {zonefile}')

        if config_template is None:
            config_template = config.config_template
            if config_template is None:
                raise RuntimeError('no config template file configured')
        if zone_template is None:
            zone_template = config.zone_template
            if zone_template is None:
                raise RuntimeError('no zone template file configured')

        # Read both templates before creating anything, so an unreadable
        # template leaves no files behind.
        try:
            with open(config_template, 'r') as f:
                zone_config = f.read()
        except Exception as e:
            raise RuntimeError(f'unable to open/read config template: {e}') from e
        try:
            with open(zone_template, 'r') as f:
                zone_content = f.read()
        except Exception as e:
            raise RuntimeError(f'unable to open/read zone template: {e}') from e

        zone_config = zone_config.replace('%ZONE%', origin) \
                                 .replace('%ZONE_FILE%', zonefile) \
                                 .replace('%ZONE_FILENAME%', f'{origin}.zone')
        zone_content = zone_content.replace('%ZONE%', origin)

        try:
            with open(cfgfile, 'w') as f:
                f.write(zone_config)
        except Exception as e:
            _remove_files(cfgfile)
            raise RuntimeError(f'unable to open/write config file: {e}') from e

        try:
            with open(zonefile, 'w') as f:
                f.write(zone_content)
        except Exception as e:
            _remove_files(cfgfile, zonefile)
            raise RuntimeError(f'unable to open/write zone file: {e}') from e

        try:
            chown(zonefile, 'named', 'named')
        except Exception as e:
            _remove_files(cfgfile, zonefile)
            raise RuntimeError(f'unable to change ownership of zone file: {e}') from e

        return zone

    def delete_zone(self, zone):
        """Remove the zone's config snippet and regenerate its view config.

        The zone file is kept; see :meth:`cleanup_zone`.

        :raises RuntimeError: if the snippet cannot be removed or the view
            config cannot be regenerated.
        """
        try:
            os.remove(zone.cfgfile)
        except Exception as e:
            raise RuntimeError(f'unable to delete zone config file: {e}') from e
        self.generate_config(zone.view)

    def cleanup_zone(self, zone):
        """Remove the zone file and any ``<zonefile>.*`` siblings in the zone dir.

        :raises RuntimeError: if any removal fails (including a missing zone file).
        """
        try:
            os.remove(zone.zonefile)
            zone_dir = self.config.zones_config[zone.view].zone_dir
            prefix = zone.zonefile + '.'
            for file in os.listdir(zone_dir):
                path = os.path.join(zone_dir, file)
                if not path.startswith(prefix) or not os.path.isfile(path):
                    continue
                os.remove(path)
        except Exception as e:
            raise RuntimeError(f'unable to delete zone file: {e}') from e