summaryrefslogtreecommitdiffstats
path: root/bitbake/lib/bb/asyncrpc/connection.py
blob: 7f0cf6ba96ec8115df76143ad98e0a5ce2d97a21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#
# Copyright BitBake Contributors
#
# SPDX-License-Identifier: GPL-2.0-only
#

import asyncio
import itertools
import json
from datetime import datetime
from .exceptions import ClientError, ConnectionClosedError


# Python's asyncio servers read with a 64K receive buffer by default, so the
# maximum chunk size is fixed here at half of that. Having client and server
# advertise their limits to each other would be more flexible, but it would
# add a round trip to connection setup, so it is only worth doing if it ever
# becomes necessary.
DEFAULT_MAX_CHUNK = 32768  # 32 * 1024


def chunkify(msg, max_chunk):
    """Split *msg* into newline-terminated lines for the stream protocol.

    A message shorter than ``max_chunk - 1`` characters is yielded as one
    line. A longer message is yielded in three parts:

    - a ``{"chunk-stream": null}`` header line;
    - the message itself, in slices of at most ``max_chunk - 1`` characters,
      each followed by a newline;
    - an empty line that terminates the chunk stream.

    *msg* must not contain raw newlines, because the receiver splits on them.
    JSON produced by ``json.dumps`` without ``indent`` contains none.

    Raises ValueError (when iterated) if *max_chunk* is smaller than 2. With
    such a limit no message content would fit on a line, and the message
    would otherwise be silently dropped.
    """
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    if len(msg) < max_chunk - 1:
        yield "".join((msg, "\n"))
        return

    yield "".join((json.dumps({"chunk-stream": None}), "\n"))

    # Slicing produces the same fixed-size pieces (the last one possibly
    # shorter) as grouping characters one at a time, without per-character
    # Python work on large messages.
    step = max_chunk - 1
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"

    # An empty line marks the end of the chunk stream.
    yield "\n"


def json_serialize(obj):
    """``default=`` hook for ``json.dumps`` covering types JSON lacks.

    ``datetime`` values are encoded with ``isoformat()``. Any other type
    raises TypeError, which ``json.dumps`` propagates to its caller.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


class StreamConnection(object):
    """JSON message connection over an asyncio reader/writer pair.

    Each message is one newline-terminated JSON line. Messages that are too
    long for one line of ``max_chunk`` characters are sent by ``chunkify``
    in three parts:

    - a ``{"chunk-stream": null}`` header line;
    - the JSON text split across several lines;
    - an empty terminator line.

    A negative ``timeout`` disables the per-line read timeout.
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        # Presumably an asyncio StreamReader/StreamWriter pair; only
        # readline(), write(), drain(), close() and get_extra_info() are used.
        self.reader = reader
        self.writer = writer
        # Seconds to wait for each incoming line; negative means wait forever.
        self.timeout = timeout
        # Maximum line length (in characters) used when sending messages.
        self.max_chunk = max_chunk

    @property
    def address(self):
        """Return the peer address reported by the writer's transport."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize *msg* as JSON and send it, chunking it if it is long."""
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive one JSON message, reassembling it if it arrives chunked.

        Returns the decoded message; falsy messages are returned as-is.
        """
        l = await self.recv()

        m = json.loads(l)
        if not m:
            return m

        # Only a JSON object can be the chunk-stream header. Checking the type
        # first avoids a substring match on string messages (and a TypeError
        # on numeric ones).
        if isinstance(m, dict) and "chunk-stream" in m:
            lines = []
            while True:
                # Chunk lines are raw slices of the JSON text, so a chunk may
                # legitimately end in a space belonging to a string value.
                # Only the line terminator may be removed here, never other
                # trailing whitespace.
                l = await self._recv_line()
                if not l:
                    break
                lines.append(l)

            m = json.loads("".join(lines))

        return m

    async def send(self, msg):
        """Send *msg* as a single newline-terminated line."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def _recv_line(self):
        """Read one line and return it decoded, minus its line terminator.

        Raises ConnectionClosedError at end of stream. Raises ConnectionError
        on a read timeout or when the data read does not end with a newline.
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        # Drop the "\n" (and any "\r" from a CRLF peer) but nothing else.
        return line.rstrip("\r\n")

    async def recv(self):
        """Read one line and return it with all trailing whitespace removed."""
        line = await self._recv_line()
        return line.rstrip()

    async def close(self):
        """Close the writer (if still open) and drop both stream references."""
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None


class WebsocketConnection(object):
    """JSON message connection carried over a websocket.

    Every message travels as one websocket message, so no chunking is
    needed. A negative ``timeout`` waits indefinitely for incoming data.
    """

    def __init__(self, socket, timeout):
        self.socket = socket
        self.timeout = timeout

    @property
    def address(self):
        """Return the socket's remote address parts joined with ":"."""
        parts = map(str, self.socket.remote_address)
        return ":".join(parts)

    async def send_message(self, msg):
        """Encode *msg* as JSON and send it as a single websocket message."""
        payload = json.dumps(msg, default=json_serialize)
        await self.send(payload)

    async def recv_message(self):
        """Receive one websocket message and decode it as JSON."""
        return json.loads(await self.recv())

    async def send(self, msg):
        """Send *msg*, mapping a closed websocket to ConnectionClosedError."""
        from websockets.exceptions import ConnectionClosed

        try:
            await self.socket.send(msg)
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one websocket message.

        Raises ConnectionError if ``timeout`` is non-negative and expires, and
        ConnectionClosedError if the websocket is closed.
        """
        from websockets.exceptions import ConnectionClosed

        try:
            if self.timeout < 0:
                data = await self.socket.recv()
            else:
                try:
                    data = await asyncio.wait_for(self.socket.recv(), self.timeout)
                except asyncio.TimeoutError:
                    raise ConnectionError("Timed out waiting for data")
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")
        return data

    async def close(self):
        """Close the websocket once; later calls are no-ops."""
        sock = self.socket
        if sock is None:
            return
        await sock.close()
        self.socket = None