diff options
Diffstat (limited to 'poky/bitbake/lib/hashserv/server.py')
-rw-r--r-- | poky/bitbake/lib/hashserv/server.py | 105 |
1 files changed, 72 insertions, 33 deletions
diff --git a/poky/bitbake/lib/hashserv/server.py b/poky/bitbake/lib/hashserv/server.py index cc7e48233b..81050715ea 100644 --- a/poky/bitbake/lib/hashserv/server.py +++ b/poky/bitbake/lib/hashserv/server.py @@ -13,6 +13,7 @@ import os import signal import socket import time +from . import chunkify, DEFAULT_MAX_CHUNK logger = logging.getLogger('hashserv.server') @@ -107,12 +108,29 @@ class Stats(object): return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} +class ClientError(Exception): + pass + class ServerClient(object): + FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' + ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' + def __init__(self, reader, writer, db, request_stats): self.reader = reader self.writer = writer self.db = db self.request_stats = request_stats + self.max_chunk = DEFAULT_MAX_CHUNK + + self.handlers = { + 'get': self.handle_get, + 'report': self.handle_report, + 'report-equiv': self.handle_equivreport, + 'get-stream': self.handle_get_stream, + 'get-stats': self.handle_get_stats, + 'reset-stats': self.handle_reset_stats, + 'chunk-stream': self.handle_chunk, + } async def process_requests(self): try: @@ -125,7 +143,11 @@ class ServerClient(object): return (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() - if proto_name != 'OEHASHEQUIV' or proto_version != '1.0': + if proto_name != 'OEHASHEQUIV': + return + + proto_version = tuple(int(v) for v in proto_version.split('.')) + if proto_version < (1, 0) or proto_version > (1, 1): return # Read headers. Currently, no headers are implemented, so look for @@ -140,40 +162,34 @@ class ServerClient(object): break # Handle messages - handlers = { - 'get': self.handle_get, - 'report': self.handle_report, - 'report-equiv': self.handle_equivreport, - 'get-stream': self.handle_get_stream, - 'get-stats': self.handle_get_stats, - 'reset-stats': self.handle_reset_stats, - } - while True: d = await self.read_message() if d is None: break - - for k in handlers.keys(): - if k in d: - logger.debug('Handling %s' % k) - if 'stream' in k: - await handlers[k](d[k]) - else: - with self.request_stats.start_sample() as self.request_sample, \ - self.request_sample.measure(): - await handlers[k](d[k]) - break - else: - logger.warning("Unrecognized command %r" % d) - break - + await self.dispatch_message(d) await self.writer.drain() + except ClientError as e: + logger.error(str(e)) finally: self.writer.close() + async def dispatch_message(self, msg): + for k in self.handlers.keys(): + if k in msg: + logger.debug('Handling %s' % k) + if 'stream' in k: + await self.handlers[k](msg[k]) + else: + with self.request_stats.start_sample() as self.request_sample, \ + self.request_sample.measure(): + await self.handlers[k](msg[k]) + return + + raise ClientError("Unrecognized command %r" % msg) + def write_message(self, msg): - self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8')) + for c in chunkify(json.dumps(msg), self.max_chunk): + self.writer.write(c.encode('utf-8')) async def read_message(self): l = await self.reader.readline() @@ -191,14 +207,38 @@ class ServerClient(object): logger.error('Bad message from client: %r' % message) raise e + async def handle_chunk(self, request): + lines = [] + try: + while True: + l = await self.reader.readline() + l = l.rstrip(b"\n").decode("utf-8") + if not l: + break + lines.append(l) + + msg = json.loads(''.join(lines)) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + logger.error('Bad message from client: %r' % message) + raise e + + if 'chunk-stream' in msg: + raise ClientError("Nested chunks are not allowed") + + await self.dispatch_message(msg) + async def handle_get(self, request): method = request['method'] taskhash = request['taskhash'] - row = self.query_equivalent(method, taskhash) + if request.get('all', False): + row = self.query_equivalent(method, taskhash, self.ALL_QUERY) + else: + row = self.query_equivalent(method, taskhash, self.FAST_QUERY) + if row is not None: logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) - d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} + d = {k: row[k] for k in row.keys()} self.write_message(d) else: @@ -228,7 +268,7 @@ class ServerClient(object): (method, taskhash) = l.split() #logger.debug('Looking up %s %s' % (method, taskhash)) - row = self.query_equivalent(method, taskhash) + row = self.query_equivalent(method, taskhash, self.FAST_QUERY) if row is not None: msg = ('%s\n' % row['unihash']).encode('utf-8') #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) @@ -328,7 +368,7 @@ class ServerClient(object): # Fetch the unihash that will be reported for the taskhash. If the # unihash matches, it means this row was inserted (or the mapping # was already valid) - row = self.query_equivalent(data['method'], data['taskhash']) + row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) if row['unihash'] == data['unihash']: logger.info('Adding taskhash equivalence for %s with unihash %s', @@ -354,12 +394,11 @@ class ServerClient(object): self.request_stats.reset() self.write_message(d) - def query_equivalent(self, method, taskhash): + def query_equivalent(self, method, taskhash, query): # This is part of the inner loop and must be as fast as possible try: cursor = self.db.cursor() - cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1', - {'method': method, 'taskhash': taskhash}) + cursor.execute(query, {'method': method, 'taskhash': taskhash}) return cursor.fetchone() except: cursor.close() |