diff --git a/pymodmilter/__init__.py b/pymodmilter/__init__.py index c284486..41db62e 100644 --- a/pymodmilter/__init__.py +++ b/pymodmilter/__init__.py @@ -26,35 +26,13 @@ import Milter import logging from Milter.utils import parse_addr -from email.charset import Charset -from email.header import Header, decode_header -from io import BytesIO +from email.message import Message +from email.parser import BytesFeedParser +from email.policy import default as default_policy from pymodmilter.conditions import Conditions -def make_header(decoded_seq, maxlinelen=None, header_name=None, - continuation_ws=' ', errors='strict'): - """Create a Header from a sequence of pairs as returned by decode_header() - - decode_header() takes a header value string and returns a sequence of - pairs of the format (decoded_string, charset) where charset is the string - name of the character set. - - This function takes one of those sequence of pairs and returns a Header - instance. Optional maxlinelen, header_name, and continuation_ws are as in - the Header constructor. - """ - h = Header(maxlinelen=maxlinelen, header_name=header_name, - continuation_ws=continuation_ws) - for s, charset in decoded_seq: - # None means us-ascii but we can simply pass it on to h.append() - if charset is not None and not isinstance(charset, Charset): - charset = Charset(charset) - h.append(s, charset, errors=errors) - return h - - class CustomLogger(logging.LoggerAdapter): def process(self, msg, kwargs): if "name" in self.extra: @@ -91,17 +69,15 @@ class Rule: self.actions = actions self.pretend = pretend - self._needs = [] + self._need_body = False for action in actions: - for need in action.needs(): - if need not in self._needs: - self._needs.append(need) + if action.need_body(): + self._need_body = True + break - self.logger.debug("needs: {}".format(", ".join(self._needs))) - - def needs(self): - """Return the needs of this rule.""" - return self._needs + def need_body(self): + """Return the if this rule needs the message body.""" + return self._need_body def ignores(self, host=None, envfrom=None, envto=None): args = {} @@ -122,17 +98,53 @@ class Rule: return True - def execute(self, milter, pretend=None): + def execute(self, milter, msg, pretend=None): """Execute all actions of this rule.""" if pretend is None: pretend = self.pretend for action in self.actions: - milter_action = action.execute(milter) + milter_action = action.execute(milter, msg, pretend=pretend) if milter_action is not None: return milter_action +class MilterMessage(Message): + def replace_header(self, _name, _value, occ=None): + _name = _name.lower() + counter = 0 + for i, (k, v) in zip(range(len(self._headers)), self._headers): + if k.lower() == _name: + counter += 1 + if not occ or counter == occ: + self._headers[i] = self.policy.header_store_parse( + k, _value) + break + + else: + raise KeyError(_name) + + def remove_header(self, name, occ=None): + name = name.lower() + newheaders = [] + counter = 0 + for k, v in self._headers: + if k.lower() == name: + counter += 1 + if counter != occ: + newheaders.append((k, v)) + else: + newheaders.append((k, v)) + + self._headers = newheaders + + +def remove_surrogates(string): + return string.encode( + "ascii", errors="surrogateescape").decode( + "ascii", errors="replace") + + class ModifyMilter(Milter.Base): """ModifyMilter based on Milter.Base to implement milter communication""" @@ -175,7 +187,7 @@ class ModifyMilter(Milter.Base): return Milter.ACCEPT except Exception as e: self.logger.exception( - f"an exception occured in connect function: {e}") + f"an exception occured in connect method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE @@ -196,7 +208,7 @@ class ModifyMilter(Milter.Base): self.recipients = set() except Exception as e: self.logger.exception( - f"an exception occured in envfrom function: {e}") + f"an exception occured in envfrom method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE @@ -207,7 +219,7 @@ class ModifyMilter(Milter.Base): self.recipients.add("@".join(parse_addr(to)).lower()) except Exception as e: self.logger.exception( - f"an exception occured in envrcpt function: {e}") + f"an exception occured in envrcpt method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE @@ -231,64 +243,64 @@ class ModifyMilter(Milter.Base): self.fields = None self.fields_bytes = None self.body_data = None - needs = [] + + self._fp = BytesFeedParser( + _factory=MilterMessage, policy=default_policy) + self._keep_body = False for rule in self.rules: - needs += rule.needs() - - if "fields" in needs: - self.fields = [] - - if "fields_bytes" in needs: - self.fields_bytes = [] - - if "body" in needs: - self.body_data = BytesIO() + if rule.need_body(): + self._keep_body = True + break except Exception as e: self.logger.exception( - f"an exception occured in data function: {e}") + f"an exception occured in data method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE - def header(self, name, value): + def header(self, field, value): try: - if self.fields_bytes is not None: - self.fields_bytes.append( - (name.encode("ascii", errors="surrogateescape"), - value.encode("ascii", errors="surrogateescape"))) + # feed header line to BytesParser + self._fp.feed(field + b": " + value + b"\r\n") - if self.fields is not None: - # remove surrogates from value - value = value.encode( - errors="surrogateescape").decode(errors="replace") - self.logger.debug(f"received header: {name}: {value}") - header = make_header(decode_header(value), errors="replace") - value = str(header).replace("\x00", "") - self.logger.debug(f"decoded header: {name}: {value}") - self.fields.append((name, value)) + # remove surrogates from field and value + field = remove_surrogates(field) + value = remove_surrogates(value) except Exception as e: self.logger.exception( - f"an exception occured in header function: {e}") + f"an exception occured in header method: {e}") + return Milter.TEMPFAIL + + return Milter.CONTINUE + + def eoh(self): + try: + self._fp.feed(b"\r\n") + except Exception as e: + self.logger.exception( + f"an exception occured in eoh method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE def body(self, chunk): try: - if self.body_data is not None: - self.body_data.write(chunk) + if self._keep_body: + self._fp.feed(chunk) except Exception as e: self.logger.exception( - f"an exception occured in body function: {e}") + f"an exception occured in body method: {e}") return Milter.TEMPFAIL return Milter.CONTINUE def eom(self): try: + msg = self._fp.close() + for rule in self.rules: - milter_action = rule.execute(self) + milter_action = rule.execute(self, msg) if milter_action is not None: if milter_action["action"] == "reject": @@ -303,7 +315,7 @@ class ModifyMilter(Milter.Base): except Exception as e: self.logger.exception( - f"an exception occured in eom function: {e}") + f"an exception occured in eom method: {e}") return Milter.TEMPFAIL return Milter.ACCEPT diff --git a/pymodmilter/actions.py b/pymodmilter/actions.py index 163da45..2b1c382 100644 --- a/pymodmilter/actions.py +++ b/pymodmilter/actions.py @@ -21,10 +21,8 @@ from collections import defaultdict from copy import copy from datetime import datetime from email.header import Header -from email.parser import BytesFeedParser from email.message import MIMEPart -from email.policy import default as default_policy, SMTP -from shutil import copyfileobj +from email.policy import SMTP from pymodmilter import CustomLogger, Conditions @@ -37,7 +35,7 @@ def _replace_illegal_chars(string): "\n", "") -def add_header(field, value, milter, idx=-1, pretend=False, +def add_header(milter, msg, field, value, pretend=False, logger=logging.getLogger(__name__)): """Add a mail header field.""" header = f"{field}: {value}" @@ -46,21 +44,18 @@ def add_header(field, value, milter, idx=-1, pretend=False, else: logger.info(f"add_header: {header[0:70]}") - if idx == -1: - milter.fields.append((field, value)) - else: - milter.fields.insert(idx, (field, value)) + msg.add_header(field, value) if pretend: return encoded_value = _replace_illegal_chars( Header(s=value).encode()) - milter.logger.debug(f"milter: addheader: {field}[{idx}]: {encoded_value}") - milter.addheader(field, encoded_value, idx) + milter.logger.debug(f"milter: addheader: {field}: {encoded_value}") + milter.addheader(field, encoded_value, -1) -def mod_header(field, value, milter, search=None, pretend=False, +def mod_header(milter, msg, field, value, search=None, pretend=False, logger=logging.getLogger(__name__)): """Change the value of a mail header field.""" if isinstance(field, str): @@ -71,8 +66,9 @@ def mod_header(field, value, milter, search=None, pretend=False, occ = defaultdict(int) - for idx, (f, v) in enumerate(milter.fields): - occ[f] += 1 + for i, (f, v) in enumerate(msg.items()): + f_lower = f.lower() + occ[f_lower] += 1 if not field.match(f): continue @@ -93,12 +89,13 @@ def mod_header(field, value, milter, search=None, pretend=False, header = f"{f}: {v}" new_header = f"{f}: {new_v}" + if logger.getEffectiveLevel() == logging.DEBUG: logger.debug(f"mod_header: {header}: {new_header}") else: logger.info(f"mod_header: {header[0:70]}: {new_header[0:70]}") - milter.fields[idx] = (f, new_v) + msg.replace_header(f, new_v, occ=occ[f_lower]) if pretend: continue @@ -106,11 +103,11 @@ def mod_header(field, value, milter, search=None, pretend=False, encoded_value = _replace_illegal_chars( Header(s=new_v).encode()) milter.logger.debug( - f"milter: chgheader: {f}[{occ[f]}]: {encoded_value}") - milter.chgheader(f, occ[f], encoded_value) + f"milter: chgheader: {f}[{occ[f_lower]}]: {encoded_value}") + milter.chgheader(f, occ[f_lower], encoded_value) -def del_header(field, milter, value=None, pretend=False, +def del_header(milter, msg, field, value=None, pretend=False, logger=logging.getLogger(__name__)): """Delete a mail header field.""" if isinstance(field, str): @@ -119,14 +116,13 @@ def del_header(field, milter, value=None, pretend=False, if isinstance(value, str): value = re.compile(value, re.MULTILINE + re.DOTALL + re.IGNORECASE) - idx = -1 occ = defaultdict(int) # iterate a copy of milter.fields because elements may get removed # during iteration - for f, v in milter.fields.copy(): - idx += 1 - occ[f] += 1 + for f, v in msg.items(): + f_lower = f.lower() + occ[f_lower] += 1 if not field.match(f): continue @@ -140,16 +136,14 @@ def del_header(field, milter, value=None, pretend=False, else: logger.info(f"del_header: {header[0:70]}") - del milter.fields[idx] + msg.remove_header(f, occ=occ[f_lower]) + + occ[f_lower] -= 1 if not pretend: - encoded_value = "" milter.logger.debug( - f"milter: chgheader: {f}[{occ[f]}]: {encoded_value}") - milter.chgheader(f, occ[f], encoded_value) - - idx -= 1 - occ[f] -= 1 + f"milter: chgheader: {f}[{occ[f_lower]}]:") + milter.chgheader(f, occ[f_lower], "") def _get_body_content(msg, body_type): @@ -273,31 +267,9 @@ def _inject_body(milter, msg): return new_msg -def add_disclaimer(text, html, action, policy, milter, pretend=False, +def add_disclaimer(milter, msg, text, html, action, policy, pretend=False, logger=logging.getLogger(__name__)): """Append or prepend a disclaimer to the mail body.""" - milter.body_data.seek(0) - fp = BytesFeedParser(policy=default_policy) - - for field, value in milter.fields_bytes: - decoded_field = field.decode("ascii") - decoded_value = value.decode("ascii") - field_lower = decoded_field.lower() - if not field_lower.startswith("content-") and \ - field_lower != "mime-version": - continue - logger.debug( - f"feed content header to message object: " - f"{decoded_field}: {decoded_value}") - fp.feed(field + b": " + value + b"\r\n") - - fp.feed(b"\r\n") - logger.debug(f"feed body to message object") - fp.feed(milter.body_data.read()) - - logger.debug("parse message") - msg = fp.close() - update_headers = False try: @@ -367,48 +339,47 @@ def add_disclaimer(text, html, action, policy, milter, pretend=False, "value": msg.get("Content-Transfer-Encoding"), "modified": False}} - for field, value in milter.fields.copy(): + for field, value in msg.items(): field_lower = field.lower() if field_lower in fields and fields[field_lower]["value"] is not None: - mod_header(field=f"^{field}$", value=fields[field_lower]["value"], - milter=milter, pretend=pretend, logger=logger) + mod_header(milter, msg, field=f"^{field}$", + value=fields[field_lower]["value"], + pretend=pretend, logger=logger) fields[field_lower]["modified"] = True elif field_lower.startswith("content-"): - del_header(field=f"^{field}$", milter=milter, + del_header(milter, msg, field=f"^{field}$", pretend=pretend, logger=logger) for field in fields.values(): if not field["modified"] and field["value"] is not None: - add_header(field=field["field"], value=field["value"], - milter=milter, pretend=pretend, logger=logger) + add_header(milter, msg, field=field["field"], value=field["value"], + pretend=pretend, logger=logger) -def store(directory, milter, pretend=False, +def store(milter, msg, directory, pretend=False, logger=logging.getLogger(__name__)): timestamp = datetime.now().strftime("%Y%m%d%H%M%S") store_id = f"{timestamp}_{milter.qid}" datafile = os.path.join(directory, store_id) + milter.body_data.seek(0) logger.info(f"store message in file {datafile}") try: with open(datafile, "wb") as fp: - for field, value in milter.fields_bytes: - fp.write(field + b": " + value + b"\r\n") - - copyfileobj(milter.body_data, fp) + fp.write(msg.as_bytes()) except IOError as e: raise RuntimeError(f"unable to store message: {e}") class Action: """Action to implement a pre-configured action to perform on e-mails.""" - _types = { - "add_header": ["fields"], - "del_header": ["fields"], - "mod_header": ["fields"], - "add_disclaimer": ["fields", "body"], - "store": ["fields_bytes", "body"]} + _need_body_map = { + "add_header": False, + "del_header": False, + "mod_header": False, + "add_disclaimer": True, + "store": True} def __init__(self, name, local_addrs, conditions, action_type, args, loglevel=logging.INFO, pretend=False): @@ -423,9 +394,9 @@ class Action: self.pretend = pretend self._args = {} - if action_type not in self._types: + if action_type not in self._need_body_map: raise RuntimeError(f"invalid action type '{action_type}'") - self._needs = self._types[action_type] + self._need_body = self._need_body_map[action_type] try: if action_type == "add_header": @@ -498,16 +469,16 @@ class Action: raise RuntimeError( f"mandatory argument not found: {e}") - def needs(self): + def need_body(self): """Return the needs of this action.""" - return self._needs + return self._need_body - def execute(self, milter, pretend=None): + def execute(self, milter, msg, pretend=None): """Execute configured action.""" if pretend is None: pretend = self.pretend logger = CustomLogger(self.logger, {"qid": milter.qid}) - return self._func( - milter=milter, pretend=pretend, logger=logger, **self._args) + return self._func(milter=milter, msg=msg, pretend=pretend, + logger=logger, **self._args)