Fix protocol handling
This commit is contained in:
2
setup.py
2
setup.py
@@ -5,7 +5,7 @@ def read_file(fname):
|
|||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
setup(name = "uvscand",
|
setup(name = "uvscand",
|
||||||
version = "0.0.3",
|
version = "0.0.4",
|
||||||
author = "Thomas Oettli",
|
author = "Thomas Oettli",
|
||||||
author_email = "spacefreak@noop.ch",
|
author_email = "spacefreak@noop.ch",
|
||||||
description = "A python daemon to perform virus scans with uvscan (McAfee) over TCP socket.",
|
description = "A python daemon to perform virus scans with uvscan (McAfee) over TCP socket.",
|
||||||
|
|||||||
@@ -53,7 +53,6 @@ class AIO(asyncio.Protocol):
|
|||||||
config = None
|
config = None
|
||||||
queue = asyncio.Queue()
|
queue = asyncio.Queue()
|
||||||
separator = b"\x00"
|
separator = b"\x00"
|
||||||
completed = False
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not AIO.config:
|
if not AIO.config:
|
||||||
@@ -61,67 +60,89 @@ class AIO(asyncio.Protocol):
|
|||||||
if not AIO.queue:
|
if not AIO.queue:
|
||||||
raise RuntimeError("queue not set")
|
raise RuntimeError("queue not set")
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.data = bytearray()
|
|
||||||
self.tmpfile = None
|
self.tmpfile = None
|
||||||
|
|
||||||
|
def _send_response(self, response):
|
||||||
|
response = response.encode() + AIO.separator
|
||||||
|
self.logger.debug("{} sending response: {}".format(self.peer, response))
|
||||||
|
self.transport.write(response)
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self.peer = transport.get_extra_info("peername")
|
self.peer = transport.get_extra_info("peername")
|
||||||
self.logger.info("new connection from {}".format(self.peer))
|
self.logger.info("new connection from {}".format(self.peer))
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.request_time = str(time.time())
|
self.request_time = str(time.time())
|
||||||
|
self.buffer = bytearray()
|
||||||
|
self.data = bytearray()
|
||||||
|
self.command = None
|
||||||
|
self.length = None
|
||||||
|
self.all_chunks = False
|
||||||
|
self.completed = False
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
self.logger.debug("data received from {}".format(self.peer))
|
|
||||||
self.data.extend(data)
|
|
||||||
if self.data[-4:] == b"\x00\x00\x00\x00":
|
|
||||||
self.logger.debug("last data chunk received from {}".format(self.peer))
|
|
||||||
self.process_request()
|
|
||||||
else:
|
|
||||||
self.logger.debug("received data chunk from {}".format(self.peer))
|
|
||||||
|
|
||||||
def process_request(self):
|
|
||||||
try:
|
try:
|
||||||
if self.data[0] != ord(b"z"):
|
if self.all_chunks:
|
||||||
|
self.logger.warning("{} received {} bytes of garbage after last chunk".format(self.peer, len(data)))
|
||||||
|
return
|
||||||
|
self.logger.debug("{} received {} bytes".format(self.peer, len(data)))
|
||||||
|
self.buffer.extend(data)
|
||||||
|
|
||||||
|
if not self.command:
|
||||||
|
if len(self.buffer) < 10:
|
||||||
|
return
|
||||||
|
if self.buffer[0] != ord(b"z"):
|
||||||
raise RuntimeError("protocol error")
|
raise RuntimeError("protocol error")
|
||||||
pos = self.data.index(ord(AIO.separator))
|
pos = self.buffer.index(ord(AIO.separator))
|
||||||
# parse command
|
# parse command
|
||||||
command = self.data[0:pos].decode()
|
command = self.buffer[0:pos].decode()
|
||||||
|
if command != "zINSTREAM":
|
||||||
|
raise RuntimeError("unknown command")
|
||||||
|
self.command = command
|
||||||
|
self.logger.debug("{} command is {}".format(self.peer, command))
|
||||||
pos += 1
|
pos += 1
|
||||||
if command == "zINSTREAM":
|
self.buffer = self.buffer[pos:]
|
||||||
# save data chunks to temporary file
|
if self.command:
|
||||||
|
while True:
|
||||||
|
if not self.length:
|
||||||
|
if len(self.buffer) < 4:
|
||||||
|
break
|
||||||
|
self.length = struct.unpack(">I", self.buffer[0:4])[0]
|
||||||
|
self.buffer = self.buffer[4:]
|
||||||
|
if self.length == 0:
|
||||||
|
self.all_chunks = True
|
||||||
tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1])))
|
tmpfile = os.path.join(AIO.config["tmpdir"], "uvscan_{}_{}".format(self.request_time, str(self.peer[1])))
|
||||||
self.logger.debug("save data from {} in temporary file {}".format(self.peer, tmpfile))
|
self.logger.debug("{} got last chunk, save data to {}".format(self.peer, tmpfile))
|
||||||
with open(tmpfile, "wb") as f:
|
with open(tmpfile, "wb") as f:
|
||||||
self.tmpfile = tmpfile
|
self.tmpfile = tmpfile
|
||||||
while True:
|
f.write(self.data)
|
||||||
length = struct.unpack(">I", self.data[pos:pos + 4])[0]
|
|
||||||
if length == 0: break
|
|
||||||
pos += 4
|
|
||||||
f.write(self.data[pos:pos + length])
|
|
||||||
pos += length
|
|
||||||
AIO.queue.put_nowait((AIO.config["uvscan_path"], tmpfile, self.process_uvscan_result))
|
AIO.queue.put_nowait((AIO.config["uvscan_path"], tmpfile, self.process_uvscan_result))
|
||||||
self.logger.info("queued uvscan of {}, queue size is {}".format(tmpfile, AIO.queue.qsize()))
|
self.logger.info("{} queued uvscan of {}, queue size is {}".format(self.peer, tmpfile, AIO.queue.qsize()))
|
||||||
|
break
|
||||||
|
self.logger.debug("{} got chunk size of {} bytes".format(self.peer, self.length))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unknown command")
|
if len(self.buffer) < self.length:
|
||||||
except (RuntimeError, IndexError, IOError, struct.error) as e:
|
self.logger.debug("{} got {} of {} bytes".format(self.peer, len(self.buffer), self.length))
|
||||||
self.logger.warning("warning: {}".format(e))
|
break
|
||||||
self.send_response(str(e))
|
self.logger.debug("{} chunk complete ({} bytes)".format(self.peer, self.length))
|
||||||
|
self.data.extend(self.buffer[0:self.length])
|
||||||
|
self.buffer = self.buffer[self.length:]
|
||||||
|
self.length = None
|
||||||
|
|
||||||
def send_response(self, response):
|
except (RuntimeError, IndexError, IOError, struct.error) as e:
|
||||||
response = response.encode() + AIO.separator
|
self.logger.warning("{} warning: {}".format(self.peer, e))
|
||||||
self.logger.debug("sending response to {}: {}".format(self.peer, response))
|
self._send_response(str(e))
|
||||||
self.transport.write(response)
|
|
||||||
self.transport.close()
|
|
||||||
|
|
||||||
def process_uvscan_result(self, result):
|
def process_uvscan_result(self, result):
|
||||||
self.logger.info("received uvscan result of {}: {}".format(self.tmpfile, result))
|
self.logger.info("{} received uvscan result of {}: {}".format(self.peer, self.tmpfile, result))
|
||||||
self.completed = True
|
self.completed = True
|
||||||
self.send_response(result)
|
self._send_response(result)
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
if self.tmpfile:
|
if self.tmpfile:
|
||||||
if not self.completed:
|
if not self.completed:
|
||||||
self.logger.warning("client prematurely closed connection, removing {} from scan queue".format(self.tmpfile))
|
self.logger.warning("{} client prematurely closed connection, removing {} from scan queue".format(self.peer, self.tmpfile))
|
||||||
entries = []
|
entries = []
|
||||||
try:
|
try:
|
||||||
for entry in iter(AIO.queue.get_nowait, None):
|
for entry in iter(AIO.queue.get_nowait, None):
|
||||||
@@ -133,7 +154,7 @@ class AIO(asyncio.Protocol):
|
|||||||
pass
|
pass
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
AIO.queue.put_nowait(entry)
|
AIO.queue.put_nowait(entry)
|
||||||
self.logger.debug("removing temporary file {}".format(self.tmpfile))
|
self.logger.debug("{} removing temporary file {}".format(self.peer, self.tmpfile))
|
||||||
os.remove(self.tmpfile)
|
os.remove(self.tmpfile)
|
||||||
self.logger.info("closed connection to {}".format(self.peer))
|
self.logger.info("closed connection to {}".format(self.peer))
|
||||||
|
|
||||||
@@ -192,6 +213,7 @@ def main():
|
|||||||
logger.error("option '{}' not present in config section 'uvscand'".format(option))
|
logger.error("option '{}' not present in config section 'uvscand'".format(option))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not args.debug:
|
||||||
# set loglevel according to config
|
# set loglevel according to config
|
||||||
stdouthandler.setLevel(int(config["loglevel"]))
|
stdouthandler.setLevel(int(config["loglevel"]))
|
||||||
sysloghandler.setLevel(int(config["loglevel"]))
|
sysloghandler.setLevel(int(config["loglevel"]))
|
||||||
|
|||||||
Reference in New Issue
Block a user