summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/hashserv/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r--bitbake/lib/hashserv/server.py939
1 files changed, 680 insertions, 259 deletions
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index 8e84989737..68f64f983b 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,17 +3,51 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing, contextmanager
-from datetime import datetime
+from datetime import datetime, timedelta
import asyncio
import logging
import math
import time
-from . import create_async_client, TABLE_COLUMNS
+import os
+import base64
+import hashlib
+from . import create_async_client
import bb.asyncrpc
+logger = logging.getLogger("hashserv.server")
-logger = logging.getLogger('hashserv.server')
+
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+ USER_ADMIN_PERM,
+ ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
+
+SALT_SIZE = 8
class Measurement(object):
@@ -103,358 +137,745 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
+ }
+
+
+token_refresh_semaphore = asyncio.Lock()
+
+
+async def new_token():
+ # Prevent malicious users from using this API to deduce the entropy
+ # pool on the server and thus be able to guess a token. *All* token
+ # refresh requests lock the same global semaphore and then sleep for a
+ # short time. The effectively rate limits the total number of requests
+ # than can be made across all clients to 10/second, which should be enough
+ # since you have to be an authenticated users to make the request in the
+ # first place
+ async with token_refresh_semaphore:
+ await asyncio.sleep(0.1)
+ raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
+
+ return base64.b64encode(raw, b"._").decode("utf-8")
+
+
+def new_salt():
+ return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
+
+
+def hash_token(algo, salt, token):
+ h = hashlib.new(algo)
+ h.update(salt.encode("utf-8"))
+ h.update(token.encode("utf-8"))
+ return ":".join([algo, salt, h.hexdigest()])
-def insert_task(cursor, data, ignore=False):
- keys = sorted(data.keys())
- query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
- " OR IGNORE" if ignore else "",
- ', '.join(keys),
- ', '.join(':' + k for k in keys))
- cursor.execute(query, data)
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+ """
+ Function decorator that can be used to decorate an RPC function call and
+ check that the current users permissions match the require permissions.
-async def copy_from_upstream(client, db, method, taskhash):
- d = await client.get_taskhash(method, taskhash, True)
- if d is not None:
- # Filter out unknown columns
- d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
+ If allow_anon is True, the user will also be allowed to make the RPC call
+ if the anonymous user permissions match the permissions.
- with closing(db.cursor()) as cursor:
- insert_task(cursor, d)
- db.commit()
+ If allow_self_service is True, and the "username" property in the request
+ is the currently logged in user, or not specified, the user will also be
+ allowed to make the request. This allows users to access normal privileged
+ API, as long as they are only modifying their own user properties (e.g.
+ users can be allowed to reset their own token without @user-admin
+ permissions, but not the token for any other user.
+ """
- return d
+ def wrapper(func):
+ async def wrap(self, request):
+ if allow_self_service and self.user is not None:
+ username = request.get("username", self.user.username)
+ if username == self.user.username:
+ request["username"] = self.user.username
+ return await func(self, request)
-async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
- d = await client.get_outhash(method, outhash, taskhash)
- if d is not None:
- # Filter out unknown columns
- d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
+ if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+ if not self.user:
+ username = "Anonymous user"
+ user_perms = self.server.anon_perms
+ else:
+ username = self.user.username
+ user_perms = self.user.permissions
+
+ self.logger.info(
+ "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
+ username,
+ ", ".join(user_perms),
+ func.__name__,
+ ", ".join(permissions),
+ )
+ raise bb.asyncrpc.InvokeError(
+ f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
+ )
+
+ return await func(self, request)
- with closing(db.cursor()) as cursor:
- insert_task(cursor, d)
- db.commit()
+ return wrap
+
+ return wrapper
- return d
class ServerClient(bb.asyncrpc.AsyncServerConnection):
- FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
- ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
- OUTHASH_QUERY = '''
- -- Find tasks with a matching outhash (that is, tasks that
- -- are equivalent)
- SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash
-
- -- If there is an exact match on the taskhash, return it.
- -- Otherwise return the oldest matching outhash of any
- -- taskhash
- ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
- created ASC
-
- -- Only return one row
- LIMIT 1
- '''
-
- def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- super().__init__(reader, writer, 'OEHASHEQUIV', logger)
- self.db = db
- self.request_stats = request_stats
+ def __init__(self, socket, server):
+ super().__init__(socket, "OEHASHEQUIV", server.logger)
+ self.server = server
self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
- self.backfill_queue = backfill_queue
- self.upstream = upstream
+ self.user = None
+
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "exists-stream": self.handle_exists_stream,
+ "get-stats": self.handle_get_stats,
+ "get-db-usage": self.handle_get_db_usage,
+ "get-db-query-columns": self.handle_get_db_query_columns,
+ # Not always read-only, but internally checks if the server is
+ # read-only
+ "report": self.handle_report,
+ "auth": self.handle_auth,
+ "get-user": self.handle_get_user,
+ "get-all-users": self.handle_get_all_users,
+ "become-user": self.handle_become_user,
+ }
+ )
+
+ if not self.server.read_only:
+ self.handlers.update(
+ {
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "gc-mark": self.handle_gc_mark,
+ "gc-sweep": self.handle_gc_sweep,
+ "gc-status": self.handle_gc_status,
+ "clean-unused": self.handle_clean_unused,
+ "refresh-token": self.handle_refresh_token,
+ "set-user-perms": self.handle_set_perms,
+ "new-user": self.handle_new_user,
+ "delete-user": self.handle_delete_user,
+ }
+ )
+
+ def raise_no_user_error(self, username):
+ raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
+
+ def user_has_permissions(self, *permissions, allow_anon=True):
+ permissions = set(permissions)
+ if allow_anon:
+ if ALL_PERM in self.server.anon_perms:
+ return True
- self.handlers.update({
- 'get': self.handle_get,
- 'get-outhash': self.handle_get_outhash,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- })
-
- if not read_only:
- self.handlers.update({
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'reset-stats': self.handle_reset_stats,
- 'backfill-wait': self.handle_backfill_wait,
- })
+ if not permissions - self.server.anon_perms:
+ return True
+
+ if self.user is None:
+ return False
+
+ if ALL_PERM in self.user.permissions:
+ return True
+
+ if not permissions - self.user.permissions:
+ return True
+
+ return False
def validate_proto_version(self):
- return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
async def process_requests(self):
- if self.upstream is not None:
- self.upstream_client = await create_async_client(self.upstream)
- else:
- self.upstream_client = None
-
- await super().process_requests()
+ async with self.server.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.server.upstream is not None:
+ self.upstream_client = await create_async_client(self.server.upstream)
+ else:
+ self.upstream_client = None
- if self.upstream_client is not None:
- await self.upstream_client.close()
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- logger.debug('Handling %s' % k)
- if 'stream' in k:
- await self.handlers[k](msg[k])
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
+ return await self.handlers[k](msg[k])
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
+ return await self.handlers[k](msg[k])
raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+ @permissions(READ_PERM)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
-
- if request.get('all', False):
- row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
+
+ return await self.get_unihash(method, taskhash, fetch_all)
+
+ async def get_unihash(self, method, taskhash, fetch_all=False):
+ d = None
+
+ if fetch_all:
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
+ if row is not None:
+ d = {k: row[k] for k in row.keys()}
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_taskhash(method, taskhash, True)
+ await self.update_unified(d)
else:
- row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+ row = await self.db.get_equivalent(method, taskhash)
- if row is not None:
- logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
- d = {k: row[k] for k in row.keys()}
- elif self.upstream_client is not None:
- d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash)
- else:
- d = None
+ if row is not None:
+ d = {k: row[k] for k in row.keys()}
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_taskhash(method, taskhash)
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
- self.write_message(d)
+ return d
+ @permissions(READ_PERM)
async def handle_get_outhash(self, request):
- with closing(self.db.cursor()) as cursor:
- cursor.execute(self.OUTHASH_QUERY,
- {k: request[k] for k in ('method', 'outhash', 'taskhash')})
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
+ with_unihash = request.get("with_unihash", True)
+
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
- row = cursor.fetchone()
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
+ d = None
+ if with_unihash:
+ row = await self.db.get_unihash_by_outhash(method, outhash)
+ else:
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
- logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash']))
d = {k: row[k] for k in row.keys()}
- else:
- d = None
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_outhash(method, outhash, taskhash)
+ await self.update_unified(d)
- self.write_message(d)
+ return d
- async def handle_get_stream(self, request):
- self.write_message('ok')
+ async def update_unified(self, data):
+ if data is None:
+ return
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
+
+ async def _stream_handler(self, handler):
+ await self.socket.send_message("ok")
while True:
upstream = None
- l = await self.reader.readline()
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
# possible (which is why the request sample is handled manually
# instead of using 'with', and also why logging statements are
# commented out.
- self.request_sample = self.request_stats.start_sample()
+ self.request_sample = self.server.request_stats.start_sample()
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
- if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
-
- (method, taskhash) = l.split()
- #logger.debug('Looking up %s %s' % (method, taskhash))
- row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
- if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
- #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
- elif self.upstream_client is not None:
- upstream = await self.upstream_client.get_unihash(method, taskhash)
- if upstream:
- msg = ("%s\n" % upstream).encode("utf-8")
- else:
- msg = "\n".encode("utf-8")
- else:
- msg = '\n'.encode('utf-8')
+ if l == "END":
+ break
- self.writer.write(msg)
+ msg = await handler(l)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
+ await self.socket.send("ok")
+ return self.NO_RESPONSE
- # Post to the backfill queue after writing the result to minimize
- # the turn around time on a request
- if upstream is not None:
- await self.backfill_queue.put((method, taskhash))
+ @permissions(READ_PERM)
+ async def handle_get_stream(self, request):
+ async def handler(l):
+ (method, taskhash) = l.split()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
- async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- cursor.execute(self.OUTHASH_QUERY,
- {k: data[k] for k in ('method', 'outhash', 'taskhash')})
-
- row = cursor.fetchone()
-
- if row is None and self.upstream_client:
- # Try upstream
- row = await copy_outhash_from_upstream(self.upstream_client,
- self.db,
- data['method'],
- data['outhash'],
- data['taskhash'])
-
- # If no matching outhash was found, or one *was* found but it
- # wasn't an exact match on the taskhash, a new entry for this
- # taskhash should be added
- if row is None or row['taskhash'] != data['taskhash']:
- # If a row matching the outhash was found, the unihash for
- # the new taskhash should be the same as that one.
- # Otherwise the caller provided unihash is used.
- unihash = data['unihash']
- if row is not None:
- unihash = row['unihash']
-
- insert_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- 'created': datetime.now()
- }
+ if row is not None:
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ return row["unihash"]
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- insert_data[k] = data[k]
+ if self.upstream_client is not None:
+ upstream = await self.upstream_client.get_unihash(method, taskhash)
+ if upstream:
+ await self.server.backfill_queue.put((method, taskhash))
+ return upstream
- insert_task(cursor, insert_data)
- self.db.commit()
+ return ""
- logger.info('Adding taskhash %s with unihash %s',
- data['taskhash'], unihash)
+ return await self._stream_handler(handler)
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash
- }
- else:
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+ @permissions(READ_PERM)
+ async def handle_exists_stream(self, request):
+ async def handler(l):
+ if await self.db.unihash_exists(l):
+ return "true"
- self.write_message(d)
+ if self.upstream_client is not None:
+ if await self.upstream_client.unihash_exists(l):
+ return "true"
- async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'outhash': "",
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- 'created': datetime.now()
- }
+ return "false"
+
+ return await self._stream_handler(handler)
+
+ async def report_readonly(self, data):
+ method = data["method"]
+ outhash = data["outhash"]
+ taskhash = data["taskhash"]
+
+ info = await self.get_outhash(method, outhash, taskhash)
+ if info:
+ unihash = info["unihash"]
+ else:
+ unihash = data["unihash"]
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- insert_data[k] = data[k]
+ return {
+ "taskhash": taskhash,
+ "method": method,
+ "unihash": unihash,
+ }
+
+ # Since this can be called either read only or to report, the check to
+ # report is made inside the function
+ @permissions(READ_PERM)
+ async def handle_report(self, data):
+ if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
+ return await self.report_readonly(data)
+
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- insert_task(cursor, insert_data, ignore=True)
- self.db.commit()
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
+
+ if self.user:
+ outhash_data["owner"] = self.user.username
+
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
+ }
- if row['unihash'] == data['unihash']:
- logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
+ @permissions(READ_PERM, REPORT_PERM)
+ async def handle_equivreport(self, data):
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
- self.write_message(d)
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
+ @permissions(READ_PERM)
async def handle_get_stats(self, request):
- d = {
- 'requests': self.request_stats.todict(),
+ return {
+ "requests": self.server.request_stats.todict(),
}
- self.write_message(d)
-
+ @permissions(DB_ADMIN_PERM)
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.server.request_stats.todict(),
}
- self.request_stats.reset()
- self.write_message(d)
+ self.server.request_stats.reset()
+ return d
+ @permissions(READ_PERM)
async def handle_backfill_wait(self, request):
d = {
- 'tasks': self.backfill_queue.qsize(),
+ "tasks": self.server.backfill_queue.qsize(),
}
- await self.backfill_queue.join()
- self.write_message(d)
+ await self.server.backfill_queue.join()
+ return d
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_remove(self, request):
+ condition = request["where"]
+ if not isinstance(condition, dict):
+ raise TypeError("Bad condition type %s" % type(condition))
+
+ return {"count": await self.db.remove(condition)}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_mark(self, request):
+ condition = request["where"]
+ mark = request["mark"]
+
+ if not isinstance(condition, dict):
+ raise TypeError("Bad condition type %s" % type(condition))
+
+ if not isinstance(mark, str):
+ raise TypeError("Bad mark type %s" % type(mark))
+
+ return {"count": await self.db.gc_mark(mark, condition)}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_sweep(self, request):
+ mark = request["mark"]
+
+ if not isinstance(mark, str):
+ raise TypeError("Bad mark type %s" % type(mark))
+
+ current_mark = await self.db.get_current_gc_mark()
+
+ if not current_mark or mark != current_mark:
+ raise bb.asyncrpc.InvokeError(
+ f"'{mark}' is not the current mark. Refusing to sweep"
+ )
+
+ count = await self.db.gc_sweep()
+
+ return {"count": count}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_status(self, request):
+ (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
+ return {
+ "keep": keep_rows,
+ "remove": remove_rows,
+ "mark": current_mark,
+ }
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_clean_unused(self, request):
+ max_age = request["max_age_seconds"]
+ oldest = datetime.now() - timedelta(seconds=-max_age)
+ return {"count": await self.db.clean_unused(oldest)}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_usage(self, request):
+ return {"usage": await self.db.get_usage()}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_query_columns(self, request):
+ return {"columns": await self.db.get_query_columns()}
+
+ # The authentication API is always allowed
+ async def handle_auth(self, request):
+ username = str(request["username"])
+ token = str(request["token"])
+
+ async def fail_auth():
+ nonlocal username
+ # Rate limit bad login attempts
+ await asyncio.sleep(1)
+ raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+ user, db_token = await self.db.lookup_user_token(username)
+
+ if not user or not db_token:
+ await fail_auth()
- def query_equivalent(self, method, taskhash, query):
- # This is part of the inner loop and must be as fast as possible
try:
- cursor = self.db.cursor()
- cursor.execute(query, {'method': method, 'taskhash': taskhash})
- return cursor.fetchone()
- except:
- cursor.close()
+ algo, salt, _ = db_token.split(":")
+ except ValueError:
+ await fail_auth()
+
+ if hash_token(algo, salt, token) != db_token:
+ await fail_auth()
+
+ self.user = user
+
+ self.logger.info("Authenticated as %s", username)
+
+ return {
+ "result": True,
+ "username": self.user.username,
+ "permissions": sorted(list(self.user.permissions)),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_refresh_token(self, request):
+ username = str(request["username"])
+
+ token = await new_token()
+
+ updated = await self.db.set_user_token(
+ username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not updated:
+ self.raise_no_user_error(username)
+
+ return {"username": username, "token": token}
+
+ def get_perm_arg(self, arg):
+ if not isinstance(arg, list):
+ raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+ arg = set(arg)
+ try:
+ arg.remove(NONE_PERM)
+ except KeyError:
+ pass
+
+ unknown_perms = arg - ALL_PERMISSIONS
+ if unknown_perms:
+ raise bb.asyncrpc.InvokeError(
+ "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+ )
+
+ return sorted(list(arg))
+
+ def return_perms(self, permissions):
+ if ALL_PERM in permissions:
+ return sorted(list(ALL_PERMISSIONS))
+ return sorted(list(permissions))
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_set_perms(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ if not await self.db.set_user_perms(username, permissions):
+ self.raise_no_user_error(username)
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_get_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ return None
+
+ return {
+ "username": user.username,
+ "permissions": self.return_perms(user.permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_get_all_users(self, request):
+ users = await self.db.get_all_users()
+ return {
+ "users": [
+ {
+ "username": u.username,
+ "permissions": self.return_perms(u.permissions),
+ }
+ for u in users
+ ]
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_new_user(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ token = await new_token()
+
+ inserted = await self.db.new_user(
+ username,
+ permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not inserted:
+ raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ "token": token,
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_delete_user(self, request):
+ username = str(request["username"])
+
+ if not await self.db.delete_user(username):
+ self.raise_no_user_error(username)
+
+ return {"username": username}
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_become_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+ self.user = user
+
+ self.logger.info("Became user %s", username)
+
+ return {
+ "username": self.user.username,
+ "permissions": self.return_perms(self.user.permissions),
+ }
class Server(bb.asyncrpc.AsyncServer):
- def __init__(self, db, loop=None, upstream=None, read_only=False):
+ def __init__(
+ self,
+ db_engine,
+ upstream=None,
+ read_only=False,
+ anon_perms=DEFAULT_ANON_PERMS,
+ admin_username=None,
+ admin_password=None,
+ ):
if upstream and read_only:
- raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
+
+ disallowed_perms = set(anon_perms) - set(
+ [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+ )
- super().__init__(logger, loop)
+ if disallowed_perms:
+ raise bb.asyncrpc.ServerError(
+ f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+ )
+
+ super().__init__(logger)
self.request_stats = Stats()
- self.db = db
+ self.db_engine = db_engine
self.upstream = upstream
self.read_only = read_only
+ self.backfill_queue = None
+ self.anon_perms = set(anon_perms)
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+
+ self.logger.info(
+ "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
+ )
+
+ def accept_client(self, socket):
+ return ServerClient(socket, self)
+
+ async def create_admin_user(self):
+ admin_permissions = (ALL_PERM,)
+ async with self.db_engine.connect(self.logger) as db:
+ added = await db.new_user(
+ self.admin_username,
+ admin_permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ if added:
+ self.logger.info("Created admin user '%s'", self.admin_username)
+ else:
+ await db.set_user_perms(
+ self.admin_username,
+ admin_permissions,
+ )
+ await db.set_user_token(
+ self.admin_username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ self.logger.info("Admin user '%s' updated", self.admin_username)
+
+ async def backfill_worker_task(self):
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(self.logger) as db:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
+ self.backfill_queue.task_done()
+ break
- def accept_client(self, reader, writer):
- return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+ method, taskhash = item
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+ self.backfill_queue.task_done()
- @contextmanager
- def _backfill_worker(self):
- async def backfill_worker_task():
- client = await create_async_client(self.upstream)
- try:
- while True:
- item = await self.backfill_queue.get()
- if item is None:
- self.backfill_queue.task_done()
- break
- method, taskhash = item
- await copy_from_upstream(client, self.db, method, taskhash)
- self.backfill_queue.task_done()
- finally:
- await client.close()
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
- async def join_worker(worker):
- await self.backfill_queue.put(None)
- await worker
+ self.loop.run_until_complete(self.db_engine.create())
- if self.upstream is not None:
- worker = asyncio.ensure_future(backfill_worker_task())
- try:
- yield
- finally:
- self.loop.run_until_complete(join_worker(worker))
- else:
- yield
+ if self.admin_username:
+ self.loop.run_until_complete(self.create_admin_user())
- def run_loop_forever(self):
- self.backfill_queue = asyncio.Queue()
+ return tasks
- with self._backfill_worker():
- super().run_loop_forever()
+ async def stop(self):
+ if self.backfill_queue is not None:
+ await self.backfill_queue.put(None)
+ await super().stop()