import asyncio
import copy
import json
import logging
from typing import Any, Optional
from uuid import UUID

import asyncpg
from fastapi import HTTPException

from core.base import (
    Handler,
    DocumentResponse,
    DocumentType,
    IngestionStatus,
    KGEnrichmentStatus,
    KGExtractionStatus,
    R2RException,
    SearchSettings,
)

from .base import PostgresConnectionManager

logger = logging.getLogger()

class PostgresDocumentsHandler(Handler):
    TABLE_NAME = "documents"

    COLUMN_VARS = [
        "extraction_id",
        "id",
        "owner_id",
        "collection_ids",
    ]

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int,
    ):
        self.dimension = dimension
        super().__init__(project_name, connection_manager)

    async def create_tables(self):
        logger.info(
            f"Creating table, if not exists: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
        )
        try:
            query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY,
                collection_ids UUID[],
                owner_id UUID,
                type TEXT,
                metadata JSONB,
                title TEXT,
                summary TEXT NULL,
                summary_embedding vector({self.dimension}) NULL,
                version TEXT,
                size_in_bytes INT,
                ingestion_status TEXT DEFAULT 'pending',
                extraction_status TEXT DEFAULT 'pending',
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                ingestion_attempt_number INT DEFAULT 0,
                raw_tsvector tsvector GENERATED ALWAYS AS (
                    setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                    setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                    setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
                ) STORED
            );
            CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);

            -- Full text search index
            CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            USING GIN (raw_tsvector);
            """
            await self.connection_manager.execute_query(query)
        except Exception as e:
            logger.warning(f"Error {e} when creating document table.")

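    # Illustrative setup sketch (hypothetical values, not part of this module's
    # API): the handler expects an already-initialized PostgresConnectionManager,
    # and create_tables() is typically awaited once at startup, e.g.
    #
    #     handler = PostgresDocumentsHandler(
    #         project_name="my_project",      # assumed example value
    #         connection_manager=manager,     # assumed to be connected
    #         dimension=768,                  # must match the embedding model's size
    #     )
    #     await handler.create_tables()
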
    async def upsert_documents_overview(
        self, documents_overview: DocumentResponse | list[DocumentResponse]
    ) -> None:
        if isinstance(documents_overview, DocumentResponse):
            documents_overview = [documents_overview]

        # TODO: make this an arg
        max_retries = 20
        for document in documents_overview:
            retries = 0
            while retries < max_retries:
                try:
                    async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                        async with conn.transaction():
                            # Lock the row for update
                            check_query = f"""
                            SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                            WHERE id = $1 FOR UPDATE
                            """
                            existing_doc = await conn.fetchrow(
                                check_query, document.id
                            )

                            db_entry = document.convert_to_db_entry()

                            if existing_doc:
                                db_version = existing_doc[
                                    "ingestion_attempt_number"
                                ]
                                db_status = existing_doc["ingestion_status"]
                                new_version = db_entry[
                                    "ingestion_attempt_number"
                                ]

                                # Only increment version if status is changing to 'success' or if it's a new version
                                if (
                                    db_status != "success"
                                    and db_entry["ingestion_status"]
                                    == "success"
                                ) or (new_version > db_version):
                                    new_attempt_number = db_version + 1
                                else:
                                    new_attempt_number = db_version

                                db_entry["ingestion_attempt_number"] = (
                                    new_attempt_number
                                )

                                update_query = f"""
                                UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                SET collection_ids = $1, owner_id = $2, type = $3, metadata = $4,
                                    title = $5, version = $6, size_in_bytes = $7, ingestion_status = $8,
                                    extraction_status = $9, updated_at = $10, ingestion_attempt_number = $11,
                                    summary = $12, summary_embedding = $13
                                WHERE id = $14
                                """
                                await conn.execute(
                                    update_query,
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["updated_at"],
                                    new_attempt_number,
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                    document.id,
                                )
                            else:
                                insert_query = f"""
                                INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                (id, collection_ids, owner_id, type, metadata, title, version,
                                size_in_bytes, ingestion_status, extraction_status, created_at,
                                updated_at, ingestion_attempt_number, summary, summary_embedding)
                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
                                """
                                await conn.execute(
                                    insert_query,
                                    db_entry["id"],
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["created_at"],
                                    db_entry["updated_at"],
                                    db_entry["ingestion_attempt_number"],
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                )

                    break  # Success, exit the retry loop

                except (
                    asyncpg.exceptions.UniqueViolationError,
                    asyncpg.exceptions.DeadlockDetectedError,
                ) as e:
                    retries += 1
                    if retries == max_retries:
                        logger.error(
                            f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
                        )
                        raise
                    else:
                        wait_time = 0.1 * (2**retries)  # Exponential backoff
                        await asyncio.sleep(wait_time)

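    # Note on the retry loop above: the backoff wait grows as 0.1 * 2**retries,
    # i.e. roughly 0.2 s, 0.4 s, 0.8 s, ... after the first, second, and third
    # conflict (UniqueViolation or DeadlockDetected), for at most max_retries
    # attempts per document.
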
    async def delete(
        self, document_id: UUID, version: Optional[str] = None
    ) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        WHERE id = $1
        """

        params = [str(document_id)]

        if version:
            query += " AND version = $2"
            params.append(version)

        await self.connection_manager.execute_query(query=query, params=params)

    async def _get_status_from_table(
        self,
        ids: list[UUID],
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """
        Get the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.
            column_name (str): The column to match the IDs against.

        Returns:
            The workflow status for the given document or list of documents.
        """
        query = f"""
            SELECT {status_type} FROM {self._get_table_name(table_name)}
            WHERE {column_name} = ANY($1)
        """
        return [
            row[status_type]
            for row in await self.connection_manager.fetch_query(query, [ids])
        ]

    async def _get_ids_from_table(
        self,
        status: list[str],
        table_name: str,
        status_type: str,
        collection_id: Optional[UUID] = None,
    ):
        """
        Get the IDs from a given table.

        Args:
            status (list[str]): The list of statuses to retrieve.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.
            collection_id (Optional[UUID]): Restrict results to this collection.
        """
        query = f"""
            SELECT id FROM {self._get_table_name(table_name)}
            WHERE {status_type} = ANY($1) AND $2 = ANY(collection_ids)
        """
        records = await self.connection_manager.fetch_query(
            query, [status, collection_id]
        )
        return [record["id"] for record in records]

    async def _set_status_in_table(
        self,
        ids: list[UUID],
        status: str,
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """
        Set the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            status (str): The status to set.
            table_name (str): The table name.
            status_type (str): The type of status to set.
            column_name (str): The column name in the table to update.
        """
        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {status_type} = $1
            WHERE {column_name} = ANY($2)
        """
        await self.connection_manager.execute_query(query, [status, ids])

    def _get_status_model(self, status_type: str):
        """
        Get the status model for a given status type.

        Args:
            status_type (str): The type of status to retrieve.

        Returns:
            The status model for the given status type.
        """
        if status_type == "ingestion":
            return IngestionStatus
        elif status_type == "extraction_status":
            return KGExtractionStatus
        elif status_type == "graph_cluster_status":
            return KGEnrichmentStatus
        elif status_type == "graph_sync_status":
            return KGEnrichmentStatus
        else:
            raise R2RException(
                status_code=400, message=f"Invalid status type: {status_type}"
            )

    async def get_workflow_status(
        self, id: UUID | list[UUID], status_type: str
    ):
        """
        Get the workflow status for a given document or list of documents.

        Args:
            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
            status_type (str): The type of status to retrieve.

        Returns:
            The workflow status for the given document or list of documents.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)
        result = await self._get_status_from_table(
            ids,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )
        result = [out_model[status.upper()] for status in result]
        return result[0] if isinstance(id, UUID) else result

    async def set_workflow_status(
        self, id: UUID | list[UUID], status_type: str, status: str
    ):
        """
        Set the workflow status for a given document or list of documents.

        Args:
            id (Union[UUID, list[UUID]]): The document ID or list of document IDs.
            status_type (str): The type of status to set.
            status (str): The status to set.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)
        return await self._set_status_in_table(
            ids,
            status,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )

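    # Illustrative sketch (hypothetical identifiers and status value): a caller
    # tracking extraction progress might pair the two helpers above, e.g.
    #
    #     current = await handler.get_workflow_status(doc_id, "extraction_status")
    #     await handler.set_workflow_status(doc_id, "extraction_status", "success")
    #
    # The status_type string selects the enum via _get_status_model, so it must
    # be one of the types that method accepts.
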
    async def get_document_ids_by_status(
        self,
        status_type: str,
        status: str | list[str],
        collection_id: Optional[UUID] = None,
    ):
        """
        Get the document IDs for a given status.

        Args:
            status_type (str): The type of status to retrieve.
            status (Union[str, list[str]]): The status or list of statuses to retrieve.
            collection_id (Optional[UUID]): Restrict results to this collection.
        """
        if isinstance(status, str):
            status = [status]

        out_model = self._get_status_model(status_type)
        return await self._get_ids_from_table(
            status, out_model.table_name(), status_type, collection_id
        )

    async def get_documents_overview(
        self,
        offset: int,
        limit: int,
        filter_user_ids: Optional[list[UUID]] = None,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, Any]:
        conditions = []
        or_conditions = []
        params: list[Any] = []
        param_index = 1

        # Handle document IDs with AND
        if filter_document_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(filter_document_ids)
            param_index += 1

        # Handle user_ids and collection_ids with OR
        if filter_user_ids:
            or_conditions.append(f"owner_id = ANY(${param_index})")
            params.append(filter_user_ids)
            param_index += 1

        if filter_collection_ids:
            or_conditions.append(f"collection_ids && ${param_index}")
            params.append(filter_collection_ids)
            param_index += 1

        base_query = f"""
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        """

        # Combine conditions with appropriate AND/OR logic
        where_conditions = []
        if conditions:
            where_conditions.append("(" + " AND ".join(conditions) + ")")
        if or_conditions:
            where_conditions.append("(" + " OR ".join(or_conditions) + ")")
        if where_conditions:
            base_query += " WHERE " + " AND ".join(where_conditions)

        # Construct the SELECT part of the query based on column existence
        select_fields = """
            SELECT id, collection_ids, owner_id, type, metadata, title, version,
                size_in_bytes, ingestion_status, extraction_status, created_at, updated_at,
                summary, summary_embedding,
                COUNT(*) OVER() AS total_entries
        """

        query = f"""
            {select_fields}
            {base_query}
            ORDER BY created_at DESC
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            query += f" LIMIT ${param_index}"
            params.append(limit)
            param_index += 1

        try:
            results = await self.connection_manager.fetch_query(query, params)
            total_entries = results[0]["total_entries"] if results else 0

            documents = []
            for row in results:
                # Safely handle the embedding
                embedding = None
                if (
                    "summary_embedding" in row
                    and row["summary_embedding"] is not None
                ):
                    try:
                        # Parse the vector string returned by Postgres
                        embedding_str = row["summary_embedding"]
                        if embedding_str.startswith(
                            "["
                        ) and embedding_str.endswith("]"):
                            embedding = [
                                float(x)
                                for x in embedding_str[1:-1].split(",")
                                if x
                            ]
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse embedding for document {row['id']}: {e}"
                        )

                documents.append(
                    DocumentResponse(
                        id=row["id"],
                        collection_ids=row["collection_ids"],
                        owner_id=row["owner_id"],
                        document_type=DocumentType(row["type"]),
                        metadata=json.loads(row["metadata"]),
                        title=row["title"],
                        version=row["version"],
                        size_in_bytes=row["size_in_bytes"],
                        ingestion_status=IngestionStatus(
                            row["ingestion_status"]
                        ),
                        extraction_status=KGExtractionStatus(
                            row["extraction_status"]
                        ),
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        summary=row["summary"] if "summary" in row else None,
                        summary_embedding=embedding,
                    )
                )
            return {"results": documents, "total_entries": total_entries}
        except Exception as e:
            logger.error(f"Error in get_documents_overview: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Database query failed",
            )

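    # Filter semantics above, in brief: filter_document_ids is combined with the
    # other filters using AND, while filter_user_ids and filter_collection_ids
    # are OR'd with each other. A hypothetical call listing the first page of a
    # collection might look like:
    #
    #     page = await handler.get_documents_overview(
    #         offset=0, limit=10, filter_collection_ids=[collection_id]
    #     )
    #     page["results"], page["total_entries"]
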
    async def semantic_document_search(
        self, query_embedding: list[float], search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Search documents using semantic similarity with their summary embeddings."""
        where_clauses = ["summary_embedding IS NOT NULL"]
        params: list[str | int | bytes] = [str(query_embedding)]

        # Handle filters
        if search_settings.filters:
            filter_clause = self._build_filters(
                search_settings.filters, params
            )
            where_clauses.append(filter_clause)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                (summary_embedding <=> $1::vector({self.dimension})) as semantic_distance
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY semantic_distance ASC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT *,
            1.0 - semantic_distance as semantic_score
        FROM document_scores
        """

        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["semantic_score"]),
                    "search_type": "semantic",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=KGExtractionStatus(row["extraction_status"]),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=[
                    float(x)
                    for x in row["summary_embedding"][1:-1].split(",")
                    if x
                ],
            )
            for row in results
        ]

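    # Scoring note for the semantic search above: `<=>` is the pgvector cosine
    # distance operator, so semantic_score = 1.0 - distance. For example, a
    # distance of 0.12 yields a score of 0.88; identical (non-zero) vectors
    # score 1.0.
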
    async def full_text_document_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Enhanced full-text search using generated tsvector."""
        where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
        params: list[str | int | bytes] = [query_text]

        # Handle filters
        if search_settings.filters:
            filter_clause = self._build_filters(
                search_settings.filters, params
            )
            where_clauses.append(filter_clause)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY text_score DESC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT * FROM document_scores
        """

        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["text_score"]),
                    "search_type": "full_text",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=KGExtractionStatus(row["extraction_status"]),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=(
                    [
                        float(x)
                        for x in row["summary_embedding"][1:-1].split(",")
                        if x
                    ]
                    if row["summary_embedding"]
                    else None
                ),
            )
            for row in results
        ]

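    # Ranking note for the full-text search above: the generated tsvector in
    # create_tables weights title ('A'), summary ('B'), and the metadata
    # description ('C'), so title matches rank highest. The ts_rank_cd
    # normalization flag 32 divides the rank by (rank + 1), keeping text_score
    # within 0..1 and roughly comparable to the semantic score.
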
    async def hybrid_document_search(
        self,
        query_text: str,
        query_embedding: list[float],
        search_settings: SearchSettings,
    ) -> list[DocumentResponse]:
        """Search documents using both semantic and full-text search with RRF fusion."""
        # Get more results than needed for better fusion
        extended_settings = copy.deepcopy(search_settings)
        extended_settings.limit = search_settings.limit * 3

        # Get results from both search methods
        semantic_results = await self.semantic_document_search(
            query_embedding, extended_settings
        )
        full_text_results = await self.full_text_document_search(
            query_text, extended_settings
        )

        # Combine results using RRF
        doc_scores: dict[str, dict] = {}

        # Process semantic results
        for rank, result in enumerate(semantic_results, 1):
            doc_id = str(result.id)
            doc_scores[doc_id] = {
                "semantic_rank": rank,
                "full_text_rank": len(full_text_results)
                + 1,  # Default rank if not found
                "data": result,
            }

        # Process full-text results
        for rank, result in enumerate(full_text_results, 1):
            doc_id = str(result.id)
            if doc_id in doc_scores:
                doc_scores[doc_id]["full_text_rank"] = rank
            else:
                doc_scores[doc_id] = {
                    "semantic_rank": len(semantic_results)
                    + 1,  # Default rank if not found
                    "full_text_rank": rank,
                    "data": result,
                }

        # Calculate RRF scores using hybrid search settings
        rrf_k = search_settings.hybrid_settings.rrf_k
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight

        for scores in doc_scores.values():
            semantic_score = 1 / (rrf_k + scores["semantic_rank"])
            full_text_score = 1 / (rrf_k + scores["full_text_rank"])

            # Weighted combination
            combined_score = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)
            scores["final_score"] = combined_score

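        # Worked example of the fusion above (illustrative numbers only): with
        # rrf_k = 50 and equal weights, a document ranked 1st semantically and
        # 3rd in full text scores (1/51 + 1/53) / 2, roughly 0.0192, and a
        # document ranked 3rd and 1st scores the same. A document returned by
        # only one method is penalized via the default rank len(results) + 1.
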
        # Sort by final score and apply offset/limit
        sorted_results = sorted(
            doc_scores.values(), key=lambda x: x["final_score"], reverse=True
        )[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            DocumentResponse(
                **{
                    **result["data"].__dict__,
                    "metadata": {
                        **(
                            result["data"].metadata
                            if search_settings.include_metadatas
                            else {}
                        ),
                        "search_score": result["final_score"],
                        "semantic_rank": result["semantic_rank"],
                        "full_text_rank": result["full_text_rank"],
                        "search_type": "hybrid",
                    },
                }
            )
            for result in sorted_results
        ]

    async def search_documents(
        self,
        query_text: str,
        query_embedding: Optional[list[float]] = None,
        settings: Optional[SearchSettings] = None,
    ) -> list[DocumentResponse]:
        """
        Main search method that delegates to the appropriate search method based on settings.
        """
        if settings is None:
            settings = SearchSettings()

        if (
            settings.use_semantic_search and settings.use_fulltext_search
        ) or settings.use_hybrid_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for hybrid search"
                )
            return await self.hybrid_document_search(
                query_text, query_embedding, settings
            )
        elif settings.use_semantic_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for vector search"
                )
            return await self.semantic_document_search(
                query_embedding, settings
            )
        else:
            return await self.full_text_document_search(query_text, settings)

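    # Routing summary for search_documents, with a hypothetical call: hybrid
    # search runs when both use_semantic_search and use_fulltext_search are set
    # (or use_hybrid_search is), semantic-only when just the former is set, and
    # full-text otherwise, e.g.
    #
    #     docs = await handler.search_documents(
    #         query_text="quarterly report",
    #         query_embedding=embedding,   # required unless doing full-text only
    #         settings=SearchSettings(use_hybrid_search=True),
    #     )
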
    # TODO - Remove copy pasta, consolidate
    def _build_filters(
        self, filters: dict, parameters: list[str | int | bytes]
    ) -> str:

        def parse_condition(key: str, value: Any) -> str:  # type: ignore
            # nonlocal parameters
            if key in self.COLUMN_VARS:
                # Handle column-based filters
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"{key} = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"{key} != ${len(parameters)}"
                    elif op == "$in":
                        parameters.append(clause)
                        return f"{key} = ANY(${len(parameters)})"
                    elif op == "$nin":
                        parameters.append(clause)
                        return f"{key} != ALL(${len(parameters)})"
                    elif op == "$overlap":
                        parameters.append(clause)
                        return f"{key} && ${len(parameters)}"
                    elif op == "$contains":
                        parameters.append(clause)
                        return f"{key} @> ${len(parameters)}"
                    elif op == "$any":
                        if key == "collection_ids":
                            parameters.append(f"%{clause}%")
                            return f"array_to_string({key}, ',') LIKE ${len(parameters)}"
                        parameters.append(clause)
                        return f"${len(parameters)} = ANY({key})"
                    else:
                        raise ValueError(
                            f"Unsupported operator for column {key}: {op}"
                        )
                else:
                    # Handle direct equality
                    parameters.append(value)
                    return f"{key} = ${len(parameters)}"
            else:
                # Handle JSON-based filters
                json_col = "metadata"
                if key.startswith("metadata."):
                    key = key.split("metadata.")[1]
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op not in (
                        "$eq",
                        "$ne",
                        "$lt",
                        "$lte",
                        "$gt",
                        "$gte",
                        "$in",
                        "$contains",
                    ):
                        raise ValueError("unknown operator")

                    if op == "$eq":
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' = ${len(parameters)}::jsonb"
                        )
                    elif op == "$ne":
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' != ${len(parameters)}::jsonb"
                        )
                    elif op == "$lt":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float < (${len(parameters)}::jsonb)::float"
                    elif op == "$lte":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float <= (${len(parameters)}::jsonb)::float"
                    elif op == "$gt":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float > (${len(parameters)}::jsonb)::float"
                    elif op == "$gte":
                        parameters.append(json.dumps(clause))
                        return f"({json_col}->'{key}')::float >= (${len(parameters)}::jsonb)::float"
                    elif op == "$in":
                        if not isinstance(clause, list):
                            raise ValueError(
                                "argument to $in filter must be a list"
                            )
                        parameters.append(json.dumps(clause))
                        return f"{json_col}->'{key}' = ANY(SELECT jsonb_array_elements(${len(parameters)}::jsonb))"
                    elif op == "$contains":
                        if not isinstance(clause, (int, str, float, list)):
                            raise ValueError(
                                "argument to $contains filter must be a scalar or array"
                            )
                        parameters.append(json.dumps(clause))
                        return (
                            f"{json_col}->'{key}' @> ${len(parameters)}::jsonb"
                        )

        def parse_filter(filter_dict: dict) -> str:
            filter_conditions = []
            for key, value in filter_dict.items():
                if key == "$and":
                    and_conditions = [
                        parse_filter(f) for f in value if f
                    ]  # Skip empty dictionaries
                    if and_conditions:
                        filter_conditions.append(
                            f"({' AND '.join(and_conditions)})"
                        )
                elif key == "$or":
                    or_conditions = [
                        parse_filter(f) for f in value if f
                    ]  # Skip empty dictionaries
                    if or_conditions:
                        filter_conditions.append(
                            f"({' OR '.join(or_conditions)})"
                        )
                else:
                    filter_conditions.append(parse_condition(key, value))

            # Check if there is only a single condition
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            else:
                return " AND ".join(filter_conditions)

        where_clause = parse_filter(filters)

        return where_clause

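    # Filter sketch (hypothetical input): assuming one parameter (the query text
    # or embedding) is already in the list, a filter such as
    #     {"$and": [{"owner_id": {"$eq": some_uuid}},
    #               {"metadata.year": {"$gte": 2020}}]}
    # is rendered by _build_filters as
    #     (owner_id = $2 AND (metadata->'year')::float >= ($3::jsonb)::float)
    # with the corresponding values appended to the parameters list.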