diff options
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r-- | bitbake/lib/bb/asyncrpc/serv.py | 385 |
1 files changed, 245 insertions, 140 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py index b4cffff213..a66117acad 100644 --- a/bitbake/lib/bb/asyncrpc/serv.py +++ b/bitbake/lib/bb/asyncrpc/serv.py @@ -1,4 +1,6 @@ # +# Copyright BitBake Contributors +# # SPDX-License-Identifier: GPL-2.0-only # @@ -10,234 +12,333 @@ import signal import socket import sys import multiprocessing -from . import chunkify, DEFAULT_MAX_CHUNK - +import logging +from .connection import StreamConnection, WebsocketConnection +from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError -class ClientError(Exception): - pass - -class ServerError(Exception): - pass +class ClientLoggerAdapter(logging.LoggerAdapter): + def process(self, msg, kwargs): + return f"[Client {self.extra['address']}] {msg}", kwargs class AsyncServerConnection(object): - def __init__(self, reader, writer, proto_name, logger): - self.reader = reader - self.writer = writer + # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no + # return message will be automatically be sent back to the client + NO_RESPONSE = object() + + def __init__(self, socket, proto_name, logger): + self.socket = socket self.proto_name = proto_name - self.max_chunk = DEFAULT_MAX_CHUNK self.handlers = { - 'chunk-stream': self.handle_chunk, - 'ping': self.handle_ping, + "ping": self.handle_ping, } - self.logger = logger + self.logger = ClientLoggerAdapter( + logger, + { + "address": socket.address, + }, + ) + self.client_headers = {} + + async def close(self): + await self.socket.close() + + async def handle_headers(self, headers): + return {} async def process_requests(self): try: - self.addr = self.writer.get_extra_info('peername') - self.logger.debug('Client %r connected' % (self.addr,)) + self.logger.info("Client %r connected" % (self.socket.address,)) # Read protocol and version - client_protocol = await self.reader.readline() - if client_protocol is None: + client_protocol = await self.socket.recv() + if not client_protocol: return - (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split() + (client_proto_name, client_proto_version) = client_protocol.split() if client_proto_name != self.proto_name: - self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name)) + self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name)) return - self.proto_version = tuple(int(v) for v in client_proto_version.split('.')) + self.proto_version = tuple(int(v) for v in client_proto_version.split(".")) if not self.validate_proto_version(): - self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version)) + self.logger.debug( + "Rejecting invalid protocol version %s" % (client_proto_version) + ) return - # Read headers. Currently, no headers are implemented, so look for - # an empty line to signal the end of the headers + # Read headers + self.client_headers = {} while True: - line = await self.reader.readline() - if line is None: - return - - line = line.decode('utf-8').rstrip() - if not line: + header = await self.socket.recv() + if not header: + # Empty line. End of headers break + tag, value = header.split(":", 1) + self.client_headers[tag.lower()] = value.strip() + + if self.client_headers.get("needs-headers", "false") == "true": + for k, v in (await self.handle_headers(self.client_headers)).items(): + await self.socket.send("%s: %s" % (k, v)) + await self.socket.send("") # Handle messages while True: - d = await self.read_message() + d = await self.socket.recv_message() if d is None: break - await self.dispatch_message(d) - await self.writer.drain() - except ClientError as e: + try: + response = await self.dispatch_message(d) + except InvokeError as e: + await self.socket.send_message( + {"invoke-error": {"message": str(e)}} + ) + break + + if response is not self.NO_RESPONSE: + await self.socket.send_message(response) + + except ConnectionClosedError as e: + self.logger.info(str(e)) + except (ClientError, ConnectionError) as e: self.logger.error(str(e)) finally: - self.writer.close() + await self.close() async def dispatch_message(self, msg): for k in self.handlers.keys(): if k in msg: - self.logger.debug('Handling %s' % k) - await self.handlers[k](msg[k]) - return + self.logger.debug("Handling %s" % k) + return await self.handlers[k](msg[k]) raise ClientError("Unrecognized command %r" % msg) - def write_message(self, msg): - for c in chunkify(json.dumps(msg), self.max_chunk): - self.writer.write(c.encode('utf-8')) + async def handle_ping(self, request): + return {"alive": True} + - async def read_message(self): - l = await self.reader.readline() - if not l: - return None +class StreamServer(object): + def __init__(self, handler, logger): + self.handler = handler + self.logger = logger + self.closed = False - try: - message = l.decode('utf-8') + async def handle_stream_client(self, reader, writer): + # writer.transport.set_write_buffer_limits(0) + socket = StreamConnection(reader, writer, -1) + if self.closed: + await socket.close() + return + + await self.handler(socket) + + async def stop(self): + self.closed = True + + +class TCPStreamServer(StreamServer): + def __init__(self, host, port, handler, logger): + super().__init__(handler, logger) + self.host = host + self.port = port + + def start(self, loop): + self.server = loop.run_until_complete( + asyncio.start_server(self.handle_stream_client, self.host, self.port) + ) + + for s in self.server.sockets: + self.logger.debug("Listening on %r" % (s.getsockname(),)) + # Newer python does this automatically. Do it manually here for + # maximum compatibility + s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) + + # Enable keep alives. This prevents broken client connections + # from persisting on the server for long periods of time. + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) + + name = self.server.sockets[0].getsockname() + if self.server.sockets[0].family == socket.AF_INET6: + self.address = "[%s]:%d" % (name[0], name[1]) + else: + self.address = "%s:%d" % (name[0], name[1]) + + return [self.server.wait_closed()] + + async def stop(self): + await super().stop() + self.server.close() + + def cleanup(self): + pass - if not message.endswith('\n'): - return None - return json.loads(message) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - self.logger.error('Bad message from client: %r' % message) - raise e +class UnixStreamServer(StreamServer): + def __init__(self, path, handler, logger): + super().__init__(handler, logger) + self.path = path - async def handle_chunk(self, request): - lines = [] + def start(self, loop): + cwd = os.getcwd() try: - while True: - l = await self.reader.readline() - l = l.rstrip(b"\n").decode("utf-8") - if not l: - break - lines.append(l) - - msg = json.loads(''.join(lines)) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - self.logger.error('Bad message from client: %r' % lines) - raise e + # Work around path length limits in AF_UNIX + os.chdir(os.path.dirname(self.path)) + self.server = loop.run_until_complete( + asyncio.start_unix_server( + self.handle_stream_client, os.path.basename(self.path) + ) + ) + finally: + os.chdir(cwd) - if 'chunk-stream' in msg: - raise ClientError("Nested chunks are not allowed") + self.logger.debug("Listening on %r" % self.path) + self.address = "unix://%s" % os.path.abspath(self.path) + return [self.server.wait_closed()] - await self.dispatch_message(msg) + async def stop(self): + await super().stop() + self.server.close() - async def handle_ping(self, request): - response = {'alive': True} - self.write_message(response) + def cleanup(self): + os.unlink(self.path) -class AsyncServer(object): - def __init__(self, logger): - self._cleanup_socket = None +class WebsocketsServer(object): + def __init__(self, host, port, handler, logger): + self.host = host + self.port = port + self.handler = handler self.logger = logger - self.start = None - self.address = None - self.loop = None - def start_tcp_server(self, host, port): - def start_tcp(): - self.server = self.loop.run_until_complete( - asyncio.start_server(self.handle_client, host, port) + def start(self, loop): + import websockets.server + + self.server = loop.run_until_complete( + websockets.server.serve( + self.client_handler, + self.host, + self.port, + ping_interval=None, ) + ) - for s in self.server.sockets: - self.logger.debug('Listening on %r' % (s.getsockname(),)) - # Newer python does this automatically. Do it manually here for - # maximum compatibility - s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1) + for s in self.server.sockets: + self.logger.debug("Listening on %r" % (s.getsockname(),)) - name = self.server.sockets[0].getsockname() - if self.server.sockets[0].family == socket.AF_INET6: - self.address = "[%s]:%d" % (name[0], name[1]) - else: - self.address = "%s:%d" % (name[0], name[1]) + # Enable keep alives. This prevents broken client connections + # from persisting on the server for long periods of time. + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4) - self.start = start_tcp + name = self.server.sockets[0].getsockname() + if self.server.sockets[0].family == socket.AF_INET6: + self.address = "ws://[%s]:%d" % (name[0], name[1]) + else: + self.address = "ws://%s:%d" % (name[0], name[1]) - def start_unix_server(self, path): - def cleanup(): - os.unlink(path) + return [self.server.wait_closed()] - def start_unix(): - cwd = os.getcwd() - try: - # Work around path length limits in AF_UNIX - os.chdir(os.path.dirname(path)) - self.server = self.loop.run_until_complete( - asyncio.start_unix_server(self.handle_client, os.path.basename(path)) - ) - finally: - os.chdir(cwd) + async def stop(self): + self.server.close() - self.logger.debug('Listening on %r' % path) + def cleanup(self): + pass - self._cleanup_socket = cleanup - self.address = "unix://%s" % os.path.abspath(path) + async def client_handler(self, websocket): + socket = WebsocketConnection(websocket, -1) + await self.handler(socket) - self.start = start_unix - @abc.abstractmethod - def accept_client(self, reader, writer): - pass +class AsyncServer(object): + def __init__(self, logger): + self.logger = logger + self.loop = None + self.run_tasks = [] - async def handle_client(self, reader, writer): - # writer.transport.set_write_buffer_limits(0) + def start_tcp_server(self, host, port): + self.server = TCPStreamServer(host, port, self._client_handler, self.logger) + + def start_unix_server(self, path): + self.server = UnixStreamServer(path, self._client_handler, self.logger) + + def start_websocket_server(self, host, port): + self.server = WebsocketsServer(host, port, self._client_handler, self.logger) + + async def _client_handler(self, socket): + address = socket.address try: - client = self.accept_client(reader, writer) + client = self.accept_client(socket) await client.process_requests() except Exception as e: import traceback - self.logger.error('Error from client: %s' % str(e), exc_info=True) + + self.logger.error( + "Error from client %s: %s" % (address, str(e)), exc_info=True + ) traceback.print_exc() - writer.close() - self.logger.debug('Client disconnected') + finally: + self.logger.debug("Client %s disconnected", address) + await socket.close() - def run_loop_forever(self): - try: - self.loop.run_forever() - except KeyboardInterrupt: - pass + @abc.abstractmethod + def accept_client(self, socket): + pass + + async def stop(self): + self.logger.debug("Stopping server") + await self.server.stop() + + def start(self): + tasks = self.server.start(self.loop) + self.address = self.server.address + return tasks def signal_handler(self): self.logger.debug("Got exit signal") - self.loop.stop() + self.loop.create_task(self.stop()) - def _serve_forever(self): + def _serve_forever(self, tasks): try: self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler) + self.loop.add_signal_handler(signal.SIGINT, self.signal_handler) + self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler) signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM]) - self.run_loop_forever() - self.server.close() + self.loop.run_until_complete(asyncio.gather(*tasks)) - self.loop.run_until_complete(self.server.wait_closed()) - self.logger.debug('Server shutting down') + self.logger.debug("Server shutting down") finally: - if self._cleanup_socket is not None: - self._cleanup_socket() + self.server.cleanup() def serve_forever(self): """ Serve requests in the current process """ + self._create_loop() + tasks = self.start() + self._serve_forever(tasks) + self.loop.close() + + def _create_loop(self): # Create loop and override any loop that may have existed in # a parent process. It is possible that the usecases of # serve_forever might be constrained enough to allow using # get_event_loop here, but better safe than sorry for now. self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.start() - self._serve_forever() - def serve_as_process(self, *, prefunc=None, args=()): + def serve_as_process(self, *, prefunc=None, args=(), log_level=None): """ Serve requests in a child process """ + def run(queue): # Create loop and override any loop that may have existed # in a parent process. Without doing this and instead @@ -250,18 +351,22 @@ class AsyncServer(object): # more general, though, as any potential use of asyncio in # Cooker could create a loop that needs to replaced in this # new process. - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + self._create_loop() try: - self.start() + self.address = None + tasks = self.start() finally: + # Always put the server address to wake up the parent task queue.put(self.address) queue.close() if prefunc is not None: prefunc(self, *args) - self._serve_forever() + if log_level is not None: + self.logger.setLevel(log_level) + + self._serve_forever(tasks) if sys.version_info >= (3, 6): self.loop.run_until_complete(self.loop.shutdown_asyncgens()) |