diff --git a/archivebox/search/backends/sqlite.py b/archivebox/search/backends/sqlite.py
index 4ed9e79c..b4c61efb 100644
--- a/archivebox/search/backends/sqlite.py
+++ b/archivebox/search/backends/sqlite.py
@@ -1,5 +1,5 @@
 import codecs
-from typing import List, Optional, Generator
+from typing import List, Generator
 import sqlite3
 
 from archivebox.util import enforce_types
@@ -22,7 +22,7 @@ if FTS_SEPARATE_DATABASE:
         return database
     SQLITE_BIND = "?"
 else:
-    from django.db import connection as database
+    from django.db import connection as database  # type: ignore[no-redef, assignment]
     get_connection = database.cursor
     SQLITE_BIND = "%s"
 
@@ -31,7 +31,7 @@ else:
 try:
     limit_id = sqlite3.SQLITE_LIMIT_LENGTH
     try:
-        with database.temporary_connection() as cursor:
+        with database.temporary_connection() as cursor:  # type: ignore[attr-defined]
             SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id)
     except AttributeError:
         SQLITE_LIMIT_LENGTH = database.getlimit(limit_id)
@@ -51,6 +51,7 @@ def _escape_sqlite3(value: str, *, quote: str, errors='strict') -> str:
                                    nul_index, nul_index + 1, "NUL not allowed")
         error_handler = codecs.lookup_error(errors)
         replacement, _ = error_handler(error)
+        assert isinstance(replacement, str), "handling a UnicodeEncodeError should return a str replacement"
         encodable = encodable.replace("\x00", replacement)
 
     return quote + encodable.replace(quote, quote * 2) + quote
@@ -99,6 +100,22 @@ def _create_tables():
             " END;"
         )
 
+def _handle_query_exception(exc: Exception):
+    """Re-raise a failed query's exception with a more helpful message.
+
+    A "no such table:" error is replaced by a RuntimeError telling the user
+    how to create the index (chained to ``exc``); any other exception is
+    re-raised as-is. This function never returns normally.
+    """
+    message = str(exc)
+    if message.startswith("no such table:"):
+        raise RuntimeError(
+            "SQLite full-text search index has not yet"
+            " been created; run `archivebox update --index-only`."
+        ) from exc
+    else:
+        raise exc
+
 @enforce_types
 def index(snapshot_id: str, texts: List[str]):
     text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH]
@@ -145,22 +162,32 @@ def search(text: str) -> List[str]:
     id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
 
     with get_connection() as cursor:
-        res = cursor.execute(
-            f"SELECT snapshot_id FROM {table}"
-            f" INNER JOIN {id_table}"
-            f" ON {id_table}.rowid = {table}.rowid"
-            f" WHERE {table} MATCH {SQLITE_BIND}",
-            [text])
+        try:
+            res = cursor.execute(
+                f"SELECT snapshot_id FROM {table}"
+                f" INNER JOIN {id_table}"
+                f" ON {id_table}.rowid = {table}.rowid"
+                f" WHERE {table} MATCH {SQLITE_BIND}",
+                [text])
+        except Exception as e:
+            _handle_query_exception(e)
+
         snap_ids = [row[0] for row in res.fetchall()]
     return snap_ids
 
 @enforce_types
 def flush(snapshot_ids: Generator[str, None, None]):
-    snapshot_ids = list(snapshot_ids)
+    """Delete the rows for the given snapshot ids from the FTS id table."""
+    snapshot_ids = list(snapshot_ids)  # type: ignore[assignment]
 
     id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
 
     with get_connection() as cursor:
-        cursor.executemany(
-            f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
-            [snapshot_ids])
+        try:
+            # executemany() runs the statement once per parameter set; the
+            # statement has one placeholder, so each id is its own 1-tuple.
+            cursor.executemany(
+                f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
+                [(snapshot_id,) for snapshot_id in snapshot_ids])
+        except Exception as e:
+            _handle_query_exception(e)