Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from os.path import Path 

2 

# Raw C++ that the "header-before-namespace" marker places ahead of the
# generated code: the SQLite C API used by all embedded C below, and
# <string.h> for the C memory/string functions (e.g. memcpy()) it calls.
c"""header-before-namespace
#include <sqlite3.h>
#include <string.h>
"""

7 

enum _Result:
    # SQLite result codes inspected by this module. Besides SQLITE_ROW and
    # SQLITE_DONE, sqlite3_step() on a statement from the legacy
    # sqlite3_prepare() may return SQLITE_ERROR, SQLITE_BUSY or
    # SQLITE_MISUSE. They are members so that _Result(...) can represent
    # them and callers can report the SQLite error, instead of depending on
    # how an unknown value is converted.
    Ok = c"SQLITE_OK"
    Row = c"SQLITE_ROW"
    Done = c"SQLITE_DONE"
    Error = c"SQLITE_ERROR"
    Busy = c"SQLITE_BUSY"
    Misuse = c"SQLITE_MISUSE"

12 

enum Type(u8):
    """Value type of a column in the current row, as returned by
    Statement.column_type(). Each member is one of SQLite's fundamental
    datatype codes: String is SQLITE_TEXT and Bytes is SQLITE_BLOB.

    """

    Integer = c"SQLITE_INTEGER"
    Float = c"SQLITE_FLOAT"
    String = c"SQLITE_TEXT"
    Null = c"SQLITE_NULL"
    Bytes = c"SQLITE_BLOB"

23 

class SqlError(Error):
    # Raised when an SQLite call fails, and on misuse of this module's API
    # (for example reading a column that the current row does not have).
    # Either the text from sqlite3_errmsg() or a short description such as
    # "bad column 3".
    message: string

26 

func _to_utf8(value: string) -> bytes:
    # Encode as UTF-8 and append a terminating zero byte, so the buffer's
    # data() can be handed to SQLite as a C string.
    encoded = value.to_utf8()
    encoded += 0
    return encoded

32 

class Database:
    # A connection to an SQLite database. The raw sqlite3 handle is the
    # embedded C++ member below; __del__ closes it.
    c"sqlite3 *database_p;"

    func __init__(self, path: Path):
        """Create or open a database.

        Raises SqlError if the database cannot be opened.

        """

        path_utf8 = _to_utf8(str(path))
        res = _Result.Ok

        # sqlite3_open() normally allocates a handle even when it fails, and
        # that handle must still be released. Close it right away and clear
        # the pointer, so the sqlite3_close() in __del__ (should it run) is
        # a harmless no-op on NULL.
        c"""
        res = sqlite3_open((const char *)path_utf8.m_bytes->data(), &this->database_p);

        if (res != SQLITE_OK) {
            sqlite3_close(this->database_p);
            this->database_p = NULL;
        }
        """

        if res != _Result.Ok:
            raise SqlError("failed to open the database")

    func __del__(self):
        c"sqlite3_close(this->database_p);"

    func execute(self, sql: string):
        """Execute given SQL directly, without binding any values. No rows
        are returned.

        Raises SqlError with SQLite's error message on failure.

        """

        sql_utf8 = _to_utf8(sql)
        res = _Result.Ok

        c"""
        res = sqlite3_exec(this->database_p,
                           (const char *)sql_utf8.m_bytes->data(),
                           NULL,
                           NULL,
                           NULL);
        """

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self))

    func prepare(self, sql: string) -> Statement:
        """Prepare a statement. Safer than execute(), and faster if used more
        than once.

        Raises SqlError if the SQL cannot be compiled.

        """

        return Statement(sql, self)

80 

func _create_error_message(database: Database) -> string:
    # Return SQLite's English description (sqlite3_errmsg()) of the latest
    # error recorded on the given connection.
    text: string? = None
    c"text = String(sqlite3_errmsg(database->database_p));"
    return text

87 

class Statement:
    # A compiled SQL statement created by Database.prepare(). Values are
    # bound to its "?" parameters with bind_*(), rows are read with fetch()
    # and column_*(), and reset() makes it reusable. It keeps a reference
    # to its Database, presumably so the connection stays alive while the
    # statement exists.
    database: Database
    # Number of columns in the current row; 0 when no row is available.
    _number_of_columns: u32
    c"sqlite3_stmt *stmt_p;"

    func __init__(self, sql: string, database: Database):
        """Compile given SQL for given database. Raises SqlError if it
        cannot be prepared.

        """

        self.database = database
        self._number_of_columns = 0
        sql_utf8 = _to_utf8(sql)
        res = _Result.Ok

        # sqlite3_prepare_v2() is the interface SQLite recommends. Unlike the
        # legacy sqlite3_prepare(), sqlite3_step() then reports the specific
        # error code and error message directly.
        c"""
        res = sqlite3_prepare_v2(database->database_p,
                                 (const char *)sql_utf8.m_bytes->data(),
                                 -1,
                                 &this->stmt_p,
                                 NULL);
        """

        if res != _Result.Ok:
            raise SqlError(_create_error_message(database))

    func __del__(self):
        c"sqlite3_finalize(this->stmt_p);"

    func bind_int(self, column: u32, value: i64):
        """Bind given integer to given parameter. Parameter indices start
        at 1.

        """

        res = _Result.Ok

        c"res = sqlite3_bind_int64(this->stmt_p, column, value);"

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self.database))

    func bind_float(self, column: u32, value: f64):
        """Bind given float to given parameter. Parameter indices start
        at 1.

        """

        res = _Result.Ok

        c"res = sqlite3_bind_double(this->stmt_p, column, value);"

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self.database))

    func bind_string(self, column: u32, value: string):
        """Bind given string to given parameter. Parameter indices start
        at 1. SQLite makes its own copy of the text (SQLITE_TRANSIENT).

        """

        value_utf8 = _to_utf8(value)
        res = _Result.Ok

        # Pass the exact length (without the terminator appended by
        # _to_utf8()) instead of -1, so text containing NUL characters is
        # not truncated at the first one.
        c"""
        res = sqlite3_bind_text(this->stmt_p,
                                column,
                                (const char *)value_utf8.m_bytes->data(),
                                (int)value_utf8.m_bytes->size() - 1,
                                SQLITE_TRANSIENT);
        """

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self.database))

    func bind_bytes(self, column: u32, value: bytes):
        """Bind given bytes to given parameter. Parameter indices start
        at 1. SQLite makes its own copy of the data (SQLITE_TRANSIENT).

        """

        res = _Result.Ok

        # With a NULL data pointer, sqlite3_bind_blob() binds NULL instead of
        # a blob, and data() of an empty buffer may be NULL. Substitute a
        # non-NULL pointer so that b"" is stored as an empty blob.
        c"""
        const void *data_p = value.m_bytes->data();

        if (data_p == NULL) {
            data_p = "";
        }

        res = sqlite3_bind_blob(this->stmt_p,
                                column,
                                data_p,
                                (int)value.m_bytes->size(),
                                SQLITE_TRANSIENT);
        """

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self.database))

    func bind_null(self, column: u32):
        """Bind null to given parameter. Parameter indices start at 1.

        """

        res = _Result.Ok

        c"res = sqlite3_bind_null(this->stmt_p, column);"

        if res != _Result.Ok:
            raise SqlError(_create_error_message(self.database))

    func execute(self):
        """Execute the statement. Bind any values to parameters before
        executing it. Calls reset() once complete, also on failure.

        Raises SqlError if execution fails, or if the statement produces
        a row instead of completing.

        """

        try:
            if self._step() != _Result.Done:
                raise SqlError(_create_error_message(self.database))
        finally:
            self.reset()

    func fetch(self) -> bool:
        """Fetch the next row from the database. Returns True if a row was
        fetched, or calls reset() and returns False when there are no
        more rows available. If fetched, get column values with
        column_*() methods.

        Raises SqlError, after resetting the statement, if stepping fails.

        """

        if self._step() == _Result.Row:
            c"this->_number_of_columns = sqlite3_data_count(this->stmt_p);"

            return True

        self.reset()

        return False

    func _step(self) -> _Result:
        # Advance the statement by one step. Returns _Result.Row or
        # _Result.Done. SQLite treats every other result code as an error:
        # the statement is reset and SqlError is raised with SQLite's
        # message. The code is compared in place rather than converted with
        # _Result(...), so codes that are not _Result members work as well.
        res = _Result.Ok

        c"res = sqlite3_step(this->stmt_p);"

        if res == _Result.Row or res == _Result.Done:
            return res

        # Capture the message before reset(), which may replace it.
        message = _create_error_message(self.database)
        self.reset()

        raise SqlError(message)

    func reset(self):
        """Reset the statement so it can be used again. Bound values are
        kept.

        """

        self._number_of_columns = 0

        # The return value is ignored; errors are reported where they occur.
        c"sqlite3_reset(this->stmt_p);"

    func column_type(self, column: u32) -> Type:
        """Get the type of given column (0-based) in the current row.

        Raises SqlError if the column does not exist or no row is fetched.

        """

        if column >= self._number_of_columns:
            raise SqlError(f"bad column {column}")

        ctype: u8 = 0

        c"ctype = sqlite3_column_type(this->stmt_p, column);"

        return Type(ctype)

    func column_int(self, column: u32) -> i64:
        """Get the value of given column (0-based) as an integer.

        Raises SqlError if the column does not exist or no row is fetched.

        """

        if column >= self._number_of_columns:
            raise SqlError(f"bad column {column}")

        value = 0

        c"value = sqlite3_column_int64(this->stmt_p, column);"

        return value

    func column_float(self, column: u32) -> f64:
        """Get the value of given column (0-based) as a float.

        Raises SqlError if the column does not exist or no row is fetched.

        """

        if column >= self._number_of_columns:
            raise SqlError(f"bad column {column}")

        value = 0.0

        c"value = sqlite3_column_double(this->stmt_p, column);"

        return value

    func column_string(self, column: u32) -> string?:
        """Get the value of given column (0-based) as a string, or None
        when SQLite returns no text for it (as for a null value).

        Raises SqlError if the column does not exist or no row is fetched.

        """

        if column >= self._number_of_columns:
            raise SqlError(f"bad column {column}")

        value: bytes? = b""

        # Call sqlite3_column_text() before sqlite3_column_bytes(), the
        # order SQLite recommends, so the length refers to the UTF-8 text.
        # Copying exactly that many bytes in one go avoids re-running
        # strlen() on every iteration and keeps embedded NUL characters.
        c"""
        const unsigned char *value_p = sqlite3_column_text(this->stmt_p, column);

        if (value_p != NULL) {
            int value_size = sqlite3_column_bytes(this->stmt_p, column);

            value = Bytes(value_size);

            if (value_size > 0) {
                memcpy(value.m_bytes->data(), value_p, value_size);
            }
        } else {
            value = nullptr;
        }
        """

        if value is None:
            return None

        return string(value)

    func column_bytes(self, column: u32) -> bytes:
        """Get the value of given column (0-based) as bytes.

        Raises SqlError if the column does not exist or no row is fetched.

        """

        if column >= self._number_of_columns:
            raise SqlError(f"bad column {column}")

        value: bytes? = None

        # Call sqlite3_column_blob() before sqlite3_column_bytes(), the order
        # SQLite recommends. Empty and null values give a NULL pointer, so
        # only copy when there is data.
        c"""
        const void *blob_p = sqlite3_column_blob(this->stmt_p, column);
        int blob_size = sqlite3_column_bytes(this->stmt_p, column);

        value = Bytes(blob_size);

        if (blob_size > 0) {
            memcpy(value.m_bytes->data(), blob_p, blob_size);
        }
        """

        return value

    func column_value_string(self, column: u32) -> string:
        """Get given column's value formatted as a string: numbers and bytes
        via str(), text in double quotes, and null as "null".

        Raises SqlError if the column does not exist or no row is fetched.

        """

        column_type = self.column_type(column)

        if column_type == Type.Integer:
            return str(self.column_int(column))
        elif column_type == Type.Float:
            return str(self.column_float(column))
        elif column_type == Type.String:
            return f"\"{self.column_string(column)}\""
        elif column_type == Type.Bytes:
            return str(self.column_bytes(column))
        elif column_type == Type.Null:
            return "null"
        else:
            raise SqlError(f"invalid column type {column_type}")

347 

test basics():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    database.execute("CREATE TABLE tab(foo, bar, baz)")
    database.execute("INSERT INTO tab VALUES(1, 'one', null)")
    database.execute("INSERT INTO tab VALUES(2, 2.2, 'two')")
    database.execute("INSERT INTO tab VALUES(3, 'three', null)")
    database.execute("INSERT INTO tab VALUES(4, X'89', null)")

    # Only rows with foo >= 2, in ascending order.
    statement = database.prepare("SELECT * FROM tab WHERE foo >= ? ORDER BY foo")
    statement.bind_int(1, 2)

    # Row (2, 2.2, 'two').
    assert statement.fetch()
    assert statement.column_type(0) == Type.Integer
    assert statement.column_int(0) == 2
    assert statement.column_type(1) == Type.Float
    assert statement.column_float(1) == 2.2
    assert statement.column_type(2) == Type.String
    assert statement.column_string(2) == "two"

    # Row (3, 'three', null).
    assert statement.fetch()
    assert statement.column_type(0) == Type.Integer
    assert statement.column_int(0) == 3
    assert statement.column_type(1) == Type.String
    assert statement.column_string(1) == "three"
    assert statement.column_type(2) == Type.Null

    # Row (4, X'89', null).
    assert statement.fetch()
    assert statement.column_type(0) == Type.Integer
    assert statement.column_int(0) == 4
    assert statement.column_type(1) == Type.Bytes
    assert statement.column_bytes(1) == b"\x89"
    assert statement.column_type(2) == Type.Null

    # Exhausted.
    assert not statement.fetch()

    path.rm(force=True)

387 

test advanced():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    database.execute("CREATE TABLE tab(foo, bar, baz)")
    database.execute("INSERT INTO tab VALUES(1, 'one', null)")
    database.execute("INSERT INTO tab VALUES(2, 2.2, 'two')")

    # Insert three more rows through one reusable prepared statement.
    insert = database.prepare("INSERT INTO tab VALUES(?, ?, ?)")

    insert.bind_int(1, 3)
    insert.bind_string(2, "three")
    insert.bind_int(3, 333)
    insert.execute()

    insert.bind_int(1, 4)
    insert.bind_string(2, "four")
    insert.bind_null(3)
    insert.execute()

    insert.bind_int(1, 5)
    insert.bind_bytes(2, b"\x12\x34")
    insert.bind_null(3)
    insert.execute()

    select = database.prepare("SELECT * FROM tab WHERE foo >= ? ORDER BY foo")
    select.bind_int(1, 2)

    # Row (2, 2.2, 'two').
    assert select.fetch()
    assert select.column_type(0) == Type.Integer
    assert select.column_int(0) == 2
    assert select.column_type(1) == Type.Float
    assert select.column_float(1) == 2.2
    assert select.column_type(2) == Type.String
    assert select.column_string(2) == "two"

    # Row (3, 'three', 333).
    assert select.fetch()
    assert select.column_type(0) == Type.Integer
    assert select.column_int(0) == 3
    assert select.column_type(1) == Type.String
    assert select.column_string(1) == "three"
    assert select.column_type(2) == Type.Integer
    assert select.column_int(2) == 333

    # Row (4, 'four', null).
    assert select.fetch()
    assert select.column_type(0) == Type.Integer
    assert select.column_int(0) == 4
    assert select.column_type(1) == Type.String
    assert select.column_string(1) == "four"
    assert select.column_type(2) == Type.Null

    # Row (5, X'1234', null).
    assert select.fetch()
    assert select.column_type(0) == Type.Integer
    assert select.column_int(0) == 5
    assert select.column_type(1) == Type.Bytes
    assert select.column_bytes(1) == b"\x12\x34"
    assert select.column_type(2) == Type.Null

    # Exhausted.
    assert not select.fetch()

    path.rm(force=True)

450 

test try_to_create_existing_table():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)
    database.execute("CREATE TABLE tab(foo, bar, baz)")

    # Creating the same table a second time must fail with SQLite's message.
    message: string? = None

    try:
        database.execute("CREATE TABLE tab(foo, bar, baz)")
    except SqlError as e:
        message = e.message

    assert message == "table tab already exists"

    path.rm(force=True)

467 

test prepare_bad_statement():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    # Invalid SQL is rejected when preparing, not when executing.
    message = ""

    try:
        database.prepare("FOOBAR 123")
    except SqlError as e:
        message = e.message

    assert message == "near \"FOOBAR\": syntax error"

    path.rm(force=True)

482 

test bad_column():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    database.execute("CREATE TABLE tab(foo, bar, baz)")
    database.execute("INSERT INTO tab VALUES(1, 'one', null)")

    statement = database.prepare("SELECT * FROM tab")

    # Columns 0-2 exist; anything from 3 upwards must be rejected.
    assert statement.fetch()
    assert statement.column_type(2) == Type.Null

    message = ""

    try:
        statement.column_type(3)
    except SqlError as e:
        message = e.message

    assert message == "bad column 3"

    message = ""

    try:
        statement.column_int(4)
    except SqlError as e:
        message = e.message

    assert message == "bad column 4"

    message = ""

    try:
        statement.column_string(10)
    except SqlError as e:
        message = e.message

    assert message == "bad column 10"

    message = ""

    try:
        statement.column_float(3)
    except SqlError as e:
        message = e.message

    assert message == "bad column 3"

    message = ""

    try:
        statement.column_bytes(3)
    except SqlError as e:
        message = e.message

    assert message == "bad column 3"

    path.rm(force=True)

537 

test column_value_string():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    database.execute("CREATE TABLE tab(foo, bar, baz, bak, bat)")
    database.execute("INSERT INTO tab VALUES(1, 'one', null, 1.0, X'012345')")

    statement = database.prepare("SELECT * FROM tab")

    # One column of each type: integer, text, null, float and blob.
    assert statement.fetch()
    assert statement.column_value_string(0) == "1"
    assert statement.column_value_string(1) == "\"one\""
    assert statement.column_value_string(2) == "null"
    assert statement.column_value_string(3) == "1.000000"
    assert statement.column_value_string(4) == "b\"\\x01#E\""

    # Clean up like the other tests, so no database file is left behind.
    path.rm(force=True)

554 

test statement():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)

    database.execute("CREATE TABLE tab(x)")
    database.execute("INSERT INTO tab VALUES(5)")
    database.execute("INSERT INTO tab VALUES(6)")
    database.execute("INSERT INTO tab VALUES(7)")

    query = database.prepare("SELECT * FROM tab WHERE x >= ? ORDER BY x")

    # Get all.
    query.bind_int(1, 2)
    assert query.fetch()
    assert query.column_int(0) == 5
    assert query.fetch()
    assert query.column_int(0) == 6
    assert query.fetch()
    assert not query.fetch()

    # Get one then reset.
    query.bind_int(1, 2)
    assert query.fetch()
    assert query.column_int(0) == 5
    query.reset()
    assert query.fetch()
    assert query.column_int(0) == 5

    # Bind before reset.
    ok = False

    try:
        query.bind_int(1, 2)
    except SqlError:
        ok = True

    assert ok
    query.reset()

    # Fetch from beginning once all rows read.
    query.bind_int(1, 2)
    assert query.fetch()
    assert query.column_int(0) == 5
    assert query.fetch()
    assert query.fetch()
    assert not query.fetch()
    assert query.fetch()
    assert query.column_int(0) == 5

    path.rm(force=True)

605 

test two_database_connections():
    path = Path("the.db")
    path.rm(force=True)

    # Two independent connections to the same file.
    database_1 = Database(path)
    database_2 = Database(path)

    database_1.execute("CREATE TABLE tab(x)")
    database_1.execute("INSERT INTO tab VALUES(5)")
    database_1.execute("INSERT INTO tab VALUES(6)")

    statement_1 = database_1.prepare("SELECT * FROM tab WHERE x = ?")
    statement_2 = database_2.prepare("SELECT * FROM tab WHERE x = ?")

    # Fetch interleaved.
    statement_1.bind_int(1, 5)
    statement_2.bind_int(1, 6)

    assert statement_2.fetch()
    assert statement_2.column_int(0) == 6
    assert statement_1.fetch()
    assert statement_1.column_int(0) == 5

    assert not statement_1.fetch()
    assert not statement_2.fetch()

    path.rm(force=True)

630 

test string():
    path = Path("the.db")
    path.rm(force=True)

    database = Database(path)
    database.execute("CREATE TABLE tab(x)")

    # A non-ASCII (4-byte UTF-8) character must survive a round trip.
    statement = database.prepare("INSERT INTO tab VALUES(?)")
    statement.bind_string(1, "📦")
    statement.execute()

    statement = database.prepare("SELECT * FROM tab")
    assert statement.fetch()
    assert statement.column_string(0) == "📦"

    # Clean up like the other tests, so no database file is left behind.
    path.rm(force=True)