diff options
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r-- | bitbake/lib/hashserv/server.py | 881 |
1 files changed, 881 insertions, 0 deletions
# Copyright (C) 2019 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#

from datetime import datetime, timedelta
import asyncio
import functools
import hmac
import logging
import math
import time
import os
import base64
import hashlib
from . import create_async_client
import bb.asyncrpc

logger = logging.getLogger("hashserv.server")


# This permission only exists to match nothing
NONE_PERM = "@none"

READ_PERM = "@read"
REPORT_PERM = "@report"
DB_ADMIN_PERM = "@db-admin"
USER_ADMIN_PERM = "@user-admin"
ALL_PERM = "@all"

ALL_PERMISSIONS = {
    READ_PERM,
    REPORT_PERM,
    DB_ADMIN_PERM,
    USER_ADMIN_PERM,
    ALL_PERM,
}

DEFAULT_ANON_PERMS = (
    READ_PERM,
    REPORT_PERM,
    DB_ADMIN_PERM,
)

TOKEN_ALGORITHM = "sha256"

# 48 bytes of random data will result in 64 characters when base64
# encoded. This number also ensures that the base64 encoding won't have any
# trailing '=' characters.
TOKEN_SIZE = 48

SALT_SIZE = 8


class Measurement(object):
    """
    Times a single interval and adds the elapsed time to a Sample.

    Can be driven manually with start()/end() or used as a context manager.
    """

    def __init__(self, sample):
        self.sample = sample

    def start(self):
        self.start_time = time.perf_counter()

    def end(self):
        self.sample.add(time.perf_counter() - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        self.end()


class Sample(object):
    """
    Accumulates the elapsed time of one or more measurements and, when ended,
    reports their sum to the owning Stats object as a single data point.
    """

    def __init__(self, stats):
        self.stats = stats
        self.num_samples = 0
        self.elapsed = 0

    def measure(self):
        """Return a new Measurement that will add its time to this sample."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        self.num_samples += 1
        self.elapsed += elapsed

    def end(self):
        """
        Report the accumulated time to the Stats object (only if at least one
        measurement was added) and reset the accumulator.
        """
        if self.num_samples:
            self.stats.add(self.elapsed)
            self.num_samples = 0
            self.elapsed = 0


class Stats(object):
    """
    Running statistics (count, total, maximum, average and standard
    deviation) over a series of elapsed times.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        # Running mean (m) and running sum of squared differences from the
        # mean (s), used to derive the standard deviation incrementally
        self.m = 0
        self.s = 0
        self.current_elapsed = None

    def add(self, elapsed):
        """Add one elapsed-time data point and update the running values."""
        self.num += 1
        if self.num == 1:
            self.m = elapsed
            self.s = 0
        else:
            # Incremental (Welford-style) update of the mean and of the sum
            # of squared differences; avoids storing every data point
            last_m = self.m
            self.m = last_m + (elapsed - last_m) / self.num
            self.s = self.s + (elapsed - last_m) * (elapsed - self.m)

        self.total_time += elapsed

        if self.max_time < elapsed:
            self.max_time = elapsed

    def start_sample(self):
        """Return a new Sample that reports to this Stats object."""
        return Sample(self)

    @property
    def average(self):
        if self.num == 0:
            return 0
        return self.total_time / self.num

    @property
    def stdev(self):
        # Sample standard deviation; undefined for fewer than two points, so
        # report 0 in that case
        if self.num <= 1:
            return 0
        return math.sqrt(self.s / (self.num - 1))

    def todict(self):
        """Return the statistics as a plain dictionary."""
        return {
            k: getattr(self, k)
            for k in ("num", "total_time", "max_time", "average", "stdev")
        }


token_refresh_semaphore = asyncio.Lock()


async def new_token():
    """
    Generate a new random authentication token.

    Returns the token as a string: TOKEN_SIZE random bytes, base64 encoded
    using "." and "_" as the two alternate characters.
    """
    # Prevent malicious users from using this API to deduce the entropy
    # pool on the server and thus be able to guess a token. *All* token
    # refresh requests lock the same global semaphore and then sleep for a
    # short time. This effectively rate limits the total number of requests
    # that can be made across all clients to 10/second, which should be enough
    # since you have to be an authenticated user to make the request in the
    # first place
    async with token_refresh_semaphore:
        await asyncio.sleep(0.1)
        raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)

    return base64.b64encode(raw, b"._").decode("utf-8")


def new_salt():
    """Return SALT_SIZE random bytes as a hexadecimal string."""
    return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()


def hash_token(algo, salt, token):
    """
    Hash a token with the given hashlib algorithm name and salt.

    Returns a string of the form "<algo>:<salt>:<hex digest>", where the
    digest covers the UTF-8 encoded salt followed by the UTF-8 encoded token.
    """
    h = hashlib.new(algo)
    h.update(salt.encode("utf-8"))
    h.update(token.encode("utf-8"))
    return ":".join([algo, salt, h.hexdigest()])


def permissions(*permissions, allow_anon=True, allow_self_service=False):
    """
    Function decorator that can be used to decorate an RPC function call and
    check that the current user's permissions match the required permissions.

    If allow_anon is True, the user will also be allowed to make the RPC call
    if the anonymous user permissions match the permissions.

    If allow_self_service is True, and the "username" property in the request
    is the currently logged in user, or not specified, the user will also be
    allowed to make the request. This allows users to access normally
    privileged API, as long as they are only modifying their own user
    properties (e.g. users can be allowed to reset their own token without
    @user-admin permissions, but not the token for any other user).

    Raises bb.asyncrpc.InvokeError from the decorated call if the check fails.
    """

    def wrapper(func):
        # Preserve the handler's name and docstring on the wrapper
        @functools.wraps(func)
        async def wrap(self, request):
            if allow_self_service and self.user is not None:
                username = request.get("username", self.user.username)
                if username == self.user.username:
                    # Fill in the username so the handler can rely on it
                    # being present
                    request["username"] = self.user.username
                    return await func(self, request)

            if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
                if not self.user:
                    username = "Anonymous user"
                    user_perms = self.server.anon_perms
                else:
                    username = self.user.username
                    user_perms = self.user.permissions

                self.logger.info(
                    "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
                    username,
                    ", ".join(user_perms),
                    func.__name__,
                    ", ".join(permissions),
                )
                raise bb.asyncrpc.InvokeError(
                    f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
                )

            return await func(self, request)

        return wrap

    return wrapper


class ServerClient(bb.asyncrpc.AsyncServerConnection):
    """
    Server-side handler for one client connection.

    Maps each RPC command name to a handler method; handlers that modify
    state are only registered when the server is not read-only.
    """

    def __init__(self, socket, server):
        super().__init__(socket, "OEHASHEQUIV", server.logger)
        self.server = server
        self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
        # The authenticated user, or None for an anonymous connection
        self.user = None

        self.handlers.update(
            {
                "get": self.handle_get,
                "get-outhash": self.handle_get_outhash,
                "get-stream": self.handle_get_stream,
                "exists-stream": self.handle_exists_stream,
                "get-stats": self.handle_get_stats,
                "get-db-usage": self.handle_get_db_usage,
                "get-db-query-columns": self.handle_get_db_query_columns,
                # Not always read-only, but internally checks if the server is
                # read-only
                "report": self.handle_report,
                "auth": self.handle_auth,
                "get-user": self.handle_get_user,
                "get-all-users": self.handle_get_all_users,
                "become-user": self.handle_become_user,
            }
        )

        if not self.server.read_only:
            self.handlers.update(
                {
                    "report-equiv": self.handle_equivreport,
                    "reset-stats": self.handle_reset_stats,
                    "backfill-wait": self.handle_backfill_wait,
                    "remove": self.handle_remove,
                    "gc-mark": self.handle_gc_mark,
                    "gc-sweep": self.handle_gc_sweep,
                    "gc-status": self.handle_gc_status,
                    "clean-unused": self.handle_clean_unused,
                    "refresh-token": self.handle_refresh_token,
                    "set-user-perms": self.handle_set_perms,
                    "new-user": self.handle_new_user,
                    "delete-user": self.handle_delete_user,
                }
            )

    def raise_no_user_error(self, username):
        raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")

    def user_has_permissions(self, *permissions, allow_anon=True):
        """
        Return True if the connection holds all of the given permissions.

        When allow_anon is True the anonymous permissions of the server are
        checked first; otherwise only the logged in user's permissions count.
        The @all permission satisfies any request.
        """
        permissions = set(permissions)
        if allow_anon:
            if ALL_PERM in self.server.anon_perms:
                return True

            if not permissions - self.server.anon_perms:
                return True

        if self.user is None:
            return False

        if ALL_PERM in self.user.permissions:
            return True

        if not permissions - self.user.permissions:
            return True

        return False

    def validate_proto_version(self):
        return self.proto_version > (1, 0) and self.proto_version <= (1, 1)

    async def process_requests(self):
        """
        Serve requests with a database connection (and an upstream client,
        if the server has an upstream) open for the life of the connection.
        """
        async with self.server.db_engine.connect(self.logger) as db:
            self.db = db
            if self.server.upstream is not None:
                self.upstream_client = await create_async_client(self.server.upstream)
            else:
                self.upstream_client = None

            try:
                await super().process_requests()
            finally:
                if self.upstream_client is not None:
                    await self.upstream_client.close()

    async def dispatch_message(self, msg):
        """
        Call the handler for the first registered command found in msg.

        Raises bb.asyncrpc.ClientError if msg contains no known command.
        """
        for k in self.handlers:
            if k in msg:
                self.logger.debug("Handling %s", k)
                if "stream" in k:
                    # Stream handlers time each request themselves
                    return await self.handlers[k](msg[k])
                else:
                    with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
                        return await self.handlers[k](msg[k])

        raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)

    @permissions(READ_PERM)
    async def handle_get(self, request):
        method = request["method"]
        taskhash = request["taskhash"]
        fetch_all = request.get("all", False)

        return await self.get_unihash(method, taskhash, fetch_all)

    async def get_unihash(self, method, taskhash, fetch_all=False):
        """
        Look up the unihash data for a method and taskhash.

        The local database is queried first. On a miss, the upstream server
        (if any) is queried and a result from it is stored locally.

        Returns a dictionary of the row, or None if nothing was found.
        """
        d = None

        if fetch_all:
            row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash, True)
                await self.update_unified(d)
        else:
            row = await self.db.get_equivalent(method, taskhash)

            if row is not None:
                d = {k: row[k] for k in row.keys()}
            elif self.upstream_client is not None:
                d = await self.upstream_client.get_taskhash(method, taskhash)
                # The upstream may not know the taskhash either; only cache
                # an actual result
                if d is not None:
                    await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])

        return d

    @permissions(READ_PERM)
    async def handle_get_outhash(self, request):
        method = request["method"]
        outhash = request["outhash"]
        taskhash = request["taskhash"]
        with_unihash = request.get("with_unihash", True)

        return await self.get_outhash(method, outhash, taskhash, with_unihash)

    async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
        """
        Look up the data for a method and outhash.

        The local database is queried first. On a miss, the upstream server
        (if any) is queried and a result from it is stored locally.

        Returns a dictionary of the row, or None if nothing was found.
        """
        d = None
        if with_unihash:
            row = await self.db.get_unihash_by_outhash(method, outhash)
        else:
            row = await self.db.get_outhash(method, outhash)

        if row is not None:
            d = {k: row[k] for k in row.keys()}
        elif self.upstream_client is not None:
            d = await self.upstream_client.get_outhash(method, outhash, taskhash)
            await self.update_unified(d)

        return d

    async def update_unified(self, data):
        """Store the unihash and outhash from data locally; no-op for None."""
        if data is None:
            return

        await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
        await self.db.insert_outhash(data)

    async def _stream_handler(self, handler):
        """
        Run streaming mode: answer each received line with the result of
        handler(line) until "END" or an empty read is received.
        """
        await self.socket.send_message("ok")

        while True:
            line = await self.socket.recv()
            if not line:
                break

            try:
                # This inner loop is very sensitive and must be as fast as
                # possible (which is why the request sample is handled manually
                # instead of using 'with', and also why logging statements are
                # commented out).
                self.request_sample = self.server.request_stats.start_sample()
                request_measure = self.request_sample.measure()
                request_measure.start()

                if line == "END":
                    break

                msg = await handler(line)
                await self.socket.send(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

        await self.socket.send("ok")
        return self.NO_RESPONSE

    @permissions(READ_PERM)
    async def handle_get_stream(self, request):
        async def handler(line):
            (method, taskhash) = line.split()
            # self.logger.debug('Looking up %s %s' % (method, taskhash))
            row = await self.db.get_equivalent(method, taskhash)

            if row is not None:
                # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
                return row["unihash"]

            if self.upstream_client is not None:
                upstream = await self.upstream_client.get_unihash(method, taskhash)
                if upstream:
                    # Let the backfill worker store the result locally so
                    # this loop is not slowed down by it
                    await self.server.backfill_queue.put((method, taskhash))
                    return upstream

            return ""

        return await self._stream_handler(handler)

    @permissions(READ_PERM)
    async def handle_exists_stream(self, request):
        async def handler(line):
            if await self.db.unihash_exists(line):
                return "true"

            if self.upstream_client is not None:
                if await self.upstream_client.unihash_exists(line):
                    return "true"

            return "false"

        return await self._stream_handler(handler)

    async def report_readonly(self, data):
        """
        Answer a report without storing it: return the known unihash for the
        outhash if there is one, otherwise the unihash the client supplied.
        """
        method = data["method"]
        outhash = data["outhash"]
        taskhash = data["taskhash"]

        info = await self.get_outhash(method, outhash, taskhash)
        if info:
            unihash = info["unihash"]
        else:
            unihash = data["unihash"]

        return {
            "taskhash": taskhash,
            "method": method,
            "unihash": unihash,
        }

    # Since this can be called either read only or to report, the check to
    # report is made inside the function
    @permissions(READ_PERM)
    async def handle_report(self, data):
        if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
            return await self.report_readonly(data)

        outhash_data = {
            "method": data["method"],
            "outhash": data["outhash"],
            "taskhash": data["taskhash"],
            "created": datetime.now(),
        }

        for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
            if k in data:
                outhash_data[k] = data[k]

        # A logged in user always overrides any client supplied owner
        if self.user:
            outhash_data["owner"] = self.user.username

        # Insert the new entry, unless it already exists
        if await self.db.insert_outhash(outhash_data):
            # If this row is new, check if it is equivalent to another
            # output hash
            row = await self.db.get_equivalent_for_outhash(
                data["method"], data["outhash"], data["taskhash"]
            )

            if row is not None:
                # A matching output hash was found. Set our taskhash to the
                # same unihash since they are equivalent
                unihash = row["unihash"]
            else:
                # No matching output hash was found. This is probably the
                # first outhash to be added.
                unihash = data["unihash"]

                # Query upstream to see if it has a unihash we can use
                if self.upstream_client is not None:
                    upstream_data = await self.upstream_client.get_outhash(
                        data["method"], data["outhash"], data["taskhash"]
                    )
                    if upstream_data is not None:
                        unihash = upstream_data["unihash"]

            await self.db.insert_unihash(data["method"], data["taskhash"], unihash)

        unihash_data = await self.get_unihash(data["method"], data["taskhash"])
        if unihash_data is not None:
            unihash = unihash_data["unihash"]
        else:
            unihash = data["unihash"]

        return {
            "taskhash": data["taskhash"],
            "method": data["method"],
            "unihash": unihash,
        }

    @permissions(READ_PERM, REPORT_PERM)
    async def handle_equivreport(self, data):
        await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])

        # Fetch the unihash that will be reported for the taskhash. If the
        # unihash matches, it means this row was inserted (or the mapping
        # was already valid)
        row = await self.db.get_equivalent(data["method"], data["taskhash"])

        if row["unihash"] == data["unihash"]:
            self.logger.info(
                "Adding taskhash equivalence for %s with unihash %s",
                data["taskhash"],
                row["unihash"],
            )

        return {k: row[k] for k in ("taskhash", "method", "unihash")}

    @permissions(READ_PERM)
    async def handle_get_stats(self, request):
        return {
            "requests": self.server.request_stats.todict(),
        }

    @permissions(DB_ADMIN_PERM)
    async def handle_reset_stats(self, request):
        # Report the statistics as they were before the reset
        d = {
            "requests": self.server.request_stats.todict(),
        }

        self.server.request_stats.reset()
        return d

    @permissions(READ_PERM)
    async def handle_backfill_wait(self, request):
        # Report the queue size as it was before waiting for it to drain
        d = {
            "tasks": self.server.backfill_queue.qsize(),
        }
        await self.server.backfill_queue.join()
        return d

    @permissions(DB_ADMIN_PERM)
    async def handle_remove(self, request):
        condition = request["where"]
        if not isinstance(condition, dict):
            raise TypeError("Bad condition type %s" % type(condition))

        return {"count": await self.db.remove(condition)}

    @permissions(DB_ADMIN_PERM)
    async def handle_gc_mark(self, request):
        condition = request["where"]
        mark = request["mark"]

        if not isinstance(condition, dict):
            raise TypeError("Bad condition type %s" % type(condition))

        if not isinstance(mark, str):
            raise TypeError("Bad mark type %s" % type(mark))

        return {"count": await self.db.gc_mark(mark, condition)}

    @permissions(DB_ADMIN_PERM)
    async def handle_gc_sweep(self, request):
        mark = request["mark"]

        if not isinstance(mark, str):
            raise TypeError("Bad mark type %s" % type(mark))

        current_mark = await self.db.get_current_gc_mark()

        # Only sweep when the caller names the mark that is actually current
        if not current_mark or mark != current_mark:
            raise bb.asyncrpc.InvokeError(
                f"'{mark}' is not the current mark. Refusing to sweep"
            )

        count = await self.db.gc_sweep()

        return {"count": count}

    @permissions(DB_ADMIN_PERM)
    async def handle_gc_status(self, request):
        (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
        return {
            "keep": keep_rows,
            "remove": remove_rows,
            "mark": current_mark,
        }

    @permissions(DB_ADMIN_PERM)
    async def handle_clean_unused(self, request):
        max_age = request["max_age_seconds"]
        # The cutoff lies max_age seconds in the past. (Subtracting a negated
        # delta would put the cutoff in the future and make even brand new
        # entries count as too old.)
        oldest = datetime.now() - timedelta(seconds=max_age)
        return {"count": await self.db.clean_unused(oldest)}

    @permissions(DB_ADMIN_PERM)
    async def handle_get_db_usage(self, request):
        return {"usage": await self.db.get_usage()}

    @permissions(DB_ADMIN_PERM)
    async def handle_get_db_query_columns(self, request):
        return {"columns": await self.db.get_query_columns()}

    # The authentication API is always allowed
    async def handle_auth(self, request):
        """
        Authenticate the connection as the requested user.

        Raises bb.asyncrpc.InvokeError (after a delay) if the user is unknown
        or the token does not match.
        """
        username = str(request["username"])
        token = str(request["token"])

        async def fail_auth():
            # Rate limit bad login attempts
            await asyncio.sleep(1)
            raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")

        user, db_token = await self.db.lookup_user_token(username)

        if not user or not db_token:
            await fail_auth()

        try:
            algo, salt, _ = db_token.split(":")
        except ValueError:
            await fail_auth()

        # Compare in constant time so the response time does not reveal how
        # much of the hash matched
        expected = hash_token(algo, salt, token)
        if not hmac.compare_digest(
            expected.encode("utf-8"), db_token.encode("utf-8")
        ):
            await fail_auth()

        self.user = user

        self.logger.info("Authenticated as %s", username)

        return {
            "result": True,
            "username": self.user.username,
            "permissions": sorted(self.user.permissions),
        }

    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
    async def handle_refresh_token(self, request):
        username = str(request["username"])

        token = await new_token()

        # Only the salted hash is stored; the token itself is returned to the
        # caller
        updated = await self.db.set_user_token(
            username,
            hash_token(TOKEN_ALGORITHM, new_salt(), token),
        )
        if not updated:
            self.raise_no_user_error(username)

        return {"username": username, "token": token}

    def get_perm_arg(self, arg):
        """
        Validate a permissions argument from a request.

        Returns the sorted list of unique permissions with NONE_PERM removed.
        Raises bb.asyncrpc.InvokeError if arg is not a list or contains an
        unknown permission.
        """
        if not isinstance(arg, list):
            raise bb.asyncrpc.InvokeError("Unexpected type for permissions")

        arg = set(arg)
        arg.discard(NONE_PERM)

        unknown_perms = arg - ALL_PERMISSIONS
        if unknown_perms:
            raise bb.asyncrpc.InvokeError(
                "Unknown permissions %s" % ", ".join(sorted(unknown_perms))
            )

        return sorted(arg)

    def return_perms(self, permissions):
        """
        Return the sorted list of permissions to report to a client; ALL_PERM
        is expanded to every known permission.
        """
        if ALL_PERM in permissions:
            return sorted(ALL_PERMISSIONS)
        return sorted(permissions)

    @permissions(USER_ADMIN_PERM, allow_anon=False)
    async def handle_set_perms(self, request):
        username = str(request["username"])
        permissions = self.get_perm_arg(request["permissions"])

        if not await self.db.set_user_perms(username, permissions):
            self.raise_no_user_error(username)

        return {
            "username": username,
            "permissions": self.return_perms(permissions),
        }

    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
    async def handle_get_user(self, request):
        username = str(request["username"])

        user = await self.db.lookup_user(username)
        if user is None:
            return None

        return {
            "username": user.username,
            "permissions": self.return_perms(user.permissions),
        }

    @permissions(USER_ADMIN_PERM, allow_anon=False)
    async def handle_get_all_users(self, request):
        users = await self.db.get_all_users()
        return {
            "users": [
                {
                    "username": u.username,
                    "permissions": self.return_perms(u.permissions),
                }
                for u in users
            ]
        }

    @permissions(USER_ADMIN_PERM, allow_anon=False)
    async def handle_new_user(self, request):
        username = str(request["username"])
        permissions = self.get_perm_arg(request["permissions"])

        token = await new_token()

        inserted = await self.db.new_user(
            username,
            permissions,
            hash_token(TOKEN_ALGORITHM, new_salt(), token),
        )
        if not inserted:
            raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")

        return {
            "username": username,
            "permissions": self.return_perms(permissions),
            "token": token,
        }

    @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
    async def handle_delete_user(self, request):
        username = str(request["username"])

        if not await self.db.delete_user(username):
            self.raise_no_user_error(username)

        return {"username": username}

    @permissions(USER_ADMIN_PERM, allow_anon=False)
    async def handle_become_user(self, request):
        username = str(request["username"])

        user = await self.db.lookup_user(username)
        if user is None:
            raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")

        self.user = user

        self.logger.info("Became user %s", username)

        return {
            "username": self.user.username,
            "permissions": self.return_perms(self.user.permissions),
        }


class Server(bb.asyncrpc.AsyncServer):
    """
    Hash equivalence server.

    Raises bb.asyncrpc.ServerError on construction if an upstream is combined
    with read-only mode, or if anon_perms contains a permission that is not
    allowed for anonymous users.
    """

    def __init__(
        self,
        db_engine,
        upstream=None,
        read_only=False,
        anon_perms=DEFAULT_ANON_PERMS,
        admin_username=None,
        admin_password=None,
    ):
        if upstream and read_only:
            raise bb.asyncrpc.ServerError(
                "Read-only hashserv cannot pull from an upstream server"
            )

        disallowed_perms = set(anon_perms) - {
            NONE_PERM,
            READ_PERM,
            REPORT_PERM,
            DB_ADMIN_PERM,
        }

        if disallowed_perms:
            raise bb.asyncrpc.ServerError(
                f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
            )

        super().__init__(logger)

        self.request_stats = Stats()
        self.db_engine = db_engine
        self.upstream = upstream
        self.read_only = read_only
        # Created in start() when an upstream server is configured
        self.backfill_queue = None
        self.anon_perms = set(anon_perms)
        self.admin_username = admin_username
        self.admin_password = admin_password

        self.logger.info(
            "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
        )

    def accept_client(self, socket):
        return ServerClient(socket, self)

    async def create_admin_user(self):
        """
        Create the admin user with ALL_PERM, or, if the user could not be
        added, reset its permissions and token to the configured values.
        """
        admin_permissions = (ALL_PERM,)
        async with self.db_engine.connect(self.logger) as db:
            added = await db.new_user(
                self.admin_username,
                admin_permissions,
                hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
            )
            if added:
                self.logger.info("Created admin user '%s'", self.admin_username)
            else:
                await db.set_user_perms(
                    self.admin_username,
                    admin_permissions,
                )
                await db.set_user_token(
                    self.admin_username,
                    hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
                )
                self.logger.info("Admin user '%s' updated", self.admin_username)

    async def backfill_worker_task(self):
        """
        Copy unihashes for queued (method, taskhash) items from the upstream
        server into the local database until a None item is received.
        """
        async with await create_async_client(
            self.upstream
        ) as client, self.db_engine.connect(self.logger) as db:
            while True:
                item = await self.backfill_queue.get()
                if item is None:
                    # Sentinel queued by stop()
                    self.backfill_queue.task_done()
                    break

                method, taskhash = item
                d = await client.get_taskhash(method, taskhash)
                if d is not None:
                    await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
                self.backfill_queue.task_done()

    def start(self):
        tasks = super().start()
        if self.upstream:
            self.backfill_queue = asyncio.Queue()
            tasks += [self.backfill_worker_task()]

        self.loop.run_until_complete(self.db_engine.create())

        if self.admin_username:
            self.loop.run_until_complete(self.create_admin_user())

        return tasks

    async def stop(self):
        # Wake the backfill worker so it can exit
        if self.backfill_queue is not None:
            await self.backfill_queue.put(None)
        await super().stop()