refactor code to better handle premature disconnects

This commit is contained in:
2024-06-16 01:15:33 +02:00
parent 792ba7c1aa
commit 393c87db08

View File

@@ -38,6 +38,7 @@ async def uvscan_worker(queue):
uvscan, filename, cb = job uvscan, filename, cb = job
proc = await asyncio.create_subprocess_exec(uvscan, "--secure", "--mime", "--noboot", "--panalyse", "--manalyse", filename, stdout=asyncio.subprocess.PIPE) proc = await asyncio.create_subprocess_exec(uvscan, "--secure", "--mime", "--noboot", "--panalyse", "--manalyse", filename, stdout=asyncio.subprocess.PIPE)
stdout, _ = await proc.communicate() stdout, _ = await proc.communicate()
os.remove(filename)
if proc.returncode == 13: if proc.returncode == 13:
match = uvscan_regex.search(stdout.decode()) match = uvscan_regex.search(stdout.decode())
name = match.group(1) if match else "UNKNOWN" name = match.group(1) if match else "UNKNOWN"
@@ -136,38 +137,37 @@ class AIO(asyncio.Protocol):
self._send_response(str(e)) self._send_response(str(e))
def process_uvscan_result(self, result): def process_uvscan_result(self, result):
self.logger.debug(f"{self.peer} removing temporary file {self.tmpfile}") if not self.tmpfile:
os.remove(self.tmpfile) return
self.tmpfile = None
if not self.cancelled:
self.logger.info(f"{self.peer} received uvscan result of {self.tmpfile}: {result}") self.logger.info(f"{self.peer} received uvscan result of {self.tmpfile}: {result}")
self._send_response(result) self._send_response(result)
self.tmpfile = None
def connection_lost(self, exc): def connection_lost(self, exc):
if self.tmpfile: if not self.tmpfile:
self.logger.info(f"closed connection to {self.peer}")
return
entries = [] entries = []
try: try:
for entry in iter(AIO.queue.get_nowait, None): for entry in iter(AIO.queue.get_nowait, None):
if not entry: if not entry:
continue continue
if entry[1] != self.tmpfile: if self.tmpfile and entry[1] == self.tmpfile:
self.logger.warning(f"{self.peer} client prematurely closed connection, skipping scan of {self.tmpfile}")
os.remove(self.tmpfile)
self.tmpfile = None
continue
entries.append(entry) entries.append(entry)
else:
self.cancelled = True
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
pass pass
for entry in entries: for entry in entries:
AIO.queue.put_nowait(entry) AIO.queue.put_nowait(entry)
if self.cancelled:
self.logger.warning(f"{self.peer} client prematurely closed connection, skipped scan of {self.tmpfile}")
self.logger.debug(f"{self.peer} removing temporary file {self.tmpfile}")
os.remove(self.tmpfile)
else:
self.logger.warning(f"{self.peer} client prematurely closed connection")
self.cancelled = True
else: if self.tmpfile:
self.logger.info(f"closed connection to {self.peer}") self.logger.warning(f"{self.peer} client prematurely closed connection, but scan is already running")
self.tmpfile = None
def main(): def main():