-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathclient.py
More file actions
276 lines (233 loc) · 10 KB
/
client.py
File metadata and controls
276 lines (233 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from __future__ import annotations
import asyncio
import inspect
import json
import pickle
from asyncio import Queue, QueueFull, Task
from io import StringIO
from time import time
from typing import Any, NamedTuple, Type
import aiohttp
import msgpack
from aiohttp import ClientConnectorError, ClientResponseError, ClientWebSocketResponse
from rich.console import Console
from rich.segment import Segment
from textual._log import LogGroup, LogVerbosity
from textual.constants import DEVTOOLS_PORT
# Seconds that `DevtoolsClient.connect` waits for the first "server_info"
# message from the server before returning anyway.
READY_TIMEOUT = 0.5
# Capacity of the client-side log queue. Logs made while the queue is full
# are not queued; they are counted as "spillover" instead.
LOG_QUEUE_MAXSIZE = 512
class DevtoolsLog(NamedTuple):
    """A single log message destined for the devtools server.

    Attributes:
        objects_or_string: What was logged. Either a plain string, or a
            tuple of objects; this is the data that ends up being handed to
            Console.print to produce the Segments for the log.
        caller: Where the log was made from, i.e. the place the user called
            `print` or `App.log`. The devtools window uses it to show the
            file name and line number.
    """

    objects_or_string: tuple[Any, ...] | str
    caller: inspect.Traceback
class DevtoolsConsole(Console):
    """A Console with recording switched on, so that whatever is printed to
    it can later be collected as Segments via `export_segments`.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Build the Console from the given arguments, then enable recording."""
        super().__init__(*args, **kwargs)
        self.record = True

    def export_segments(self) -> list[Segment]:
        """Take the Segments printed to this console since the last export.

        The record buffer is emptied as part of the call, so each Segment is
        returned exactly once.

        Returns:
            The list of Segments that have been printed using this console.
        """
        # Copy and empty the buffer under the lock so a concurrent print
        # cannot slip a Segment in between the two steps.
        with self._record_buffer_lock:
            exported = list(self._record_buffer)
            del self._record_buffer[:]
        return exported
class DevtoolsConnectionError(Exception):
    """Raised when the devtools client cannot connect to the server."""
class ClientShutdown:
    """Sentinel placed on the client queue(s) to signal shutdown.

    The class object itself is the sentinel; it is never instantiated.
    """
class DevtoolsClient:
    """Client responsible for websocket communication with the devtools server.

    Messages sent from client to server have the shape
    `{"type": <str>, "payload": <dict>}`. Valid values for `"type"` are
    `"client_log"` (for log messages) and `"client_spillover"` (for reporting
    to the server that messages were discarded due to rate limiting).

    A `"client_log"` message is packed with msgpack and sent as a binary
    websocket message. Its `"payload"` has the following format:

    ```
    {"group": <value of the LogGroup of the log>,
     "verbosity": <value of the LogVerbosity of the log>,
     "timestamp": <int, unix timestamp>,
     "path": <str, path of file>,
     "line_number": <int, line number log was made from>,
     "segments": <bytes, pickled list of Segments to log>}
    ```

    A `"client_spillover"` message is JSON encoded and sent as a text
    websocket message. Its `"payload"` has the following format:

    ```
    {"spillover": <int, the number of messages discarded by rate-limiting>}
    ```

    The client also listens for JSON `"server_info"` messages from the server,
    whose payload carries the `"width"` and `"height"` to render logs at and,
    optionally, a `"verbose"` flag.

    Args:
        host: The host the devtools server is running on, defaults to "127.0.0.1"
        port: The port the devtools server is accessed via, `DEVTOOLS_PORT` by default.
    """

    def __init__(self, host: str = "127.0.0.1", port: int | None = None) -> None:
        if port is None:
            port = DEVTOOLS_PORT
        self.url: str = f"ws://{host}:{port}"
        self.session: aiohttp.ClientSession | None = None
        self.log_queue_task: Task | None = None
        self.update_console_task: Task | None = None
        # Logs are rendered into this in-memory console; nothing is written
        # to a real terminal by the client.
        self.console: DevtoolsConsole = DevtoolsConsole(file=StringIO())
        self.websocket: ClientWebSocketResponse | None = None
        self.log_queue: Queue[str | bytes | Type[ClientShutdown]] | None = None
        # Number of logs dropped because the queue was full and not yet
        # reported to the server.
        self.spillover: int = 0
        self.verbose: bool = False
        # Set once the first "server_info" message has been handled.
        self._ready_event: asyncio.Event = asyncio.Event()

    async def connect(self) -> None:
        """Connect to the devtools server.

        Starts the task that sends queued logs and the task that listens for
        server updates, then waits (for at most `READY_TIMEOUT` seconds) for
        the first server info message to be handled.

        Raises:
            DevtoolsConnectionError: If we're unable to establish
                a connection to the server for any reason.
        """
        session = aiohttp.ClientSession()
        self.session = session
        self.log_queue = Queue(maxsize=LOG_QUEUE_MAXSIZE)
        try:
            self.websocket = await session.ws_connect(
                f"{self.url}/textual-devtools-websocket"
            )
        except (ClientConnectorError, ClientResponseError) as error:
            await session.close()
            self.session = None
            # Chain the underlying error so the cause is not lost.
            raise DevtoolsConnectionError() from error

        # The tasks work on the queue/websocket of *this* connection, even if
        # the attributes are later replaced.
        log_queue = self.log_queue
        websocket = self.websocket

        async def update_console() -> None:
            """Coroutine function scheduled as a Task, which listens on
            the websocket for updates from the server regarding any changes
            in the server Console dimensions. When the client learns of this
            change, it will update its own Console to ensure it renders at
            the correct width for server-side display.
            """
            async for message in websocket:
                if message.type != aiohttp.WSMsgType.TEXT:
                    continue
                message_json = json.loads(message.data)
                if message_json["type"] == "server_info":
                    payload = message_json["payload"]
                    self.console.width = payload["width"]
                    self.console.height = payload["height"]
                    self.verbose = payload.get("verbose", False)
                    self._ready_event.set()

        async def send_queued_logs() -> None:
            """Coroutine function which is scheduled as a Task, which consumes
            messages from the log queue and sends them to the server via websocket.
            """
            while True:
                log = await log_queue.get()
                if log is ClientShutdown:
                    log_queue.task_done()
                    break
                if isinstance(log, str):
                    await websocket.send_str(log)
                else:
                    assert isinstance(log, bytes)
                    await websocket.send_bytes(log)
                log_queue.task_done()

        self.log_queue_task = asyncio.create_task(send_queued_logs())
        self.update_console_task = asyncio.create_task(update_console())

        # Wait for the first server info message to be received and handled.
        # Not hearing from the server in time is deliberately not an error:
        # the client simply carries on with its current console settings.
        try:
            await asyncio.wait_for(self._ready_event.wait(), timeout=READY_TIMEOUT)
        except asyncio.TimeoutError:
            pass

    async def _stop_log_queue_processing(self) -> None:
        """Schedule end of processing of the log queue, meaning that any messages a
        user logs will be added to the queue, but not consumed and sent to
        the server.

        Safe to call when the sending task was never started or has already
        finished: the shutdown sentinel is only queued while a task is alive
        to consume it, so this cannot block forever on a full queue.

        Raises:
            Exception: Whatever exception the sending task failed with, if
                it did not finish cleanly.
        """
        sender = self.log_queue_task
        sender_alive = sender is not None and not sender.done()
        if self.log_queue is not None and sender_alive:
            await self.log_queue.put(ClientShutdown)
        if sender is not None:
            await sender

    async def _stop_incoming_message_processing(self) -> None:
        """Schedule stop of the task which listens for incoming messages from the
        server around changes in the server console size.

        The session is closed even if closing the websocket or awaiting the
        listening task raises.
        """
        try:
            if self.websocket:
                await self.websocket.close()
            if self.update_console_task:
                await self.update_console_task
        finally:
            if self.session:
                await self.session.close()

    async def disconnect(self) -> None:
        """Disconnect from the devtools server by stopping tasks and
        closing connections.

        Connections are closed even if the log sending task failed; its
        exception still propagates to the caller afterwards.
        """
        try:
            await self._stop_log_queue_processing()
        finally:
            await self._stop_incoming_message_processing()

    @property
    def is_connected(self) -> bool:
        """Checks connection to devtools server.

        Returns:
            True if this host is connected to the server. False otherwise.
        """
        if not self.session or not self.websocket:
            return False
        return not (self.session.closed or self.websocket.closed)

    def log(
        self,
        log: DevtoolsLog,
        group: LogGroup = LogGroup.UNDEFINED,
        verbosity: LogVerbosity = LogVerbosity.NORMAL,
    ) -> None:
        """Queue a log to be sent to the devtools server for display.

        If the queue is full the log is dropped and counted; the count is
        reported to the server with a later log, once there is room.

        Args:
            log: The log to write to devtools
            group: The group the log belongs to; its value is sent with the log.
            verbosity: The verbosity of the log; its value is sent with the log.
        """
        # Render the log into Segments using the in-memory console.
        if isinstance(log.objects_or_string, str):
            self.console.print(log.objects_or_string, markup=False)
        else:
            self.console.print(*log.objects_or_string, markup=False)
        segments = self.console.export_segments()

        message: bytes | None = msgpack.packb(
            {
                "type": "client_log",
                "payload": {
                    "group": group.value,
                    "verbosity": verbosity.value,
                    "timestamp": int(time()),
                    "path": getattr(log.caller, "filename", ""),
                    "line_number": getattr(log.caller, "lineno", 0),
                    "segments": self._encode_segments(segments),
                },
            }
        )
        assert message is not None

        queue = self.log_queue
        if queue is None:
            # connect() has not been called yet, so there is nowhere to queue.
            return
        try:
            queue.put_nowait(message)
            # Check the queue's own capacity, so the spillover report is only
            # attempted when it can actually be queued.
            if self.spillover > 0 and not queue.full():
                # Tell the server how many messages we had to discard due
                # to the log queue filling to capacity on the client.
                spillover_message = json.dumps(
                    {
                        "type": "client_spillover",
                        "payload": {
                            "spillover": self.spillover,
                        },
                    }
                )
                queue.put_nowait(spillover_message)
                self.spillover = 0
        except QueueFull:
            self.spillover += 1

    @classmethod
    def _encode_segments(cls, segments: list[Segment]) -> bytes:
        """Pickle a list of Segments

        Args:
            segments: A list of Segments to encode

        Returns:
            The Segment list pickled with pickle protocol 4.
        """
        return pickle.dumps(segments, protocol=4)