summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitbake/lib/bb/asyncrpc/client.py')
-rw-r--r--bitbake/lib/bb/asyncrpc/client.py291
1 files changed, 226 insertions, 65 deletions
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
index 881434d2e9..65f3f8964d 100644
--- a/bitbake/lib/bb/asyncrpc/client.py
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -10,47 +10,148 @@ import json
import os
import socket
import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+import re
+import contextlib
+from threading import Thread
+from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError, InvokeError
+
+UNIX_PREFIX = "unix://"
+WS_PREFIX = "ws://"
+WSS_PREFIX = "wss://"
+
+ADDR_TYPE_UNIX = 0
+ADDR_TYPE_TCP = 1
+ADDR_TYPE_WS = 2
+
+WEBSOCKETS_MIN_VERSION = (9, 1)
+
+
+def parse_address(addr):
+ if addr.startswith(UNIX_PREFIX):
+ return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX) :],))
+ elif addr.startswith(WS_PREFIX) or addr.startswith(WSS_PREFIX):
+ return (ADDR_TYPE_WS, (addr,))
+ else:
+ m = re.match(r"\[(?P<host>[^\]]*)\]:(?P<port>\d+)$", addr)
+ if m is not None:
+ host = m.group("host")
+ port = m.group("port")
+ else:
+ host, port = addr.split(":")
+
+ return (ADDR_TYPE_TCP, (host, int(port)))
class AsyncClient(object):
- def __init__(self, proto_name, proto_version, logger, timeout=30):
- self.reader = None
- self.writer = None
+ def __init__(
+ self,
+ proto_name,
+ proto_version,
+ logger,
+ timeout=30,
+ server_headers=False,
+ headers={},
+ ):
+ self.socket = None
self.max_chunk = DEFAULT_MAX_CHUNK
self.proto_name = proto_name
self.proto_version = proto_version
self.logger = logger
self.timeout = timeout
+ self.needs_server_headers = server_headers
+ self.server_headers = {}
+ self.headers = headers
async def connect_tcp(self, address, port):
async def connect_sock():
- return await asyncio.open_connection(address, port)
+ reader, writer = await asyncio.open_connection(address, port)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
self._connect_sock = connect_sock
async def connect_unix(self, path):
async def connect_sock():
- return await asyncio.open_unix_connection(path)
+ # AF_UNIX has path length issues so chdir here to workaround
+ cwd = os.getcwd()
+ try:
+ os.chdir(os.path.dirname(path))
+ # The socket must be opened synchronously so that CWD doesn't get
+ # changed out from underneath us so we pass as a sock into asyncio
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+ sock.connect(os.path.basename(path))
+ finally:
+ os.chdir(cwd)
+ reader, writer = await asyncio.open_unix_connection(sock=sock)
+ return StreamConnection(reader, writer, self.timeout, self.max_chunk)
+
+ self._connect_sock = connect_sock
+
+ async def connect_websocket(self, uri):
+ import websockets
+
+ try:
+ version = tuple(
+ int(v)
+ for v in websockets.__version__.split(".")[
+ 0 : len(WEBSOCKETS_MIN_VERSION)
+ ]
+ )
+ except ValueError:
+ raise ImportError(
+ f"Unable to parse websockets version '{websockets.__version__}'"
+ )
+
+ if version < WEBSOCKETS_MIN_VERSION:
+ min_ver_str = ".".join(str(v) for v in WEBSOCKETS_MIN_VERSION)
+ raise ImportError(
+ f"Websockets version {websockets.__version__} is less than minimum required version {min_ver_str}"
+ )
+
+ async def connect_sock():
+ websocket = await websockets.connect(uri, ping_interval=None)
+ return WebsocketConnection(websocket, self.timeout)
self._connect_sock = connect_sock
async def setup_connection(self):
- s = '%s %s\n\n' % (self.proto_name, self.proto_version)
- self.writer.write(s.encode("utf-8"))
- await self.writer.drain()
+ # Send headers
+ await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+ await self.socket.send(
+ "needs-headers: %s" % ("true" if self.needs_server_headers else "false")
+ )
+ for k, v in self.headers.items():
+ await self.socket.send("%s: %s" % (k, v))
+
+ # End of headers
+ await self.socket.send("")
+
+ self.server_headers = {}
+ if self.needs_server_headers:
+ while True:
+ line = await self.socket.recv()
+ if not line:
+ # End headers
+ break
+ tag, value = line.split(":", 1)
+ self.server_headers[tag.lower()] = value.strip()
+
+ async def get_header(self, tag, default):
+ await self.connect()
+ return self.server_headers.get(tag, default)
async def connect(self):
- if self.reader is None or self.writer is None:
- (self.reader, self.writer) = await self._connect_sock()
+ if self.socket is None:
+ self.socket = await self._connect_sock()
await self.setup_connection()
- async def close(self):
- self.reader = None
+ async def disconnect(self):
+ if self.socket is not None:
+ await self.socket.close()
+ self.socket = None
- if self.writer is not None:
- self.writer.close()
- self.writer = None
+ async def close(self):
+ await self.disconnect()
async def _send_wrapper(self, proc):
count = 0
@@ -61,6 +162,7 @@ class AsyncClient(object):
except (
OSError,
ConnectionError,
+ ConnectionClosedError,
json.JSONDecodeError,
UnicodeDecodeError,
) as e:
@@ -72,49 +174,27 @@ class AsyncClient(object):
await self.close()
count += 1
- async def send_message(self, msg):
- async def get_line():
- try:
- line = await asyncio.wait_for(self.reader.readline(), self.timeout)
- except asyncio.TimeoutError:
- raise ConnectionError("Timed out waiting for server")
-
- if not line:
- raise ConnectionError("Connection closed")
-
- line = line.decode("utf-8")
-
- if not line.endswith("\n"):
- raise ConnectionError("Bad message %r" % (line))
-
- return line
+ def check_invoke_error(self, msg):
+ if isinstance(msg, dict) and "invoke-error" in msg:
+ raise InvokeError(msg["invoke-error"]["message"])
+ async def invoke(self, msg):
async def proc():
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode("utf-8"))
- await self.writer.drain()
-
- l = await get_line()
-
- m = json.loads(l)
- if m and "chunk-stream" in m:
- lines = []
- while True:
- l = (await get_line()).rstrip("\n")
- if not l:
- break
- lines.append(l)
+ await self.socket.send_message(msg)
+ return await self.socket.recv_message()
- m = json.loads("".join(lines))
+ result = await self._send_wrapper(proc)
+ self.check_invoke_error(result)
+ return result
- return m
+ async def ping(self):
+ return await self.invoke({"ping": {}})
- return await self._send_wrapper(proc)
+ async def __aenter__(self):
+ return self
- async def ping(self):
- return await self.send_message(
- {'ping': {}}
- )
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
class Client(object):
@@ -132,7 +212,7 @@ class Client(object):
# required (but harmless) with it.
asyncio.set_event_loop(self.loop)
- self._add_methods('connect_tcp', 'ping')
+ self._add_methods("connect_tcp", "ping")
@abc.abstractmethod
def _get_async_client(self):
@@ -150,14 +230,8 @@ class Client(object):
setattr(self, m, self._get_downcall_wrapper(downcall))
def connect_unix(self, path):
- # AF_UNIX has path length issues so chdir here to workaround
- cwd = os.getcwd()
- try:
- os.chdir(os.path.dirname(path))
- self.loop.run_until_complete(self.client.connect_unix(os.path.basename(path)))
- self.loop.run_until_complete(self.client.connect())
- finally:
- os.chdir(cwd)
+ self.loop.run_until_complete(self.client.connect_unix(path))
+ self.loop.run_until_complete(self.client.connect())
@property
def max_chunk(self):
@@ -167,8 +241,95 @@ class Client(object):
def max_chunk(self, value):
self.client.max_chunk = value
- def close(self):
+ def disconnect(self):
self.loop.run_until_complete(self.client.close())
- if sys.version_info >= (3, 6):
+
+ def close(self):
+ if self.loop:
+ self.loop.run_until_complete(self.client.close())
+ if sys.version_info >= (3, 6):
+ self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+ self.loop.close()
+ self.loop = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False
+
+
+class ClientPool(object):
+ def __init__(self, max_clients):
+ self.avail_clients = []
+ self.num_clients = 0
+ self.max_clients = max_clients
+ self.loop = None
+ self.client_condition = None
+
+ @abc.abstractmethod
+ async def _new_client(self):
+ raise NotImplementedError("Must be implemented in derived class")
+
+ def close(self):
+ if self.client_condition:
+ self.client_condition = None
+
+ if self.loop:
+ self.loop.run_until_complete(self.__close_clients())
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ self.loop.close()
+ self.loop = None
+
+ def run_tasks(self, tasks):
+ if not self.loop:
+ self.loop = asyncio.new_event_loop()
+
+ thread = Thread(target=self.__thread_main, args=(tasks,))
+ thread.start()
+ thread.join()
+
+ @contextlib.asynccontextmanager
+ async def get_client(self):
+ async with self.client_condition:
+ if self.avail_clients:
+ client = self.avail_clients.pop()
+ elif self.num_clients < self.max_clients:
+ self.num_clients += 1
+ client = await self._new_client()
+ else:
+ while not self.avail_clients:
+ await self.client_condition.wait()
+ client = self.avail_clients.pop()
+
+ try:
+ yield client
+ finally:
+ async with self.client_condition:
+ self.avail_clients.append(client)
+ self.client_condition.notify()
+
+ def __thread_main(self, tasks):
+ async def process_task(task):
+ async with self.get_client() as client:
+ await task(client)
+
+ asyncio.set_event_loop(self.loop)
+ if not self.client_condition:
+ self.client_condition = asyncio.Condition()
+ tasks = [process_task(t) for t in tasks]
+ self.loop.run_until_complete(asyncio.gather(*tasks))
+
+ async def __close_clients(self):
+ for c in self.avail_clients:
+ await c.close()
+ self.avail_clients = []
+ self.num_clients = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+ return False