Diffstat (limited to 'bitbake/lib/hashserv/server.py')
-rw-r--r--  bitbake/lib/hashserv/server.py | 986
1 file changed, 689 insertions(+), 297 deletions(-)
diff --git a/bitbake/lib/hashserv/server.py b/bitbake/lib/hashserv/server.py
index 81050715ea..68f64f983b 100644
--- a/bitbake/lib/hashserv/server.py
+++ b/bitbake/lib/hashserv/server.py
@@ -3,19 +3,51 @@
# SPDX-License-Identifier: GPL-2.0-only
#
-from contextlib import closing
-from datetime import datetime
+from datetime import datetime, timedelta
import asyncio
-import json
import logging
import math
-import os
-import signal
-import socket
import time
-from . import chunkify, DEFAULT_MAX_CHUNK
+import os
+import base64
+import hashlib
+from . import create_async_client
+import bb.asyncrpc
+
+logger = logging.getLogger("hashserv.server")
+
+
+# This permission only exists to match nothing
+NONE_PERM = "@none"
+
+READ_PERM = "@read"
+REPORT_PERM = "@report"
+DB_ADMIN_PERM = "@db-admin"
+USER_ADMIN_PERM = "@user-admin"
+ALL_PERM = "@all"
+
+ALL_PERMISSIONS = {
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+ USER_ADMIN_PERM,
+ ALL_PERM,
+}
+
+DEFAULT_ANON_PERMS = (
+ READ_PERM,
+ REPORT_PERM,
+ DB_ADMIN_PERM,
+)
+
+TOKEN_ALGORITHM = "sha256"
+
+# 48 bytes of random data will result in 64 characters when base64
+# encoded. This number also ensures that the base64 encoding won't have any
+# trailing '=' characters.
+TOKEN_SIZE = 48
-logger = logging.getLogger('hashserv.server')
+SALT_SIZE = 8
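As the comment above notes, 48 is divisible by 3, so base64 encodes it to exactly (48 / 3) * 4 = 64 characters with no '=' padding. A minimal sketch verifying the claim (standalone, not part of the patch; os.urandom stands in for os.getrandom):

    import base64
    import os

    raw = os.urandom(48)  # stand-in for os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
    token = base64.b64encode(raw, b"._").decode("utf-8")
    assert len(token) == 64 and "=" not in token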
class Measurement(object):
@@ -105,385 +137,745 @@ class Stats(object):
return math.sqrt(self.s / (self.num - 1))
def todict(self):
- return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-
-
-class ClientError(Exception):
- pass
-
-class ServerClient(object):
- FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
- ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
-
- def __init__(self, reader, writer, db, request_stats):
- self.reader = reader
- self.writer = writer
- self.db = db
- self.request_stats = request_stats
- self.max_chunk = DEFAULT_MAX_CHUNK
-
- self.handlers = {
- 'get': self.handle_get,
- 'report': self.handle_report,
- 'report-equiv': self.handle_equivreport,
- 'get-stream': self.handle_get_stream,
- 'get-stats': self.handle_get_stats,
- 'reset-stats': self.handle_reset_stats,
- 'chunk-stream': self.handle_chunk,
+ return {
+ k: getattr(self, k)
+ for k in ("num", "total_time", "max_time", "average", "stdev")
}
- async def process_requests(self):
- try:
- self.addr = self.writer.get_extra_info('peername')
- logger.debug('Client %r connected' % (self.addr,))
- # Read protocol and version
- protocol = await self.reader.readline()
- if protocol is None:
- return
+token_refresh_semaphore = asyncio.Lock()
- (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
- if proto_name != 'OEHASHEQUIV':
- return
- proto_version = tuple(int(v) for v in proto_version.split('.'))
- if proto_version < (1, 0) or proto_version > (1, 1):
- return
+async def new_token():
+ # Prevent malicious users from using this API to deduce the entropy
+ # pool on the server and thus be able to guess a token. *All* token
+ # refresh requests lock the same global semaphore and then sleep for a
+ # short time. This effectively rate limits the total number of requests
+ # that can be made across all clients to 10/second, which should be enough
+ # since you have to be an authenticated user to make the request in the
+ # first place.
+ async with token_refresh_semaphore:
+ await asyncio.sleep(0.1)
+ raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
- # Read headers. Currently, no headers are implemented, so look for
- # an empty line to signal the end of the headers
- while True:
- line = await self.reader.readline()
- if line is None:
- return
+ return base64.b64encode(raw, b"._").decode("utf-8")
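The lock-plus-sleep above is a simple global rate limiter: every refresh serializes on one asyncio.Lock and holds it for 0.1 s, so total throughput is capped at roughly 10 tokens per second no matter how many clients connect. The same pattern in isolation (a sketch; names are illustrative):

    import asyncio
    import os

    _refresh_lock = asyncio.Lock()

    async def rate_limited_random(n):
        # All callers contend on the same lock, so at most one call
        # completes per 0.1 s across the whole process.
        async with _refresh_lock:
            await asyncio.sleep(0.1)
            return os.urandom(n)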
- line = line.decode('utf-8').rstrip()
- if not line:
- break
- # Handle messages
- while True:
- d = await self.read_message()
- if d is None:
- break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
- logger.error(str(e))
- finally:
- self.writer.close()
+def new_salt():
+ return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
- async def dispatch_message(self, msg):
- for k in self.handlers.keys():
- if k in msg:
- logger.debug('Handling %s' % k)
- if 'stream' in k:
- await self.handlers[k](msg[k])
+
+def hash_token(algo, salt, token):
+ h = hashlib.new(algo)
+ h.update(salt.encode("utf-8"))
+ h.update(token.encode("utf-8"))
+ return ":".join([algo, salt, h.hexdigest()])
+
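hash_token packs the algorithm, salt, and digest into one colon-separated string, so a verifier can recover the parameters from the stored value and recompute the digest, as handle_auth does further down. A minimal round trip using the helpers above (the literal token is illustrative):

    stored = hash_token(TOKEN_ALGORITHM, new_salt(), "example-token")
    algo, salt, _ = stored.split(":")
    assert hash_token(algo, salt, "example-token") == stored
    assert hash_token(algo, salt, "wrong-token") != stored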
+
+def permissions(*permissions, allow_anon=True, allow_self_service=False):
+ """
+ Function decorator that can be used to decorate an RPC function call and
+ check that the current user's permissions match the required permissions.
+
+ If allow_anon is True, the user will also be allowed to make the RPC call
+ if the anonymous user permissions match the required permissions.
+
+ If allow_self_service is True and the "username" property in the request
+ is the currently logged-in user, or is not specified, the user will also
+ be allowed to make the request. This allows users to access normally
+ privileged APIs, as long as they are only modifying their own user
+ properties (e.g. users can be allowed to reset their own token without
+ @user-admin permissions, but not the token of any other user).
+ """
+
+ def wrapper(func):
+ async def wrap(self, request):
+ if allow_self_service and self.user is not None:
+ username = request.get("username", self.user.username)
+ if username == self.user.username:
+ request["username"] = self.user.username
+ return await func(self, request)
+
+ if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
+ if not self.user:
+ username = "Anonymous user"
+ user_perms = self.server.anon_perms
else:
- with self.request_stats.start_sample() as self.request_sample, \
- self.request_sample.measure():
- await self.handlers[k](msg[k])
- return
+ username = self.user.username
+ user_perms = self.user.permissions
+
+ self.logger.info(
+ "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
+ username,
+ ", ".join(user_perms),
+ func.__name__,
+ ", ".join(permissions),
+ )
+ raise bb.asyncrpc.InvokeError(
+ f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
+ )
+
+ return await func(self, request)
+
+ return wrap
+
+ return wrapper
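A handler then opts into a permission check with a single decorator line; a hypothetical handler requiring read access, for illustration:

    class ExampleClient(ServerClient):
        @permissions(READ_PERM)
        async def handle_ping(self, request):
            # Only reached if the user (or the anonymous permission set,
            # since allow_anon defaults to True) grants @read; otherwise
            # bb.asyncrpc.InvokeError is raised by the wrapper.
            return {"pong": True}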
+
+
+class ServerClient(bb.asyncrpc.AsyncServerConnection):
+ def __init__(self, socket, server):
+ super().__init__(socket, "OEHASHEQUIV", server.logger)
+ self.server = server
+ self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
+ self.user = None
+
+ self.handlers.update(
+ {
+ "get": self.handle_get,
+ "get-outhash": self.handle_get_outhash,
+ "get-stream": self.handle_get_stream,
+ "exists-stream": self.handle_exists_stream,
+ "get-stats": self.handle_get_stats,
+ "get-db-usage": self.handle_get_db_usage,
+ "get-db-query-columns": self.handle_get_db_query_columns,
+ # Not always read-only, but internally checks if the server is
+ # read-only
+ "report": self.handle_report,
+ "auth": self.handle_auth,
+ "get-user": self.handle_get_user,
+ "get-all-users": self.handle_get_all_users,
+ "become-user": self.handle_become_user,
+ }
+ )
+
+ if not self.server.read_only:
+ self.handlers.update(
+ {
+ "report-equiv": self.handle_equivreport,
+ "reset-stats": self.handle_reset_stats,
+ "backfill-wait": self.handle_backfill_wait,
+ "remove": self.handle_remove,
+ "gc-mark": self.handle_gc_mark,
+ "gc-sweep": self.handle_gc_sweep,
+ "gc-status": self.handle_gc_status,
+ "clean-unused": self.handle_clean_unused,
+ "refresh-token": self.handle_refresh_token,
+ "set-user-perms": self.handle_set_perms,
+ "new-user": self.handle_new_user,
+ "delete-user": self.handle_delete_user,
+ }
+ )
- raise ClientError("Unrecognized command %r" % msg)
+ def raise_no_user_error(self, username):
+ raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
+ def user_has_permissions(self, *permissions, allow_anon=True):
+ permissions = set(permissions)
+ if allow_anon:
+ if ALL_PERM in self.server.anon_perms:
+ return True
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
+ if not permissions - self.server.anon_perms:
+ return True
- try:
- message = l.decode('utf-8')
+ if self.user is None:
+ return False
- if not message.endswith('\n'):
- return None
+ if ALL_PERM in self.user.permissions:
+ return True
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- logger.error('Bad message from client: %r' % message)
- raise e
+ if not permissions - self.user.permissions:
+ return True
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
+ return False
+
+ def validate_proto_version(self):
+ return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
+
+ async def process_requests(self):
+ async with self.server.db_engine.connect(self.logger) as db:
+ self.db = db
+ if self.server.upstream is not None:
+ self.upstream_client = await create_async_client(self.server.upstream)
+ else:
+ self.upstream_client = None
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- logger.error('Bad message from client: %r' % message)
- raise e
+ try:
+ await super().process_requests()
+ finally:
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ async def dispatch_message(self, msg):
+ for k in self.handlers.keys():
+ if k in msg:
+ self.logger.debug("Handling %s" % k)
+ if "stream" in k:
+ return await self.handlers[k](msg[k])
+ else:
+ with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
+ return await self.handlers[k](msg[k])
- await self.dispatch_message(msg)
+ raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
+ @permissions(READ_PERM)
async def handle_get(self, request):
- method = request['method']
- taskhash = request['taskhash']
+ method = request["method"]
+ taskhash = request["taskhash"]
+ fetch_all = request.get("all", False)
+
+ return await self.get_unihash(method, taskhash, fetch_all)
+
+ async def get_unihash(self, method, taskhash, fetch_all=False):
+ d = None
+
+ if fetch_all:
+ row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
+ if row is not None:
+ d = {k: row[k] for k in row.keys()}
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_taskhash(method, taskhash, True)
+ await self.update_unified(d)
+ else:
+ row = await self.db.get_equivalent(method, taskhash)
+
+ if row is not None:
+ d = {k: row[k] for k in row.keys()}
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_taskhash(method, taskhash)
+ await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+
+ return d
+
+ @permissions(READ_PERM)
+ async def handle_get_outhash(self, request):
+ method = request["method"]
+ outhash = request["outhash"]
+ taskhash = request["taskhash"]
+ with_unihash = request.get("with_unihash", True)
- if request.get('all', False):
- row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+ return await self.get_outhash(method, outhash, taskhash, with_unihash)
+
+ async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
+ d = None
+ if with_unihash:
+ row = await self.db.get_unihash_by_outhash(method, outhash)
else:
- row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+ row = await self.db.get_outhash(method, outhash)
if row is not None:
- logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
d = {k: row[k] for k in row.keys()}
+ elif self.upstream_client is not None:
+ d = await self.upstream_client.get_outhash(method, outhash, taskhash)
+ await self.update_unified(d)
- self.write_message(d)
- else:
- self.write_message(None)
+ return d
- async def handle_get_stream(self, request):
- self.write_message('ok')
+ async def update_unified(self, data):
+ if data is None:
+ return
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+ await self.db.insert_outhash(data)
+
+ async def _stream_handler(self, handler):
+ await self.socket.send_message("ok")
while True:
- l = await self.reader.readline()
+ upstream = None
+
+ l = await self.socket.recv()
if not l:
- return
+ break
try:
# This inner loop is very sensitive and must be as fast as
# possible (which is why the request sample is handled manually
# instead of using 'with', and also why logging statements are
# commented out).
- self.request_sample = self.request_stats.start_sample()
+ self.request_sample = self.server.request_stats.start_sample()
request_measure = self.request_sample.measure()
request_measure.start()
- l = l.decode('utf-8').rstrip()
- if l == 'END':
- self.writer.write('ok\n'.encode('utf-8'))
- return
-
- (method, taskhash) = l.split()
- #logger.debug('Looking up %s %s' % (method, taskhash))
- row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
- if row is not None:
- msg = ('%s\n' % row['unihash']).encode('utf-8')
- #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
- else:
- msg = '\n'.encode('utf-8')
+ if l == "END":
+ break
- self.writer.write(msg)
+ msg = await handler(l)
+ await self.socket.send(msg)
finally:
request_measure.end()
self.request_sample.end()
- await self.writer.drain()
+ await self.socket.send("ok")
+ return self.NO_RESPONSE
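For reference, the line-oriented framing this helper implements looks roughly as follows (a transcript inferred from the code above, not a protocol specification):

    # client -> {"get-stream": {}}      stream mode begins
    # server -> "ok"
    # client -> "<method> <taskhash>"   one query per line
    # server -> "<unihash>"             empty string when there is no match
    # client -> "END"                   ends the stream
    # server -> "ok"                    normal request handling resumes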
- async def handle_report(self, data):
- with closing(self.db.cursor()) as cursor:
- cursor.execute('''
- -- Find tasks with a matching outhash (that is, tasks that
- -- are equivalent)
- SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
-
- -- If there is an exact match on the taskhash, return it.
- -- Otherwise return the oldest matching outhash of any
- -- taskhash
- ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
- created ASC
-
- -- Only return one row
- LIMIT 1
- ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
-
- row = cursor.fetchone()
-
- # If no matching outhash was found, or one *was* found but it
- # wasn't an exact match on the taskhash, a new entry for this
- # taskhash should be added
- if row is None or row['taskhash'] != data['taskhash']:
- # If a row matching the outhash was found, the unihash for
- # the new taskhash should be the same as that one.
- # Otherwise the caller provided unihash is used.
- unihash = data['unihash']
- if row is not None:
- unihash = row['unihash']
-
- insert_data = {
- 'method': data['method'],
- 'outhash': data['outhash'],
- 'taskhash': data['taskhash'],
- 'unihash': unihash,
- 'created': datetime.now()
- }
+ @permissions(READ_PERM)
+ async def handle_get_stream(self, request):
+ async def handler(l):
+ (method, taskhash) = l.split()
+ # self.logger.debug('Looking up %s %s' % (method, taskhash))
+ row = await self.db.get_equivalent(method, taskhash)
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- insert_data[k] = data[k]
+ if row is not None:
+ # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+ return row["unihash"]
- cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
- ', '.join(sorted(insert_data.keys())),
- ', '.join(':' + k for k in sorted(insert_data.keys()))),
- insert_data)
+ if self.upstream_client is not None:
+ upstream = await self.upstream_client.get_unihash(method, taskhash)
+ if upstream:
+ await self.server.backfill_queue.put((method, taskhash))
+ return upstream
- self.db.commit()
+ return ""
- logger.info('Adding taskhash %s with unihash %s',
- data['taskhash'], unihash)
+ return await self._stream_handler(handler)
- d = {
- 'taskhash': data['taskhash'],
- 'method': data['method'],
- 'unihash': unihash
- }
- else:
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+ @permissions(READ_PERM)
+ async def handle_exists_stream(self, request):
+ async def handler(l):
+ if await self.db.unihash_exists(l):
+ return "true"
- self.write_message(d)
+ if self.upstream_client is not None:
+ if await self.upstream_client.unihash_exists(l):
+ return "true"
- async def handle_equivreport(self, data):
- with closing(self.db.cursor()) as cursor:
- insert_data = {
- 'method': data['method'],
- 'outhash': "",
- 'taskhash': data['taskhash'],
- 'unihash': data['unihash'],
- 'created': datetime.now()
- }
+ return "false"
- for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
- if k in data:
- insert_data[k] = data[k]
+ return await self._stream_handler(handler)
- cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
- ', '.join(sorted(insert_data.keys())),
- ', '.join(':' + k for k in sorted(insert_data.keys()))),
- insert_data)
+ async def report_readonly(self, data):
+ method = data["method"]
+ outhash = data["outhash"]
+ taskhash = data["taskhash"]
- self.db.commit()
+ info = await self.get_outhash(method, outhash, taskhash)
+ if info:
+ unihash = info["unihash"]
+ else:
+ unihash = data["unihash"]
+
+ return {
+ "taskhash": taskhash,
+ "method": method,
+ "unihash": unihash,
+ }
- # Fetch the unihash that will be reported for the taskhash. If the
- # unihash matches, it means this row was inserted (or the mapping
- # was already valid)
- row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
+ # Since this can be called either read-only or to report, the check for
+ # the report permission is made inside the function
+ @permissions(READ_PERM)
+ async def handle_report(self, data):
+ if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
+ return await self.report_readonly(data)
+
+ outhash_data = {
+ "method": data["method"],
+ "outhash": data["outhash"],
+ "taskhash": data["taskhash"],
+ "created": datetime.now(),
+ }
- if row['unihash'] == data['unihash']:
- logger.info('Adding taskhash equivalence for %s with unihash %s',
- data['taskhash'], row['unihash'])
+ for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
+ if k in data:
+ outhash_data[k] = data[k]
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+ if self.user:
+ outhash_data["owner"] = self.user.username
- self.write_message(d)
+ # Insert the new entry, unless it already exists
+ if await self.db.insert_outhash(outhash_data):
+ # If this row is new, check if it is equivalent to another
+ # output hash
+ row = await self.db.get_equivalent_for_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if row is not None:
+ # A matching output hash was found. Set our taskhash to the
+ # same unihash since they are equivalent
+ unihash = row["unihash"]
+ else:
+ # No matching output hash was found. This is probably the
+ # first outhash to be added.
+ unihash = data["unihash"]
+
+ # Query upstream to see if it has a unihash we can use
+ if self.upstream_client is not None:
+ upstream_data = await self.upstream_client.get_outhash(
+ data["method"], data["outhash"], data["taskhash"]
+ )
+ if upstream_data is not None:
+ unihash = upstream_data["unihash"]
+
+ await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
+
+ unihash_data = await self.get_unihash(data["method"], data["taskhash"])
+ if unihash_data is not None:
+ unihash = unihash_data["unihash"]
+ else:
+ unihash = data["unihash"]
- async def handle_get_stats(self, request):
- d = {
- 'requests': self.request_stats.todict(),
+ return {
+ "taskhash": data["taskhash"],
+ "method": data["method"],
+ "unihash": unihash,
}
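The net effect of the branch above is output-hash unification: two taskhashes that produce the same outhash end up sharing one unihash. A worked example with hypothetical hashes:

    # report(taskhash=AAA, outhash=XYZ, unihash=AAA)
    #   -> no existing row for outhash XYZ; AAA keeps unihash AAA
    # report(taskhash=BBB, outhash=XYZ, unihash=BBB)
    #   -> outhash XYZ already maps to AAA, so BBB is recorded with
    #      unihash AAA; both taskhashes now resolve to the same unihash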
- self.write_message(d)
+ @permissions(READ_PERM, REPORT_PERM)
+ async def handle_equivreport(self, data):
+ await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
+
+ # Fetch the unihash that will be reported for the taskhash. If the
+ # unihash matches, it means this row was inserted (or the mapping
+ # was already valid)
+ row = await self.db.get_equivalent(data["method"], data["taskhash"])
+
+ if row["unihash"] == data["unihash"]:
+ self.logger.info(
+ "Adding taskhash equivalence for %s with unihash %s",
+ data["taskhash"],
+ row["unihash"],
+ )
+
+ return {k: row[k] for k in ("taskhash", "method", "unihash")}
+
+ @permissions(READ_PERM)
+ async def handle_get_stats(self, request):
+ return {
+ "requests": self.server.request_stats.todict(),
+ }
+ @permissions(DB_ADMIN_PERM)
async def handle_reset_stats(self, request):
d = {
- 'requests': self.request_stats.todict(),
+ "requests": self.server.request_stats.todict(),
}
- self.request_stats.reset()
- self.write_message(d)
+ self.server.request_stats.reset()
+ return d
- def query_equivalent(self, method, taskhash, query):
- # This is part of the inner loop and must be as fast as possible
- try:
- cursor = self.db.cursor()
- cursor.execute(query, {'method': method, 'taskhash': taskhash})
- return cursor.fetchone()
- except:
- cursor.close()
+ @permissions(READ_PERM)
+ async def handle_backfill_wait(self, request):
+ d = {
+ "tasks": self.server.backfill_queue.qsize(),
+ }
+ await self.server.backfill_queue.join()
+ return d
+ @permissions(DB_ADMIN_PERM)
+ async def handle_remove(self, request):
+ condition = request["where"]
+ if not isinstance(condition, dict):
+ raise TypeError("Bad condition type %s" % type(condition))
-class Server(object):
- def __init__(self, db, loop=None):
- self.request_stats = Stats()
- self.db = db
+ return {"count": await self.db.remove(condition)}
- if loop is None:
- self.loop = asyncio.new_event_loop()
- self.close_loop = True
- else:
- self.loop = loop
- self.close_loop = False
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_mark(self, request):
+ condition = request["where"]
+ mark = request["mark"]
- self._cleanup_socket = None
+ if not isinstance(condition, dict):
+ raise TypeError("Bad condition type %s" % type(condition))
- def start_tcp_server(self, host, port):
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port, loop=self.loop)
- )
+ if not isinstance(mark, str):
+ raise TypeError("Bad mark type %s" % type(mark))
- for s in self.server.sockets:
- logger.info('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+ return {"count": await self.db.gc_mark(mark, condition)}
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_sweep(self, request):
+ mark = request["mark"]
- def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
+ if not isinstance(mark, str):
+ raise TypeError("Bad mark type %s" % type(mark))
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
+ current_mark = await self.db.get_current_gc_mark()
+
+ if not current_mark or mark != current_mark:
+ raise bb.asyncrpc.InvokeError(
+ f"'{mark}' is not the current mark. Refusing to sweep"
)
- finally:
- os.chdir(cwd)
- logger.info('Listening on %r' % path)
+ count = await self.db.gc_sweep()
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
+ return {"count": count}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_gc_status(self, request):
+ (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
+ return {
+ "keep": keep_rows,
+ "remove": remove_rows,
+ "mark": current_mark,
+ }
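Together, gc-mark, gc-sweep, and gc-status form a two-phase collector: rows matching the "where" condition are marked as live under a named mark, and the sweep then removes everything unmarked, refusing to run if its mark is stale. A sketch of the request flow implied by the handlers above (payloads are illustrative):

    # {"gc-mark":   {"mark": "gc-1", "where": {...}}}  -> {"count": marked}
    # {"gc-status": {}}  -> {"keep": K, "remove": R, "mark": "gc-1"}
    # {"gc-sweep":  {"mark": "gc-1"}}                   -> {"count": removed}
    # A sweep with any mark other than "gc-1" raises InvokeError.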
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_clean_unused(self, request):
+ max_age = request["max_age_seconds"]
+ oldest = datetime.now() - timedelta(seconds=max_age)
+ return {"count": await self.db.clean_unused(oldest)}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_usage(self, request):
+ return {"usage": await self.db.get_usage()}
+
+ @permissions(DB_ADMIN_PERM)
+ async def handle_get_db_query_columns(self, request):
+ return {"columns": await self.db.get_query_columns()}
+
+ # The authentication API is always allowed
+ async def handle_auth(self, request):
+ username = str(request["username"])
+ token = str(request["token"])
+
+ async def fail_auth():
+ nonlocal username
+ # Rate limit bad login attempts
+ await asyncio.sleep(1)
+ raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
+
+ user, db_token = await self.db.lookup_user_token(username)
+
+ if not user or not db_token:
+ await fail_auth()
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
try:
- client = ServerClient(reader, writer, self.db, self.request_stats)
- await client.process_requests()
- except Exception as e:
- import traceback
- logger.error('Error from client: %s' % str(e), exc_info=True)
- traceback.print_exc()
- writer.close()
- logger.info('Client disconnected')
+ algo, salt, _ = db_token.split(":")
+ except ValueError:
+ await fail_auth()
+
+ if hash_token(algo, salt, token) != db_token:
+ await fail_auth()
- def serve_forever(self):
- def signal_handler():
- self.loop.stop()
+ self.user = user
- self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
+ self.logger.info("Authenticated as %s", username)
+ return {
+ "result": True,
+ "username": self.user.username,
+ "permissions": sorted(list(self.user.permissions)),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_refresh_token(self, request):
+ username = str(request["username"])
+
+ token = await new_token()
+
+ updated = await self.db.set_user_token(
+ username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not updated:
+ self.raise_no_user_error(username)
+
+ return {"username": username, "token": token}
+
+ def get_perm_arg(self, arg):
+ if not isinstance(arg, list):
+ raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
+
+ arg = set(arg)
try:
- self.loop.run_forever()
- except KeyboardInterrupt:
+ arg.remove(NONE_PERM)
+ except KeyError:
pass
- self.server.close()
- self.loop.run_until_complete(self.server.wait_closed())
- logger.info('Server shutting down')
+ unknown_perms = arg - ALL_PERMISSIONS
+ if unknown_perms:
+ raise bb.asyncrpc.InvokeError(
+ "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
+ )
+
+ return sorted(list(arg))
+
+ def return_perms(self, permissions):
+ if ALL_PERM in permissions:
+ return sorted(list(ALL_PERMISSIONS))
+ return sorted(list(permissions))
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_set_perms(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ if not await self.db.set_user_perms(username, permissions):
+ self.raise_no_user_error(username)
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_get_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ return None
+
+ return {
+ "username": user.username,
+ "permissions": self.return_perms(user.permissions),
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_get_all_users(self, request):
+ users = await self.db.get_all_users()
+ return {
+ "users": [
+ {
+ "username": u.username,
+ "permissions": self.return_perms(u.permissions),
+ }
+ for u in users
+ ]
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_new_user(self, request):
+ username = str(request["username"])
+ permissions = self.get_perm_arg(request["permissions"])
+
+ token = await new_token()
+
+ inserted = await self.db.new_user(
+ username,
+ permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), token),
+ )
+ if not inserted:
+ raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
+
+ return {
+ "username": username,
+ "permissions": self.return_perms(permissions),
+ "token": token,
+ }
+
+ @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
+ async def handle_delete_user(self, request):
+ username = str(request["username"])
+
+ if not await self.db.delete_user(username):
+ self.raise_no_user_error(username)
+
+ return {"username": username}
+
+ @permissions(USER_ADMIN_PERM, allow_anon=False)
+ async def handle_become_user(self, request):
+ username = str(request["username"])
+
+ user = await self.db.lookup_user(username)
+ if user is None:
+ raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
+
+ self.user = user
+
+ self.logger.info("Became user %s", username)
+
+ return {
+ "username": self.user.username,
+ "permissions": self.return_perms(self.user.permissions),
+ }
+
+
+class Server(bb.asyncrpc.AsyncServer):
+ def __init__(
+ self,
+ db_engine,
+ upstream=None,
+ read_only=False,
+ anon_perms=DEFAULT_ANON_PERMS,
+ admin_username=None,
+ admin_password=None,
+ ):
+ if upstream and read_only:
+ raise bb.asyncrpc.ServerError(
+ "Read-only hashserv cannot pull from an upstream server"
+ )
+
+ disallowed_perms = set(anon_perms) - set(
+ [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
+ )
+
+ if disallowed_perms:
+ raise bb.asyncrpc.ServerError(
+ f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
+ )
+
+ super().__init__(logger)
+
+ self.request_stats = Stats()
+ self.db_engine = db_engine
+ self.upstream = upstream
+ self.read_only = read_only
+ self.backfill_queue = None
+ self.anon_perms = set(anon_perms)
+ self.admin_username = admin_username
+ self.admin_password = admin_password
+
+ self.logger.info(
+ "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
+ )
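Because the anonymous permission set is validated in the constructor, a misconfigured server fails at startup rather than at request time. A hypothetical instantiation (db_engine is assumed to be one of hashserv's database engines):

    server = Server(
        db_engine,
        anon_perms=[READ_PERM],     # anonymous clients may only query
        admin_username="admin",
        admin_password="example-password",
    )
    # anon_perms=[USER_ADMIN_PERM] would raise bb.asyncrpc.ServerError,
    # and upstream=... combined with read_only=True is rejected as well.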
+
+ def accept_client(self, socket):
+ return ServerClient(socket, self)
+
+ async def create_admin_user(self):
+ admin_permissions = (ALL_PERM,)
+ async with self.db_engine.connect(self.logger) as db:
+ added = await db.new_user(
+ self.admin_username,
+ admin_permissions,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ if added:
+ self.logger.info("Created admin user '%s'", self.admin_username)
+ else:
+ await db.set_user_perms(
+ self.admin_username,
+ admin_permissions,
+ )
+ await db.set_user_token(
+ self.admin_username,
+ hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
+ )
+ self.logger.info("Admin user '%s' updated", self.admin_username)
+
+ async def backfill_worker_task(self):
+ async with await create_async_client(
+ self.upstream
+ ) as client, self.db_engine.connect(self.logger) as db:
+ while True:
+ item = await self.backfill_queue.get()
+ if item is None:
+ self.backfill_queue.task_done()
+ break
+
+ method, taskhash = item
+ d = await client.get_taskhash(method, taskhash)
+ if d is not None:
+ await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
+ self.backfill_queue.task_done()
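The worker uses the common asyncio shutdown idiom: a None sentinel on the queue stops the loop, and task_done() is called for every item (sentinel included) so that backfill_queue.join() in handle_backfill_wait can unblock. The pattern in isolation:

    import asyncio

    async def worker(queue):
        while True:
            item = await queue.get()
            if item is None:        # sentinel: request shutdown
                queue.task_done()
                break
            # ... process item ...
            queue.task_done()       # lets queue.join() return

    # Producer: queue.put_nowait(work) per item, then queue.put_nowait(None).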
+
+ def start(self):
+ tasks = super().start()
+ if self.upstream:
+ self.backfill_queue = asyncio.Queue()
+ tasks += [self.backfill_worker_task()]
+
+ self.loop.run_until_complete(self.db_engine.create())
+
+ if self.admin_username:
+ self.loop.run_until_complete(self.create_admin_user())
- if self.close_loop:
- self.loop.close()
+ return tasks
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ async def stop(self):
+ if self.backfill_queue is not None:
+ await self.backfill_queue.put(None)
+ await super().stop()