From 393c87db0823b9c67999075b522edabc97835470 Mon Sep 17 00:00:00 2001 From: Thomas Oettli Date: Sun, 16 Jun 2024 01:15:33 +0200 Subject: [PATCH] refactor code to better handle premature disconnects --- uvscand/__init__.py | 56 ++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/uvscand/__init__.py b/uvscand/__init__.py index 18d49ed..702ee74 100644 --- a/uvscand/__init__.py +++ b/uvscand/__init__.py @@ -38,6 +38,7 @@ async def uvscan_worker(queue): uvscan, filename, cb = job proc = await asyncio.create_subprocess_exec(uvscan, "--secure", "--mime", "--noboot", "--panalyse", "--manalyse", filename, stdout=asyncio.subprocess.PIPE) stdout, _ = await proc.communicate() + os.remove(filename) if proc.returncode == 13: match = uvscan_regex.search(stdout.decode()) name = match.group(1) if match else "UNKNOWN" @@ -136,38 +137,37 @@ class AIO(asyncio.Protocol): self._send_response(str(e)) def process_uvscan_result(self, result): - self.logger.debug(f"{self.peer} removing temporary file {self.tmpfile}") - os.remove(self.tmpfile) + if not self.tmpfile: + return + + self.logger.info(f"{self.peer} received uvscan result of {self.tmpfile}: {result}") + self._send_response(result) self.tmpfile = None - if not self.cancelled: - self.logger.info(f"{self.peer} received uvscan result of {self.tmpfile}: {result}") - self._send_response(result) def connection_lost(self, exc): - if self.tmpfile: - entries = [] - try: - for entry in iter(AIO.queue.get_nowait, None): - if not entry: - continue - if entry[1] != self.tmpfile: - entries.append(entry) - else: - self.cancelled = True - except asyncio.QueueEmpty: - pass - for entry in entries: - AIO.queue.put_nowait(entry) - if self.cancelled: - self.logger.warning(f"{self.peer} client prematurely closed connection, skipped scan of {self.tmpfile}") - self.logger.debug(f"{self.peer} removing temporary file {self.tmpfile}") - os.remove(self.tmpfile) - else: - self.logger.warning(f"{self.peer} client prematurely closed connection") - self.cancelled = True - - else: + if not self.tmpfile: self.logger.info(f"closed connection to {self.peer}") + return + + entries = [] + try: + for entry in iter(AIO.queue.get_nowait, None): + if not entry: + continue + if self.tmpfile and entry[1] == self.tmpfile: + self.logger.warning(f"{self.peer} client prematurely closed connection, skipping scan of {self.tmpfile}") + os.remove(self.tmpfile) + self.tmpfile = None + continue + entries.append(entry) + except asyncio.QueueEmpty: + pass + for entry in entries: + AIO.queue.put_nowait(entry) + + if self.tmpfile: + self.logger.warning(f"{self.peer} client prematurely closed connection, but scan is already running") + self.tmpfile = None def main():