diff options
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r-- | bitbake/lib/hashserv/server.py | 986 |
1 files changed, 689 insertions, 297 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py index 81050715ea..68f64f983b 100644 --- a/bitbake/lib/hashserv/server.py +++ b/bitbake/lib/hashserv/server.py @@ -3,19 +3,51 @@ # SPDX-License-Identifier: GPL-2.0-only # -from contextlib import closing -from datetime import datetime +from datetime import datetime, timedelta import asyncio -import json import logging import math -import os -import signal -import socket import time -from . import chunkify, DEFAULT_MAX_CHUNK +import os +import base64 +import hashlib +from . import create_async_client +import bb.asyncrpc + +logger = logging.getLogger("hashserv.server") + + +# This permission only exists to match nothing +NONE_PERM = "@none" + +READ_PERM = "@read" +REPORT_PERM = "@report" +DB_ADMIN_PERM = "@db-admin" +USER_ADMIN_PERM = "@user-admin" +ALL_PERM = "@all" + +ALL_PERMISSIONS = { + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, + USER_ADMIN_PERM, + ALL_PERM, +} + +DEFAULT_ANON_PERMS = ( + READ_PERM, + REPORT_PERM, + DB_ADMIN_PERM, +) + +TOKEN_ALGORITHM = "sha256" + +# 48 bytes of random data will result in 64 characters when base64 +# encoded. This number also ensures that the base64 encoding won't have any +# trailing '=' characters. +TOKEN_SIZE = 48 -logger = logging.getLogger('hashserv.server') +SALT_SIZE = 8 class Measurement(object): @@ -105,385 +137,745 @@ class Stats(object): return math.sqrt(self.s / (self.num - 1)) def todict(self): - return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')} - - -class ClientError(Exception): - pass - -class ServerClient(object): - FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - - def __init__(self, reader, writer, db, request_stats): - self.reader = reader - self.writer = writer - self.db = db - self.request_stats = request_stats - self.max_chunk = DEFAULT_MAX_CHUNK - - self.handlers = { - 'get': self.handle_get, - 'report': self.handle_report, - 'report-equiv': self.handle_equivreport, - 'get-stream': self.handle_get_stream, - 'get-stats': self.handle_get_stats, - 'reset-stats': self.handle_reset_stats, - 'chunk-stream': self.handle_chunk, + return { + k: getattr(self, k) + for k in ("num", "total_time", "max_time", "average", "stdev") } - async def process_requests(self): - try: - self.addr = self.writer.get_extra_info('peername') - logger.debug('Client %r connected' % (self.addr,)) - # Read protocol and version - protocol = await self.reader.readline() - if protocol is None: - return +token_refresh_semaphore = asyncio.Lock() - (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split() - if proto_name != 'OEHASHEQUIV': - return - proto_version = tuple(int(v) for v in proto_version.split('.')) - if proto_version < (1, 0) or proto_version > (1, 1): - return +async def new_token(): + # Prevent malicious users from using this API to deduce the entropy + # pool on the server and thus be able to guess a token. *All* token + # refresh requests lock the same global semaphore and then sleep for a + # short time. The effectively rate limits the total number of requests + # than can be made across all clients to 10/second, which should be enough + # since you have to be an authenticated users to make the request in the + # first place + async with token_refresh_semaphore: + await asyncio.sleep(0.1) + raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK) - # Read headers. Currently, no headers are implemented, so look for - # an empty line to signal the end of the headers - while True: - line = await self.reader.readline() - if line is None: - return + return base64.b64encode(raw, b"._").decode("utf-8") - line = line.decode('utf-8').rstrip() - if not line: - break - # Handle messages - while True: - d = await self.read_message() - if d is None: - break - await self.dispatch_message(d) - await self.writer.drain() - except ClientError as e: - logger.error(str(e)) - finally: - self.writer.close() +def new_salt(): + return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex() - async def dispatch_message(self, msg): - for k in self.handlers.keys(): - if k in msg: - logger.debug('Handling %s' % k) - if 'stream' in k: - await self.handlers[k](msg[k]) + +def hash_token(algo, salt, token): + h = hashlib.new(algo) + h.update(salt.encode("utf-8")) + h.update(token.encode("utf-8")) + return ":".join([algo, salt, h.hexdigest()]) + + +def permissions(*permissions, allow_anon=True, allow_self_service=False): + """ + Function decorator that can be used to decorate an RPC function call and + check that the current users permissions match the require permissions. + + If allow_anon is True, the user will also be allowed to make the RPC call + if the anonymous user permissions match the permissions. + + If allow_self_service is True, and the "username" property in the request + is the currently logged in user, or not specified, the user will also be + allowed to make the request. This allows users to access normal privileged + API, as long as they are only modifying their own user properties (e.g. + users can be allowed to reset their own token without @user-admin + permissions, but not the token for any other user. + """ + + def wrapper(func): + async def wrap(self, request): + if allow_self_service and self.user is not None: + username = request.get("username", self.user.username) + if username == self.user.username: + request["username"] = self.user.username + return await func(self, request) + + if not self.user_has_permissions(*permissions, allow_anon=allow_anon): + if not self.user: + username = "Anonymous user" + user_perms = self.server.anon_perms else: - with self.request_stats.start_sample() as self.request_sample, \ - self.request_sample.measure(): - await self.handlers[k](msg[k]) - return + username = self.user.username + user_perms = self.user.permissions + + self.logger.info( + "User %s with permissions %r denied from calling %s. Missing permissions(s) %r", + username, + ", ".join(user_perms), + func.__name__, + ", ".join(permissions), + ) + raise bb.asyncrpc.InvokeError( + f"{username} is not allowed to access permissions(s) {', '.join(permissions)}" + ) + + return await func(self, request) + + return wrap + + return wrapper + + +class ServerClient(bb.asyncrpc.AsyncServerConnection): + def __init__(self, socket, server): + super().__init__(socket, "OEHASHEQUIV", server.logger) + self.server = server + self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK + self.user = None + + self.handlers.update( + { + "get": self.handle_get, + "get-outhash": self.handle_get_outhash, + "get-stream": self.handle_get_stream, + "exists-stream": self.handle_exists_stream, + "get-stats": self.handle_get_stats, + "get-db-usage": self.handle_get_db_usage, + "get-db-query-columns": self.handle_get_db_query_columns, + # Not always read-only, but internally checks if the server is + # read-only + "report": self.handle_report, + "auth": self.handle_auth, + "get-user": self.handle_get_user, + "get-all-users": self.handle_get_all_users, + "become-user": self.handle_become_user, + } + ) + + if not self.server.read_only: + self.handlers.update( + { + "report-equiv": self.handle_equivreport, + "reset-stats": self.handle_reset_stats, + "backfill-wait": self.handle_backfill_wait, + "remove": self.handle_remove, + "gc-mark": self.handle_gc_mark, + "gc-sweep": self.handle_gc_sweep, + "gc-status": self.handle_gc_status, + "clean-unused": self.handle_clean_unused, + "refresh-token": self.handle_refresh_token, + "set-user-perms": self.handle_set_perms, + "new-user": self.handle_new_user, + "delete-user": self.handle_delete_user, + } + ) - raise ClientError("Unrecognized command %r" % msg) + def raise_no_user_error(self, username): + raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists") - def write_message(self, msg): - for c in chunkify(json.dumps(msg), self.max_chunk): - self.writer.write(c.encode('utf-8')) + def user_has_permissions(self, *permissions, allow_anon=True): + permissions = set(permissions) + if allow_anon: + if ALL_PERM in self.server.anon_perms: + return True - async def read_message(self): - l = await self.reader.readline() - if not l: - return None + if not permissions - self.server.anon_perms: + return True - try: - message = l.decode('utf-8') + if self.user is None: + return False - if not message.endswith('\n'): - return None + if ALL_PERM in self.user.permissions: + return True - return json.loads(message) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error('Bad message from client: %r' % message) - raise e + if not permissions - self.user.permissions: + return True - async def handle_chunk(self, request): - lines = [] - try: - while True: - l = await self.reader.readline() - l = l.rstrip(b"\n").decode("utf-8") - if not l: - break - lines.append(l) + return False + + def validate_proto_version(self): + return self.proto_version > (1, 0) and self.proto_version <= (1, 1) + + async def process_requests(self): + async with self.server.db_engine.connect(self.logger) as db: + self.db = db + if self.server.upstream is not None: + self.upstream_client = await create_async_client(self.server.upstream) + else: + self.upstream_client = None - msg = json.loads(''.join(lines)) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error('Bad message from client: %r' % message) - raise e + try: + await super().process_requests() + finally: + if self.upstream_client is not None: + await self.upstream_client.close() - if 'chunk-stream' in msg: - raise ClientError("Nested chunks are not allowed") + async def dispatch_message(self, msg): + for k in self.handlers.keys(): + if k in msg: + self.logger.debug("Handling %s" % k) + if "stream" in k: + return await self.handlers[k](msg[k]) + else: + with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure(): + return await self.handlers[k](msg[k]) - await self.dispatch_message(msg) + raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg) + @permissions(READ_PERM) async def handle_get(self, request): - method = request['method'] - taskhash = request['taskhash'] + method = request["method"] + taskhash = request["taskhash"] + fetch_all = request.get("all", False) + + return await self.get_unihash(method, taskhash, fetch_all) + + async def get_unihash(self, method, taskhash, fetch_all=False): + d = None + + if fetch_all: + row = await self.db.get_unihash_by_taskhash_full(method, taskhash) + if row is not None: + d = {k: row[k] for k in row.keys()} + elif self.upstream_client is not None: + d = await self.upstream_client.get_taskhash(method, taskhash, True) + await self.update_unified(d) + else: + row = await self.db.get_equivalent(method, taskhash) + + if row is not None: + d = {k: row[k] for k in row.keys()} + elif self.upstream_client is not None: + d = await self.upstream_client.get_taskhash(method, taskhash) + await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) + + return d + + @permissions(READ_PERM) + async def handle_get_outhash(self, request): + method = request["method"] + outhash = request["outhash"] + taskhash = request["taskhash"] + with_unihash = request.get("with_unihash", True) - if request.get('all', False): - row = self.query_equivalent(method, taskhash, self.ALL_QUERY) + return await self.get_outhash(method, outhash, taskhash, with_unihash) + + async def get_outhash(self, method, outhash, taskhash, with_unihash=True): + d = None + if with_unihash: + row = await self.db.get_unihash_by_outhash(method, outhash) else: - row = self.query_equivalent(method, taskhash, self.FAST_QUERY) + row = await self.db.get_outhash(method, outhash) if row is not None: - logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) d = {k: row[k] for k in row.keys()} + elif self.upstream_client is not None: + d = await self.upstream_client.get_outhash(method, outhash, taskhash) + await self.update_unified(d) - self.write_message(d) - else: - self.write_message(None) + return d - async def handle_get_stream(self, request): - self.write_message('ok') + async def update_unified(self, data): + if data is None: + return + + await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) + await self.db.insert_outhash(data) + + async def _stream_handler(self, handler): + await self.socket.send_message("ok") while True: - l = await self.reader.readline() + upstream = None + + l = await self.socket.recv() if not l: - return + break try: # This inner loop is very sensitive and must be as fast as # possible (which is why the request sample is handled manually # instead of using 'with', and also why logging statements are # commented out. - self.request_sample = self.request_stats.start_sample() + self.request_sample = self.server.request_stats.start_sample() request_measure = self.request_sample.measure() request_measure.start() - l = l.decode('utf-8').rstrip() - if l == 'END': - self.writer.write('ok\n'.encode('utf-8')) - return - - (method, taskhash) = l.split() - #logger.debug('Looking up %s %s' % (method, taskhash)) - row = self.query_equivalent(method, taskhash, self.FAST_QUERY) - if row is not None: - msg = ('%s\n' % row['unihash']).encode('utf-8') - #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) - else: - msg = '\n'.encode('utf-8') + if l == "END": + break - self.writer.write(msg) + msg = await handler(l) + await self.socket.send(msg) finally: request_measure.end() self.request_sample.end() - await self.writer.drain() + await self.socket.send("ok") + return self.NO_RESPONSE - async def handle_report(self, data): - with closing(self.db.cursor()) as cursor: - cursor.execute(''' - -- Find tasks with a matching outhash (that is, tasks that - -- are equivalent) - SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash - - -- If there is an exact match on the taskhash, return it. - -- Otherwise return the oldest matching outhash of any - -- taskhash - ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, - created ASC - - -- Only return one row - LIMIT 1 - ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')}) - - row = cursor.fetchone() - - # If no matching outhash was found, or one *was* found but it - # wasn't an exact match on the taskhash, a new entry for this - # taskhash should be added - if row is None or row['taskhash'] != data['taskhash']: - # If a row matching the outhash was found, the unihash for - # the new taskhash should be the same as that one. - # Otherwise the caller provided unihash is used. - unihash = data['unihash'] - if row is not None: - unihash = row['unihash'] - - insert_data = { - 'method': data['method'], - 'outhash': data['outhash'], - 'taskhash': data['taskhash'], - 'unihash': unihash, - 'created': datetime.now() - } + @permissions(READ_PERM) + async def handle_get_stream(self, request): + async def handler(l): + (method, taskhash) = l.split() + # self.logger.debug('Looking up %s %s' % (method, taskhash)) + row = await self.db.get_equivalent(method, taskhash) - for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): - if k in data: - insert_data[k] = data[k] + if row is not None: + # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) + return row["unihash"] - cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % ( - ', '.join(sorted(insert_data.keys())), - ', '.join(':' + k for k in sorted(insert_data.keys()))), - insert_data) + if self.upstream_client is not None: + upstream = await self.upstream_client.get_unihash(method, taskhash) + if upstream: + await self.server.backfill_queue.put((method, taskhash)) + return upstream - self.db.commit() + return "" - logger.info('Adding taskhash %s with unihash %s', - data['taskhash'], unihash) + return await self._stream_handler(handler) - d = { - 'taskhash': data['taskhash'], - 'method': data['method'], - 'unihash': unihash - } - else: - d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} + @permissions(READ_PERM) + async def handle_exists_stream(self, request): + async def handler(l): + if await self.db.unihash_exists(l): + return "true" - self.write_message(d) + if self.upstream_client is not None: + if await self.upstream_client.unihash_exists(l): + return "true" - async def handle_equivreport(self, data): - with closing(self.db.cursor()) as cursor: - insert_data = { - 'method': data['method'], - 'outhash': "", - 'taskhash': data['taskhash'], - 'unihash': data['unihash'], - 'created': datetime.now() - } + return "false" - for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'): - if k in data: - insert_data[k] = data[k] + return await self._stream_handler(handler) - cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % ( - ', '.join(sorted(insert_data.keys())), - ', '.join(':' + k for k in sorted(insert_data.keys()))), - insert_data) + async def report_readonly(self, data): + method = data["method"] + outhash = data["outhash"] + taskhash = data["taskhash"] - self.db.commit() + info = await self.get_outhash(method, outhash, taskhash) + if info: + unihash = info["unihash"] + else: + unihash = data["unihash"] + + return { + "taskhash": taskhash, + "method": method, + "unihash": unihash, + } - # Fetch the unihash that will be reported for the taskhash. If the - # unihash matches, it means this row was inserted (or the mapping - # was already valid) - row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY) + # Since this can be called either read only or to report, the check to + # report is made inside the function + @permissions(READ_PERM) + async def handle_report(self, data): + if self.server.read_only or not self.user_has_permissions(REPORT_PERM): + return await self.report_readonly(data) + + outhash_data = { + "method": data["method"], + "outhash": data["outhash"], + "taskhash": data["taskhash"], + "created": datetime.now(), + } - if row['unihash'] == data['unihash']: - logger.info('Adding taskhash equivalence for %s with unihash %s', - data['taskhash'], row['unihash']) + for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"): + if k in data: + outhash_data[k] = data[k] - d = {k: row[k] for k in ('taskhash', 'method', 'unihash')} + if self.user: + outhash_data["owner"] = self.user.username - self.write_message(d) + # Insert the new entry, unless it already exists + if await self.db.insert_outhash(outhash_data): + # If this row is new, check if it is equivalent to another + # output hash + row = await self.db.get_equivalent_for_outhash( + data["method"], data["outhash"], data["taskhash"] + ) + if row is not None: + # A matching output hash was found. Set our taskhash to the + # same unihash since they are equivalent + unihash = row["unihash"] + else: + # No matching output hash was found. This is probably the + # first outhash to be added. + unihash = data["unihash"] + + # Query upstream to see if it has a unihash we can use + if self.upstream_client is not None: + upstream_data = await self.upstream_client.get_outhash( + data["method"], data["outhash"], data["taskhash"] + ) + if upstream_data is not None: + unihash = upstream_data["unihash"] + + await self.db.insert_unihash(data["method"], data["taskhash"], unihash) + + unihash_data = await self.get_unihash(data["method"], data["taskhash"]) + if unihash_data is not None: + unihash = unihash_data["unihash"] + else: + unihash = data["unihash"] - async def handle_get_stats(self, request): - d = { - 'requests': self.request_stats.todict(), + return { + "taskhash": data["taskhash"], + "method": data["method"], + "unihash": unihash, } - self.write_message(d) + @permissions(READ_PERM, REPORT_PERM) + async def handle_equivreport(self, data): + await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"]) + + # Fetch the unihash that will be reported for the taskhash. If the + # unihash matches, it means this row was inserted (or the mapping + # was already valid) + row = await self.db.get_equivalent(data["method"], data["taskhash"]) + + if row["unihash"] == data["unihash"]: + self.logger.info( + "Adding taskhash equivalence for %s with unihash %s", + data["taskhash"], + row["unihash"], + ) + + return {k: row[k] for k in ("taskhash", "method", "unihash")} + + @permissions(READ_PERM) + async def handle_get_stats(self, request): + return { + "requests": self.server.request_stats.todict(), + } + @permissions(DB_ADMIN_PERM) async def handle_reset_stats(self, request): d = { - 'requests': self.request_stats.todict(), + "requests": self.server.request_stats.todict(), } - self.request_stats.reset() - self.write_message(d) + self.server.request_stats.reset() + return d - def query_equivalent(self, method, taskhash, query): - # This is part of the inner loop and must be as fast as possible - try: - cursor = self.db.cursor() - cursor.execute(query, {'method': method, 'taskhash': taskhash}) - return cursor.fetchone() - except: - cursor.close() + @permissions(READ_PERM) + async def handle_backfill_wait(self, request): + d = { + "tasks": self.server.backfill_queue.qsize(), + } + await self.server.backfill_queue.join() + return d + @permissions(DB_ADMIN_PERM) + async def handle_remove(self, request): + condition = request["where"] + if not isinstance(condition, dict): + raise TypeError("Bad condition type %s" % type(condition)) -class Server(object): - def __init__(self, db, loop=None): - self.request_stats = Stats() - self.db = db + return {"count": await self.db.remove(condition)} - if loop is None: - self.loop = asyncio.new_event_loop() - self.close_loop = True - else: - self.loop = loop - self.close_loop = False + @permissions(DB_ADMIN_PERM) + async def handle_gc_mark(self, request): + condition = request["where"] + mark = request["mark"] - self._cleanup_socket = None + if not isinstance(condition, dict): + raise TypeError("Bad condition type %s" % type(condition)) - def start_tcp_server(self, host, port): - self.server = self.loop.run_until_complete( - asyncio.start_server(self.handle_client, host, port, loop=self.loop) - ) + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) - for s in self.server.sockets: - logger.info('Listening on %r' % (s.getsockname(),)) - # Newer python does this automatically. Do it manually here for - # maximum compatibility - s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) + return {"count": await self.db.gc_mark(mark, condition)} - name = self.server.sockets[0].getsockname() - if self.server.sockets[0].family == socket.AF_INET6: - self.address = "[%s]:%d" % (name[0], name[1]) - else: - self.address = "%s:%d" % (name[0], name[1]) + @permissions(DB_ADMIN_PERM) + async def handle_gc_sweep(self, request): + mark = request["mark"] - def start_unix_server(self, path): - def cleanup(): - os.unlink(path) + if not isinstance(mark, str): + raise TypeError("Bad mark type %s" % type(mark)) - cwd = os.getcwd() - try: - # Work around path length limits in AF_UNIX - os.chdir(os.path.dirname(path)) - self.server = self.loop.run_until_complete( - asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop) + current_mark = await self.db.get_current_gc_mark() + + if not current_mark or mark != current_mark: + raise bb.asyncrpc.InvokeError( + f"'{mark}' is not the current mark. Refusing to sweep" ) - finally: - os.chdir(cwd) - logger.info('Listening on %r' % path) + count = await self.db.gc_sweep() - self._cleanup_socket = cleanup - self.address = "unix://%s" % os.path.abspath(path) + return {"count": count} + + @permissions(DB_ADMIN_PERM) + async def handle_gc_status(self, request): + (keep_rows, remove_rows, current_mark) = await self.db.gc_status() + return { + "keep": keep_rows, + "remove": remove_rows, + "mark": current_mark, + } + + @permissions(DB_ADMIN_PERM) + async def handle_clean_unused(self, request): + max_age = request["max_age_seconds"] + oldest = datetime.now() - timedelta(seconds=-max_age) + return {"count": await self.db.clean_unused(oldest)} + + @permissions(DB_ADMIN_PERM) + async def handle_get_db_usage(self, request): + return {"usage": await self.db.get_usage()} + + @permissions(DB_ADMIN_PERM) + async def handle_get_db_query_columns(self, request): + return {"columns": await self.db.get_query_columns()} + + # The authentication API is always allowed + async def handle_auth(self, request): + username = str(request["username"]) + token = str(request["token"]) + + async def fail_auth(): + nonlocal username + # Rate limit bad login attempts + await asyncio.sleep(1) + raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}") + + user, db_token = await self.db.lookup_user_token(username) + + if not user or not db_token: + await fail_auth() - async def handle_client(self, reader, writer): - # writer.transport.set_write_buffer_limits(0) try: - client = ServerClient(reader, writer, self.db, self.request_stats) - await client.process_requests() - except Exception as e: - import traceback - logger.error('Error from client: %s' % str(e), exc_info=True) - traceback.print_exc() - writer.close() - logger.info('Client disconnected') + algo, salt, _ = db_token.split(":") + except ValueError: + await fail_auth() + + if hash_token(algo, salt, token) != db_token: + await fail_auth() - def serve_forever(self): - def signal_handler(): - self.loop.stop() + self.user = user - self.loop.add_signal_handler(signal.SIGTERM, signal_handler) + self.logger.info("Authenticated as %s", username) + return { + "result": True, + "username": self.user.username, + "permissions": sorted(list(self.user.permissions)), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_refresh_token(self, request): + username = str(request["username"]) + + token = await new_token() + + updated = await self.db.set_user_token( + username, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not updated: + self.raise_no_user_error(username) + + return {"username": username, "token": token} + + def get_perm_arg(self, arg): + if not isinstance(arg, list): + raise bb.asyncrpc.InvokeError("Unexpected type for permissions") + + arg = set(arg) try: - self.loop.run_forever() - except KeyboardInterrupt: + arg.remove(NONE_PERM) + except KeyError: pass - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) - logger.info('Server shutting down') + unknown_perms = arg - ALL_PERMISSIONS + if unknown_perms: + raise bb.asyncrpc.InvokeError( + "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms))) + ) + + return sorted(list(arg)) + + def return_perms(self, permissions): + if ALL_PERM in permissions: + return sorted(list(ALL_PERMISSIONS)) + return sorted(list(permissions)) + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_set_perms(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + if not await self.db.set_user_perms(username, permissions): + self.raise_no_user_error(username) + + return { + "username": username, + "permissions": self.return_perms(permissions), + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_get_user(self, request): + username = str(request["username"]) + + user = await self.db.lookup_user(username) + if user is None: + return None + + return { + "username": user.username, + "permissions": self.return_perms(user.permissions), + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_get_all_users(self, request): + users = await self.db.get_all_users() + return { + "users": [ + { + "username": u.username, + "permissions": self.return_perms(u.permissions), + } + for u in users + ] + } + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_new_user(self, request): + username = str(request["username"]) + permissions = self.get_perm_arg(request["permissions"]) + + token = await new_token() + + inserted = await self.db.new_user( + username, + permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), token), + ) + if not inserted: + raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'") + + return { + "username": username, + "permissions": self.return_perms(permissions), + "token": token, + } + + @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False) + async def handle_delete_user(self, request): + username = str(request["username"]) + + if not await self.db.delete_user(username): + self.raise_no_user_error(username) + + return {"username": username} + + @permissions(USER_ADMIN_PERM, allow_anon=False) + async def handle_become_user(self, request): + username = str(request["username"]) + + user = await self.db.lookup_user(username) + if user is None: + raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist") + + self.user = user + + self.logger.info("Became user %s", username) + + return { + "username": self.user.username, + "permissions": self.return_perms(self.user.permissions), + } + + +class Server(bb.asyncrpc.AsyncServer): + def __init__( + self, + db_engine, + upstream=None, + read_only=False, + anon_perms=DEFAULT_ANON_PERMS, + admin_username=None, + admin_password=None, + ): + if upstream and read_only: + raise bb.asyncrpc.ServerError( + "Read-only hashserv cannot pull from an upstream server" + ) + + disallowed_perms = set(anon_perms) - set( + [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM] + ) + + if disallowed_perms: + raise bb.asyncrpc.ServerError( + f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users" + ) + + super().__init__(logger) + + self.request_stats = Stats() + self.db_engine = db_engine + self.upstream = upstream + self.read_only = read_only + self.backfill_queue = None + self.anon_perms = set(anon_perms) + self.admin_username = admin_username + self.admin_password = admin_password + + self.logger.info( + "Anonymous user permissions are: %s", ", ".join(self.anon_perms) + ) + + def accept_client(self, socket): + return ServerClient(socket, self) + + async def create_admin_user(self): + admin_permissions = (ALL_PERM,) + async with self.db_engine.connect(self.logger) as db: + added = await db.new_user( + self.admin_username, + admin_permissions, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + if added: + self.logger.info("Created admin user '%s'", self.admin_username) + else: + await db.set_user_perms( + self.admin_username, + admin_permissions, + ) + await db.set_user_token( + self.admin_username, + hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password), + ) + self.logger.info("Admin user '%s' updated", self.admin_username) + + async def backfill_worker_task(self): + async with await create_async_client( + self.upstream + ) as client, self.db_engine.connect(self.logger) as db: + while True: + item = await self.backfill_queue.get() + if item is None: + self.backfill_queue.task_done() + break + + method, taskhash = item + d = await client.get_taskhash(method, taskhash) + if d is not None: + await db.insert_unihash(d["method"], d["taskhash"], d["unihash"]) + self.backfill_queue.task_done() + + def start(self): + tasks = super().start() + if self.upstream: + self.backfill_queue = asyncio.Queue() + tasks += [self.backfill_worker_task()] + + self.loop.run_until_complete(self.db_engine.create()) + + if self.admin_username: + self.loop.run_until_complete(self.create_admin_user()) - if self.close_loop: - self.loop.close() + return tasks - if self._cleanup_socket is not None: - self._cleanup_socket() + async def stop(self): + if self.backfill_queue is not None: + await self.backfill_queue.put(None) + await super().stop() |