summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc')
-rw-r--r--bitbake/lib/bb/asyncrpc/__init__.py16
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py337
-rw-r--r--bitbake/lib/bb/asyncrpc/connection.py146
-rw-r--r--bitbake/lib/bb/asyncrpc/exceptions.py21
-rw-r--r--bitbake/lib/bb/asyncrpc/serv.py391
5 files changed, 911 insertions, 0 deletions
diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
new file mode 100644
index 0000000000..639e1607f8
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -0,0 +1,16 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+from .client import AsyncClient, Client, ClientPool
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+ ClientError,
+ ServerError,
+ ConnectionClosedError,
+ InvokeError,
+)
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
new file mode 100644
index 0000000000..b49de99313
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -0,0 +1,337 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import abc
+import asyncio
+import json
+import os
+import socket
+import sys
+import re
+import contextlib
+from threading import Thread
+from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError, InvokeError
+
+UNIX_PREFIX = "unix://"
+WS_PREFIX = "ws://"
+WSS_PREFIX = "wss://"
+
+ADDR_TYPE_UNIX = 0
+ADDR_TYPE_TCP = 1
+ADDR_TYPE_WS = 2
+
+WEBSOCKETS_MIN_VERSION = (9, 1)
+# Need websockets 10 with python 3.10+
+if sys.version_info >= (3, 10, 0):
+ WEBSOCKETS_MIN_VERSION = (10, 0)
+
+def parse_address(addr):
+ if addr.startswith(UNIX_PREFIX):
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
+ elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
+ return (ADDR_TYPE_WS, (addr,))
+ else:
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
+ if m is not None:
+ host = m.group("host")
+ port = m.group("port")
+ else:
+ host, port = addr.split(":")
+
+ return (ADDR_TYPE_TCP, (host, int(port)))
+
+
+class AsyncClient(object):
+ def __init__(
+ self,
+ proto_name,
+ proto_version,
+ logger,
+ timeout=30,
+ server_headers=False,
+ headers={},
+ ):
+ self.socket = None
+ self.max_chunk = DEFAULT_MAX_CHUNK
+ self.proto_name = proto_name
+ self.proto_version = proto_version
+ self.logger = logger
+ self.timeout = timeout
+ self.needs_server_headers = server_headers
+ self.server_headers = {}
+ self.headers = headers
+
+ async def connect_tcp(self, address, port):
+ async def connect_sock():
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
+
+ self._connect_sock = connect_sock
+
+ async def connect_unix(self, path):
+ async def connect_sock():
+ # AF_UNIX has path length issues so chdir here to workaround
+ cwd = os.getcwd()
+ try:
+ os.chdir(os.path.dirname(path))
+ # The socket must be opened synchronously so that CWD doesn't get
+ # changed out from underneath us so we pass as a sock into asyncio
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+ sock.connect(os.path.basename(path))
+ finally:
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
+
+ self._connect_sock = connect_sock
+
+ async def connect_websocket(self, uri):
+ import websockets
+
+ try:
+ version = tuple(
+ int(v)
+ for v in websockets.__version__.split(".")[
+ 0 : len(WEBSOCKETS_MIN_VERSION)
+ ]
+ )
+ except ValueError:
+ raise ImportError(
+ f"Unable to parse websockets version '{websockets.__version__}'"
+ )
+
+ if version < WEBSOCKETS_MIN_VERSION:
+ min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
+ raise ImportError(
+ f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
+ )
+
+ async def connect_sock():
+ websocket = await websockets.connect(uri, ping_interval=None)
+ return WebsocketConnection(websocket, self.timeout)
+
+ self._connect_sock = connect_sock
+
+ async def setup_connection(self):
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ await self.socket.send(
+ "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
+ )
+ for k, v in self.headers.items():
+ await self.socket.send("%s: %s" % (k, v))
+
+ # End of headers
+ await self.socket.send("")
+
+ self.server_headers = {}
+ if self.needs_server_headers:
+ while True:
+ line = await self.socket.recv()
+ if not line:
+ # End headers
+ break
+ tag, value = line.split(":", 1)
+ self.server_headers[tag.lower()] = value.strip()
+
+ async def get_header(self, tag, default):
+ await self.connect()
+ return self.server_headers.get(tag, default)
+
+ async def connect(self):
+ if self.socket is None:
+ self.socket = await self._connect_sock()
+ await self.setup_connection()
+
+ async def disconnect(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
+
+ async def close(self):
+ await self.disconnect()
+
+ async def _send_wrapper(self, proc):
+ count = 0
+ while True:
+ try:
+ await self.connect()
+ return await proc()
+ except (
+ OSError,
+ ConnectionError,
+ ConnectionClosedError,
+ json.JSONDecodeError,
+ UnicodeDecodeError,
+ ) as e:
+ self.logger.warning("Error talking to server: %s" % e)
+ if count >= 3:
+ if not isinstance(e, ConnectionError):
+ raise ConnectionError(str(e))
+ raise e
+ await self.close()
+ count += 1
+
+ def check_invoke_error(self, msg):
+ if isinstance(msg, dict) and "invoke-error" in msg:
+ raise InvokeError(msg["invoke-error"]["message"])
+
+ async def invoke(self, msg):
+ async def proc():
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
+
+ result = await self._send_wrapper(proc)
+ self.check_invoke_error(result)
+ return result
+
+ async def ping(self):
+ return await self.invoke({"ping": {}})
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+
+class Client(object):
+ def __init__(self):
+ self.client = self._get_async_client()
+ self.loop = asyncio.new_event_loop()
+
+ # Override any pre-existing loop.
+ # Without this, the PR server export selftest triggers a hang
+ # when running with Python 3.7. The drawback is that there is
+ # potential for issues if the PR and hash equiv (or some new)
+ # clients need to both be instantiated in the same process.
+ # This should be revisited if/when Python 3.9 becomes the
+ # minimum required version for BitBake, as it seems not
+ # required (but harmless) with it.
+ asyncio.set_event_loop(self.loop)
+
+ self._add_methods("connect_tcp", "ping")
+
+ @abc.abstractmethod
+ def _get_async_client(self):
+ pass
+
+ def _get_downcall_wrapper(self, downcall):
+ def wrapper(*args, **kwargs):
+ return self.loop.run_until_complete(downcall(*args, **kwargs))
+
+ return wrapper
+
+ def _add_methods(self, *methods):
+ for m in methods:
+ downcall = getattr(self.client, m)
+ setattr(self, m, self._get_downcall_wrapper(downcall))
+
+ def connect_unix(self, path):
+ self.loop.run_until_complete(self.client.connect_unix(path))
+ self.loop.run_until_complete(self.client.connect())
+
+ @property
+ def max_chunk(self):
+ return self.client.max_chunk
+
+ @max_chunk.setter
+ def max_chunk(self, value):
+ self.client.max_chunk = value
+
+ def disconnect(self):
+ self.loop.run_until_complete(self.client.close())
+
+ def close(self):
+ if self.loop:
+ self.loop.run_until_complete(self.client.close())
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
+
+
+class ClientPool(object):
+ def __init__(self, max_clients):
+ self.avail_clients = []
+ self.num_clients = 0
+ self.max_clients = max_clients
+ self.loop = None
+ self.client_condition = None
+
+ @abc.abstractmethod
+ async def _new_client(self):
+ raise NotImplementedError("Must be implemented in derived class")
+
+ def close(self):
+ if self.client_condition:
+ self.client_condition = None
+
+ if self.loop:
+ self.loop.run_until_complete(self.__close_clients())
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
+
+ def run_tasks(self, tasks):
+ if not self.loop:
+ self.loop = asyncio.new_event_loop()
+
+ thread = Thread(target=self.__thread_main, args=(tasks,))
+ thread.start()
+ thread.join()
+
+ @contextlib.asynccontextmanager
+ async def get_client(self):
+ async with self.client_condition:
+ if self.avail_clients:
+ client = self.avail_clients.pop()
+ elif self.num_clients < self.max_clients:
+ self.num_clients += 1
+ client = await self._new_client()
+ else:
+ while not self.avail_clients:
+ await self.client_condition.wait()
+ client = self.avail_clients.pop()
+
+ try:
+ yield client
+ finally:
+ async with self.client_condition:
+ self.avail_clients.append(client)
+ self.client_condition.notify()
+
+ def __thread_main(self, tasks):
+ async def process_task(task):
+ async with self.get_client() as client:
+ await task(client)
+
+ asyncio.set_event_loop(self.loop)
+ if not self.client_condition:
+ self.client_condition = asyncio.Condition()
+ tasks = [process_task(t) for t in tasks]
+ self.loop.run_until_complete(asyncio.gather(*tasks))
+
+ async def __close_clients(self):
+ for c in self.avail_clients:
+ await c.close()
+ self.avail_clients = []
+ self.num_clients = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
diff --git a/bitbake/lib/bb/asyncrpc/connection.py b/bitbake/lib/bb/asyncrpc/connection.py
new file mode 100644
index 0000000000..7f0cf6ba96
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,146 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from datetime import datetime
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+ if len(msg) < max_chunk - 1:
+ yield "".join((msg, "\n"))
+ else:
+ yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+ args = [iter(msg)] * (max_chunk - 1)
+ for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+ yield "".join(itertools.chain(m, "\n"))
+ yield "\n"
+
+
+def json_serialize(obj):
+ if isinstance(obj, datetime):
+ return obj.isoformat()
+ raise TypeError("Type %s not serializeable" % type(obj))
+
+
+class StreamConnection(object):
+ def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+ self.reader = reader
+ self.writer = writer
+ self.timeout = timeout
+ self.max_chunk = max_chunk
+
+ @property
+ def address(self):
+ return self.writer.get_extra_info("peername")
+
+ async def send_message(self, msg):
+ for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
+ self.writer.write(c.encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv_message(self):
+ l = await self.recv()
+
+ m = json.loads(l)
+ if not m:
+ return m
+
+ if "chunk-stream" in m:
+ lines = []
+ while True:
+ l = await self.recv()
+ if not l:
+ break
+ lines.append(l)
+
+ m = json.loads("".join(lines))
+
+ return m
+
+ async def send(self, msg):
+ self.writer.write(("%s\n" % msg).encode("utf-8"))
+ await self.writer.drain()
+
+ async def recv(self):
+ if self.timeout < 0:
+ line = await self.reader.readline()
+ else:
+ try:
+ line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+
+ if not line:
+ raise ConnectionClosedError("Connection closed")
+
+ line = line.decode("utf-8")
+
+ if not line.endswith("\n"):
+ raise ConnectionError("Bad message %r" % (line))
+
+ return line.rstrip()
+
+ async def close(self):
+ self.reader = None
+ if self.writer is not None:
+ self.writer.close()
+ self.writer = None
+
+
+class WebsocketConnection(object):
+ def __init__(self, socket, timeout):
+ self.socket = socket
+ self.timeout = timeout
+
+ @property
+ def address(self):
+ return ":".join(str(s) for s in self.socket.remote_address)
+
+ async def send_message(self, msg):
+ await self.send(json.dumps(msg, default=json_serialize))
+
+ async def recv_message(self):
+ m = await self.recv()
+ return json.loads(m)
+
+ async def send(self, msg):
+ import websockets.exceptions
+
+ try:
+ await self.socket.send(msg)
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def recv(self):
+ import websockets.exceptions
+
+ try:
+ if self.timeout < 0:
+ return await self.socket.recv()
+
+ try:
+ return await asyncio.wait_for(self.socket.recv(), self.timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError("Timed out waiting for data")
+ except websockets.exceptions.ConnectionClosed:
+ raise ConnectionClosedError("Connection closed")
+
+ async def close(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
diff --git a/bitbake/lib/bb/asyncrpc/exceptions.py b/bitbake/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 0000000000..ae1043a38b
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,21 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+ pass
+
+
+class InvokeError(Exception):
+ pass
+
+
+class ServerError(Exception):
+ pass
+
+
+class ConnectionClosedError(Exception):
+ pass
diff --git a/bitbake/lib/bb/asyncrpc/serv.py b/bitbake/lib/bb/asyncrpc/serv.py
new file mode 100644
index 0000000000..a66117acad
--- /dev/null
+++ b/bitbake/lib/bb/asyncrpc/serv.py
@@ -0,0 +1,391 @@
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import abc
+import asyncio
+import json
+import os
+import signal
+import socket
+import sys
+import multiprocessing
+import logging
+from .connection import StreamConnection, WebsocketConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError, InvokeError
+
+
+class ClientLoggerAdapter(logging.LoggerAdapter):
+ def process(self, msg, kwargs):
+ return f"[Client {self.extra['address']}] {msg}", kwargs
+
+
+class AsyncServerConnection(object):
+ # If a handler returns this object (e.g. `return self.NO_RESPONSE`), no
+ # return message will be automatically be sent back to the client
+ NO_RESPONSE = object()
+
+ def __init__(self, socket, proto_name, logger):
+ self.socket = socket
+ self.proto_name = proto_name
+ self.handlers = {
+ "ping": self.handle_ping,
+ }
+ self.logger = ClientLoggerAdapter(
+ logger,
+ {
+ "address": socket.address,
+ },
+ )
+ self.client_headers = {}
+
+ async def close(self):
+ await self.socket.close()
+
+ async def handle_headers(self, headers):
+ return {}
+
+ async def process_requests(self):
+ try:
+ self.logger.info("Client %r connected" % (self.socket.address,))
+
+ # Read protocol and version
+ client_protocol = await self.socket.recv()
+ if not client_protocol:
+ return
+
+ (client_proto_name, client_proto_version) = client_protocol.split()
+ if client_proto_name != self.proto_name:
+ self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
+ return
+
+ self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
+ if not self.validate_proto_version():
+ self.logger.debug(
+ "Rejecting invalid protocol version %s" % (client_proto_version)
+ )
+ return
+
+ # Read headers
+ self.client_headers = {}
+ while True:
+ header = await self.socket.recv()
+ if not header:
+ # Empty line. End of headers
+ break
+ tag, value = header.split(":", 1)
+ self.client_headers[tag.lower()] = value.strip()
+
+ if self.client_headers.get("needs-headers", "false") == "true":
+ for k, v in (await self.handle_headers(self.client_headers)).items():
+ await self.socket.send("%s: %s" % (k, v))
+ await self.socket.send("")
+
+ # Handle messages
+ while True:
+ d = await self.socket.recv_message()
+ if d is None:
+ break
+ try:
+ response = await self.dispatch_message(d)
+ except InvokeError as e:
+ await self.socket.send_message(
+ {"invoke-error": {"message": str(e)}}
+ )
+ break
+
+ if response is not self.NO_RESPONSE:
+ await self.socket.send_message(response)
+
+ except ConnectionClosedError as e:
+ self.logger.info(str(e))
+ except (ClientError, ConnectionError) as e:
+ self.logger.error(str(e))
+ finally:
+ await self.close()
+
+ async def dispatch_message(self, msg):
+ for k in self.handlers.keys():
+ if k in msg:
+ self.logger.debug("Handling %s" % k)
+ return await self.handlers[k](msg[k])
+
+ raise ClientError("Unrecognized command %r" % msg)
+
+ async def handle_ping(self, request):
+ return {"alive": True}
+
+
+class StreamServer(object):
+ def __init__(self, handler, logger):
+ self.handler = handler
+ self.logger = logger
+ self.closed = False
+
+ async def handle_stream_client(self, reader, writer):
+ # writer.transport.set_write_buffer_limits(0)
+ socket = StreamConnection(reader, writer, -1)
+ if self.closed:
+ await socket.close()
+ return
+
+ await self.handler(socket)
+
+ async def stop(self):
+ self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+ def __init__(self, host, port, handler, logger):
+ super().__init__(handler, logger)
+ self.host = host
+ self.port = port
+
+ def start(self, loop):
+ self.server = loop.run_until_complete(
+ asyncio.start_server(self.handle_stream_client, self.host, self.port)
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+ # Newer python does this automatically. Do it manually here for
+ # maximum compatibility
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ pass
+
+
+class UnixStreamServer(StreamServer):
+ def __init__(self, path, handler, logger):
+ super().__init__(handler, logger)
+ self.path = path
+
+ def start(self, loop):
+ cwd = os.getcwd()
+ try:
+ # Work around path length limits in AF_UNIX
+ os.chdir(os.path.dirname(self.path))
+ self.server = loop.run_until_complete(
+ asyncio.start_unix_server(
+ self.handle_stream_client, os.path.basename(self.path)
+ )
+ )
+ finally:
+ os.chdir(cwd)
+
+ self.logger.debug("Listening on %r" % self.path)
+ self.address = "unix://%s" % os.path.abspath(self.path)
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ await super().stop()
+ self.server.close()
+
+ def cleanup(self):
+ os.unlink(self.path)
+
+
+class WebsocketsServer(object):
+ def __init__(self, host, port, handler, logger):
+ self.host = host
+ self.port = port
+ self.handler = handler
+ self.logger = logger
+
+ def start(self, loop):
+ import websockets.server
+
+ self.server = loop.run_until_complete(
+ websockets.server.serve(
+ self.client_handler,
+ self.host,
+ self.port,
+ ping_interval=None,
+ )
+ )
+
+ for s in self.server.sockets:
+ self.logger.debug("Listening on %r" % (s.getsockname(),))
+
+ # Enable keep alives. This prevents broken client connections
+ # from persisting on the server for long periods of time.
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+ name = self.server.sockets[0].getsockname()
+ if self.server.sockets[0].family == socket.AF_INET6:
+ self.address = "ws://[%s]:%d" % (name[0], name[1])
+ else:
+ self.address = "ws://%s:%d" % (name[0], name[1])
+
+ return [self.server.wait_closed()]
+
+ async def stop(self):
+ self.server.close()
+
+ def cleanup(self):
+ pass
+
+ async def client_handler(self, websocket):
+ socket = WebsocketConnection(websocket, -1)
+ await self.handler(socket)
+
+
+class AsyncServer(object):
+ def __init__(self, logger):
+ self.logger = logger
+ self.loop = None
+ self.run_tasks = []
+
+ def start_tcp_server(self, host, port):
+ self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
+
+ def start_unix_server(self, path):
+ self.server = UnixStreamServer(path, self._client_handler, self.logger)
+
+ def start_websocket_server(self, host, port):
+ self.server = WebsocketsServer(host, port, self._client_handler, self.logger)
+
+ async def _client_handler(self, socket):
+ address = socket.address
+ try:
+ client = self.accept_client(socket)
+ await client.process_requests()
+ except Exception as e:
+ import traceback
+
+ self.logger.error(
+ "Error from client %s: %s" % (address, str(e)), exc_info=True
+ )
+ traceback.print_exc()
+ finally:
+ self.logger.debug("Client %s disconnected", address)
+ await socket.close()
+
+ @abc.abstractmethod
+ def accept_client(self, socket):
+ pass
+
+ async def stop(self):
+ self.logger.debug("Stopping server")
+ await self.server.stop()
+
+ def start(self):
+ tasks = self.server.start(self.loop)
+ self.address = self.server.address
+ return tasks
+
+ def signal_handler(self):
+ self.logger.debug("Got exit signal")
+ self.loop.create_task(self.stop())
+
+ def _serve_forever(self, tasks):
+ try:
+ self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+ self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
+ signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
+
+ self.loop.run_until_complete(asyncio.gather(*tasks))
+
+ self.logger.debug("Server shutting down")
+ finally:
+ self.server.cleanup()
+
+ def serve_forever(self):
+ """
+ Serve requests in the current process
+ """
+ self._create_loop()
+ tasks = self.start()
+ self._serve_forever(tasks)
+ self.loop.close()
+
+ def _create_loop(self):
+ # Create loop and override any loop that may have existed in
+ # a parent process. It is possible that the usecases of
+ # serve_forever might be constrained enough to allow using
+ # get_event_loop here, but better safe than sorry for now.
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(self.loop)
+
+ def serve_as_process(self, *, prefunc=None, args=(), log_level=None):
+ """
+ Serve requests in a child process
+ """
+
+ def run(queue):
+ # Create loop and override any loop that may have existed
+ # in a parent process. Without doing this and instead
+ # using get_event_loop, at the very minimum the hashserv
+ # unit tests will hang when running the second test.
+ # This happens since get_event_loop in the spawned server
+ # process for the second testcase ends up with the loop
+ # from the hashserv client created in the unit test process
+ # when running the first testcase. The problem is somewhat
+ # more general, though, as any potential use of asyncio in
+ # Cooker could create a loop that needs to replaced in this
+ # new process.
+ self._create_loop()
+ try:
+ self.address = None
+ tasks = self.start()
+ finally:
+ # Always put the server address to wake up the parent task
+ queue.put(self.address)
+ queue.close()
+
+ if prefunc is not None:
+ prefunc(self, *args)
+
+ if log_level is not None:
+ self.logger.setLevel(log_level)
+
+ self._serve_forever(tasks)
+
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+
+ queue = multiprocessing.Queue()
+
+ # Temporarily block SIGTERM. The server process will inherit this
+ # block which will ensure it doesn't receive the SIGTERM until the
+ # handler is ready for it
+ mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGTERM])
+ try:
+ self.process = multiprocessing.Process(target=run, args=(queue,))
+ self.process.start()
+
+ self.address = queue.get()
+ queue.close()
+ queue.join_thread()
+
+ return self.process
+ finally:
+ signal.pthread_sigmask(signal.SIG_SETMASK, mask)