Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from base64 import encode as base64_encode 

2from fiber import CancelledError 

3from fiber import Fiber 

4from fiber import Queue 

5from fiber import QueueError 

6from net.stcp.client import Client as StcpClient 

7from net.tcp.client import Client as TcpClient 

8from random.pseudo import randbytes 

9from . import WebsocketError 

10from .common import HEADER_FIN 

11from .common import HEADER_MASK 

12from .common import OpCode 

13from .server import Server 

14 

# Raised by the reader fiber when a read returns fewer bytes than
# requested. `_ReaderFiber.run()` catches it and reports the event to
# the handler through `on_disconnected()`.
class _ClosedError(Error):
    pass

17 

# Minimal byte transport used by the websocket code, so that framing
# and handshake logic do not care whether the connection is plain TCP
# (`_TcpIo`) or secure TCP (`_StcpIo`).
trait _Io:

    # Write `data` to the transport.
    func write(self, data: bytes):
        pass

    # Read up to `size` bytes from the transport. Callers in this file
    # treat a result shorter than `size` as a closed connection.
    func read(self, size: i64) -> bytes:
        pass

25 

class _TcpIo(_Io):
    """`_Io` implementation that forwards to a plain TCP client.

    """

    client: TcpClient

    func write(self, data: bytes):
        """Write `data` to the wrapped TCP client.

        """

        self.client.write(data)

    func read(self, size: i64) -> bytes:
        """Read `size` bytes from the wrapped TCP client and return
        whatever it delivers.

        """

        received = self.client.read(size)

        return received

34 

class _StcpIo(_Io):
    """`_Io` implementation that forwards to a secure TCP client.

    """

    client: StcpClient

    func write(self, data: bytes):
        """Write `data` to the wrapped secure TCP client.

        """

        self.client.write(data)

    func read(self, size: i64) -> bytes:
        """Read `size` bytes from the wrapped secure TCP client and
        return whatever it delivers.

        """

        received = self.client.read(size)

        return received

43 

trait Handler:
    """Callbacks invoked by a websocket client when something arrives
    from the server. All methods are called from the client's reader
    fiber.

    """

    func on_binary(self, data: bytes):
        """Called when a complete binary message has been received from
        the server. `data` is the message payload, with fragments
        already joined.

        """

    func on_text(self, data: string):
        """Called when a complete text message has been received from
        the server. `data` is the message payload converted to a
        string, with fragments already joined.

        """

    func on_disconnected(self):
        """Called when disconnected by the server, that is, when the
        reader fiber could not read as many bytes as it asked for.

        """

63 

class _DefaultHandler(Handler):
    """Handler used when the user does not supply one. Received
    messages are put on the queues given to the constructor, and both
    queues are closed when the server disconnects.

    """

    _binary_queue: Queue[bytes]
    _text_queue: Queue[string]

    func __init__(self, binary_queue: Queue[bytes], text_queue: Queue[string]):
        """Create a handler that delivers binary messages to
        `binary_queue` and text messages to `text_queue`.

        """

        self._binary_queue = binary_queue
        self._text_queue = text_queue

    # `data` is `bytes`, not `bytes?`: this matches the `Handler` trait
    # signature and the element type of the queue it is put on.
    func on_binary(self, data: bytes):
        """Put the received binary message on the binary queue.

        """

        self._binary_queue.put(data)

    func on_text(self, data: string):
        """Put the received text message on the text queue.

        """

        self._text_queue.put(data)

    func on_disconnected(self):
        """Close both queues.

        """

        self._binary_queue.close()
        self._text_queue.close()

81 

func _send_frame(io: _Io, data: bytes, op_code: OpCode):
    """Write `data` to `io` as one final (FIN set), masked frame of
    type `op_code`.

    The payload length is encoded in the shortest of the three forms
    the frame header offers: 7 bits for lengths below 126, 16 bits
    for lengths below 65536 and 64 bits otherwise.

    """

    header = b""
    data_size = data.length()
    header += (HEADER_FIN | u8(op_code))

    if data_size < 126:
        header += (HEADER_MASK | u8(data_size))
    elif data_size < 65536:
        header += (HEADER_MASK | 126)
        header += u8((data_size >> 8) & 0xff)
        header += u8((data_size >> 0) & 0xff)
    else:
        header += (HEADER_MASK | 127)
        # All eight bytes of the big-endian length are derived from
        # `data_size`, so lengths of 4 GiB and above are no longer
        # truncated to their low 32 bits.
        header += u8((data_size >> 56) & 0xff)
        header += u8((data_size >> 48) & 0xff)
        header += u8((data_size >> 40) & 0xff)
        header += u8((data_size >> 32) & 0xff)
        header += u8((data_size >> 24) & 0xff)
        header += u8((data_size >> 16) & 0xff)
        header += u8((data_size >> 8) & 0xff)
        header += u8((data_size >> 0) & 0xff)

    # Masking key. With an all-zero key the XOR masking is a no-op,
    # which is why the payload can be written unmodified below.
    # NOTE(review): RFC 6455 section 5.3 requires an unpredictable key
    # per frame. Switching to a random key also requires XOR-ing the
    # payload with it; confirm the peer unmasks before changing this.
    header += 0
    header += 0
    header += 0
    header += 0

    io.write(header)
    io.write(data)

111 

class _ReaderFiber(Fiber):
    """Fiber that reads frames from the server and dispatches them.

    Control frames (close, ping, pong) are handled internally. Data
    frames are reassembled into complete messages and handed to the
    handler given to the constructor.

    """

    _io: _Io
    _handler: Handler
    # Opcode of the first frame of the message being assembled.
    _data_op_code: OpCode
    # Payload collected so far, or None when no message is in progress.
    _data: bytes?

    func __init__(self, handler: Handler, io: _Io):
        """Create a reader that reads from `io` and reports to
        `handler`.

        """

        self._handler = handler
        self._io = io
        self._data = None

    func _tcp_read(self, size: i64) -> bytes:
        """Read exactly `size` bytes. Raises `_ClosedError` if fewer
        bytes were delivered.

        """

        data = self._io.read(size)

        if data.length() != size:
            raise _ClosedError()

        return data

    func _read_frame(self) -> (OpCode, bytes, bool):
        """Read one frame and return its opcode, payload and FIN flag.

        Raises `_ClosedError` if the connection ends in the middle of
        the frame or the frame announces an invalid length, and
        `NotImplementedError` if the frame is masked.

        """

        header = self._tcp_read(2)
        # The FIN flag shares the first byte with the opcode, so it is
        # the FIN bit (not the MASK bit of the second byte) that has
        # to be stripped here.
        op_code = OpCode(header[0] & ~HEADER_FIN)
        fin = (header[0] & HEADER_FIN) == HEADER_FIN
        masked = (header[1] & HEADER_MASK) == HEADER_MASK
        data_size = i64(header[1] & ~HEADER_MASK)

        if data_size == 126:
            # 16 bits big-endian extended length.
            header = self._tcp_read(2)
            data_size = (i64(header[0]) << 8 | i64(header[1]))
        elif data_size == 127:
            # 64 bits big-endian extended length. All eight bytes are
            # decoded so that lengths of 4 GiB and above do not
            # desynchronize the stream.
            header = self._tcp_read(8)
            data_size = ((i64(header[0]) << 56)
                         | (i64(header[1]) << 48)
                         | (i64(header[2]) << 40)
                         | (i64(header[3]) << 32)
                         | (i64(header[4]) << 24)
                         | (i64(header[5]) << 16)
                         | (i64(header[6]) << 8)
                         | i64(header[7]))

            # A set top bit makes the signed length negative. Such a
            # length is treated like a broken connection.
            if data_size < 0:
                raise _ClosedError()

        if masked:
            raise NotImplementedError()

        data = self._tcp_read(data_size)

        return op_code, data, fin

    func _handle_close(self):
        """Handle a close frame. Only prints a message; the reader
        keeps running until a read fails.

        """

        print("Close.")

    func _handle_ping(self, data: bytes):
        """Answer a ping with a pong carrying the same payload.

        """

        _send_frame(self._io, data, OpCode.Pong)

    func _handle_pong(self):
        """Handle a pong frame. Pongs are ignored.

        """

        pass

    func _handle_data(self, op_code: OpCode, data: bytes, fin: bool):
        """Add a data frame to the message being assembled and deliver
        the message to the handler once its final frame has arrived.

        """

        if self._data is None:
            # First frame of a new message.
            self._data_op_code = op_code
            self._data = data
        elif op_code == OpCode.Continuation:
            self._data += data
        else:
            # A new message started before the previous one was
            # finished. Both the partial message and this frame are
            # dropped.
            self._data = None
            return

        if not fin:
            return

        # Messages that started with any other opcode are silently
        # discarded.
        match self._data_op_code:
            case OpCode.Text:
                self._handler.on_text(string(self._data))
            case OpCode.Binary:
                self._handler.on_binary(self._data)

        self._data = None

    func _run(self):
        """Read and dispatch frames forever. Only returns by raising.

        """

        while True:
            op_code, data, fin = self._read_frame()

            match op_code:
                case OpCode.Close:
                    self._handle_close()
                case OpCode.Ping:
                    self._handle_ping(data)
                case OpCode.Pong:
                    self._handle_pong()
                case _:
                    self._handle_data(op_code, data, fin)

    func run(self):
        """Fiber entry point. Calls the handler's `on_disconnected()`
        if the connection is lost, and returns silently if the fiber
        is cancelled.

        """

        try:
            self._run()
        except _ClosedError:
            self._handler.on_disconnected()
        except CancelledError:
            pass

206 

class Client:
    """A websocket client, used to communicate with a websocket server.

    """

    _handler: Handler
    _tcp_client: TcpClient
    _reader_fiber: _ReaderFiber
    _binary_queue: Queue[bytes]
    _text_queue: Queue[string]
    _stcp_client: StcpClient
    _secure: bool
    _io: _Io

    func __init__(self, handler: Handler? = None, secure: bool = False):
        """Create a client. Give `handler` as ``None`` to use the default
        handler which puts received messages on message queues that
        are read from with the `receive_binary()` and `receive_text()`
        methods.

        Give `secure` as ``True`` to communicate over secure TCP
        instead of plain TCP.

        """

        # ToDo: Assign to self._binary_queue directly when Mys supports it.
        binary_queue = Queue[bytes]()
        text_queue = Queue[string]()
        self._binary_queue = binary_queue
        self._text_queue = text_queue

        if handler is None:
            handler = _DefaultHandler(binary_queue, text_queue)

        self._handler = handler
        self._secure = secure

        if secure:
            self._stcp_client = StcpClient()
            self._io = _StcpIo(self._stcp_client)
        else:
            self._tcp_client = TcpClient()
            self._io = _TcpIo(self._tcp_client)

    func _read_line(self) -> string:
        """Read one CRLF terminated line from the server and return it
        without the line ending. Raises `WebsocketError` if the
        connection ends before the line is complete.

        """

        line = b""

        while True:
            byte = self._io.read(1)

            if byte.length() != 1:
                raise WebsocketError("Handshake failed.")

            line += byte

            if line.length() < 2:
                continue

            if line[-2] == u8('\r') and line[-1] == u8('\n'):
                break

        return string(line)[:-2]

    func _disconnect_transport(self):
        """Disconnect the TCP or secure TCP client in use.

        """

        if self._secure:
            self._stcp_client.disconnect()
        else:
            self._tcp_client.disconnect()

    func _perform_handshake(self, host: string, path: string) -> bool:
        """Send the HTTP upgrade request and consume the response
        headers.

        Returns ``True`` if the server answered with status code 101,
        and ``False`` if it answered with any other status or the
        connection ended before the response headers were complete.

        """

        sec_websocket_key = base64_encode(randbytes(16))
        self._io.write(f"GET {path} HTTP/1.1\r\n"
                       f"Host: {host}\r\n"
                       "Upgrade: WebSocket\r\n"
                       "Connection: Upgrade\r\n"
                       f"Sec-WebSocket-Key: {sec_websocket_key}\r\n"
                       "Origin: MysWebSocketClient\r\n"
                       "Sec-WebSocket-Version: 13\r\n"
                       "\r\n".to_utf8())

        try:
            # Status line, for example "HTTP/1.1 101 Switching
            # Protocols". Anything but 101 means that the server did
            # not switch to the websocket protocol.
            status_parts = self._read_line().split(" ")

            if status_parts.length() < 2:
                return False

            if status_parts[1] != "101":
                return False

            # Skip the response headers. An empty line ends them.
            # NOTE(review): Sec-WebSocket-Accept is not verified.
            while True:
                line = self._read_line()

                if line == "":
                    break
        except WebsocketError:
            return False

        return True

    func connect(self, host: string, port: i64, path: string = "/"):
        """Connect to the server identified by given `host` and
        `port`. Non-secure websockets normally use port 80, while
        secure use port 443.

        `path` is the path as sent in the HTTP request to the
        server. For example "/info/299?name=Kalle&date=2021-03-01".

        Raises `WebsocketError` if the server does not accept the
        websocket upgrade. The connection to the server is closed in
        that case.

        """

        self._binary_queue.open()
        self._text_queue.open()

        # Throw away messages left over from an earlier connection.
        while self._binary_queue.length() > 0:
            self._binary_queue.get()

        while self._text_queue.length() > 0:
            self._text_queue.get()

        if self._secure:
            self._stcp_client.connect(host, port)
        else:
            self._tcp_client.connect(host, port)

        if not self._perform_handshake(host, path):
            # No reader fiber exists yet, so disconnect() cannot be
            # used by the caller to clean up. Close the transport here.
            self._disconnect_transport()

            raise WebsocketError("Handshake failed.")

        self._reader_fiber = _ReaderFiber(self._handler, self._io)
        self._reader_fiber.start()

    func disconnect(self):
        """Disconnect from the server.

        Stops the reader fiber before closing the connection.

        """

        self._reader_fiber.cancel()
        self._reader_fiber.join()
        self._disconnect_transport()

    func send_binary(self, data: bytes):
        """Send `data` to the server as a binary message, in a single
        frame.

        NOTE(review): earlier documentation stated that this method
        never blocks, but instead enqueues the message if the OS would
        block the write. That depends on the underlying TCP client;
        confirm against its documentation.

        """

        _send_frame(self._io, data, OpCode.Binary)

    func send_text(self, data: string):
        """Send `data` to the server as a text message, UTF-8 encoded
        in a single frame.

        NOTE(review): earlier documentation stated that this method
        never blocks, but instead enqueues the message if the OS would
        block the write. That depends on the underlying TCP client;
        confirm against its documentation.

        """

        _send_frame(self._io, data.to_utf8(), OpCode.Text)

    func receive_binary(self) -> bytes:
        """Receive a binary message from the server. This method can only be
        used if no handler was passed to __init__(). Raises an error if
        disconnected.

        """

        return self._binary_queue.get()

    func receive_text(self) -> string:
        """Receive a text message from the server. This method can only be
        used if no handler was passed to __init__(). Raises an error if
        disconnected.

        """

        return self._text_queue.get()

360 

class _ConnectionFiber(Fiber):
    """Test helper fiber: accepts one client on `server` and greets it
    with a text message.

    """

    server: Server

    func run(self):
        accepted = self.server.accept()
        assert accepted.is_connected()
        accepted.send_text("Hi!")

368 

test connection():
    # Server side: listen and let a fiber accept and greet one client.
    listen_port = 60101
    server = Server()
    server.listen(listen_port)
    acceptor = _ConnectionFiber(server)
    acceptor.start()

    # Client side: connect, expect the greeting, then disconnect.
    client = Client()
    client.connect("localhost", listen_port)
    greeting = client.receive_text()
    assert greeting == "Hi!"
    client.disconnect()

    acceptor.join()