diff --git a/setup.py b/setup.py index 479492a..393f60e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ def read_file(fname): return f.read() setup(name = "uvscand", - version = "0.0.3", + version = "0.0.4", author = "Thomas Oettli", author_email = "spacefreak@noop.ch", description = "A python daemon to perform virus scans with uvscan (McAfee) over TCP socket.", diff --git a/uvscand/__init__.py b/uvscand/__init__.py index 446fd16..3332d94 100644 --- a/uvscand/__init__.py +++ b/uvscand/__init__.py @@ -53,7 +53,6 @@ class AIO(asyncio.Protocol): config = None queue = asyncio.Queue() separator = b"\x00" - completed = False def __init__(self): if not AIO.config: @@ -61,67 +60,89 @@ class AIO(asyncio.Protocol): if not AIO.queue: raise RuntimeError("queue not set") self.logger = logging.getLogger(__name__) - self.data = bytearray() self.tmpfile = None + def _send_response(self, response): + response = response.encode() + AIO.separator + self.logger.debug("{} sending response: {}".format(self.peer, response)) + self.transport.write(response) + self.transport.close() + + def connection_made(self, transport): self.peer = transport.get_extra_info("peername") self.logger.info("new connection from {}".format(self.peer)) self.transport = transport self.request_time = str(time.time()) + self.buffer = bytearray() + self.data = bytearray() + self.command = None + self.length = None + self.all_chunks = False + self.completed = False def data_received(self, data): - self.logger.debug("data received from {}".format(self.peer)) - self.data.extend(data) - if self.data[-4:] == b"\x00\x00\x00\x00": - self.logger.debug("last data chunk received from {}".format(self.peer)) - self.process_request() - else: - self.logger.debug("received data chunk from {}".format(self.peer)) - - def process_request(self): try: - if self.data[0] != ord(b"z"): - raise RuntimeError("protocol error") - pos = self.data.index(ord(AIO.separator)) - # parse command - command = self.data[0:pos].decode() - pos += 1 - if command == "zINSTREAM": - # save data chunks to temporary file - tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1]))) - self.logger.debug("save data from {} in temporary file {}".format(self.peer, tmpfile)) - with open(tmpfile, "wb") as f: - self.tmpfile = tmpfile - while True: - length = struct.unpack(">I", self.data[pos:pos + 4])[0] - if length == 0: break - pos += 4 - f.write(self.data[pos:pos + length]) - pos += length - AIO.queue.put_nowait((AIO.config["uvscan_path"], tmpfile, self.process_uvscan_result)) - self.logger.info("queued uvscan of {}, queue size is {}".format(tmpfile, AIO.queue.qsize())) - else: - raise RuntimeError("unknown command") - except (RuntimeError, IndexError, IOError, struct.error) as e: - self.logger.warning("warning: {}".format(e)) - self.send_response(str(e)) + if self.all_chunks: + self.logger.warning("{} received {} bytes of garbage after last chunk".format(self.peer, len(data))) + return + self.logger.debug("{} received {} bytes".format(self.peer, len(data))) + self.buffer.extend(data) - def send_response(self, response): - response = response.encode() + AIO.separator - self.logger.debug("sending response to {}: {}".format(self.peer, response)) - self.transport.write(response) - self.transport.close() + if not self.command: + if len(self.buffer) < 10: + return + if self.buffer[0] != ord(b"z"): + raise RuntimeError("protocol error") + pos = self.buffer.index(ord(AIO.separator)) + # parse command + command = self.buffer[0:pos].decode() + if command != "zINSTREAM": + raise RuntimeError("unknown command") + self.command = command + self.logger.debug("{} command is {}".format(self.peer, command)) + pos += 1 + self.buffer = self.buffer[pos:] + if self.command: + while True: + if not self.length: + if len(self.buffer) < 4: + break + self.length = struct.unpack(">I", self.buffer[0:4])[0] + self.buffer = self.buffer[4:] + if self.length == 0: + self.all_chunks = True + tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1]))) + self.logger.debug("{} got last chunk, save data to {}".format(self.peer, tmpfile)) + with open(tmpfile, "wb") as f: + self.tmpfile = tmpfile + f.write(self.data) + AIO.queue.put_nowait((AIO.config["uvscan_path"], tmpfile, self.process_uvscan_result)) + self.logger.info("{} queued uvscan of {}, queue size is {}".format(self.peer, tmpfile, AIO.queue.qsize())) + break + self.logger.debug("{} got chunk size of {} bytes".format(self.peer, self.length)) + else: + if len(self.buffer) < self.length: + self.logger.debug("{} got {} of {} bytes".format(self.peer, len(self.buffer), self.length)) + break + self.logger.debug("{} chunk complete ({} bytes)".format(self.peer, self.length)) + self.data.extend(self.buffer[0:self.length]) + self.buffer = self.buffer[self.length:] + self.length = None + + except (RuntimeError, IndexError, IOError, struct.error) as e: + self.logger.warning("{} warning: {}".format(self.peer, e)) + self._send_response(str(e)) def process_uvscan_result(self, result): - self.logger.info("received uvscan result of {}: {}".format(self.tmpfile, result)) + self.logger.info("{} received uvscan result of {}: {}".format(self.peer, self.tmpfile, result)) self.completed = True - self.send_response(result) + self._send_response(result) def connection_lost(self, exc): if self.tmpfile: if not self.completed: - self.logger.warning("client prematurely closed connection, removing {} from scan queue".format(self.tmpfile)) + self.logger.warning("{} client prematurely closed connection, removing {} from scan queue".format(self.peer, self.tmpfile)) entries = [] try: for entry in iter(AIO.queue.get_nowait, None): @@ -133,7 +154,7 @@ class AIO(asyncio.Protocol): pass for entry in entries: AIO.queue.put_nowait(entry) - self.logger.debug("removing temporary file {}".format(self.tmpfile)) + self.logger.debug("{} removing temporary file {}".format(self.peer, self.tmpfile)) os.remove(self.tmpfile) self.logger.info("closed connection to {}".format(self.peer)) @@ -192,9 +213,10 @@ def main(): logger.error("option '{}' not present in config section 'uvscand'".format(option)) sys.exit(1) - # set loglevel according to config - stdouthandler.setLevel(int(config["loglevel"])) - sysloghandler.setLevel(int(config["loglevel"])) + if not args.debug: + # set loglevel according to config + stdouthandler.setLevel(int(config["loglevel"])) + sysloghandler.setLevel(int(config["loglevel"])) # check if uvscan binary exists and is executable if not os.path.isfile(config["uvscan_path"]) or not os.access(config["uvscan_path"], os.X_OK):