From 5a746f56361d5074ff948388eaf701a6782d1bca Mon Sep 17 00:00:00 2001 From: Thomas Oettli Date: Wed, 17 Feb 2021 16:55:08 +0100 Subject: [PATCH] restructure code and add replace_links --- pymodmilter/__init__.py | 60 ++++++++++++++- pymodmilter/actions.py | 162 ++++++++++++++-------------------------- 2 files changed, 114 insertions(+), 108 deletions(-) diff --git a/pymodmilter/__init__.py b/pymodmilter/__init__.py index 48d01f1..71a04e4 100644 --- a/pymodmilter/__init__.py +++ b/pymodmilter/__init__.py @@ -29,9 +29,12 @@ import logging from Milter.utils import parse_addr +from collections import defaultdict + +from email.header import Header from email.message import MIMEPart from email.parser import BytesFeedParser -from email.policy import default as default_policy +from email.policy import default as default_policy, SMTP from pymodmilter.conditions import Conditions @@ -79,7 +82,7 @@ class Rule: break def need_body(self): - """Return the if this rule needs the message body.""" + """Return True if this rule needs the message body.""" return self._need_body def ignores(self, host=None, envfrom=None, envto=None): @@ -142,6 +145,14 @@ class MilterMessage(MIMEPart): self._headers = newheaders +def replace_illegal_chars(string): + """Replace illegal characters in header values.""" + return string.replace( + "\x00", "").replace( + "\r", "").replace( + "\n", "") + + class ModifyMilter(Milter.Base): """ModifyMilter based on Milter.Base to implement milter communication""" @@ -162,6 +173,50 @@ class ModifyMilter(Milter.Base): # save rules, it must not change during runtime self.rules = ModifyMilter._rules.copy() + self.msg = None + + def addheader(self, field, value, idx=-1): + value = replace_illegal_chars(Header(s=value).encode()) + self.logger.debug(f"milter: addheader: {field}: {value}") + super().addheader(field, value, idx) + + def chgheader(self, field, value, idx=1): + value = replace_illegal_chars(Header(s=value).encode()) + if value: +
self.logger.debug(f"milter: chgheader: {field}[{idx}]: {value}") + else: + self.logger.debug(f"milter: delheader: {field}[{idx}]") + super().chgheader(field, idx, value) + + def update_headers(self, old_headers): + if self.msg.is_multipart() and not self.msg["MIME-Version"]: + self.msg.add_header("MIME-Version", "1.0") + + # serialize the message object so it updates its internal structure + self.msg.as_bytes() + + old_headers = [(f, f.lower(), v) for f, v in old_headers] + headers = [(f, f.lower(), v) for f, v in self.msg.items()] + + idx = defaultdict(int) + for field, field_lower, value in old_headers: + idx[field_lower] += 1 + if (field, field_lower, value) not in headers: + self.chgheader(field, "", idx=idx[field_lower]) + idx[field_lower] -= 1 + + for field, value in self.msg.items(): + field_lower = field.lower() + if (field, field_lower, value) not in old_headers: + self.addheader(field, value) + + def replacebody(self): + data = self.msg.as_bytes(policy=SMTP) + body_pos = data.find(b"\r\n\r\n") + 4 + self.logger.debug("milter: replacebody") + super().replacebody(data[body_pos:]) + del data + def connect(self, IPname, family, hostaddr): try: if hostaddr is None: @@ -210,7 +265,6 @@ class ModifyMilter(Milter.Base): return Milter.CONTINUE - @Milter.noreply def envrcpt(self, to, *str): try: self.recipients.add("@".join(parse_addr(to)).lower()) diff --git a/pymodmilter/actions.py b/pymodmilter/actions.py index 7db1ac3..eae8ae5 100644 --- a/pymodmilter/actions.py +++ b/pymodmilter/actions.py @@ -20,26 +20,9 @@ from bs4 import BeautifulSoup from collections import defaultdict from copy import copy from datetime import datetime -from email.header import Header from email.message import MIMEPart -from email.policy import SMTP -from pymodmilter import CustomLogger, Conditions - - -def _replace_illegal_chars(string): - """Replace illegal characters in header values.""" - return string.replace( - "\x00", "").replace( - "\r", "").replace( - "\n", "") - - -def
_add_header(milter, field, value, idx=-1): - value = _replace_illegal_chars( - Header(s=value).encode()) - milter.logger.debug(f"milter: addheader: {field}: {value}") - milter.addheader(field, value, idx) +from pymodmilter import CustomLogger, Conditions, replace_illegal_chars def add_header(milter, field, value, pretend=False, @@ -51,20 +34,10 @@ def add_header(milter, field, value, pretend=False, else: logger.info(f"add_header: {header[0:70]}") - milter.msg.add_header(field, _replace_illegal_chars(value)) + milter.msg.add_header(field, replace_illegal_chars(value)) - if pretend: - return - - _add_header(milter, field, value) - - -def _mod_header(milter, field, value, occ=1): - value = _replace_illegal_chars( - Header(s=value).encode()) - milter.logger.debug( - f"milter: chgheader: {field}[{occ}]: {value}") - milter.chgheader(field, occ, value) + if not pretend: + milter.addheader(field, value) def mod_header(milter, field, value, search=None, pretend=False, @@ -76,11 +49,11 @@ def mod_header(milter, field, value, search=None, pretend=False, if isinstance(search, str): search = re.compile(search, re.MULTILINE + re.DOTALL + re.IGNORECASE) - occ = defaultdict(int) + idx = defaultdict(int) for i, (f, v) in enumerate(milter.msg.items()): f_lower = f.lower() - occ[f_lower] += 1 + idx[f_lower] += 1 if not field.match(f): continue @@ -109,18 +82,10 @@ def mod_header(milter, field, value, search=None, pretend=False, logger.info(f"mod_header: {header[0:70]}: {new_header[0:70]}") milter.msg.replace_header( - f, _replace_illegal_chars(new_value), occ=occ[f_lower]) + f, replace_illegal_chars(new_value), idx=idx[f_lower]) - if pretend: - continue - - _mod_header(milter, f, new_value, occ=occ[f_lower]) - - -def _del_header(milter, field, occ=1): - milter.logger.debug( - f"milter: delheader: {field}[{occ}]") - milter.chgheader(field, occ, "") + if not pretend: + milter.chgheader(f, new_value, idx=idx[f_lower]) def del_header(milter, field, value=None, pretend=False, @@ -132,11 
+97,11 @@ def del_header(milter, field, value=None, pretend=False, if isinstance(value, str): value = re.compile(value, re.MULTILINE + re.DOTALL + re.IGNORECASE) - occ = defaultdict(int) + idx = defaultdict(int) for f, v in milter.msg.items(): f_lower = f.lower() - occ[f_lower] += 1 + idx[f_lower] += 1 if not field.match(f): continue @@ -149,32 +114,12 @@ def del_header(milter, field, value=None, pretend=False, logger.debug(f"del_header: {header}") else: logger.info(f"del_header: {header[0:70]}") - milter.msg.remove_header(f, occ=occ[f_lower]) + milter.msg.remove_header(f, idx=idx[f_lower]) if not pretend: - _del_header(milter, f, occ=occ[f_lower]) + milter.chgheader(f, "", idx=idx[f_lower]) - occ[f_lower] -= 1 - - -def _serialize_msg(msg, logger): - if msg.is_multipart() and not msg["MIME-Version"]: - msg.add_header("MIME-Version", "1.0") - - try: - logger.debug("serialize message as bytes") - data = msg.as_bytes(policy=SMTP) - except Exception as e: - logger.waring( - f"unable to serialize message as bytes: {e}") - try: - logger.warning("try to serialize message as string") - data = msg.as_string(policy=SMTP) - data = data.encode("ascii", errors="replace") - except Exception as e: - raise e - - return data + idx[f_lower] -= 1 def _get_body_content(msg, pref): @@ -244,37 +189,6 @@ def _patch_message_body(milter, action, text, html, logger): del html_body["MIME-Version"] -def _update_body(milter, logger): - data = _serialize_msg(milter.msg, logger) - body_pos = data.find(b"\r\n\r\n") + 4 - logger.debug("milter: replacebody") - milter.replacebody(data[body_pos:]) - del data - - -def _update_headers(milter, original_headers, logger): - if milter.msg.is_multipart() and not milter.msg["MIME-Version"]: - milter.msg.add_header("MIME-Version", "1.0") - - # serialize the message object so it updates its internal strucure - milter.msg.as_bytes() - - original_headers = [(f, f.lower(), v) for f, v in original_headers] - headers = [(f, f.lower(), v) for f, v in 
milter.msg.items()] - - occ = defaultdict(int) - for field, field_lower, value in original_headers: - occ[field_lower] += 1 - if (field, field_lower, value) not in headers: - _del_header(milter, field, occ=occ[field_lower]) - occ[field] -= 1 - - for field, value in milter.msg.items(): - field_lower = field.lower() - if (field, field_lower, value) not in original_headers: - _add_header(milter, field, value) - - def _wrap_message(milter, logger): attachment = MIMEPart() attachment.set_content(milter.msg.as_bytes(), @@ -315,7 +229,7 @@ def _inject_body(milter): def add_disclaimer(milter, text, html, action, policy, pretend=False, logger=logging.getLogger(__name__)): """Append or prepend a disclaimer to the mail body.""" - original_headers = milter.msg.items() + old_headers = milter.msg.items() try: try: @@ -348,11 +262,43 @@ def add_disclaimer(milter, text, html, action, policy, pretend=False, "unable to wrap message in a new message envelope, " "give up ...") - if pretend: - return + if not pretend: + milter.update_headers(old_headers) + milter.replacebody() - _update_headers(milter, original_headers, logger) - _update_body(milter, logger) + +def replace_links(milter, repl, pretend=False, + logger=logging.getLogger(__name__)): + """Replace links in the mail body.""" + + text_body, text_content = _get_body_content(milter.msg, "plain") + html_body, html_content = _get_body_content(milter.msg, "html") + + if text_content is not None: + logger.info("replace links in text body") + + content = text_content + + text_body.set_content( + content.encode(), maintype="text", subtype="plain") + text_body.set_param("charset", "UTF-8", header="Content-Type") + del text_body["MIME-Version"] + + if html_content is not None: + logger.info("replace links in html body") + + soup = BeautifulSoup(html_content, "html.parser") + + for link in soup.find_all("a", href=True): + link["href"] = repl + + html_body.set_content( + str(soup).encode(), maintype="text", subtype="html") + 
html_body.set_param("charset", "UTF-8", header="Content-Type") + del html_body["MIME-Version"] + + if not pretend: + milter.replacebody() def store(milter, directory, pretend=False, @@ -376,6 +322,7 @@ class Action: "del_header": False, "mod_header": False, "add_disclaimer": True, + "replace_links": True, "store": True} def __init__(self, name, local_addrs, conditions, action_type, args, @@ -451,6 +398,11 @@ class Action: self._args["text"] = f.read() except IOError as e: raise RuntimeError(f"unable to read template: {e}") + + elif action_type == "replace_links": + self._func = replace_links + self._args["repl"] = args["repl"] + elif action_type == "store": self._func = store if args["storage_type"] not in ["file"]: