From f4bb0d38ebb18a9e8efc544f1e3eae75b1475175 Mon Sep 17 00:00:00 2001 From: Thomas Oettli Date: Tue, 21 Sep 2021 05:20:47 +0200 Subject: [PATCH] add whitelist functionality to quarantine --- pymodmilter/__init__.py | 1 + pymodmilter/action.py | 8 ++ pymodmilter/conditions.py | 46 +++++++ pymodmilter/storage.py | 26 +++- pymodmilter/whitelist.py | 276 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 354 insertions(+), 3 deletions(-) create mode 100644 pymodmilter/whitelist.py diff --git a/pymodmilter/__init__.py b/pymodmilter/__init__.py index 591e2a7..d1f032f 100644 --- a/pymodmilter/__init__.py +++ b/pymodmilter/__init__.py @@ -22,6 +22,7 @@ __all__ = [ "rule", "run", "storage", + "whitelist", "ModifyMilterConfig", "ModifyMilter"] diff --git a/pymodmilter/action.py b/pymodmilter/action.py index 69866c7..d5ef9ee 100644 --- a/pymodmilter/action.py +++ b/pymodmilter/action.py @@ -188,6 +188,14 @@ class ActionConfig(BaseConfig): if "reject_reason" in cfg: self.add_string_arg(cfg, "reject_reason") + if "whitelist" in cfg: + wl = {"whitelist": cfg["whitelist"]} + wl["name"] = f"{self.name}: whitelist" + if "loglevel" not in wl: + wl["loglevel"] = self.loglevel + self.args["whitelist"] = ConditionsConfig( + wl, self.local_addrs, self.debug) + class Action: """Action to implement a pre-configured action to perform on e-mails.""" diff --git a/pymodmilter/conditions.py b/pymodmilter/conditions.py index 1a4c462..f097ff6 100644 --- a/pymodmilter/conditions.py +++ b/pymodmilter/conditions.py @@ -20,6 +20,7 @@ import re from netaddr import IPAddress, IPNetwork, AddrFormatError from pymodmilter import BaseConfig, CustomLogger +from pymodmilter.whitelist import DatabaseWhitelist class ConditionsConfig(BaseConfig): @@ -52,6 +53,33 @@ class ConditionsConfig(BaseConfig): if "metavar" in cfg: self.add_string_arg(cfg, "metavar") + if "whitelist" in cfg: + assert isinstance(cfg["whitelist"], dict), \ + f"{self.name}: whitelist: invalid value, " \ + f"should be dict" + whitelist 
= cfg["whitelist"] + assert "type" in whitelist, \ + f"{self.name}: whitelist: mandatory parameter 'type' not found" + assert isinstance(whitelist["type"], str), \ + f"{self.name}: whitelist: type: invalid value, " \ + f"should be string" + self.args["whitelist"] = { + "type": whitelist["type"], + "name": f"{self.name}: whitelist"} + if whitelist["type"] == "db": + for arg in ["connection", "table"]: + assert arg in whitelist, \ + f"{self.name}: whitelist: mandatory parameter " \ + f"'{arg}' not found" + assert isinstance(whitelist[arg], str), \ + f"{self.name}: whitelist: {arg}: invalid value, " \ + f"should be string" + self.args["whitelist"][arg] = whitelist[arg] + + else: + raise RuntimeError( + f"{self.name}: whitelist: type: invalid type") + self.logger.debug(f"{self.name}: " f"loglevel={self.loglevel}, " f"args={self.args}") @@ -94,6 +122,13 @@ class Conditions: except re.error as e: raise RuntimeError(e) + if "whitelist" in cfg.args: + wl_cfg = cfg.args["whitelist"] + if wl_cfg["type"] == "db": + self.whitelist = DatabaseWhitelist(wl_cfg) + else: + raise RuntimeError("invalid storage type") + def match_host(self, host): logger = CustomLogger( self.logger, {"name": self.name}) @@ -134,6 +169,17 @@ class Conditions: return True + def get_wl_rcpts(self, mailfrom, rcpts): + if not self.whitelist: + return {} + + wl_rcpts = [] + for rcpt in rcpts: + if self.whitelist.check(mailfrom, rcpt): + wl_rcpts.append(rcpt) + + return wl_rcpts + def match(self, milter): logger = CustomLogger( self.logger, {"qid": milter.qid, "name": self.name}) diff --git a/pymodmilter/storage.py b/pymodmilter/storage.py index fc6d0f0..9666ca7 100644 --- a/pymodmilter/storage.py +++ b/pymodmilter/storage.py @@ -27,6 +27,7 @@ from glob import glob from time import gmtime from pymodmilter.base import CustomLogger +from pymodmilter.conditions import Conditions class BaseMailStorage: @@ -278,8 +279,8 @@ class Quarantine: "Quarantine class." 
_headersonly = False - def __init__(self, storage, notification=None, milter_action=None, - reject_reason="Message rejected"): + def __init__(self, storage, notification=None, whitelist=None, + milter_action=None, reject_reason="Message rejected"): self.storage = storage.action(**storage.args, metadata=True) self.storage_name = storage.name self.storage_logger = storage.logger @@ -289,20 +290,39 @@ class Quarantine: self.notification = notification.action(**notification.args) self.notification_name = notification.name self.notification_logger = notification.logger + self.whitelist = Conditions(whitelist) self.milter_action = milter_action self.reject_reason = reject_reason def execute(self, milter, pretend=False, logger=logging.getLogger(__name__)): + wl_rcpts = [] + if self.whitelist: + wl_rcpts = self.whitelist.get_wl_rcpts( + milter.msginfo["mailfrom"], milter.msginfo["rcpts"]) + logger.info(f"whitelisted recipients: {wl_rcpts}") + + rcpts = [ + rcpt for rcpt in milter.msginfo["rcpts"] if rcpt not in wl_rcpts] + + if not rcpts: + # all recipients whitelisted + return + + logger.info(f"add to quarantine for recipients: {rcpts}") + milter.msginfo["rcpts"] = rcpts + custom_logger = CustomLogger( self.storage_logger, {"name": self.storage_name}) self.storage.execute(milter, pretend, custom_logger) + if self.notification is not None: custom_logger = CustomLogger( self.notification_logger, {"name": self.notification_name}) self.notification.execute(milter, pretend, custom_logger) - milter.delrcpt(milter.msginfo["rcpts"].copy()) + milter.msginfo["rcpts"].extend(wl_rcpts) + milter.delrcpt(rcpts) if self.milter_action is not None: return (self.milter_action, self.reject_reason) diff --git a/pymodmilter/whitelist.py b/pymodmilter/whitelist.py new file mode 100644 index 0000000..9c0fcd9 --- /dev/null +++ b/pymodmilter/whitelist.py @@ -0,0 +1,276 @@ +# PyMod-Milter is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public 
License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PyMod-Milter is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PyMod-Milter. If not, see <https://www.gnu.org/licenses/>. +# + +__all__ = [ + "WhitelistBase", + "DatabaseWhitelist"] + +import logging +import peewee +import re + +from datetime import datetime +from playhouse.db_url import connect + + +class WhitelistBase: + "Whitelist base class" + def __init__(self, cfg): + self.name = cfg["name"] + self.logger = logging.getLogger(__name__) + self.valid_entry_regex = re.compile( + r"^[a-zA-Z0-9_.=+-]*?(@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)?$") + + def check(self, mailfrom, recipient): + "Check if mailfrom/recipient combination is whitelisted." + return + + def find(self, mailfrom=None, recipients=None, older_than=None): + "Find whitelist entries." + return + + def add(self, mailfrom, recipient, comment, permanent): + "Add entry to whitelist." + # check if mailfrom and recipient are valid + if not self.valid_entry_regex.match(mailfrom): + raise RuntimeError("invalid from address") + if not self.valid_entry_regex.match(recipient): + raise RuntimeError("invalid recipient") + return + + def delete(self, whitelist_id): + "Delete entry from whitelist." 
+ return + + +class WhitelistModel(peewee.Model): + mailfrom = peewee.CharField() + recipient = peewee.CharField() + created = peewee.DateTimeField(default=datetime.now) + last_used = peewee.DateTimeField(default=datetime.now) + comment = peewee.TextField(default="") + permanent = peewee.BooleanField(default=False) + + +class Meta: + indexes = ( + # trailing comma is mandatory if only one index should be created + (('mailfrom', 'recipient'), True), + ) + + +class DatabaseWhitelist(WhitelistBase): + "Whitelist class to store whitelist in a database" + whitelist_type = "db" + _db_connections = {} + _db_tables = {} + + def __init__(self, cfg): + super().__init__(cfg) + + tablename = cfg["table"] + connection_string = cfg["connection"] + + if connection_string in DatabaseWhitelist._db_connections.keys(): + db = DatabaseWhitelist._db_connections[connection_string] + else: + try: + # connect to database + conn = re.sub( + r"(.*?://.*?):.*?(@.*)", + r"\1:\2", + connection_string) + self.logger.debug( + f"connecting to database '{conn}'") + db = connect(connection_string) + except Exception as e: + raise RuntimeError( + f"unable to connect to database: {e}") + + DatabaseWhitelist._db_connections[connection_string] = db + + # generate model meta class + self.meta = Meta + self.meta.database = db + self.meta.table_name = tablename + self.model = type(f"WhitelistModel_{self.name}", (WhitelistModel,), { + "Meta": self.meta + }) + + if connection_string not in DatabaseWhitelist._db_tables.keys(): + DatabaseWhitelist._db_tables[connection_string] = [] + + if tablename not in DatabaseWhitelist._db_tables[connection_string]: + DatabaseWhitelist._db_tables[connection_string].append(tablename) + try: + db.create_tables([self.model]) + except Exception as e: + raise RuntimeError( + f"unable to initialize table '{tablename}': {e}") + + def _entry_to_dict(self, entry): + result = {} + result[entry.id] = { + "id": entry.id, + "mailfrom": entry.mailfrom, + "recipient": entry.recipient, + 
"created": entry.created, + "last_used": entry.last_used, + "comment": entry.comment, + "permanent": entry.permanent + } + return result + + def get_weight(self, entry): + value = 0 + for address in [entry.mailfrom, entry.recipient]: + if address == "": + value += 2 + elif address[0] == "@": + value += 1 + return value + + def check(self, mailfrom, recipient): + # check if mailfrom/recipient combination is whitelisted + super().check(mailfrom, recipient) + + # generate list of possible mailfroms + self.logger.debug( + f"query database for whitelist entries from <{mailfrom}> " + f"to <{recipient}>") + mailfroms = [""] + if "@" in mailfrom and not mailfrom.startswith("@"): + domain = mailfrom.split("@")[1] + mailfroms.append(f"@{domain}") + mailfroms.append(mailfrom) + + # generate list of possible recipients + recipients = [""] + if "@" in recipient and not recipient.startswith("@"): + domain = recipient.split("@")[1] + recipients.append(f"@{domain}") + recipients.append(recipient) + + # query the database + try: + entries = list( + self.model.select().where( + self.model.mailfrom.in_(mailfroms), + self.model.recipient.in_(recipients))) + except Exception as e: + raise RuntimeError(f"unable to query database: {e}") + + if not entries: + # no whitelist entry found + return {} + + if len(entries) > 1: + entries.sort(key=lambda x: self.get_weight(x), reverse=True) + + # use entry with the highest weight + entry = entries[0] + entry.last_used = datetime.now() + entry.save() + result = {} + for entry in entries: + result.update(self._entry_to_dict(entry)) + + return result + + def find(self, mailfrom=None, recipients=None, older_than=None): + "Find whitelist entries." 
+ super().find(mailfrom, recipients, older_than) + + if isinstance(mailfrom, str): + mailfrom = [mailfrom] + if isinstance(recipients, str): + recipients = [recipients] + + entries = {} + try: + for entry in list(self.model.select()): + if older_than is not None: + delta = (datetime.now() - entry.last_used).total_seconds() + if delta < (older_than * 86400): + continue + + if mailfrom is not None: + if entry.mailfrom not in mailfrom: + continue + + if recipients is not None: + if entry.recipient not in recipients: + continue + + entries.update(self._entry_to_dict(entry)) + except Exception as e: + raise RuntimeError(f"unable to query database: {e}") + + return entries + + def add(self, mailfrom, recipient, comment, permanent): + "Add entry to whitelist." + super().add( + mailfrom, + recipient, + comment, + permanent) + + try: + self.model.create( + mailfrom=mailfrom, + recipient=recipient, + comment=comment, + permanent=permanent) + except Exception as e: + raise RuntimeError(f"unable to add entry to database: {e}") + + def delete(self, whitelist_id): + "Delete entry from whitelist." 
+ super().delete(whitelist_id) + + try: + query = self.model.delete().where(self.model.id == whitelist_id) + deleted = query.execute() + except Exception as e: + raise RuntimeError( + f"unable to delete entry from database: {e}") + + if deleted == 0: + raise RuntimeError("invalid whitelist id") + + +class WhitelistCache: + def __init__(self): + self.cache = {} + + def load(self, whitelist, mailfrom, recipients): + for recipient in recipients: + self.check(whitelist, mailfrom, recipient) + + def check(self, whitelist, mailfrom, recipient): + if whitelist not in self.cache.keys(): + self.cache[whitelist] = {} + if recipient not in self.cache[whitelist].keys(): + self.cache[whitelist][recipient] = None + if self.cache[whitelist][recipient] is None: + self.cache[whitelist][recipient] = whitelist.check( + mailfrom, recipient) + return self.cache[whitelist][recipient] + + def get_recipients(self, whitelist, mailfrom, recipients): + self.load(whitelist, mailfrom, recipients) + return list(filter( + lambda x: self.cache[whitelist][x], + self.cache[whitelist].keys()))