Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from io.buffered_reader import BufferedReader 

2from net.tcp.client import Client as TcpClient 

3 

class RedisError(Error):
    # Raised by this module both for error replies from the server and for
    # replies that do not have the expected kind or shape.
    message: string

6 

trait Message:
    """A received Pub/Sub message, as returned by Client.get_message().

    Implemented by SubscribeMessage, UnsubscribeMessage and
    PublishMessage.

    """

11 

class SubscribeMessage(Message):
    """A Pub/Sub subscribe reply message, received after subscribing to a
    channel or a pattern.

    """

    # The channel (or pattern, for psubscribe) that was subscribed to.
    channel: string
    # Subscription count reported by the server in the reply.
    number_of_subscriptions: i64

19 

class UnsubscribeMessage(Message):
    """A Pub/Sub unsubscribe reply message, received after unsubscribing
    from a channel or a pattern.

    """

    # The channel (or pattern, for punsubscribe) that was unsubscribed from.
    channel: string
    # Subscription count reported by the server in the reply.
    number_of_subscriptions: i64

27 

class PublishMessage(Message):
    """A Pub/Sub published message, received for a subscribed channel or a
    channel matching a subscribed pattern.

    """

    # The channel the message was published on (not the matching pattern).
    channel: string
    # The published message data.
    payload: bytes

35 

trait Reply:
    """A command reply, as parsed from the server's RESP reply.

    Implemented by SimpleStringReply, ErrorReply, IntegerReply,
    BulkStringReply and ArrayReply.

    """

40 

class SimpleStringReply(Reply):
    """A simple string (+) command reply.

    """

    # The reply line without the leading "+" and trailing CRLF, e.g. "OK".
    value: string

47 

class ErrorReply(Reply):
    """An error (-) command reply.

    """

    # The first word of the error line, e.g. "ERR".
    error: string
    # The rest of the error line after the first space.
    message: string

55 

class IntegerReply(Reply):
    """An integer (:) command reply.

    """

    value: i64

62 

class BulkStringReply(Reply):
    """A bulk string ($) command reply.

    """

    # The binary payload, or None for a null bulk string (length -1).
    value: bytes?

69 

class ArrayReply(Reply):
    """An array (*) command reply.

    """

    # The parsed elements, or None for a null array (length -1).
    items: [Reply]?

76 

func _unpack_bulk_string(value: Reply, message: string) -> bytes?:
    """Returns the payload of given bulk string reply (None for a null bulk
    string). Raises RedisError with given message for any other reply kind.

    """

    match value:
        case BulkStringReply() as reply:
            return reply.value

    raise RedisError(message)

83 

func _unpack_array(value: Reply, message: string) -> [Reply]?:
    """Returns the items of given array reply (None for a null array).
    Raises RedisError with given message for any other reply kind.

    """

    match value:
        case ArrayReply() as reply:
            return reply.items

    raise RedisError(message)

90 

func _unpack_integer(value: Reply, message: string) -> i64:
    """Returns the value of given integer reply. Raises RedisError with
    given message for any other reply kind.

    """

    match value:
        case IntegerReply() as reply:
            return reply.value

    raise RedisError(message)

97 

func _unpack_simple_string(value: Reply, message: string) -> string:
    """Returns the text of given simple string reply. Raises RedisError
    with given message for any other reply kind.

    """

    match value:
        case SimpleStringReply() as reply:
            return reply.value

    raise RedisError(message)

104 

func _check_simple_string_reply(reply: string, expected: string):
    """Raises RedisError unless given simple string reply is exactly the
    expected one.

    """

    if reply == expected:
        return

    raise RedisError(
        f"Unexpected simple string reply '{reply}' when expecting "
        f"'{expected}'.")

110 

func _unpack_publish(channel: Reply, payload: Reply) -> PublishMessage:
    """Builds a published message from the channel and payload items of a
    "message" or "pmessage" Pub/Sub reply.

    """

    channel_name = string(_unpack_bulk_string(channel, "No channel."))
    data = _unpack_bulk_string(payload, "No payload.")

    return PublishMessage(channel_name, data)

114 

func _unpack_subscribe(channel: Reply, count: Reply) -> SubscribeMessage:
    """Builds a subscribe message from the channel and count items of a
    "subscribe" or "psubscribe" Pub/Sub reply.

    """

    name = string(_unpack_bulk_string(channel, "No channel."))
    subscriptions = _unpack_integer(count, "No count.")

    return SubscribeMessage(name, subscriptions)

118 

func _unpack_unsubscribe(channel: Reply, count: Reply) -> UnsubscribeMessage:
    """Builds an unsubscribe message from the channel and count items of an
    "unsubscribe" or "punsubscribe" Pub/Sub reply.

    """

    name = string(_unpack_bulk_string(channel, "No channel."))
    subscriptions = _unpack_integer(count, "No count.")

    return UnsubscribeMessage(name, subscriptions)

122 

class Client:
    """A Redis client.

    Each command method writes a request and then blocks reading its
    reply. The *_write() and *_read() method pairs allow pipelining: write
    several requests first, then read their replies in the same order.

    """

    # The TCP connection to the server.
    _client: TcpClient
    # Buffered reader on top of _client, used to parse replies.
    _buffered_reader: BufferedReader

    func __init__(self):
        self._client = TcpClient()
        self._buffered_reader = BufferedReader(self._client)

    func connect(self, host: string = "127.0.0.1", port: i64 = 6379):
        """Connect to given server. Data buffered from any previous
        connection is discarded first.

        """

        self._buffered_reader.clear()
        self._client.connect(host, port)

    func disconnect(self):
        """Disconnect from the server.

        """

        self._client.disconnect()

    func auth(self, password: string):
        """Authenticate with given password. Raises RedisError on an error
        reply or on any simple string reply other than "OK".

        """

        self._call_simple_string_reply_ok([b"AUTH", password.to_utf8()])

    func del(self, key: string) -> i64:
        """Delete given key. Returns the integer reply of DEL (the number
        of deleted keys).

        """

        self.del_write(key)

        return self.del_read()

    func set(self, key: string, value: bytes):
        """Set given value for given key. Raises RedisError unless the
        server replies "OK".

        """

        self.set_write(key, value)
        self.set_read()

    func get(self, key: string) -> bytes?:
        """Get the value for given key, or None if the key does not exist.

        """

        self.get_write(key)

        return self.get_read()

    func getdel(self, key: string) -> bytes?:
        """Get the value for given key and then delete it. Returns None if
        the key does not exist.

        """

        return self.call_bulk_string_reply([b"GETDEL", key.to_utf8()])

    func append(self, key: string, value: bytes) -> i64:
        """Append given value for given key. Returns the value length after
        the append.

        """

        return self.call_integer_reply([b"APPEND", key.to_utf8(), value])

    func incr(self, key: string) -> i64:
        """Increment the value for given key. Returns the value after the
        increment.

        """

        return self.call_integer_reply([b"INCR", key.to_utf8()])

    func decr(self, key: string) -> i64:
        """Decrement the value for given key. Returns the value after the
        decrement.

        """

        return self.call_integer_reply([b"DECR", key.to_utf8()])

    func strlen(self, key: string) -> i64:
        """Get the value length for given key.

        """

        return self.call_integer_reply([b"STRLEN", key.to_utf8()])

    func scan(self, cursor: i64, pattern: string) -> (i64, [string]):
        """Scan for keys matching given pattern, starting at given cursor.
        Returns the cursor for the next call and the keys found in this
        iteration. Raises RedisError on an error reply or a malformed
        reply.

        """

        # Check for an error reply so that a server error is reported as
        # such, instead of as a generic "Bad SCAN reply.".
        self._write_request(
            [b"SCAN", str(cursor).to_utf8(), b"MATCH", pattern.to_utf8()])
        reply_items = _unpack_array(self._read_reply_check_error(),
                                    "Bad SCAN reply.")

        if reply_items.length() != 2:
            raise RedisError("Bad SCAN reply length.")

        keys: [string] = []
        cursor = i64(string(_unpack_bulk_string(reply_items[0],
                                                 "Bad SCAN cursor.")))

        for item in _unpack_array(reply_items[1], "Bad SCAN keys."):
            keys.append(string(_unpack_bulk_string(item, "Bad SCAN key.")))

        return cursor, keys

    func lpush(self, key: string, value: bytes) -> i64:
        """Prepend given value for given list key. Returns the list length
        after the push.

        """

        return self.call_integer_reply([b"LPUSH", key.to_utf8(), value])

    func lpop(self, key: string) -> bytes?:
        """Pop the first value for given list key. Returns None if the list
        is empty.

        """

        return self.call_bulk_string_reply([b"LPOP", key.to_utf8()])

    func rpush(self, key: string, value: bytes) -> i64:
        """Append given value for given list key. Returns the list length
        after the push.

        """

        return self.call_integer_reply([b"RPUSH", key.to_utf8(), value])

    func rpop(self, key: string) -> bytes?:
        """Pop the last value for given list key. Returns None if the list
        is empty.

        """

        return self.call_bulk_string_reply([b"RPOP", key.to_utf8()])

    func hset(self, key: string, field: string, value: bytes) -> i64:
        """Set given field to given value for given hash key. Returns the
        integer reply of HSET (the number of added fields).

        """

        return self.call_integer_reply([b"HSET", key.to_utf8(), field.to_utf8(), value])

    func hget(self, key: string, field: string) -> bytes?:
        """Get the value for given field for given hash key, or None if
        there is no such field.

        """

        return self.call_bulk_string_reply([b"HGET", key.to_utf8(), field.to_utf8()])

    func hgetall(self, key: string) -> {string: bytes}:
        """Get all fields and values for given hash key.

        """

        reply = self.call([b"HGETALL", key.to_utf8()])
        reply_items = _unpack_array(reply, "Bad hash type.")

        # Fields and values alternate, so a well-formed reply has an even
        # number of items. Without this check an odd count would fail
        # with an index error instead of a RedisError.
        if reply_items.length() % 2 != 0:
            raise RedisError("Bad hash reply length.")

        items: {string: bytes} = {}

        for i in range(0, reply_items.length(), 2):
            field = string(_unpack_bulk_string(reply_items[i + 0],
                                               "Bad hash field type."))
            value = _unpack_bulk_string(reply_items[i + 1],
                                        "Bad hash value type.")
            items[field] = value

        return items

    func hdel(self, key: string, field: string) -> i64:
        """Delete given field for given hash key. Returns the number of
        deleted fields.

        """

        return self.call_integer_reply([b"HDEL", key.to_utf8(), field.to_utf8()])

    func publish(self, channel: string, message: bytes) -> i64:
        """Publish given message on given channel. Returns the number of
        clients that received the message.

        """

        return self.call_integer_reply([b"PUBLISH", channel.to_utf8(), message])

    func subscribe(self, channel: string):
        """Subscribe to given channel. May be called from any fiber, even
        if another fiber is waiting for a message.

        Call get_message() to get the next message. The subscribe
        confirmation is itself delivered as a SubscribeMessage.

        Only a limited set of commands are allowed once in Pub/Sub mode.

        """

        # Only the request is written; the reply is read by get_message().
        self._write_request([b"SUBSCRIBE", channel.to_utf8()])

    func unsubscribe(self, channel: string):
        """Unsubscribe from given channel. May be called from any fiber, even
        if another fiber is waiting for a message.

        The confirmation is delivered by get_message() as an
        UnsubscribeMessage.

        """

        self._write_request([b"UNSUBSCRIBE", channel.to_utf8()])

    func psubscribe(self, pattern: string):
        """Subscribe to given pattern. May be called from any fiber, even
        if another fiber is waiting for a message.

        Call get_message() to get the next message. The subscribe
        confirmation is itself delivered as a SubscribeMessage.

        Only a limited set of commands are allowed once in Pub/Sub mode.

        """

        self._write_request([b"PSUBSCRIBE", pattern.to_utf8()])

    func punsubscribe(self, pattern: string):
        """Unsubscribe from given pattern. May be called from any fiber, even
        if another fiber is waiting for a message.

        The confirmation is delivered by get_message() as an
        UnsubscribeMessage.

        """

        self._write_request([b"PUNSUBSCRIBE", pattern.to_utf8()])

    func get_message(self) -> Message:
        """Get the next Pub/Sub message. Blocks until a message is received.
        Raises RedisError if the reply is not a recognized Pub/Sub message.

        """

        items = _unpack_array(self._read_reply(), "Not a reply.")

        if items.length() < 3:
            raise RedisError("Message too short.")

        match _unpack_bulk_string(items[0], "No message kind."):
            case b"message":
                return _unpack_publish(items[1], items[2])
            case b"pmessage":
                # Items are: kind, matching pattern, channel, payload.
                if items.length() < 4:
                    raise RedisError("Message too short.")

                return _unpack_publish(items[2], items[3])
            case b"subscribe":
                return _unpack_subscribe(items[1], items[2])
            case b"unsubscribe":
                return _unpack_unsubscribe(items[1], items[2])
            case b"psubscribe":
                return _unpack_subscribe(items[1], items[2])
            case b"punsubscribe":
                return _unpack_unsubscribe(items[1], items[2])
            case _ as kind:
                raise RedisError(f"Invalid message kind '{kind}'.")

    func call(self, command: [bytes]) -> Reply:
        """Call given command and return its raw reply. Error replies are
        returned as ErrorReply, not raised.

        """

        self.call_write(command)

        return self.call_read()

    func call_integer_reply(self, command: [bytes]) -> i64:
        """Call given command and expect an integer (:) reply.

        """

        self._write_request(command)

        return _unpack_integer(self._read_reply_check_error(),
                               "Not an integer reply.")

    func call_bulk_string_reply(self, command: [bytes]) -> bytes?:
        """Call given command and expect a bulk string ($) reply.

        """

        self._write_request(command)

        return self._read_bulk_string_reply_check_error()

    func call_simple_string_reply(self, command: [bytes]) -> string:
        """Call given command and expect a simple string (+) reply.

        """

        self._write_request(command)

        return self._read_simple_string_reply_check_error()

    func del_write(self, key: string):
        """Write pipelined delete given key.

        """

        self._write_request([b"DEL", key.to_utf8()])

    func del_read(self) -> i64:
        """Read pipelined delete given key.

        """

        return self._read_integer_reply_check_error()

    func set_write(self, key: string, value: bytes):
        """Write pipelined set given value for given key.

        """

        self._write_request([b"SET", key.to_utf8(), value])

    func set_read(self):
        """Read pipelined set given value for given key. Raises RedisError
        unless the reply is "OK".

        """

        reply = self._read_simple_string_reply_check_error()
        _check_simple_string_reply(reply, "OK")

    func get_write(self, key: string):
        """Write pipelined get the value for given key.

        """

        self._write_request([b"GET", key.to_utf8()])

    func get_read(self) -> bytes?:
        """Read pipelined get the value for given key.

        """

        return self._read_bulk_string_reply_check_error()

    func call_write(self, command: [bytes]):
        """Write pipelined call given command.

        """

        self._write_request(command)

    func call_read(self) -> Reply:
        """Read pipelined call given command.

        """

        return self._read_reply()

    func _write_request(self, command: [bytes]):
        # Encodes the command as a RESP array of bulk strings and sends it.
        request = f"*{command.length()}\r\n".to_utf8()

        for item in command:
            request += f"${item.length()}\r\n".to_utf8()
            request += item
            request += b"\r\n"

        self._client.write(request)

    func _read_reply(self) -> Reply:
        # Reads and parses one reply; arrays are parsed recursively. Error
        # replies are returned as ErrorReply, not raised.

        # The first byte is the reply kind; the rest of the first line is
        # its data (a value, or a length for bulk strings and arrays).
        kind = self._buffered_reader.read(1)
        data = self._read_line()

        match kind:
            case b"-":
                error, _, message = string(data).partition(" ")

                return ErrorReply(error, message)
            case b"+":
                return SimpleStringReply(string(data))
            case b":":
                return IntegerReply(i64(string(data)))
            case b"$":
                length = i64(string(data))

                if length == -1:
                    return BulkStringReply(None)

                return BulkStringReply(self._read_bulk_string_data(length))
            case b"*":
                length = i64(string(data))

                if length == -1:
                    return ArrayReply(None)
                else:
                    return ArrayReply([self._read_reply() for _ in range(length)])
            case _:
                raise RedisError(f"Invalid reply kind '{kind}'.")

    func _read_bulk_string_data(self, length: i64) -> bytes:
        # Reads a bulk string payload of given length and its trailing
        # CRLF. Bulk strings are binary safe and may themselves contain
        # CRLF, so the payload is read by length, not up to a line end.

        # NOTE(review): assumes read() returns fewer bytes than requested
        # only when the connection is closed; confirm against
        # io.buffered_reader.
        value = self._buffered_reader.read(length)

        if value.length() != length:
            raise RedisError("Short bulk string.")

        # The payload must be followed directly by CRLF, i.e. an empty line.
        if self._read_line().length() != 0:
            raise RedisError("Bad bulk string terminator.")

        return value

    func _read_reply_check_error(self) -> Reply:
        # Like _read_reply(), but raises RedisError for error replies.
        reply = self._read_reply()

        match reply:
            case ErrorReply() as error_reply:
                raise RedisError(f"{error_reply.error}: {error_reply.message}")

        return reply

    func _read_bulk_string_reply_check_error(self) -> bytes?:
        return _unpack_bulk_string(self._read_reply_check_error(),
                                   "Not a bulk string reply.")

    func _read_simple_string_reply_check_error(self) -> string:
        return _unpack_simple_string(self._read_reply_check_error(),
                                     "Not a simple string reply.")

    func _read_integer_reply_check_error(self) -> i64:
        return _unpack_integer(self._read_reply_check_error(),
                               "Not an integer reply.")

    func _call_simple_string_reply_ok(self, command: [bytes]):
        # Calls given command and raises RedisError unless it replies "OK".
        reply = self.call_simple_string_reply(command)
        _check_simple_string_reply(reply, "OK")

    func _read_line(self) -> bytes:
        # Reads one line without its CRLF terminator. Raises RedisError if
        # no line could be read.
        line = self._buffered_reader.read_until(b"\r\n", keep_pattern=False)

        if line is None:
            raise RedisError("No line.")

        return line

551 

test string():
    # Exercises the string commands against the server at the connect()
    # defaults (127.0.0.1:6379).
    client = Client()
    client.connect()
    # Start from a known state; the key may or may not exist beforehand.
    client.del("foo")
    assert client.get("foo") is None
    # An empty value still creates the key.
    client.set("foo", b"")
    assert client.del("foo") == 1
    # Values are arbitrary bytes, including NUL.
    client.set("foo", b"\x00\x01\x02")
    assert client.get("foo") == b"\x00\x01\x02"
    assert client.strlen("foo") == 3
    assert client.append("foo", b"\x03\x04") == 5
    assert client.get("foo") == b"\x00\x01\x02\x03\x04"
    assert client.getdel("foo") == b"\x00\x01\x02\x03\x04"
    # The key is gone after GETDEL.
    assert client.getdel("foo") is None
    # APPEND on a missing key creates it.
    assert client.append("foo", b"\x00") == 1
    assert client.get("foo") == b"\x00"
    client.set("count", b"0")
    assert client.incr("count") == 1
    assert client.incr("count") == 2
    assert client.incr("count") == 3
    assert client.decr("count") == 2
    client.disconnect()

574 

test call():
    # Exercises the generic call() API, which returns the raw reply object.
    client = Client()
    client.connect()

    # SET replies with the simple string "OK".
    match client.call([b"SET", b"foo", b"\x31\x32"]):
        case SimpleStringReply() as reply:
            assert reply.value == "OK"
        case _:
            assert False

    # DEL replies with an integer. Clearing "bar" makes the HSET below
    # add a new field.
    match client.call([b"DEL", b"bar"]):
        case IntegerReply():
            pass
        case _:
            assert False

    match client.call([b"HSET", b"bar", b"k1", b"v1"]):
        case IntegerReply() as reply:
            assert reply.value == 1
        case _:
            assert False

    match client.call([b"HLEN", b"bar"]):
        case IntegerReply() as reply:
            assert reply.value == 1
        case _:
            assert False

    client.disconnect()

604 

test list():
    # Exercises the list commands; "bar" is cleared first.
    client = Client()
    client.connect()

    client.del("bar")

    # Push results are the list length after each push: [1, 2, 3].
    assert client.lpush("bar", b"2") == 1
    assert client.lpush("bar", b"1") == 2
    assert client.rpush("bar", b"3") == 3

    assert client.lpop("bar") == b"1"
    assert client.lpop("bar") == b"2"
    assert client.rpop("bar") == b"3"

    # Popping from an empty list gives None.
    assert client.lpop("bar") is None
    assert client.rpop("bar") is None

    client.disconnect()

623 

test hash():
    # Exercises the hash commands; "fie" is cleared first.
    client = Client()
    client.connect()

    client.del("fie")

    # A missing field (and hash) gives None.
    assert client.hget("fie", "a") is None

    # Each HSET adds one new field.
    assert client.hset("fie", "a", b"x") == 1
    assert client.hset("fie", "b", b"y") == 1
    assert client.hset("fie", "c", b"z") == 1

    assert client.hget("fie", "a") == b"x"
    assert client.hget("fie", "b") == b"y"
    assert client.hget("fie", "c") == b"z"

    items = client.hgetall("fie")
    assert items.length() == 3
    assert items["a"] == b"x"
    assert items["b"] == b"y"
    assert items["c"] == b"z"

    # Deleting a field twice: the second delete removes nothing.
    assert client.hdel("fie", "b") == 1
    assert client.hdel("fie", "b") == 0

    assert client.hget("fie", "a") == b"x"
    assert client.hget("fie", "b") is None
    assert client.hget("fie", "c") == b"z"

    client.disconnect()

654 

test auth():
    # Expects AUTH to fail with an "ERR" error reply, presumably because
    # the test server has no password configured. TODO confirm test setup.
    client = Client()
    client.connect()

    try:
        message = ""
        client.auth("pass")
    except RedisError as error:
        message = error.message

    # Error replies are raised as "<error>: <message>".
    assert "ERR: AUTH" in message

    client.disconnect()

668 

test pub_sub():
    # Exercises Pub/Sub with one subscribing and one publishing client.
    subscriber = Client()
    publisher = Client()

    subscriber.connect()
    publisher.connect()

    subscriber.subscribe("foo")

    # The subscription confirmation arrives as a message.
    match subscriber.get_message():
        case SubscribeMessage() as subscribe_message:
            assert subscribe_message.channel == "foo"
            assert subscribe_message.number_of_subscriptions == 1
        case _:
            assert False

    publisher.publish("foo", b"bar")

    match subscriber.get_message():
        case PublishMessage() as publish_message:
            assert publish_message.channel == "foo"
            assert publish_message.payload == b"bar"
        case _:
            assert False

    # Pattern subscriptions are counted together with channel ones.
    subscriber.psubscribe("foo*")

    match subscriber.get_message():
        case SubscribeMessage() as subscribe_message:
            assert subscribe_message.channel == "foo*"
            assert subscribe_message.number_of_subscriptions == 2
        case _:
            assert False

    subscriber.unsubscribe("foo")

    match subscriber.get_message():
        case UnsubscribeMessage() as unsubscribe_message:
            assert unsubscribe_message.channel == "foo"
            assert unsubscribe_message.number_of_subscriptions == 1
        case _:
            assert False

    # Received only through the "foo*" pattern; channel is the actual one.
    publisher.publish("foo123", b"bar987")

    match subscriber.get_message():
        case PublishMessage() as publish_message:
            assert publish_message.channel == "foo123"
            assert publish_message.payload == b"bar987"
        case _:
            assert False

    subscriber.punsubscribe("foo*")

    match subscriber.get_message():
        case UnsubscribeMessage() as unsubscribe_message:
            assert unsubscribe_message.channel == "foo*"
            assert unsubscribe_message.number_of_subscriptions == 0
        case _:
            assert False

    subscriber.disconnect()
    publisher.disconnect()

732 

test wrong_reply():
    # Checks that the typed call_*_reply() methods raise RedisError with a
    # descriptive message when the reply has another kind.
    client = Client()
    client.connect()

    # GET replies with a bulk string, not an integer.
    try:
        message = ""
        client.call_integer_reply([b"GET", b"foo"])
    except RedisError as error:
        message = error.message

    assert message == "Not an integer reply."

    # DEL replies with an integer, not a bulk string.
    try:
        message = ""
        client.call_bulk_string_reply([b"DEL", b"foo"])
    except RedisError as error:
        message = error.message

    assert message == "Not a bulk string reply."

    # DEL replies with an integer, not a simple string.
    try:
        message = ""
        client.call_simple_string_reply([b"DEL", b"foo"])
    except RedisError as error:
        message = error.message

    assert message == "Not a simple string reply."

    client.disconnect()

762 

test string_pipeline():
    # Exercises pipelining: all requests are written first, then the
    # replies are read back in the same order.
    client = Client()
    client.connect()
    client.del("foo")

    client.set_write("foo", b"")
    client.del_write("foo")
    client.set_write("foo", b"\x00\x01\x02")
    client.get_write("foo")

    client.set_read()
    assert client.del_read() == 1
    client.set_read()
    assert client.get_read() == b"\x00\x01\x02"

    # NOTE(review): disabled non-pipelined checks, apparently copied from
    # the string test; either pipeline them or remove them.
    # assert client.strlen("foo") == 3
    # assert client.append("foo", b"\x03\x04") == 5
    # assert client.get("foo") == b"\x00\x01\x02\x03\x04"
    # assert client.getdel("foo") == b"\x00\x01\x02\x03\x04"
    # assert client.append("foo", b"\x00") == 1
    # assert client.get("foo") == b"\x00"
    # client.set("count", b"0")
    # assert client.incr("count") == 1
    # assert client.incr("count") == 2
    # assert client.incr("count") == 3
    # assert client.decr("count") == 2
    client.disconnect()

790 

test connect_disconnect():
    # Repeated disconnect() and connect() calls on the same client must
    # not raise.
    client = Client()
    client.connect()
    client.disconnect()
    client.disconnect()
    client.connect()
    client.connect()