diff --git a/pymodmilter/__init__.py b/pymodmilter/__init__.py index 41db62e..f5e038e 100644 --- a/pymodmilter/__init__.py +++ b/pymodmilter/__init__.py @@ -13,7 +13,6 @@ # __all__ = [ - "make_header", "actions", "conditions", "run", @@ -26,7 +25,7 @@ import Milter import logging from Milter.utils import parse_addr -from email.message import Message +from email.message import MIMEPart from email.parser import BytesFeedParser from email.policy import default as default_policy @@ -98,18 +97,18 @@ class Rule: return True - def execute(self, milter, msg, pretend=None): + def execute(self, milter, pretend=None): """Execute all actions of this rule.""" if pretend is None: pretend = self.pretend for action in self.actions: - milter_action = action.execute(milter, msg, pretend=pretend) + milter_action = action.execute(milter, pretend=pretend) if milter_action is not None: return milter_action -class MilterMessage(Message): +class MilterMessage(MIMEPart): def replace_header(self, _name, _value, occ=None): _name = _name.lower() counter = 0 @@ -139,12 +138,6 @@ class MilterMessage(Message): self._headers = newheaders -def remove_surrogates(string): - return string.encode( - "ascii", errors="surrogateescape").decode( - "ascii", errors="replace") - - class ModifyMilter(Milter.Base): """ModifyMilter based on Milter.Base to implement milter communication""" @@ -261,12 +254,11 @@ class ModifyMilter(Milter.Base): def header(self, field, value): try: - # feed header line to BytesParser - self._fp.feed(field + b": " + value + b"\r\n") + # remove surrogates + field = field.encode("ascii", errors="surrogateescape") + value = value.encode("ascii", errors="surrogateescape") - # remove surrogates from field and value - field = remove_surrogates(field) - value = remove_surrogates(value) + self._fp.feed(field + b": " + value + b"\r\n") except Exception as e: self.logger.exception( f"an exception occured in header method: {e}") @@ -297,10 +289,9 @@ class ModifyMilter(Milter.Base): def eom(self): try: - msg = self._fp.close() - + self.msg = self._fp.close() for rule in self.rules: - milter_action = rule.execute(self, msg) + milter_action = rule.execute(self) if milter_action is not None: if milter_action["action"] == "reject": diff --git a/pymodmilter/actions.py b/pymodmilter/actions.py index 2b1c382..cad80e1 100644 --- a/pymodmilter/actions.py +++ b/pymodmilter/actions.py @@ -35,16 +35,17 @@ def _replace_illegal_chars(string): "\n", "") -def add_header(milter, msg, field, value, pretend=False, +def add_header(milter, field, value, pretend=False, update_msg=True, logger=logging.getLogger(__name__)): """Add a mail header field.""" - header = f"{field}: {value}" - if logger.getEffectiveLevel() == logging.DEBUG: - logger.debug(f"add_header: {header}") - else: - logger.info(f"add_header: {header[0:70]}") + if update_msg: + header = f"{field}: {value}" + if logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"add_header: {header}") + else: + logger.info(f"add_header: {header[0:70]}") - msg.add_header(field, value) + milter.msg.add_header(field, value) if pretend: return @@ -55,8 +56,8 @@ def add_header(milter, msg, field, value, pretend=False, milter.addheader(field, encoded_value, -1) -def mod_header(milter, msg, field, value, search=None, pretend=False, - logger=logging.getLogger(__name__)): +def mod_header(milter, field, value, search=None, pretend=False, + update_msg=True, logger=logging.getLogger(__name__)): """Change the value of a mail header field.""" if isinstance(field, str): field = re.compile(field, re.IGNORECASE) @@ -66,7 +67,7 @@ def mod_header(milter, msg, field, value, search=None, pretend=False, occ = defaultdict(int) - for i, (f, v) in enumerate(msg.items()): + for i, (f, v) in enumerate(milter.msg.items()): f_lower = f.lower() occ[f_lower] += 1 @@ -87,15 +88,16 @@ def mod_header(milter, msg, field, value, search=None, pretend=False, "skip modification") continue - header = f"{f}: {v}" - new_header = f"{f}: {new_v}" + if update_msg: + header = f"{f}: {v}" + new_header = f"{f}: {new_v}" - if logger.getEffectiveLevel() == logging.DEBUG: - logger.debug(f"mod_header: {header}: {new_header}") - else: - logger.info(f"mod_header: {header[0:70]}: {new_header[0:70]}") + if logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"mod_header: {header}: {new_header}") + else: + logger.info(f"mod_header: {header[0:70]}: {new_header[0:70]}") - msg.replace_header(f, new_v, occ=occ[f_lower]) + milter.msg.replace_header(f, new_v, occ=occ[f_lower]) if pretend: continue @@ -107,7 +109,7 @@ def mod_header(milter, msg, field, value, search=None, pretend=False, milter.chgheader(f, occ[f_lower], encoded_value) -def del_header(milter, msg, field, value=None, pretend=False, +def del_header(milter, field, value=None, pretend=False, update_msg=True, logger=logging.getLogger(__name__)): """Delete a mail header field.""" if isinstance(field, str): @@ -118,9 +120,7 @@ def del_header(milter, msg, field, value=None, pretend=False, occ = defaultdict(int) - # iterate a copy of milter.fields because elements may get removed - # during iteration - for f, v in msg.items(): + for f, v in milter.msg.items(): f_lower = f.lower() occ[f_lower] += 1 @@ -130,13 +130,13 @@ def del_header(milter, msg, field, value=None, pretend=False, if value is not None and not value.search(v): continue - header = f"{f}: {v}" - if logger.getEffectiveLevel() == logging.DEBUG: - logger.debug(f"del_header: {header}") - else: - logger.info(f"del_header: {header[0:70]}") - - msg.remove_header(f, occ=occ[f_lower]) + if update_msg: + header = f"{f}: {v}" + if logger.getEffectiveLevel() == logging.DEBUG: + logger.debug(f"del_header: {header}") + else: + logger.info(f"del_header: {header[0:70]}") + milter.msg.remove_header(f, occ=occ[f_lower]) occ[f_lower] -= 1 @@ -146,13 +146,38 @@ def del_header(milter, msg, field, value=None, pretend=False, milter.chgheader(f, occ[f_lower], "") -def _get_body_content(msg, body_type): - content = None - body_part = msg.get_body(preferencelist=(body_type)) - if body_part is not None: - content = body_part.get_content() +def _serialize_msg(msg, logger): + if not msg["MIME-Version"]: + msg.add_header("MIME-Version", "1.0") - return (body_part, content) + try: + logger.debug("serialize message as bytes") + data = msg.as_bytes(policy=SMTP) + except Exception as e: + logger.waring( + f"unable to serialize message as bytes: {e}") + try: + logger.warning("try to serialize message as string") + data = msg.as_string(policy=SMTP) + data = data.encode("ascii", errors="replace") + except Exception as e: + raise e + + return data + + +def _get_body_content(msg, pref): + part = None + content = None + if not msg.is_multipart() and msg.get_content_type() == f"text/{pref}": + part = msg + else: + part = msg.get_body(preferencelist=(pref)) + + if part is not None: + content = part.get_content() + + return (part, content) def _has_content_before_body_tag(soup): @@ -166,9 +191,9 @@ def _has_content_before_body_tag(soup): return False -def _patch_message_body(msg, action, text, html, logger): - text_body, text_content = _get_body_content(msg, "plain") - html_body, html_content = _get_body_content(msg, "html") +def _patch_message_body(milter, action, text, html, logger): + text_body, text_content = _get_body_content(milter.msg, "plain") + html_body, html_content = _get_body_content(milter.msg, "html") if text_content is None and html_content is None: raise RuntimeError("message does not contain any body part") @@ -184,6 +209,7 @@ def _patch_message_body(msg, action, text, html, logger): text_body.set_content( content.encode(), maintype="text", subtype="plain") text_body.set_param("charset", "UTF-8", header="Content-Type") + del text_body["MIME-Version"] if html_content is not None: logger.info(f"{action} html disclaimer") @@ -204,86 +230,84 @@ def _patch_message_body(msg, action, text, html, logger): html_body.set_content( str(soup).encode(), maintype="text", subtype="html") html_body.set_param("charset", "UTF-8", header="Content-Type") + del html_body["MIME-Version"] -def _serialize_msg(msg, logger): - try: - logger.debug("serialize message as bytes") - data = msg.as_bytes(policy=SMTP) - except Exception as e: - logger.waring( - f"unable to serialize message as bytes: {e}") - try: - logger.warning("try to serialize message as string") - data = msg.as_string(policy=SMTP) - data = data.encode("ascii", errors="replace") - except Exception as e: - raise e - - return data +def _update_body(milter, logger): + data = _serialize_msg(milter.msg, logger) + body_pos = data.find(b"\r\n\r\n") + 4 + logger.debug("milter: replacebody") + milter.replacebody(data[body_pos:]) + del data -def _wrap_message(milter): - msg = MIMEPart() - msg.add_header("MIME-Version", "1.0") +def _update_headers(milter, original_headers, logger): + # serialize the message object so it updates its headers internally + milter.msg.as_bytes() + for field, value in original_headers: + if field not in milter.msg: + del_header(milter, field=f"^{field}$", update_msg=False, + logger=logger) - msg.set_content( + for field, value in milter.msg.items(): + field_lower = field.lower() + if not [f for f in original_headers if f[0].lower() == field_lower]: + add_header(milter, field=field, value=value, update_msg=False, + logger=logger) + else: + mod_header(milter, field=f"^{field}$", value=value, + update_msg=False, logger=logger) + + +def _wrap_message(milter, logger): + attachment = MIMEPart() + attachment.set_content(milter.msg.as_bytes(), + maintype="plain", subtype="text", + disposition="attachment", + filename=f"{milter.qid}.eml", + params={"name": f"{milter.qid}.eml"}) + + milter.msg.clear_content() + milter.msg.set_content( "Please see the original email attached.") - msg.add_alternative( - "Please see the original email attached.", + milter.msg.add_alternative( + "Please see the original email attached.", subtype="html") - - data = b"" - for field, value in milter.fields: - encoded_value = _replace_illegal_chars( - Header(s=value).encode()) - data += field.encode("ascii", errors="replace") - data += b": " - data += encoded_value.encode("ascii", errors="replace") - data += b"\r\n" - - milter.body_data.seek(0) - data += b"\r\n" + milter.body_data.read() - - msg.add_attachment( - data, maintype="plain", subtype="text", - filename=f"{milter.qid}.eml") - - return msg + milter.msg.make_mixed() + milter.msg.attach(attachment) -def _inject_body(milter, msg): - if not msg.is_multipart(): - msg.make_mixed() +def _inject_body(milter): + if not milter.msg.is_multipart(): + milter.msg.make_mixed() - new_msg = MIMEPart() - new_msg.add_header("MIME-Version", "1.0") - new_msg.set_content("") - new_msg.add_alternative("", subtype="html") - new_msg.make_mixed() - for attachment in msg.iter_attachments(): - new_msg.attach(attachment) + attachments = [] + for attachment in milter.msg.iter_attachments(): + if "content-disposition" not in attachment: + attachment["Content-Disposition"] = "attachment" + attachments.append(attachment) - return new_msg + milter.msg.clear_content() + milter.msg.set_content("") + milter.msg.add_alternative("", subtype="html") + milter.msg.make_mixed() + + for attachment in attachments: + milter.msg.attach(attachment) -def add_disclaimer(milter, msg, text, html, action, policy, pretend=False, +def add_disclaimer(milter, text, html, action, policy, pretend=False, logger=logging.getLogger(__name__)): """Append or prepend a disclaimer to the mail body.""" - update_headers = False + original_headers = milter.msg.items() try: try: - _patch_message_body(msg, action, text, html, logger) - data = _serialize_msg(msg, logger) - if not msg.is_multipart(): - update_headers = True + _patch_message_body(milter, action, text, html, logger) except RuntimeError as e: logger.info(f"{e}, inject empty plain and html body") - msg = _inject_body(milter, msg) - _patch_message_body(msg, action, text, html, logger) - data = _serialize_msg(msg, logger) - update_headers = True + _inject_body(milter) + _patch_message_body(milter, action, text, html, logger) except Exception as e: logger.warning(e) if policy == "ignore": @@ -300,74 +324,31 @@ def add_disclaimer(milter, msg, text, html, action, policy, pretend=False, logger.info("wrap original message in a new message envelope") try: - msg = _wrap_message(milter) - _patch_message_body(msg, action, text, html, logger) - data = _serialize_msg(msg, logger) - update_headers = True + _wrap_message(milter, logger) + _patch_message_body(milter, action, text, html, logger) except Exception as e: logger.error(e) raise Exception( "unable to wrap message in a new message envelope, " "give up ...") - body_pos = data.find(b"\r\n\r\n") + 4 - milter.body_data.seek(0) - milter.body_data.write(data[body_pos:]) - milter.body_data.truncate() - if pretend: return - logger.debug("milter: replacebody") - milter.replacebody(data[body_pos:]) - del data - - if not update_headers: - return - - fields = { - "mime-version": { - "field": "MIME-Version", - "value": msg.get("MIME-Version"), - "modified": False}, - "content-type": { - "field": "Content-Type", - "value": msg.get("Content-Type"), - "modified": False}, - "content-transfer-encoding": { - "field": "Content-Transfer-Encoding", - "value": msg.get("Content-Transfer-Encoding"), - "modified": False}} - - for field, value in msg.items(): - field_lower = field.lower() - if field_lower in fields and fields[field_lower]["value"] is not None: - mod_header(milter, msg, field=f"^{field}$", - value=fields[field_lower]["value"], - pretend=pretend, logger=logger) - fields[field_lower]["modified"] = True - - elif field_lower.startswith("content-"): - del_header(milter, msg, field=f"^{field}$", - pretend=pretend, logger=logger) - - for field in fields.values(): - if not field["modified"] and field["value"] is not None: - add_header(milter, msg, field=field["field"], value=field["value"], - pretend=pretend, logger=logger) + _update_headers(milter, original_headers, logger) + _update_body(milter, logger) -def store(milter, msg, directory, pretend=False, +def store(milter, directory, pretend=False, logger=logging.getLogger(__name__)): timestamp = datetime.now().strftime("%Y%m%d%H%M%S") store_id = f"{timestamp}_{milter.qid}" datafile = os.path.join(directory, store_id) - milter.body_data.seek(0) logger.info(f"store message in file {datafile}") try: with open(datafile, "wb") as fp: - fp.write(msg.as_bytes()) + fp.write(milter.msg.as_bytes()) except IOError as e: raise RuntimeError(f"unable to store message: {e}") @@ -473,12 +454,12 @@ class Action: """Return the needs of this action.""" return self._need_body - def execute(self, milter, msg, pretend=None): + def execute(self, milter, pretend=None): """Execute configured action.""" if pretend is None: pretend = self.pretend logger = CustomLogger(self.logger, {"qid": milter.qid}) - return self._func(milter=milter, msg=msg, pretend=pretend, + return self._func(milter=milter, pretend=pretend, logger=logger, **self._args)