diff options
Diffstat (limited to 'bitbake/lib/prserv/serv.py')
-rw-r--r-- | bitbake/lib/prserv/serv.py | 315 |
1 files changed, 167 insertions, 148 deletions
diff --git a/bitbake/lib/prserv/serv.py b/bitbake/lib/prserv/serv.py index 5e322bf83d..dc4be5b620 100644 --- a/bitbake/lib/prserv/serv.py +++ b/bitbake/lib/prserv/serv.py @@ -1,160 +1,167 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # import os,sys,logging import signal, time -from xmlrpc.server import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler import socket import io import sqlite3 -import bb.server.xmlrpcclient import prserv import prserv.db import errno -import multiprocessing +import bb.asyncrpc logger = logging.getLogger("BitBake.PRserv") -class Handler(SimpleXMLRPCRequestHandler): - def _dispatch(self,method,params): +PIDPREFIX = "/tmp/PRServer_%s_%s.pid" +singleton = None + +class PRServerClient(bb.asyncrpc.AsyncServerConnection): + def __init__(self, socket, server): + super().__init__(socket, "PRSERVICE", server.logger) + self.server = server + + self.handlers.update({ + "get-pr": self.handle_get_pr, + "test-pr": self.handle_test_pr, + "test-package": self.handle_test_package, + "max-package-pr": self.handle_max_package_pr, + "import-one": self.handle_import_one, + "export": self.handle_export, + "is-readonly": self.handle_is_readonly, + }) + + def validate_proto_version(self): + return (self.proto_version == (1, 0)) + + async def dispatch_message(self, msg): try: - value=self.server.funcs[method](*params) + return await super().dispatch_message(msg) except: - import traceback - traceback.print_exc() + self.server.table.sync() raise - return value + else: + self.server.table.sync_if_dirty() -PIDPREFIX = "/tmp/PRServer_%s_%s.pid" -singleton = None + async def handle_test_pr(self, request): + '''Finds the PR value corresponding to the request. If not found, returns None and doesn't insert a new value''' + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + value = self.server.table.find_value(version, pkgarch, checksum) + return {"value": value} -class PRServer(SimpleXMLRPCServer): - def __init__(self, dbfile, logfile, interface): - ''' constructor ''' - try: - SimpleXMLRPCServer.__init__(self, interface, - logRequests=False, allow_none=True) - except socket.error: - ip=socket.gethostbyname(interface[0]) - port=interface[1] - msg="PR Server unable to bind to %s:%s\n" % (ip, port) - sys.stderr.write(msg) - raise PRServiceConfigError - - self.dbfile=dbfile - self.logfile=logfile - self.host, self.port = self.socket.getsockname() + async def handle_test_package(self, request): + '''Tells whether there are entries for (version, pkgarch) in the db. Returns True or False''' + version = request["version"] + pkgarch = request["pkgarch"] - self.register_function(self.getPR, "getPR") - self.register_function(self.ping, "ping") - self.register_function(self.export, "export") - self.register_function(self.importone, "importone") - self.register_introspection_functions() + value = self.server.table.test_package(version, pkgarch) + return {"value": value} - self.iter_count = 0 - # 60 iterations between syncs or sync if dirty every ~30 seconds - self.iterations_between_sync = 60 + async def handle_max_package_pr(self, request): + '''Finds the greatest PR value for (version, pkgarch) in the db. Returns None if no entry was found''' + version = request["version"] + pkgarch = request["pkgarch"] - def sigint_handler(self, signum, stack): - if self.table: - self.table.sync() + value = self.server.table.find_max_value(version, pkgarch) + return {"value": value} - def sigterm_handler(self, signum, stack): - if self.table: - self.table.sync() - raise(SystemExit) + async def handle_get_pr(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] - def process_request(self, request, client_address): - if request is None: - return + response = None try: - self.finish_request(request, client_address) - self.shutdown_request(request) - self.iter_count = (self.iter_count + 1) % self.iterations_between_sync - if self.iter_count == 0: - self.table.sync_if_dirty() - except: - self.handle_error(request, client_address) - self.shutdown_request(request) - self.table.sync() - self.table.sync_if_dirty() - - def serve_forever(self, poll_interval=0.5): - signal.signal(signal.SIGINT, self.sigint_handler) - signal.signal(signal.SIGTERM, self.sigterm_handler) + value = self.server.table.get_value(version, pkgarch, checksum) + response = {"value": value} + except prserv.NotFoundError: + self.logger.error("failure storing value in database for (%s, %s)",version, checksum) - self.db = prserv.db.PRData(self.dbfile) - self.table = self.db["PRMAIN"] - return super().serve_forever(poll_interval) + return response - def export(self, version=None, pkgarch=None, checksum=None, colinfo=True): - try: - return self.table.export(version, pkgarch, checksum, colinfo) - except sqlite3.Error as exc: - logger.error(str(exc)) - return None + async def handle_import_one(self, request): + response = None + if not self.server.read_only: + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + value = request["value"] - def importone(self, version, pkgarch, checksum, value): - return self.table.importone(version, pkgarch, checksum, value) + value = self.server.table.importone(version, pkgarch, checksum, value) + if value is not None: + response = {"value": value} - def ping(self): - return True + return response - def getinfo(self): - return (self.host, self.port) + async def handle_export(self, request): + version = request["version"] + pkgarch = request["pkgarch"] + checksum = request["checksum"] + colinfo = request["colinfo"] - def getPR(self, version, pkgarch, checksum): try: - return self.table.getValue(version, pkgarch, checksum) - except prserv.NotFoundError: - logger.error("can not find value for (%s, %s)",version, checksum) - return None + (metainfo, datainfo) = self.server.table.export(version, pkgarch, checksum, colinfo) except sqlite3.Error as exc: - logger.error(str(exc)) - return None + self.logger.error(str(exc)) + metainfo = datainfo = None -class PRServSingleton(object): - def __init__(self, dbfile, logfile, interface): + return {"metainfo": metainfo, "datainfo": datainfo} + + async def handle_is_readonly(self, request): + return {"readonly": self.server.read_only} + +class PRServer(bb.asyncrpc.AsyncServer): + def __init__(self, dbfile, read_only=False): + super().__init__(logger) self.dbfile = dbfile - self.logfile = logfile - self.interface = interface - self.host = None - self.port = None + self.table = None + self.read_only = read_only - def start(self): - self.prserv = PRServer(self.dbfile, self.logfile, self.interface) - self.process = multiprocessing.Process(target=self.prserv.serve_forever) - self.process.start() + def accept_client(self, socket): + return PRServerClient(socket, self) - self.host, self.port = self.prserv.getinfo() + def start(self): + tasks = super().start() + self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only) + self.table = self.db["PRMAIN"] - def getinfo(self): - return (self.host, self.port) + self.logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" % + (self.dbfile, self.address, str(os.getpid()))) -class PRServerConnection(object): - def __init__(self, host, port): - if is_local_special(host, port): - host, port = singleton.getinfo() - self.host = host - self.port = port - self.connection, self.transport = bb.server.xmlrpcclient._create_server(self.host, self.port) + return tasks - def getPR(self, version, pkgarch, checksum): - return self.connection.getPR(version, pkgarch, checksum) + async def stop(self): + self.table.sync_if_dirty() + self.db.disconnect() + await super().stop() - def ping(self): - return self.connection.ping() + def signal_handler(self): + super().signal_handler() + if self.table: + self.table.sync() - def export(self,version=None, pkgarch=None, checksum=None, colinfo=True): - return self.connection.export(version, pkgarch, checksum, colinfo) +class PRServSingleton(object): + def __init__(self, dbfile, logfile, host, port): + self.dbfile = dbfile + self.logfile = logfile + self.host = host + self.port = port - def importone(self, version, pkgarch, checksum, value): - return self.connection.importone(version, pkgarch, checksum, value) + def start(self): + self.prserv = PRServer(self.dbfile) + self.prserv.start_tcp_server(socket.gethostbyname(self.host), self.port) + self.process = self.prserv.serve_as_process(log_level=logging.WARNING) - def getinfo(self): - return self.host, self.port + if not self.prserv.address: + raise PRServiceConfigError + if not self.port: + self.port = int(self.prserv.address.rsplit(":", 1)[1]) def run_as_daemon(func, pidfile, logfile): """ @@ -190,18 +197,18 @@ def run_as_daemon(func, pidfile, logfile): # stdout/stderr or it could be 'real' unix fd forking where we need # to physically close the fds to prevent the program launching us from # potentially hanging on a pipe. Handle both cases. - si = open('/dev/null', 'r') + si = open("/dev/null", "r") try: - os.dup2(si.fileno(),sys.stdin.fileno()) + os.dup2(si.fileno(), sys.stdin.fileno()) except (AttributeError, io.UnsupportedOperation): sys.stdin = si - so = open(logfile, 'a+') + so = open(logfile, "a+") try: - os.dup2(so.fileno(),sys.stdout.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) except (AttributeError, io.UnsupportedOperation): sys.stdout = so try: - os.dup2(so.fileno(),sys.stderr.fileno()) + os.dup2(so.fileno(), sys.stderr.fileno()) except (AttributeError, io.UnsupportedOperation): sys.stderr = so @@ -219,14 +226,14 @@ def run_as_daemon(func, pidfile, logfile): # write pidfile pid = str(os.getpid()) - with open(pidfile, 'w') as pf: + with open(pidfile, "w") as pf: pf.write("%s\n" % pid) func() os.remove(pidfile) os._exit(0) -def start_daemon(dbfile, host, port, logfile): +def start_daemon(dbfile, host, port, logfile, read_only=False): ip = socket.gethostbyname(host) pidfile = PIDPREFIX % (ip, port) try: @@ -240,15 +247,13 @@ def start_daemon(dbfile, host, port, logfile): % pidfile) return 1 - server = PRServer(os.path.abspath(dbfile), os.path.abspath(logfile), (ip,port)) - run_as_daemon(server.serve_forever, pidfile, os.path.abspath(logfile)) + dbfile = os.path.abspath(dbfile) + def daemon_main(): + server = PRServer(dbfile, read_only=read_only) + server.start_tcp_server(ip, port) + server.serve_forever() - # Sometimes, the port (i.e. localhost:0) indicated by the user does not match with - # the one the server actually is listening, so at least warn the user about it - _,rport = server.getinfo() - if port != rport: - sys.stdout.write("Server is listening at port %s instead of %s\n" - % (rport,port)) + run_as_daemon(daemon_main, pidfile, os.path.abspath(logfile)) return 0 def stop_daemon(host, port): @@ -266,15 +271,15 @@ def stop_daemon(host, port): # so at least advise the user which ports the corresponding server is listening ports = [] portstr = "" - for pf in glob.glob(PIDPREFIX % (ip,'*')): + for pf in glob.glob(PIDPREFIX % (ip, "*")): bn = os.path.basename(pf) root, _ = os.path.splitext(bn) - ports.append(root.split('_')[-1]) + ports.append(root.split("_")[-1]) if len(ports): - portstr = "Wrong port? Other ports listening at %s: %s" % (host, ' '.join(ports)) + portstr = "Wrong port? Other ports listening at %s: %s" % (host, " ".join(ports)) sys.stderr.write("pidfile %s does not exist. Daemon not running? %s\n" - % (pidfile,portstr)) + % (pidfile, portstr)) return 1 try: @@ -283,8 +288,11 @@ def stop_daemon(host, port): os.kill(pid, signal.SIGTERM) time.sleep(0.1) - if os.path.exists(pidfile): + try: os.remove(pidfile) + except FileNotFoundError: + # The PID file might have been removed by the exiting process + pass except OSError as e: err = str(e) @@ -302,7 +310,7 @@ def is_running(pid): return True def is_local_special(host, port): - if host.strip().upper() == 'localhost'.upper() and (not port): + if (host == "localhost" or host == "127.0.0.1") and not port: return True else: return False @@ -313,7 +321,7 @@ class PRServiceConfigError(Exception): def auto_start(d): global singleton - host_params = list(filter(None, (d.getVar('PRSERV_HOST') or '').split(':'))) + host_params = list(filter(None, (d.getVar("PRSERV_HOST") or "").split(":"))) if not host_params: # Shutdown any existing PR Server auto_shutdown() @@ -322,11 +330,13 @@ def auto_start(d): if len(host_params) != 2: # Shutdown any existing PR Server auto_shutdown() - logger.critical('\n'.join(['PRSERV_HOST: incorrect format', + logger.critical("\n".join(["PRSERV_HOST: incorrect format", 'Usage: PRSERV_HOST = "<hostname>:<port>"'])) raise PRServiceConfigError - if is_local_special(host_params[0], int(host_params[1])): + host = host_params[0].strip().lower() + port = int(host_params[1]) + if is_local_special(host, port): import bb.utils cachedir = (d.getVar("PERSISTENT_DIR") or d.getVar("CACHE")) if not cachedir: @@ -340,20 +350,16 @@ def auto_start(d): auto_shutdown() if not singleton: bb.utils.mkdirhier(cachedir) - singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), ("localhost",0)) + singleton = PRServSingleton(os.path.abspath(dbfile), os.path.abspath(logfile), host, port) singleton.start() if singleton: - host, port = singleton.getinfo() - else: - host = host_params[0] - port = int(host_params[1]) + host = singleton.host + port = singleton.port try: - connection = PRServerConnection(host,port) - connection.ping() - realhost, realport = connection.getinfo() - return str(realhost) + ":" + str(realport) - + ping(host, port) + return str(host) + ":" + str(port) + except Exception: logger.critical("PRservice %s:%d not available" % (host, port)) raise PRServiceConfigError @@ -366,8 +372,21 @@ def auto_shutdown(): singleton = None def ping(host, port): - conn=PRServerConnection(host, port) - return conn.ping() + from . import client + + with client.PRClient() as conn: + conn.connect_tcp(host, port) + return conn.ping() def connect(host, port): - return PRServerConnection(host, port) + from . import client + + global singleton + + if host.strip().lower() == "localhost" and not port: + host = "localhost" + port = singleton.port + + conn = client.PRClient() + conn.connect_tcp(host, port) + return conn |