Fix queue handling and log level config

This commit is contained in:
2020-02-17 12:59:49 +01:00
parent 8f8e075541
commit 693d7ac3e1
2 changed files with 34 additions and 10 deletions

View File

@@ -5,7 +5,7 @@ def read_file(fname):
return f.read() return f.read()
setup(name = "uvscand", setup(name = "uvscand",
version = "0.0.2", version = "0.0.3",
author = "Thomas Oettli", author = "Thomas Oettli",
author_email = "spacefreak@noop.ch", author_email = "spacefreak@noop.ch",
description = "A python daemon to perform virus scans with uvscan (McAfee) over TCP socket.", description = "A python daemon to perform virus scans with uvscan (McAfee) over TCP socket.",

View File

@@ -53,6 +53,7 @@ class AIO(asyncio.Protocol):
config = None config = None
queue = asyncio.Queue() queue = asyncio.Queue()
separator = b"\x00" separator = b"\x00"
completed = False
def __init__(self): def __init__(self):
if not AIO.config: if not AIO.config:
@@ -65,7 +66,7 @@ class AIO(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
self.peer = transport.get_extra_info("peername") self.peer = transport.get_extra_info("peername")
self.logger.debug("new connection from {}".format(self.peer)) self.logger.info("new connection from {}".format(self.peer))
self.transport = transport self.transport = transport
self.request_time = str(time.time()) self.request_time = str(time.time())
@@ -88,34 +89,53 @@ class AIO(asyncio.Protocol):
pos += 1 pos += 1
if command == "zINSTREAM": if command == "zINSTREAM":
# save data chunks to temporary file # save data chunks to temporary file
self.tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1]))) tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1])))
self.logger.debug("save data from {} in temporary file {}".format(self.peer, self.tmpfile)) self.logger.debug("save data from {} in temporary file {}".format(self.peer, tmpfile))
with open(self.tmpfile, "wb") as f: with open(tmpfile, "wb") as f:
self.tmpfile = tmpfile
while True: while True:
length = struct.unpack(">I", self.data[pos:pos + 4])[0] length = struct.unpack(">I", self.data[pos:pos + 4])[0]
if length == 0: break if length == 0: break
pos += 4 pos += 4
f.write(self.data[pos:pos + length]) f.write(self.data[pos:pos + length])
pos += length pos += length
self.logger.debug("starting uvscan for file {}".format(self.tmpfile)) AIO.queue.put_nowait((AIO.config["uvscan_path"], tmpfile, self.process_uvscan_result))
asyncio.async(AIO.queue.put((AIO.config["uvscan_path"], self.tmpfile, self.send_response))) self.logger.info("queued uvscan of {}, queue size is {}".format(tmpfile, AIO.queue.qsize()))
else: else:
raise RuntimeError("unknown command") raise RuntimeError("unknown command")
except (RuntimeError, IndexError, IOError, struct.error) as e: except (RuntimeError, IndexError, IOError, struct.error) as e:
self.logger.warning("warning: {}".format(e))
self.send_response(str(e)) self.send_response(str(e))
def send_response(self, response): def send_response(self, response):
response = response.encode() response = response.encode() + AIO.separator
response += AIO.separator
self.logger.debug("sending response to {}: {}".format(self.peer, response)) self.logger.debug("sending response to {}: {}".format(self.peer, response))
self.transport.write(response) self.transport.write(response)
self.transport.close() self.transport.close()
def process_uvscan_result(self, result):
self.logger.info("received uvscan result of {}: {}".format(self.tmpfile, result))
self.completed = True
self.send_response(result)
def connection_lost(self, exc): def connection_lost(self, exc):
if self.tmpfile: if self.tmpfile:
if not self.completed:
self.logger.warning("client prematurely closed connection, removing {} from scan queue".format(self.tmpfile))
entries = []
try:
for entry in iter(AIO.queue.get_nowait, None):
if not entry:
continue
if entry[1] != self.tmpfile:
entries.append(entry)
except asyncio.QueueEmpty:
pass
for entry in entries:
AIO.queue.put_nowait(entry)
self.logger.debug("removing temporary file {}".format(self.tmpfile)) self.logger.debug("removing temporary file {}".format(self.tmpfile))
os.remove(self.tmpfile) os.remove(self.tmpfile)
self.logger.debug("closed connection to {}".format(self.peer)) self.logger.info("closed connection to {}".format(self.peer))
def main(): def main():
@@ -172,6 +192,10 @@ def main():
logger.error("option '{}' not present in config section 'uvscand'".format(option)) logger.error("option '{}' not present in config section 'uvscand'".format(option))
sys.exit(1) sys.exit(1)
# set loglevel according to config
stdouthandler.setLevel(int(config["loglevel"]))
sysloghandler.setLevel(int(config["loglevel"]))
# check if uvscan binary exists and is executable # check if uvscan binary exists and is executable
if not os.path.isfile(config["uvscan_path"]) or not os.access(config["uvscan_path"], os.X_OK): if not os.path.isfile(config["uvscan_path"]) or not os.access(config["uvscan_path"], os.X_OK):
logger.error("uvscan binary '{}' does not exist or is not executable".format(config["uvscan_path"])) logger.error("uvscan binary '{}' does not exist or is not executable".format(config["uvscan_path"]))