# documents.py

import asyncio
import copy
import csv
import json
import logging
import tempfile
from typing import IO, Any, Optional
from uuid import UUID

import asyncpg
from fastapi import HTTPException

from core.base import (
    DocumentResponse,
    DocumentType,
    Handler,
    IngestionStatus,
    KGEnrichmentStatus,
    KGExtractionStatus,
    R2RException,
    SearchSettings,
)

from .base import PostgresConnectionManager
from .filters import apply_filters

logger = logging.getLogger()


class PostgresDocumentsHandler(Handler):
    TABLE_NAME = "documents"

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int,
    ):
        self.dimension = dimension
        super().__init__(project_name, connection_manager)

    async def create_tables(self):
        logger.info(
            f"Creating table, if not exists: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
        )
        try:
            query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY,
                collection_ids UUID[],
                owner_id UUID,
                type TEXT,
                metadata JSONB,
                title TEXT,
                summary TEXT NULL,
                summary_embedding vector({self.dimension}) NULL,
                version TEXT,
                size_in_bytes INT,
                ingestion_status TEXT DEFAULT 'pending',
                extraction_status TEXT DEFAULT 'pending',
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                ingestion_attempt_number INT DEFAULT 0,
                raw_tsvector tsvector GENERATED ALWAYS AS (
                    setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                    setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                    setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
                ) STORED
            );
            CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);

            -- Full text search index
            CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            USING GIN (raw_tsvector);
            """
            await self.connection_manager.execute_query(query)
        except Exception as e:
            logger.warning(f"Error {e} when creating document table.")
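
    # Note on the schema above: raw_tsvector is a generated column that weights
    # title ('A') over summary ('B') over metadata->>'description' ('C'), so the
    # full-text ranking used later naturally favors title matches. Both the
    # collection_ids array and the tsvector get GIN indexes to keep collection
    # membership checks and text search fast.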

    async def upsert_documents_overview(
        self, documents_overview: DocumentResponse | list[DocumentResponse]
    ) -> None:
        if isinstance(documents_overview, DocumentResponse):
            documents_overview = [documents_overview]

        # TODO: make this an arg
        max_retries = 20
        for document in documents_overview:
            retries = 0
            while retries < max_retries:
                try:
                    async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                        async with conn.transaction():
                            # Lock the row for update
                            check_query = f"""
                            SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                            WHERE id = $1 FOR UPDATE
                            """
                            existing_doc = await conn.fetchrow(
                                check_query, document.id
                            )

                            db_entry = document.convert_to_db_entry()

                            if existing_doc:
                                db_version = existing_doc[
                                    "ingestion_attempt_number"
                                ]
                                db_status = existing_doc["ingestion_status"]
                                new_version = db_entry[
                                    "ingestion_attempt_number"
                                ]

                                # Only increment version if status is changing to
                                # 'success' or if it's a new version
                                if (
                                    db_status != "success"
                                    and db_entry["ingestion_status"]
                                    == "success"
                                ) or (new_version > db_version):
                                    new_attempt_number = db_version + 1
                                else:
                                    new_attempt_number = db_version

                                db_entry["ingestion_attempt_number"] = (
                                    new_attempt_number
                                )

                                update_query = f"""
                                UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                SET collection_ids = $1, owner_id = $2, type = $3, metadata = $4,
                                    title = $5, version = $6, size_in_bytes = $7, ingestion_status = $8,
                                    extraction_status = $9, updated_at = $10, ingestion_attempt_number = $11,
                                    summary = $12, summary_embedding = $13
                                WHERE id = $14
                                """
                                await conn.execute(
                                    update_query,
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["updated_at"],
                                    new_attempt_number,
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                    document.id,
                                )
                            else:
                                insert_query = f"""
                                INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                (id, collection_ids, owner_id, type, metadata, title, version,
                                size_in_bytes, ingestion_status, extraction_status, created_at,
                                updated_at, ingestion_attempt_number, summary, summary_embedding)
                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
                                """
                                await conn.execute(
                                    insert_query,
                                    db_entry["id"],
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["created_at"],
                                    db_entry["updated_at"],
                                    db_entry["ingestion_attempt_number"],
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                )

                    break  # Success, exit the retry loop
                except (
                    asyncpg.exceptions.UniqueViolationError,
                    asyncpg.exceptions.DeadlockDetectedError,
                ) as e:
                    retries += 1
                    if retries == max_retries:
                        logger.error(
                            f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
                        )
                        raise
                    else:
                        wait_time = 0.1 * (2**retries)  # Exponential backoff
                        await asyncio.sleep(wait_time)
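
    # A rough sense of the retry backoff above (just the arithmetic of
    # 0.1 * 2**retries): retry 1 waits 0.2s, retry 2 waits 0.4s, retry 3 waits
    # 0.8s, and so on, doubling until the upsert succeeds or max_retries (20)
    # is exhausted.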

    async def delete(
        self, document_id: UUID, version: Optional[str] = None
    ) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        WHERE id = $1
        """
        params = [str(document_id)]

        if version:
            query += " AND version = $2"
            params.append(version)

        await self.connection_manager.execute_query(query=query, params=params)

    async def _get_status_from_table(
        self,
        ids: list[UUID],
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """
        Get the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.
            column_name (str): The column used to match the given IDs.

        Returns:
            The workflow status for the given document or list of documents.
        """
        query = f"""
            SELECT {status_type} FROM {self._get_table_name(table_name)}
            WHERE {column_name} = ANY($1)
        """
        return [
            row[status_type]
            for row in await self.connection_manager.fetch_query(query, [ids])
        ]

    async def _get_ids_from_table(
        self,
        status: list[str],
        table_name: str,
        status_type: str,
        collection_id: Optional[UUID] = None,
    ):
        """
        Get the IDs from a given table.

        Args:
            status (list[str]): The list of statuses to retrieve.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.
            collection_id (Optional[UUID]): Restrict results to documents in this collection.
        """
        query = f"""
            SELECT id FROM {self._get_table_name(table_name)}
            WHERE {status_type} = ANY($1) AND $2 = ANY(collection_ids)
        """
        records = await self.connection_manager.fetch_query(
            query, [status, collection_id]
        )
        return [record["id"] for record in records]

    async def _set_status_in_table(
        self,
        ids: list[UUID],
        status: str,
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """
        Set the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            status (str): The status to set.
            table_name (str): The table name.
            status_type (str): The type of status to set.
            column_name (str): The column name in the table to update.
        """
        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {status_type} = $1
            WHERE {column_name} = ANY($2)
        """
        await self.connection_manager.execute_query(query, [status, ids])

    def _get_status_model(self, status_type: str):
        """
        Get the status model for a given status type.

        Args:
            status_type (str): The type of status to retrieve.

        Returns:
            The status model for the given status type.
        """
        if status_type == "ingestion":
            return IngestionStatus
        elif status_type == "extraction_status":
            return KGExtractionStatus
        elif status_type in {"graph_cluster_status", "graph_sync_status"}:
            return KGEnrichmentStatus
        else:
            raise R2RException(
                status_code=400, message=f"Invalid status type: {status_type}"
            )

    async def get_workflow_status(
        self, id: UUID | list[UUID], status_type: str
    ):
        """
        Get the workflow status for a given document or list of documents.

        Args:
            id (UUID | list[UUID]): The document ID or list of document IDs.
            status_type (str): The type of status to retrieve.

        Returns:
            The workflow status for the given document or list of documents.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)

        result = await self._get_status_from_table(
            ids,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )
        result = [out_model[status.upper()] for status in result]
        return result[0] if isinstance(id, UUID) else result

    async def set_workflow_status(
        self, id: UUID | list[UUID], status_type: str, status: str
    ):
        """
        Set the workflow status for a given document or list of documents.

        Args:
            id (UUID | list[UUID]): The document ID or list of document IDs.
            status_type (str): The type of status to set.
            status (str): The status to set.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)

        return await self._set_status_in_table(
            ids,
            status,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )

    async def get_document_ids_by_status(
        self,
        status_type: str,
        status: str | list[str],
        collection_id: Optional[UUID] = None,
    ):
        """
        Get the IDs for a given status.

        Args:
            status_type (str): The type of status to retrieve.
            status (str | list[str]): The status or list of statuses to retrieve.
            collection_id (Optional[UUID]): Restrict results to documents in this collection.
        """
        if isinstance(status, str):
            status = [status]

        out_model = self._get_status_model(status_type)
        return await self._get_ids_from_table(
            status, out_model.table_name(), status_type, collection_id
        )

    async def get_documents_overview(
        self,
        offset: int,
        limit: int,
        filter_user_ids: Optional[list[UUID]] = None,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, Any]:
        conditions = []
        or_conditions = []
        params: list[Any] = []
        param_index = 1

        # Handle document IDs with AND
        if filter_document_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(filter_document_ids)
            param_index += 1

        # Handle user_ids and collection_ids with OR
        if filter_user_ids:
            or_conditions.append(f"owner_id = ANY(${param_index})")
            params.append(filter_user_ids)
            param_index += 1

        if filter_collection_ids:
            or_conditions.append(f"collection_ids && ${param_index}")
            params.append(filter_collection_ids)
            param_index += 1

        base_query = f"""
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        """

        # Combine conditions with appropriate AND/OR logic
        where_conditions = []
        if conditions:
            where_conditions.append("(" + " AND ".join(conditions) + ")")
        if or_conditions:
            where_conditions.append("(" + " OR ".join(or_conditions) + ")")
        if where_conditions:
            base_query += " WHERE " + " AND ".join(where_conditions)

        # Construct the SELECT part of the query
        select_fields = """
            SELECT id, collection_ids, owner_id, type, metadata, title, version,
                size_in_bytes, ingestion_status, extraction_status, created_at, updated_at,
                summary, summary_embedding,
                COUNT(*) OVER() AS total_entries
        """

        query = f"""
            {select_fields}
            {base_query}
            ORDER BY created_at DESC
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            query += f" LIMIT ${param_index}"
            params.append(limit)
            param_index += 1

        try:
            results = await self.connection_manager.fetch_query(query, params)
            total_entries = results[0]["total_entries"] if results else 0

            documents = []
            for row in results:
                # Safely handle the embedding
                embedding = None
                if (
                    "summary_embedding" in row
                    and row["summary_embedding"] is not None
                ):
                    try:
                        # Parse the vector string returned by Postgres
                        embedding_str = row["summary_embedding"]
                        if embedding_str.startswith(
                            "["
                        ) and embedding_str.endswith("]"):
                            embedding = [
                                float(x)
                                for x in embedding_str[1:-1].split(",")
                                if x
                            ]
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse embedding for document {row['id']}: {e}"
                        )

                documents.append(
                    DocumentResponse(
                        id=row["id"],
                        collection_ids=row["collection_ids"],
                        owner_id=row["owner_id"],
                        document_type=DocumentType(row["type"]),
                        metadata=json.loads(row["metadata"]),
                        title=row["title"],
                        version=row["version"],
                        size_in_bytes=row["size_in_bytes"],
                        ingestion_status=IngestionStatus(
                            row["ingestion_status"]
                        ),
                        extraction_status=KGExtractionStatus(
                            row["extraction_status"]
                        ),
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        summary=row["summary"] if "summary" in row else None,
                        summary_embedding=embedding,
                    )
                )
            return {"results": documents, "total_entries": total_entries}
        except Exception as e:
            logger.error(f"Error in get_documents_overview: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Database query failed",
            ) from e
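
    # Note on pagination above: COUNT(*) OVER() is a window function, so
    # total_entries reports the full number of matching rows even though the
    # returned page is trimmed by OFFSET/LIMIT; passing limit=-1 skips the
    # LIMIT clause entirely and returns everything after the offset.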

    async def semantic_document_search(
        self, query_embedding: list[float], search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Search documents using semantic similarity with their summary embeddings."""
        where_clauses = ["summary_embedding IS NOT NULL"]
        params: list[str | int | bytes] = [str(query_embedding)]

        if search_settings.filters:
            filter_condition, params = apply_filters(
                search_settings.filters, params, mode="condition_only"
            )
            if filter_condition:
                where_clauses.append(filter_condition)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                (summary_embedding <=> $1::vector({self.dimension})) as semantic_distance
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY semantic_distance ASC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT *,
            1.0 - semantic_distance as semantic_score
        FROM document_scores
        """

        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["semantic_score"]),
                    "search_type": "semantic",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=KGExtractionStatus(row["extraction_status"]),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=[
                    float(x)
                    for x in row["summary_embedding"][1:-1].split(",")
                    if x
                ],
            )
            for row in results
        ]
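
    # For readers unfamiliar with pgvector: `<=>` above is its cosine-distance
    # operator, so semantic_score = 1.0 - semantic_distance behaves like cosine
    # similarity, with 1.0 for a summary embedding identical to the query.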

    async def full_text_document_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Enhanced full-text search using generated tsvector."""
        where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
        params: list[str | int | bytes] = [query_text]

        if search_settings.filters:
            filter_condition, params = apply_filters(
                search_settings.filters, params, mode="condition_only"
            )
            if filter_condition:
                where_clauses.append(filter_condition)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY text_score DESC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT * FROM document_scores
        """

        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["text_score"]),
                    "search_type": "full_text",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=KGExtractionStatus(row["extraction_status"]),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=(
                    [
                        float(x)
                        for x in row["summary_embedding"][1:-1].split(",")
                        if x
                    ]
                    if row["summary_embedding"]
                    else None
                ),
            )
            for row in results
        ]
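
    # The normalization flag 32 passed to ts_rank_cd above divides the raw rank
    # by (rank + 1), which squeezes text_score into the 0..1 range and makes it
    # easier to compare against the semantic score produced elsewhere.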

    async def hybrid_document_search(
        self,
        query_text: str,
        query_embedding: list[float],
        search_settings: SearchSettings,
    ) -> list[DocumentResponse]:
        """Search documents using both semantic and full-text search with RRF fusion."""
        # Get more results than needed for better fusion
        extended_settings = copy.deepcopy(search_settings)
        extended_settings.limit = search_settings.limit * 3

        # Get results from both search methods
        semantic_results = await self.semantic_document_search(
            query_embedding, extended_settings
        )
        full_text_results = await self.full_text_document_search(
            query_text, extended_settings
        )

        # Combine results using RRF
        doc_scores: dict[str, dict] = {}

        # Process semantic results
        for rank, result in enumerate(semantic_results, 1):
            doc_id = str(result.id)
            doc_scores[doc_id] = {
                "semantic_rank": rank,
                "full_text_rank": len(full_text_results)
                + 1,  # Default rank if not found
                "data": result,
            }

        # Process full-text results
        for rank, result in enumerate(full_text_results, 1):
            doc_id = str(result.id)
            if doc_id in doc_scores:
                doc_scores[doc_id]["full_text_rank"] = rank
            else:
                doc_scores[doc_id] = {
                    "semantic_rank": len(semantic_results)
                    + 1,  # Default rank if not found
                    "full_text_rank": rank,
                    "data": result,
                }

        # Calculate RRF scores using hybrid search settings
        rrf_k = search_settings.hybrid_settings.rrf_k
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight

        for scores in doc_scores.values():
            semantic_score = 1 / (rrf_k + scores["semantic_rank"])
            full_text_score = 1 / (rrf_k + scores["full_text_rank"])

            # Weighted combination
            combined_score = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)
            scores["final_score"] = combined_score

        # Sort by final score and apply offset/limit
        sorted_results = sorted(
            doc_scores.values(), key=lambda x: x["final_score"], reverse=True
        )[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            DocumentResponse(
                **{
                    **result["data"].__dict__,
                    "metadata": {
                        **(
                            result["data"].metadata
                            if search_settings.include_metadatas
                            else {}
                        ),
                        "search_score": result["final_score"],
                        "semantic_rank": result["semantic_rank"],
                        "full_text_rank": result["full_text_rank"],
                        "search_type": "hybrid",
                    },
                }
            )
            for result in sorted_results
        ]
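
    # A worked example of the reciprocal rank fusion above (using rrf_k = 60, a
    # common default; the real value comes from hybrid_settings): a document
    # ranked 1st semantically and 3rd in full text with equal weights scores
    # (1/61 + 1/63) / 2 ≈ 0.0161. Only ranks enter the formula, so a document
    # ranked 3rd and 1st would tie it unless the weights differ.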

    async def search_documents(
        self,
        query_text: str,
        query_embedding: Optional[list[float]] = None,
        settings: Optional[SearchSettings] = None,
    ) -> list[DocumentResponse]:
        """
        Main search method that delegates to the appropriate search method based on settings.
        """
        if settings is None:
            settings = SearchSettings()

        if (
            settings.use_semantic_search and settings.use_fulltext_search
        ) or settings.use_hybrid_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for hybrid search"
                )
            return await self.hybrid_document_search(
                query_text, query_embedding, settings
            )
        elif settings.use_semantic_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for vector search"
                )
            return await self.semantic_document_search(
                query_embedding, settings
            )
        else:
            return await self.full_text_document_search(query_text, settings)
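
    # A minimal usage sketch for the dispatcher above (hypothetical handler and
    # embedding; constructing SearchSettings by keyword is an assumption):
    #
    #   settings = SearchSettings(use_semantic_search=True, use_fulltext_search=True)
    #   docs = await handler.search_documents(
    #       query_text="quarterly revenue",
    #       query_embedding=embedding,  # required for the semantic/hybrid paths
    #       settings=settings,
    #   )
    #
    # With both flags set (or use_hybrid_search=True) the call routes to
    # hybrid_document_search; with neither set it falls back to full text.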

    async def export_to_csv(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """
        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
        """
        valid_columns = {
            "id",
            "collection_ids",
            "owner_id",
            "type",
            "metadata",
            "title",
            "summary",
            "version",
            "size_in_bytes",
            "ingestion_status",
            "extraction_status",
            "created_at",
            "updated_at",
        }

        # Map each exportable column to the SQL expression that produces it, so
        # the SELECT list (and therefore every CSV row) stays aligned with the
        # header written from `columns`.
        column_exprs = {
            "id": "id::text",
            "collection_ids": "collection_ids::text",
            "owner_id": "owner_id::text",
            "type": "type::text",
            "metadata": "metadata::text AS metadata",
            "title": "title",
            "summary": "summary",
            "version": "version",
            "size_in_bytes": "size_in_bytes",
            "ingestion_status": "ingestion_status",
            "extraction_status": "extraction_status",
            "created_at": "to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at",
            "updated_at": "to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at",
        }

        if not columns:
            columns = list(column_exprs.keys())
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_cols = ", ".join(column_exprs[col] for col in columns)
        select_stmt = f"""
            SELECT {select_cols}
            FROM {self._get_table_name(self.TABLE_NAME)}
        """

        conditions = []
        params: list[Any] = []
        param_index = 1

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            writer.writerow(row)

            temp_file.flush()
            return temp_file.name, temp_file

        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
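
# A minimal usage sketch for export_to_csv (illustrative names; the handler and
# its connection pool are assumed to be set up elsewhere):
#
#   path, file_obj = await handler.export_to_csv(
#       columns=["id", "title", "ingestion_status"],
#       filters={"ingestion_status": {"$eq": "success"}},
#   )
#   # Read or copy the CSV before closing file_obj: the NamedTemporaryFile is
#   # created with delete=True, so it disappears on close.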