summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/serv.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/serv.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py415
1 files changed, 294 insertions, 121 deletions
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
index cb3384639d..a66117acad 100644
--- a/bitbake/lib/bb/asyncrpc/serv.py
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -1,4 +1,6 @@
#
+# Copyright BitBake Contributors
+#
# SPDX-License-Identifier: GPL-2.0-only
#
@@ -9,210 +11,381 @@ import os
import signal
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
- pass
+import multiprocessing
+import logging
+from .connection import StreamConnection, WebsocketConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
-class ServerError(Exception):
- pass
+class ClientLoggerAdapter(logging.LoggerAdapter):
+ def process(self, msg, kwargs):
+ return f"[Client {self.extra['address']}] {msg}", kwargs
class AsyncServerConnection(object):
- def __init__(self, reader, writer, proto_name, logger):
- self.reader = reader
- self.writer = writer
+ # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
+ # return message will be automatically be sent back to the client
+ NO_RESPONSE = object()
+
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
self.proto_name = proto_name
- self.max_chunk = DEFAULT_MAX_CHUNK
self.handlers = {
- 'chunk-stream': self.handle_chunk,
+ "ping": self.handle_ping,
}
- self.logger = logger
+ self.logger = ClientLoggerAdapter(
+ logger,
+ {
+ "address": socket.address,
+ },
+ )
+ self.client_headers = {}
+
+ async def close(self):
+ await self.socket.close()
+
+ async def handle_headers(self, headers):
+ return {}
async def process_requests(self):
try:
- self.addr = self.writer.get_extra_info('peername')
- self.logger.debug('Client %r connected' % (self.addr,))
+ self.logger.info("Client %r connected" % (self.socket.address,))
# Read protocol and version
- client_protocol = await self.reader.readline()
- if client_protocol is None:
+ client_protocol = await self.socket.recv()
+ if not client_protocol:
return
- (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+ (client_proto_name, client_proto_version) = client_protocol.split()
if client_proto_name != self.proto_name:
- self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
return
- self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
if not self.validate_proto_version():
- self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
return
- # Read headers. Currently, no headers are implemented, so look for
- # an empty line to signal the end of the headers
+ # Read headers
+ self.client_headers = {}
while True:
- line = await self.reader.readline()
- if line is None:
- return
-
- line = line.decode('utf-8').rstrip()
- if not line:
+ header = await self.socket.recv()
+ if not header:
+ # Empty line. End of headers
break
+ tag, value = header.split(":", 1)
+ self.client_headers[tag.lower()] = value.strip()
+
+ if self.client_headers.get("needs-headers", "false") == "true":
+ for k, v in (await self.handle_headers(self.client_headers)).items():
+ await self.socket.send("%s: %s" % (k, v))
+ await self.socket.send("")
# Handle messages
while True:
- d = await self.read_message()
+ d = await self.socket.recv_message()
if d is None:
break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
+ try:
+ response = await self.dispatch_message(d)
+ except InvokeError as e:
+ await self.socket.send_message(
+ {"invoke-error": {"message": str(e)}}
+ )
+ break
+
+ if response is not self.NO_RESPONSE:
+ await self.socket.send_message(response)
+
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
self.logger.error(str(e))
finally:
- self.writer.close()
+ await self.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
if k in msg:
- self.logger.debug('Handling %s' % k)
- await self.handlers[k](msg[k])
- return
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
raise ClientError("Unrecognized command %r" % msg)
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
-
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
-
- try:
- message = l.decode('utf-8')
-
- if not message.endswith('\n'):
- return None
+ async def handle_ping(self, request):
+ return {"alive": True}
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % message)
- raise e
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
-
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- self.logger.error('Bad message from client: %r' % lines)
- raise e
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
- await self.dispatch_message(msg)
+ await self.handler(socket)
+ async def stop(self):
+ self.closed = True
-class AsyncServer(object):
- def __init__(self, logger, loop=None):
- if loop is None:
- self.loop = asyncio.new_event_loop()
- self.close_loop = True
- else:
- self.loop = loop
- self.close_loop = False
- self._cleanup_socket = None
- self.logger = logger
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
- def start_tcp_server(self, host, port):
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port, loop=self.loop)
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
)
for s in self.server.sockets:
- self.logger.info('Listening on %r' % (s.getsockname(),))
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
# Newer python does this automatically. Do it manually here for
# maximum compatibility
s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
name = self.server.sockets[0].getsockname()
if self.server.sockets[0].family == socket.AF_INET6:
self.address = "[%s]:%d" % (name[0], name[1])
else:
self.address = "%s:%d" % (name[0], name[1])
- def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+ def cleanup(self):
+ pass
+
+
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
+
+ def start(self, loop):
cwd = os.getcwd()
try:
# Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
)
finally:
os.chdir(cwd)
- self.logger.info('Listening on %r' % path)
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
+ async def stop(self):
+ await super().stop()
+ self.server.close()
- @abc.abstractmethod
- def accept_client(self, reader, writer):
+ def cleanup(self):
+ os.unlink(self.path)
+
+
+class WebsocketsServer(object):
+ def __init__(self, host, port, handler, logger):
+ self.host = host
+ self.port = port
+ self.handler = handler
+ self.logger = logger
+
+ def start(self, loop):
+ import websockets.server
+
+ self.server = loop.run_until_complete(
+ websockets.server.serve(
+ self.client_handler,
+ self.host,
+ self.port,
+ ping_interval=None,
+ )
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "ws://[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "ws://%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ self.server.close()
+
+ def cleanup(self):
pass
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
+ async def client_handler(self, websocket):
+ socket = WebsocketConnection(websocket, -1)
+ await self.handler(socket)
+
+
+class AsyncServer(object):
+ def __init__(self, logger):
+ self.logger = logger
+ self.loop = None
+ self.run_tasks = []
+
+ def start_tcp_server(self, host, port):
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
+
+ def start_unix_server(self, path):
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
+
+ def start_websocket_server(self, host, port):
+ self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
+
+ async def _client_handler(self, socket):
+ address = socket.address
try:
- client = self.accept_client(reader, writer)
+ client = self.accept_client(socket)
await client.process_requests()
except Exception as e:
import traceback
- self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+ self.logger.error(
+ "Error from client %s: %s" % (address, str(e)), exc_info=True
+ )
traceback.print_exc()
- writer.close()
- self.logger.info('Client disconnected')
+ finally:
+ self.logger.debug("Client %s disconnected", address)
+ await socket.close()
- def run_loop_forever(self):
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
def signal_handler(self):
- self.loop.stop()
+ self.logger.debug("Got exit signal")
+ self.loop.create_task(self.stop())
- def serve_forever(self):
- asyncio.set_event_loop(self.loop)
+ def _serve_forever(self, tasks):
try:
self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
- self.run_loop_forever()
- self.server.close()
+ self.loop.run_until_complete(asyncio.gather(*tasks))
- self.loop.run_until_complete(self.server.wait_closed())
- self.logger.info('Server shutting down')
+ self.logger.debug("Server shutting down")
finally:
- if self.close_loop:
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ self.server.cleanup()
+
+ def serve_forever(self):
+ """
+ Serve requests in the current process
+ """
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
+ # Create loop and override any loop that may have existed in
+ # a parent process. It is possible that the usecases of
+ # serve_forever might be constrained enough to allow using
+ # get_event_loop here, but better safe than sorry for now.
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
+ """
+ Serve requests in a child process
+ """
+
+ def run(queue):
+ # Create loop and override any loop that may have existed
+ # in a parent process. Without doing this and instead
+ # using get_event_loop, at the very minimum the hashserv
+ # unit tests will hang when running the second test.
+ # This happens since get_event_loop in the spawned server
+ # process for the second testcase ends up with the loop
+ # from the hashserv client created in the unit test process
+ # when running the first testcase. The problem is somewhat
+ # more general, though, as any potential use of asyncio in
+ # Cooker could create a loop that needs to replaced in this
+ # new process.
+ self._create_loop()
+ try:
+ self.address = None
+ tasks = self.start()
+ finally:
+ # Always put the server address to wake up the parent task
+ queue.put(self.address)
+ queue.close()
+
+ if prefunc is not None:
+ prefunc(self, *args)
+
+ if log_level is not None:
+ self.logger.setLevel(log_level)
+
+ self._serve_forever(tasks)
+
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+
+ queue = multiprocessing.Queue()
+
+ # Temporarily block SIGTERM. The server process will inherit this
+ # block which will ensure it doesn't receive the SIGTERM until the
+ # handler is ready for it
+ mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
+ try:
+ self.process = multiprocessing.Process(target=run, args=(queue,))
+ self.process.start()
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ self.address = queue.get()
+ queue.close()
+ queue.join_thread()
+
+ return self.process
+ finally:
+ signal.pthread_sigmask(signal.SIG_SETMASK, mask)