Diffstat (limited to 'bitbake/lib/hashserv/server.py')
 bitbake/lib/hashserv/server.py | 149
 1 file changed, 119 insertions(+), 30 deletions(-)
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index 81050715ea..3ff4c51ccb 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: GPL-2.0-only
 #
-from contextlib import closing
+from contextlib import closing, contextmanager
 from datetime import datetime
 import asyncio
 import json
@@ -12,8 +12,9 @@ import math
 import os
 import signal
 import socket
+import sys
 import time
-from . import chunkify, DEFAULT_MAX_CHUNK
+from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
 
 logger = logging.getLogger('hashserv.server')
@@ -111,16 +112,40 @@ class Stats(object):
 class ClientError(Exception):
     pass
 
+def insert_task(cursor, data, ignore=False):
+    keys = sorted(data.keys())
+    query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
+        " OR IGNORE" if ignore else "",
+        ', '.join(keys),
+        ', '.join(':' + k for k in keys))
+    cursor.execute(query, data)
+
+async def copy_from_upstream(client, db, method, taskhash):
+    d = await client.get_taskhash(method, taskhash, True)
+    if d is not None:
+        # Filter out unknown columns
+        d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
+
+        with closing(db.cursor()) as cursor:
+            insert_task(cursor, d)
+            db.commit()
+
+    return d
+
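
The insert_task() helper above replaces the two hand-rolled INSERT statements that later hunks remove: the column list and the :name placeholders are both derived from the row dictionary, and sqlite3 binds the values rather than having them spliced into the SQL (column names come from dict keys that are filtered against TABLE_COLUMNS before use). A minimal standalone sketch of the same pattern, using a toy one-table schema rather than the real tasks_v2 layout:

    import sqlite3
    from contextlib import closing

    def insert_task(cursor, data, ignore=False):
        # Column names and :name placeholders come from the dict keys;
        # values are bound by sqlite3, not formatted into the query.
        keys = sorted(data.keys())
        query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
            " OR IGNORE" if ignore else "",
            ', '.join(keys),
            ', '.join(':' + k for k in keys))
        cursor.execute(query, data)

    db = sqlite3.connect(':memory:')
    # Toy schema for illustration only; the real table has more columns.
    db.execute('CREATE TABLE tasks_v2 '
               '(method TEXT, taskhash TEXT, unihash TEXT, '
               'UNIQUE(method, taskhash))')

    row = {'method': 'do_compile', 'taskhash': 'abc', 'unihash': 'def'}
    with closing(db.cursor()) as cursor:
        insert_task(cursor, row)
        insert_task(cursor, row, ignore=True)   # duplicate, silently skipped
        db.commit()
    print(db.execute('SELECT count(*) FROM tasks_v2').fetchone()[0])  # -> 1
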
 class ServerClient(object):
     FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
     ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
 
-    def __init__(self, reader, writer, db, request_stats):
+    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream):
         self.reader = reader
         self.writer = writer
         self.db = db
         self.request_stats = request_stats
         self.max_chunk = DEFAULT_MAX_CHUNK
+        self.backfill_queue = backfill_queue
+        self.upstream = upstream
 
         self.handlers = {
             'get': self.handle_get,
@@ -130,10 +155,18 @@ class ServerClient(object):
             'get-stats': self.handle_get_stats,
             'reset-stats': self.handle_reset_stats,
             'chunk-stream': self.handle_chunk,
+            'backfill-wait': self.handle_backfill_wait,
         }
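
Because handlers is a dict of message type to bound coroutine, wiring up 'backfill-wait' is a single new entry. A hedged sketch of that dispatch shape (the class and message format here are illustrative, not the actual hashserv wire protocol):

    import asyncio

    class MiniClient:
        def __init__(self):
            self.handlers = {
                'get': self.handle_get,
                'backfill-wait': self.handle_backfill_wait,
            }

        async def dispatch_message(self, msg):
            # First matching key wins; unknown messages are an error.
            for k in self.handlers.keys():
                if k in msg:
                    return await self.handlers[k](msg[k])
            raise ValueError('Unknown message: %r' % msg)

        async def handle_get(self, request):
            return {'taskhash': request['taskhash'], 'unihash': None}

        async def handle_backfill_wait(self, request):
            return {'tasks': 0}

    print(asyncio.run(MiniClient().dispatch_message({'get': {'taskhash': 'abc'}})))
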
 
     async def process_requests(self):
+        if self.upstream is not None:
+            self.upstream_client = await create_async_client(self.upstream)
+        else:
+            self.upstream_client = None
+
         try:
             self.addr = self.writer.get_extra_info('peername')
             logger.debug('Client %r connected' % (self.addr,))
@@ -171,6 +204,9 @@ class ServerClient(object):
         except ClientError as e:
             logger.error(str(e))
         finally:
+            if self.upstream_client is not None:
+                await self.upstream_client.close()
+
             self.writer.close()
 
     async def dispatch_message(self, msg):
@@ -239,15 +275,19 @@ class ServerClient(object):
         if row is not None:
             logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
             d = {k: row[k] for k in row.keys()}
-
-            self.write_message(d)
+        elif self.upstream_client is not None:
+            d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash)
         else:
-            self.write_message(None)
+            d = None
+
+        self.write_message(d)
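
handle_get now funnels all three outcomes (local hit, row copied from upstream, miss) into one write_message() call instead of two exit paths. The shape of that lookup-with-fallback, sketched with in-memory dicts standing in for the database and the upstream client:

    import asyncio

    local = {('do_compile', 'aaa'): {'unihash': 'local-hash'}}
    upstream = {('do_compile', 'bbb'): {'unihash': 'upstream-hash'}}

    async def copy_from_upstream(method, taskhash):
        # Stand-in for the real helper: fetch from upstream and cache
        # locally so the next query is answered without a round trip.
        d = upstream.get((method, taskhash))
        if d is not None:
            local[(method, taskhash)] = d
        return d

    async def handle_get(method, taskhash, have_upstream=True):
        row = local.get((method, taskhash))
        if row is not None:
            d = row
        elif have_upstream:
            d = await copy_from_upstream(method, taskhash)
        else:
            d = None
        return d          # single exit point, as in the patch

    print(asyncio.run(handle_get('do_compile', 'bbb')))  # {'unihash': 'upstream-hash'}
    print(asyncio.run(handle_get('do_compile', 'ccc')))  # None
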
 
     async def handle_get_stream(self, request):
         self.write_message('ok')
 
         while True:
+            upstream = None
+
             l = await self.reader.readline()
             if not l:
                 return
@@ -272,6 +312,12 @@ class ServerClient(object):
                 if row is not None:
                     msg = ('%s\n' % row['unihash']).encode('utf-8')
                     #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+                elif self.upstream_client is not None:
+                    upstream = await self.upstream_client.get_unihash(method, taskhash)
+                    if upstream:
+                        msg = ("%s\n" % upstream).encode("utf-8")
+                    else:
+                        msg = "\n".encode("utf-8")
                 else:
                     msg = '\n'.encode('utf-8')
@@ -282,6 +328,11 @@ class ServerClient(object):
             await self.writer.drain()
 
+            # Post to the backfill queue after writing the result to minimize
+            # the turnaround time on a request
+            if upstream is not None:
+                await self.backfill_queue.put((method, taskhash))
+
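
The streaming path deliberately answers the client before persisting anything: the upstream unihash goes out on the wire, the writer is drained, and only then is (method, taskhash) queued for the backfill worker to copy into the local database. A sketch of that answer-now, persist-later ordering, with a fake writer standing in for asyncio.StreamWriter:

    import asyncio

    backfill_queue = asyncio.Queue()

    class FakeWriter:
        # Minimal stand-in for asyncio.StreamWriter in this sketch.
        def write(self, data):
            print('reply on wire:', data)
        async def drain(self):
            pass

    async def answer_query(writer, method, taskhash):
        upstream = 'deadbeef'  # pretend the upstream server resolved it
        writer.write(('%s\n' % upstream).encode('utf-8'))
        await writer.drain()
        # Only after the reply is sent does the slow path get scheduled;
        # put() on an unbounded Queue never blocks the request.
        await backfill_queue.put((method, taskhash))

    asyncio.run(answer_query(FakeWriter(), 'do_compile', 'abc'))
    print('queued for backfill:', backfill_queue.qsize())
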
     async def handle_report(self, data):
         with closing(self.db.cursor()) as cursor:
             cursor.execute('''
@@ -324,11 +375,7 @@ class ServerClient(object):
                 if k in data:
                     insert_data[k] = data[k]
 
-            cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
-                ', '.join(sorted(insert_data.keys())),
-                ', '.join(':' + k for k in sorted(insert_data.keys()))),
-                insert_data)
-
+            insert_task(cursor, insert_data)
             self.db.commit()
 
         logger.info('Adding taskhash %s with unihash %s',
@@ -358,11 +405,7 @@ class ServerClient(object):
                 if k in data:
                     insert_data[k] = data[k]
 
-            cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
-                ', '.join(sorted(insert_data.keys())),
-                ', '.join(':' + k for k in sorted(insert_data.keys()))),
-                insert_data)
-
+            insert_task(cursor, insert_data, ignore=True)
             self.db.commit()
 
         # Fetch the unihash that will be reported for the taskhash. If the
@@ -394,6 +437,13 @@ class ServerClient(object):
             self.request_stats.reset()
         self.write_message(d)
 
+    async def handle_backfill_wait(self, request):
+        d = {
+            'tasks': self.backfill_queue.qsize(),
+        }
+        await self.backfill_queue.join()
+        self.write_message(d)
+
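
'backfill-wait' reports the current queue depth and then blocks on Queue.join(), which only returns once task_done() has been called for every queued item, so a client can wait until the local database has caught up with upstream. The join()/task_done() contract in isolation:

    import asyncio

    async def main():
        queue = asyncio.Queue()
        for i in range(3):
            await queue.put(i)

        async def worker():
            while True:
                item = await queue.get()
                await asyncio.sleep(0)  # stand-in for copy_from_upstream()
                queue.task_done()       # join() unblocks only after these

        asyncio.ensure_future(worker())
        print('tasks pending:', queue.qsize())  # what handle_backfill_wait reports
        await queue.join()                      # returns once all 3 are done
        print('backfill drained')

    asyncio.run(main())
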
     def query_equivalent(self, method, taskhash, query):
         # This is part of the inner loop and must be as fast as possible
         try:
@@ -405,7 +455,7 @@ class ServerClient(object):
 class Server(object):
-    def __init__(self, db, loop=None):
+    def __init__(self, db, loop=None, upstream=None):
         self.request_stats = Stats()
         self.db = db
@@ -416,6 +466,8 @@ class Server(object):
             self.loop = loop
             self.close_loop = False
 
+        self.upstream = upstream
+
         self._cleanup_socket = None
 
     def start_tcp_server(self, host, port):
@@ -458,7 +510,7 @@ class Server(object):
     async def handle_client(self, reader, writer):
         # writer.transport.set_write_buffer_limits(0)
         try:
-            client = ServerClient(reader, writer, self.db, self.request_stats)
+            client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream)
             await client.process_requests()
         except Exception as e:
             import traceback
@@ -467,23 +519,60 @@ class Server(object):
             writer.close()
         logger.info('Client disconnected')
 
+    @contextmanager
+    def _backfill_worker(self):
+        async def backfill_worker_task():
+            client = await create_async_client(self.upstream)
+            try:
+                while True:
+                    item = await self.backfill_queue.get()
+                    if item is None:
+                        self.backfill_queue.task_done()
+                        break
+                    method, taskhash = item
+                    await copy_from_upstream(client, self.db, method, taskhash)
+                    self.backfill_queue.task_done()
+            finally:
+                await client.close()
+
+        async def join_worker(worker):
+            await self.backfill_queue.put(None)
+            await worker
+
+        if self.upstream is not None:
+            worker = asyncio.ensure_future(backfill_worker_task())
+            try:
+                yield
+            finally:
+                self.loop.run_until_complete(join_worker(worker))
+        else:
+            yield
+
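
Shutdown of the worker is driven by a None sentinel: the contextmanager starts the task only when an upstream is configured, and on exit it enqueues None and waits for the task to finish, so every item queued before shutdown still gets copied. A condensed, runnable sketch of that lifecycle with the upstream copy stubbed out:

    import asyncio
    from contextlib import contextmanager

    class MiniServer:
        def __init__(self, loop):
            self.loop = loop
            self.backfill_queue = asyncio.Queue()

        @contextmanager
        def _backfill_worker(self):
            async def backfill_worker_task():
                while True:
                    item = await self.backfill_queue.get()
                    if item is None:                 # sentinel: stop cleanly
                        self.backfill_queue.task_done()
                        break
                    print('copied', item)            # stub for copy_from_upstream()
                    self.backfill_queue.task_done()

            async def join_worker(worker):
                await self.backfill_queue.put(None)
                await worker

            worker = asyncio.ensure_future(backfill_worker_task())
            try:
                yield
            finally:
                # Drain the queue before the server finishes shutting down.
                self.loop.run_until_complete(join_worker(worker))

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server = MiniServer(loop)
    with server._backfill_worker():
        loop.run_until_complete(server.backfill_queue.put(('do_compile', 'abc')))
    loop.close()
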
     def serve_forever(self):
         def signal_handler():
             self.loop.stop()
 
-        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
+        asyncio.set_event_loop(self.loop)
         try:
-            self.loop.run_forever()
-        except KeyboardInterrupt:
-            pass
+            self.backfill_queue = asyncio.Queue()
+
+            self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
 
-        self.server.close()
-        self.loop.run_until_complete(self.server.wait_closed())
-        logger.info('Server shutting down')
+            with self._backfill_worker():
+                try:
+                    self.loop.run_forever()
+                except KeyboardInterrupt:
+                    pass
 
-        if self.close_loop:
-            self.loop.close()
+                self.server.close()
+
+                self.loop.run_until_complete(self.server.wait_closed())
+                logger.info('Server shutting down')
+        finally:
+            if self.close_loop:
+                if sys.version_info >= (3, 6):
+                    self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+                self.loop.close()
 
-        if self._cleanup_socket is not None:
-            self._cleanup_socket()
+            if self._cleanup_socket is not None:
+                self._cleanup_socket()
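
The reworked serve_forever() pins down the shutdown ordering: the SIGTERM handler just stops the loop, the backfill worker drains inside the with block, and the finally clause flushes async generators (a loop API added in Python 3.6, hence the version guard) before closing the loop. A condensed sketch of that ordering, with a timer standing in for real client traffic (note that add_signal_handler is Unix-only):

    import asyncio
    import signal
    import sys

    def serve_forever(loop):
        def signal_handler():
            loop.stop()

        asyncio.set_event_loop(loop)
        try:
            loop.add_signal_handler(signal.SIGTERM, signal_handler)
            loop.call_later(0.1, loop.stop)   # stand-in for a real workload
            loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            if sys.version_info >= (3, 6):
                # Flush half-finished async generators before closing.
                loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()

    serve_forever(asyncio.new_event_loop())
    print('loop closed cleanly')
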