Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from fiber import CancelledError 

2from fiber import Fiber 

3from fiber import current 

4from fiber import sleep 

5from fiber import suspend 

6from io import Reader 

7from io import Writer 

8from .. import NetError 

9from .server import Server 

10 

11c""" 

12class WriteRequest { 

13public: 

14 uv_buf_t m_buf; 

15 uv_write_t m_request; 

16 Bytes m_data; 

17 Client *m_client; 

18}; 

19 

20static void on_close_complete(uv_handle_t *handle_p) 

21{ 

22 Client *client_p = (Client *)(handle_p->data); 

23 

24 resume(client_p->_fiber); 

25} 

26 

27static void on_shutdown_complete(uv_shutdown_t *request_p, int status) 

28{ 

29 Client *client_p = (Client *)(request_p->data); 

30 

31 uv_close((uv_handle_t*)&client_p->m_socket, on_close_complete); 

32} 

33 

34static void on_connect_complete(uv_connect_t *request_p, int status) 

35{ 

36 Client *client_p = (Client *)(request_p->data); 

37 

38 client_p->_status = status; 

39 

40 if (status == 0) { 

41 resume(client_p->_fiber); 

42 } else { 

43 uv_shutdown(&client_p->m_shutdown, 

44 (uv_stream_t *)&client_p->m_socket, 

45 on_shutdown_complete); 

46 } 

47} 

48 

49static void on_read_complete(uv_stream_t *request_p, 

50 ssize_t nread, 

51 const uv_buf_t* buf_p) 

52{ 

53 Client *client_p = (Client *)(request_p->data); 

54 bool completed = false; 

55 

56 if (nread > 0) { 

57 client_p->_read_offset += nread; 

58 

59 if (client_p->_read_offset == client_p->_read_data.m_bytes->size()) { 

60 completed = true; 

61 } 

62 } else if (nread < 0) { 

63 completed = true; 

64 client_p->_read_data.m_bytes->resize(client_p->_read_offset); 

65 client_p->_error = true; 

66 } 

67 

68 if (completed) { 

69 uv_read_stop(request_p); 

70 resume(client_p->_fiber); 

71 } 

72} 

73 

74static void read_alloc(uv_handle_t *handle_p, size_t size, uv_buf_t *buf_p) 

75{ 

76 Client *client_p = (Client *)(handle_p->data); 

77 

78 buf_p->base = ((char *)client_p->_read_data.m_bytes->data() 

79 + client_p->_read_offset); 

80 buf_p->len = client_p->_read_data.m_bytes->size() - client_p->_read_offset; 

81} 

82 

83static void on_write_complete(uv_write_t *request_p, int status) 

84{ 

85 WriteRequest *write_request_p = (WriteRequest *)(request_p->data); 

86 

87 if (status != 0) { 

88 write_request_p->m_client->_error = true; 

89 } 

90 

91 delete write_request_p; 

92} 

93 

// libuv getaddrinfo callback: start the TCP connect to the first resolved
// address.
//
// If resolution fails, or uv_tcp_connect() fails synchronously, no connect
// callback will ever fire. In that case the error is stored in the client's
// _status and the socket handle (initialised in Client.connect()) is closed;
// on_close_complete() then resumes the fiber waiting in connect(), which
// sees the non-zero status and raises instead of hanging forever.
static void on_getaddrinfo_complete(uv_getaddrinfo_t *resolver_p,
                                    int status,
                                    struct addrinfo *info_p)
{
    Client *client_p = (Client *)(resolver_p->data);

    if (status < 0) {
        fprintf(stderr, "getaddrinfo callback error %s\n", uv_err_name(status));
    } else {
        client_p->m_connect.data = client_p;
        status = uv_tcp_connect(&client_p->m_connect,
                                &client_p->m_socket,
                                (const struct sockaddr*)info_p->ai_addr,
                                on_connect_complete);
    }

    // Passing NULL (the failed-resolution case) is allowed and is a no-op.
    uv_freeaddrinfo(info_p);

    if (status < 0) {
        client_p->_status = status;
        uv_close((uv_handle_t *)&client_p->m_socket, on_close_complete);
    }
}

112""" 

113 

# A TCP client on top of libuv. connect(), disconnect() and read() suspend
# the calling fiber until the matching libuv callback resumes it; only one
# fiber at a time may be inside such an operation. write() queues the data
# and returns immediately.
class Client(Reader):
    c"""
    uv_tcp_t m_socket;
    uv_connect_t m_connect;
    struct addrinfo m_hints;
    uv_getaddrinfo_t m_resolver;
    uv_buf_t m_buf;
    uv_shutdown_t m_shutdown;
    """
    # Destination buffer of the read in progress, None otherwise.
    _read_data: bytes?
    # Number of bytes received so far into _read_data.
    _read_offset: i64
    # Fiber currently suspended in a blocking operation, None otherwise.
    _fiber: Fiber?
    # Result of the latest connect attempt, 0 on success.
    _status: i32
    _connected: bool
    # Set by the libuv callbacks when a read or write fails.
    _error: bool

    func __init__(self):
        # The libuv callbacks find their way back to this object through
        # the data fields of the handles/requests.
        c"""
        m_socket.data = this;
        m_shutdown.data = this;
        """

        self._read_data = None
        self._read_offset = 0
        self._fiber = None
        self._status = 0
        self._connected = False
        self._error = False

    func __del__(self):
        self.disconnect()

    func _ensure_one_caller(self):
        """Raise NetError if another fiber is already suspended in a
        blocking operation on this client.

        """

        if self._fiber is not None:
            raise NetError("Only one fiber may perform blocking operations.")

    func _wait_for_completion(self):
        """Suspend the current fiber until a libuv callback resumes it.

        """

        self._fiber = current()
        suspend()
        self._fiber = None

    func is_connected(self) -> bool:
        """Returns true if connected to the server, false otherwise.

        """

        return self._connected and not self._error

    func connect(self, host: string, port: i64):
        """Connect to a server using given `host` and `port`. Reconnects
        if already connected. Raises NetError if the connection cannot be
        established.

        """

        self.disconnect()
        self._ensure_one_caller()

        # getaddrinfo wants NUL-terminated C strings.
        host_utf8 = host.to_utf8()
        host_utf8 += 0
        port_utf8 = str(port).to_utf8()
        port_utf8 += 0

        c"""
        uv_tcp_init(uv_default_loop(), &m_socket);
        m_resolver.data = this;

        // Every hints field that is not set explicitly must be zero/NULL.
        m_hints = {};
        m_hints.ai_family = PF_INET;
        m_hints.ai_socktype = SOCK_STREAM;
        m_hints.ai_protocol = IPPROTO_TCP;
        m_hints.ai_flags = 0;

        uv_getaddrinfo(uv_default_loop(),
                       &m_resolver,
                       on_getaddrinfo_complete,
                       (const char *)host_utf8.m_bytes->data(),
                       (const char *)port_utf8.m_bytes->data(),
                       &m_hints);
        """

        self._wait_for_completion()

        if self._status != 0:
            raise NetError(f"Connect to {host}:{port} failed.")

        self._connected = True

    func disconnect(self):
        """Disconnect from the server. Does nothing if not connected.

        """

        if not self._connected:
            return

        self._ensure_one_caller()

        # Shutdown is followed by a close; the close callback resumes us.
        c"""
        uv_shutdown(&m_shutdown, (uv_stream_t *)&m_socket, on_shutdown_complete);
        """

        self._wait_for_completion()
        self._error = False
        self._connected = False

    func write(self, data: bytes):
        """Write data to the server. Never blocks. Raises an error if disconnected.

        """

        if not self.is_connected():
            raise NetError("Not connected.")

        # The request keeps a reference to `data` so the buffer handed to
        # libuv stays valid until on_write_complete() deletes the request.
        c"""
        WriteRequest *request_p = new WriteRequest();
        request_p->m_buf = uv_buf_init((char *)data.m_bytes->data(),
                                       data.m_bytes->size());
        request_p->m_request.data = request_p;
        request_p->m_data = data;
        request_p->m_client = this;
        uv_write(&request_p->m_request,
                 (uv_stream_s *)&m_socket,
                 &request_p->m_buf,
                 1,
                 on_write_complete);
        """

    func read(self, size: i64) -> bytes:
        """Read data from the server. Always returns size number of bytes,
        unless the connection was closed, in which case the remaining
        data is returned. Returns no data at once if not connected or
        if size is zero.

        """

        if not self.is_connected():
            return b""

        # A zero-length request would hand libuv an empty buffer, which it
        # reports as an error (UV_ENOBUFS) and which would wrongly mark the
        # connection as broken.
        if size == 0:
            return b""

        self._ensure_one_caller()

        self._read_offset = 0
        self._read_data = bytes(size)

        c"""
        uv_read_start((uv_stream_t *)&m_socket, read_alloc, on_read_complete);
        """

        self._fiber = current()

        try:
            suspend()
        except CancelledError:
            # Stop reading so libuv no longer touches the dropped buffer.
            c"uv_read_stop((uv_stream_t *)&m_socket);"
            self._read_data = None
            raise
        finally:
            self._fiber = None

        data = self._read_data
        self._read_data = None

        return data

271 

# Server side of the server_communication test: accepts one client and
# answers each of its two ten-byte messages with b"0987654321".
class _ServerCommunicationFiber(Fiber):
    server: Server

    func run(self):
        peer = self.server.accept()
        assert peer.is_connected()

        # First message, received with two exact-size reads.
        assert peer.read(1) == b"1"
        assert peer.read(9) == b"234567890"
        peer.write(b"0")
        peer.write(b"987654321")

        # Second message, collected piecewise with try_read_into().
        received = bytes(10)
        offset = 0

        while offset < received.length():
            count = peer.try_read_into(received,
                                       offset,
                                       received.length() - offset)
            assert count > 0
            offset += count

        assert received == b"1234567890"
        peer.write(b"0")
        peer.write(b"987654321")

        # The client disconnects, so the next read yields nothing.
        assert peer.read(1) == b""
        assert not peer.is_connected()

300 

# Exchanges two request/response rounds with _ServerCommunicationFiber and
# then disconnects.
test server_communication():
    port = 50222

    listener = Server()
    listener.listen(port)
    peer_fiber = _ServerCommunicationFiber(listener)
    peer_fiber.start()

    client = Client()
    assert not client.is_connected()
    client.connect("localhost", port)
    assert client.is_connected()

    # Round one: both parts are written back to back.
    client.write(b"123456789")
    client.write(b"0")
    assert client.read(9) == b"098765432"
    assert client.read(1) == b"1"

    # Round two: pause between the parts.
    client.write(b"123456789")
    # To make try_read_into() not read all at once.
    sleep(0.2)
    client.write(b"0")
    assert client.read(9) == b"098765432"
    assert client.read(1) == b"1"

    client.disconnect()
    assert not client.is_connected()
    peer_fiber.join()

329 

# Connecting to a port nobody listens on must raise NetError with the
# expected message and leave the client disconnected.
test connection_refused():
    port = 50223
    client = Client()
    message = ""

    try:
        client.connect("localhost", port)
    except NetError as error:
        message = error.message

    assert message == "Connect to localhost:50223 failed."
    assert not client.is_connected()

343 

# Server side of the client_read_close test: accepts one client, sends ten
# bytes and disconnects.
class _ClientReadCloseFiber(Fiber):
    server: Server

    func run(self):
        peer = self.server.accept()
        assert peer.is_connected()
        peer.write(b"1234567890")
        peer.disconnect()

352 

# A read that asks for more than the server sends before closing returns
# only what arrived; the read after that returns nothing.
test client_read_close():
    port = 50224

    listener = Server()
    listener.listen(port)
    peer_fiber = _ClientReadCloseFiber(listener)
    peer_fiber.start()

    client = Client()
    client.connect("localhost", port)
    assert client.read(20) == b"1234567890"
    assert client.read(1) == b""
    client.disconnect()
    peer_fiber.join()

367 

# Server side of the client_re_use_client test: per round, accepts three
# connections in turn, waits for each to be closed by the remote end and
# disconnects (the first one repeatedly).
class _ClientReUseFiber(Fiber):
    server: Server

    func run(self):
        for _ in range(10):
            # First connection; the extra disconnect() calls must be harmless.
            peer = self.server.accept()
            assert peer.read(1) == b""
            peer.disconnect()
            peer.disconnect()
            peer.disconnect()

            # Second connection.
            peer = self.server.accept()
            assert peer.read(1) == b""
            peer.disconnect()

            # Third connection.
            peer = self.server.accept()
            assert peer.read(1) == b""
            peer.disconnect()

384 

# One Client object is connected and disconnected many times, including
# redundant disconnect() calls and a connect() while already connected.
test client_re_use_client():
    port = 50225

    listener = Server()
    listener.listen(port)
    peer_fiber = _ClientReUseFiber(listener)
    peer_fiber.start()

    client = Client()

    for _ in range(10):
        client.connect("localhost", port)
        client.disconnect()
        client.disconnect()
        client.disconnect()
        # The second connect() reconnects without an explicit disconnect().
        client.connect("localhost", port)
        client.connect("localhost", port)
        client.disconnect()

    peer_fiber.join()

405 

# Client side of the client_cancel test: blocks in read() until cancelled,
# then shows the connection is still usable by writing and reading again.
class _ClientCancelFiber(Fiber):
    port: i64

    func run(self):
        connection = Client()
        connection.connect("localhost", self.port)
        connection.write(b"1")

        try:
            connection.read(1)
        except CancelledError:
            connection.write(b"2")
            assert connection.read(1) == b"3"

419 

# Cancels a fiber that is blocked in Client.read() and checks that the
# connection keeps working in both directions afterwards.
test client_cancel():
    port = 50226

    listener = Server()
    listener.listen(port)

    remote_fiber = _ClientCancelFiber(port)
    remote_fiber.start()

    peer = listener.accept()
    assert peer.read(1) == b"1"
    remote_fiber.cancel()
    assert peer.read(1) == b"2"
    peer.write(b"3")
    remote_fiber.join()