1
0
Fork 0
mirror of synced 2024-07-02 04:50:47 +12:00

sqlite search: clean up errors and type-checking

Clean up error handling, and report a better error message
on search and flush if FTS5 tables haven't yet been created.

Add some mypy comments to clean up type-checking errors.
This commit is contained in:
Ross Williams 2023-10-16 14:31:52 -04:00
parent adb9f0ecc9
commit 1e604a1352

View file

@ -1,5 +1,5 @@
import codecs import codecs
from typing import List, Optional, Generator from typing import List, Generator
import sqlite3 import sqlite3
from archivebox.util import enforce_types from archivebox.util import enforce_types
@ -22,7 +22,7 @@ if FTS_SEPARATE_DATABASE:
return database return database
SQLITE_BIND = "?" SQLITE_BIND = "?"
else: else:
from django.db import connection as database from django.db import connection as database # type: ignore[no-redef, assignment]
get_connection = database.cursor get_connection = database.cursor
SQLITE_BIND = "%s" SQLITE_BIND = "%s"
@ -31,7 +31,7 @@ else:
try: try:
limit_id = sqlite3.SQLITE_LIMIT_LENGTH limit_id = sqlite3.SQLITE_LIMIT_LENGTH
try: try:
with database.temporary_connection() as cursor: with database.temporary_connection() as cursor: # type: ignore[attr-defined]
SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id) SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id)
except AttributeError: except AttributeError:
SQLITE_LIMIT_LENGTH = database.getlimit(limit_id) SQLITE_LIMIT_LENGTH = database.getlimit(limit_id)
@ -51,6 +51,7 @@ def _escape_sqlite3(value: str, *, quote: str, errors='strict') -> str:
nul_index, nul_index + 1, "NUL not allowed") nul_index, nul_index + 1, "NUL not allowed")
error_handler = codecs.lookup_error(errors) error_handler = codecs.lookup_error(errors)
replacement, _ = error_handler(error) replacement, _ = error_handler(error)
assert isinstance(replacement, str), "handling a UnicodeEncodeError should return a str replacement"
encodable = encodable.replace("\x00", replacement) encodable = encodable.replace("\x00", replacement)
return quote + encodable.replace(quote, quote * 2) + quote return quote + encodable.replace(quote, quote * 2) + quote
@ -99,6 +100,16 @@ def _create_tables():
" END;" " END;"
) )
def _handle_query_exception(exc: Exception):
message = str(exc)
if message.startswith("no such table:"):
raise RuntimeError(
"SQLite full-text search index has not yet"
" been created; run `archivebox update --index-only`."
)
else:
raise exc
@enforce_types @enforce_types
def index(snapshot_id: str, texts: List[str]): def index(snapshot_id: str, texts: List[str]):
text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH] text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH]
@ -145,22 +156,29 @@ def search(text: str) -> List[str]:
id_table = _escape_sqlite3_identifier(FTS_ID_TABLE) id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
with get_connection() as cursor: with get_connection() as cursor:
try:
res = cursor.execute( res = cursor.execute(
f"SELECT snapshot_id FROM {table}" f"SELECT snapshot_id FROM {table}"
f" INNER JOIN {id_table}" f" INNER JOIN {id_table}"
f" ON {id_table}.rowid = {table}.rowid" f" ON {id_table}.rowid = {table}.rowid"
f" WHERE {table} MATCH {SQLITE_BIND}", f" WHERE {table} MATCH {SQLITE_BIND}",
[text]) [text])
except Exception as e:
_handle_query_exception(e)
snap_ids = [row[0] for row in res.fetchall()] snap_ids = [row[0] for row in res.fetchall()]
return snap_ids return snap_ids
@enforce_types
def flush(snapshot_ids: Generator[str, None, None]):
    """Delete rows for the given snapshot ids from the FTS id table.

    One DELETE is executed per snapshot id against the table named by
    ``FTS_ID_TABLE``. An empty iterable results in no statements being run.

    Args:
        snapshot_ids: Iterable of snapshot id strings to remove.

    Raises:
        Whatever ``_handle_query_exception`` raises for a failed query.
    """
    snapshot_ids = list(snapshot_ids)  # type: ignore[assignment]

    id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)

    with get_connection() as cursor:
        try:
            # executemany() expects a sequence of parameter *sets*; the
            # statement has a single placeholder, so each id must be its
            # own 1-tuple. Passing ``[snapshot_ids]`` would bind the whole
            # list to one execution and fail unless it held exactly one id.
            cursor.executemany(
                f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
                [(snapshot_id,) for snapshot_id in snapshot_ids])
        except Exception as e:
            _handle_query_exception(e)