# documents.py

import asyncio
import copy
import csv
import json
import logging
import math
import tempfile
from typing import IO, Any, Optional
from uuid import UUID

import asyncpg
from fastapi import HTTPException

from core.base import (
    DocumentResponse,
    DocumentType,
    GraphConstructionStatus,
    GraphExtractionStatus,
    Handler,
    IngestionStatus,
    R2RException,
    SearchSettings,
)

from .base import PostgresConnectionManager
from .filters import apply_filters

logger = logging.getLogger()

def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
    """Recursively transform filter field names by replacing 'document_id'
    with 'id'. Handles nested logical operators like $and, $or, etc.

    Args:
        filters (dict[str, Any]): The original filters dictionary

    Returns:
        dict[str, Any]: A new dictionary with transformed field names
    """
    if not filters:
        return {}

    transformed = {}
    for key, value in filters.items():
        # Handle logical operators recursively
        if key in ("$and", "$or", "$not"):
            if isinstance(value, list):
                transformed[key] = [
                    transform_filter_fields(item) for item in value
                ]
            else:
                transformed[key] = transform_filter_fields(value)  # type: ignore
            continue

        # Replace 'document_id' with 'id'
        new_key = "id" if key == "document_id" else key

        # Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.)
        if isinstance(value, dict):
            transformed[new_key] = transform_filter_fields(value)  # type: ignore
        else:
            transformed[new_key] = value

    logger.debug(f"Transformed filters from {filters} to {transformed}")
    return transformed

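# Illustrative sketch (not executed anywhere in this module) of how
# transform_filter_fields rewrites a caller-supplied filter dict; the id value
# and metadata key below are hypothetical placeholders.
#
#     transform_filter_fields(
#         {"$or": [{"document_id": {"$eq": "1234"}}, {"metadata.tag": "draft"}]}
#     )
#     # -> {"$or": [{"id": {"$eq": "1234"}}, {"metadata.tag": "draft"}]}
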
class PostgresDocumentsHandler(Handler):
    TABLE_NAME = "documents"

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int | float,
    ):
        self.dimension = dimension
        super().__init__(project_name, connection_manager)

    async def create_tables(self):
        logger.info(
            f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
        )
        vector_dim = (
            "" if math.isnan(self.dimension) else f"({self.dimension})"
        )
        vector_type = f"vector{vector_dim}"
        try:
            query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY,
                collection_ids UUID[],
                owner_id UUID,
                type TEXT,
                metadata JSONB,
                title TEXT,
                summary TEXT NULL,
                summary_embedding {vector_type} NULL,
                version TEXT,
                size_in_bytes INT,
                ingestion_status TEXT DEFAULT 'pending',
                extraction_status TEXT DEFAULT 'pending',
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                ingestion_attempt_number INT DEFAULT 0,
                raw_tsvector tsvector GENERATED ALWAYS AS (
                    setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                    setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                    setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
                ) STORED,
                total_tokens INT DEFAULT 0
            );
            CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);

            -- Full text search index
            CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
            ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            USING GIN (raw_tsvector);
            """
            await self.connection_manager.execute_query(query)

            # ---------------------------------------------------------------
            # Now check if total_tokens column exists in the 'documents' table
            # ---------------------------------------------------------------
            # 1) See what columns exist
            # column_check_query = f"""
            # SELECT column_name
            # FROM information_schema.columns
            # WHERE table_name = '{self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}'
            #   AND table_schema = CURRENT_SCHEMA()
            # """
            # existing_columns = await self.connection_manager.fetch_query(column_check_query)

            # 2) Parse the table name for schema checks
            table_full_name = self._get_table_name(
                PostgresDocumentsHandler.TABLE_NAME
            )
            parsed_schema = "public"
            parsed_table_name = table_full_name
            if "." in table_full_name:
                parts = table_full_name.split(".", maxsplit=1)
                parsed_schema = parts[0].replace('"', "").strip()
                parsed_table_name = parts[1].replace('"', "").strip()
            else:
                parsed_table_name = parsed_table_name.replace('"', "").strip()

            # 3) Check columns
            column_check_query = f"""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = '{parsed_table_name}'
              AND table_schema = '{parsed_schema}'
            """
            existing_columns = await self.connection_manager.fetch_query(
                column_check_query
            )
            existing_column_names = {
                row["column_name"] for row in existing_columns
            }

            if "total_tokens" not in existing_column_names:
                # 4) If the column is missing, see if the table already has data
                # doc_count_query = f"SELECT COUNT(*) FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
                # doc_count = await self.connection_manager.fetchval(doc_count_query)
                doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
                row = await self.connection_manager.fetchrow_query(
                    doc_count_query
                )
                if row is None:
                    doc_count = 0
                else:
                    doc_count = row[
                        "doc_count"
                    ]  # or row[0] if you prefer positional indexing

                if doc_count > 0:
                    # We already have documents but no total_tokens column;
                    # warn that altering the table will affect existing files.
                    logger.warning(
                        "Adding the missing 'total_tokens' column to the 'documents' table, this will impact existing files."
                    )

                create_tokens_col = f"""
                ALTER TABLE {table_full_name}
                ADD COLUMN total_tokens INT DEFAULT 0
                """
                await self.connection_manager.execute_query(create_tokens_col)
        except Exception as e:
            logger.warning(f"Error {e} when creating document table.")
            raise e

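    # Example (sketch): how the vector column type above is rendered from
    # `dimension`; the concrete number is hypothetical.
    #
    #   dimension = 512          -> vector_dim = "(512)" -> column type "vector(512)"
    #   dimension = float("nan") -> vector_dim = ""      -> column type "vector"
    #                               (a dimensionless pgvector column)
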
    async def upsert_documents_overview(
        self, documents_overview: DocumentResponse | list[DocumentResponse]
    ) -> None:
        if isinstance(documents_overview, DocumentResponse):
            documents_overview = [documents_overview]

        # TODO: make this an arg
        max_retries = 20
        for document in documents_overview:
            retries = 0
            while retries < max_retries:
                try:
                    async with (
                        self.connection_manager.pool.get_connection() as conn  # type: ignore
                    ):
                        async with conn.transaction():
                            # Lock the row for update
                            check_query = f"""
                            SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                            WHERE id = $1 FOR UPDATE
                            """
                            existing_doc = await conn.fetchrow(
                                check_query, document.id
                            )

                            db_entry = document.convert_to_db_entry()

                            if existing_doc:
                                db_version = existing_doc[
                                    "ingestion_attempt_number"
                                ]
                                db_status = existing_doc["ingestion_status"]
                                new_version = db_entry[
                                    "ingestion_attempt_number"
                                ]

                                # Only increment the attempt number if the status
                                # is changing to 'success' or if it's a new version
                                if (
                                    db_status != "success"
                                    and db_entry["ingestion_status"]
                                    == "success"
                                ) or (new_version > db_version):
                                    new_attempt_number = db_version + 1
                                else:
                                    new_attempt_number = db_version

                                db_entry["ingestion_attempt_number"] = (
                                    new_attempt_number
                                )

                                update_query = f"""
                                UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                SET collection_ids = $1,
                                    owner_id = $2,
                                    type = $3,
                                    metadata = $4,
                                    title = $5,
                                    version = $6,
                                    size_in_bytes = $7,
                                    ingestion_status = $8,
                                    extraction_status = $9,
                                    updated_at = $10,
                                    ingestion_attempt_number = $11,
                                    summary = $12,
                                    summary_embedding = $13,
                                    total_tokens = $14
                                WHERE id = $15
                                """
                                await conn.execute(
                                    update_query,
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["updated_at"],
                                    db_entry["ingestion_attempt_number"],
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                    db_entry[
                                        "total_tokens"
                                    ],  # pass the new field here
                                    document.id,
                                )
                            else:
                                insert_query = f"""
                                INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
                                (id, collection_ids, owner_id, type, metadata, title, version,
                                size_in_bytes, ingestion_status, extraction_status, created_at,
                                updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
                                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
                                """
                                await conn.execute(
                                    insert_query,
                                    db_entry["id"],
                                    db_entry["collection_ids"],
                                    db_entry["owner_id"],
                                    db_entry["document_type"],
                                    db_entry["metadata"],
                                    db_entry["title"],
                                    db_entry["version"],
                                    db_entry["size_in_bytes"],
                                    db_entry["ingestion_status"],
                                    db_entry["extraction_status"],
                                    db_entry["created_at"],
                                    db_entry["updated_at"],
                                    db_entry["ingestion_attempt_number"],
                                    db_entry["summary"],
                                    db_entry["summary_embedding"],
                                    db_entry["total_tokens"],
                                )

                    break  # Success, exit the retry loop
                except (
                    asyncpg.exceptions.UniqueViolationError,
                    asyncpg.exceptions.DeadlockDetectedError,
                ) as e:
                    retries += 1
                    if retries == max_retries:
                        logger.error(
                            f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
                        )
                        raise
                    else:
                        wait_time = 0.1 * (2**retries)  # Exponential backoff
                        await asyncio.sleep(wait_time)

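    # Worked example of the backoff above: wait_time = 0.1 * 2**retries, so the
    # first few waits after a unique-violation or deadlock retry are
    # 0.2s, 0.4s, 0.8s, 1.6s, ... up to max_retries attempts.
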
    async def delete(
        self, document_id: UUID, version: Optional[str] = None
    ) -> None:
        query = f"""
        DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        WHERE id = $1
        """

        params = [str(document_id)]

        if version:
            query += " AND version = $2"
            params.append(version)

        await self.connection_manager.execute_query(query=query, params=params)

    async def _get_status_from_table(
        self,
        ids: list[UUID],
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """Get the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.

        Returns:
            The workflow status for the given document or list of documents.
        """
        query = f"""
            SELECT {status_type} FROM {self._get_table_name(table_name)}
            WHERE {column_name} = ANY($1)
        """
        return [
            row[status_type]
            for row in await self.connection_manager.fetch_query(query, [ids])
        ]

    async def _get_ids_from_table(
        self,
        status: list[str],
        table_name: str,
        status_type: str,
        collection_id: Optional[UUID] = None,
    ):
        """Get the IDs from a given table.

        Args:
            status (str | list[str]): The status or list of statuses to retrieve.
            table_name (str): The table name.
            status_type (str): The type of status to retrieve.
        """
        query = f"""
            SELECT id FROM {self._get_table_name(table_name)}
            WHERE {status_type} = ANY($1) AND $2 = ANY(collection_ids)
        """
        records = await self.connection_manager.fetch_query(
            query, [status, collection_id]
        )
        return [record["id"] for record in records]

    async def _set_status_in_table(
        self,
        ids: list[UUID],
        status: str,
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """Set the workflow status for a given document or list of documents.

        Args:
            ids (list[UUID]): The document IDs.
            status (str): The status to set.
            table_name (str): The table name.
            status_type (str): The type of status to set.
            column_name (str): The column name in the table to update.
        """
        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {status_type} = $1
            WHERE {column_name} = ANY($2)
        """
        await self.connection_manager.execute_query(query, [status, ids])

    def _get_status_model(self, status_type: str):
        """Get the status model for a given status type.

        Args:
            status_type (str): The type of status to retrieve.

        Returns:
            The status model for the given status type.
        """
        if status_type == "ingestion":
            return IngestionStatus
        elif status_type == "extraction_status":
            return GraphExtractionStatus
        elif status_type in {"graph_cluster_status", "graph_sync_status"}:
            return GraphConstructionStatus
        else:
            raise R2RException(
                status_code=400, message=f"Invalid status type: {status_type}"
            )

    async def get_workflow_status(
        self, id: UUID | list[UUID], status_type: str
    ):
        """Get the workflow status for a given document or list of documents.

        Args:
            id (UUID | list[UUID]): The document ID or list of document IDs.
            status_type (str): The type of status to retrieve.

        Returns:
            The workflow status for the given document or list of documents.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)

        result = await self._get_status_from_table(
            ids,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )
        result = [out_model[status.upper()] for status in result]
        return result[0] if isinstance(id, UUID) else result

    async def set_workflow_status(
        self, id: UUID | list[UUID], status_type: str, status: str
    ):
        """Set the workflow status for a given document or list of documents.

        Args:
            id (UUID | list[UUID]): The document ID or list of document IDs.
            status_type (str): The type of status to set.
            status (str): The status to set.
        """
        ids = [id] if isinstance(id, UUID) else id
        out_model = self._get_status_model(status_type)

        return await self._set_status_in_table(
            ids,
            status,
            out_model.table_name(),
            status_type,
            out_model.id_column(),
        )

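    # Illustrative usage (hedged; the ID variable and the literal status value
    # are hypothetical placeholders and assume the status model accepts them):
    #
    #   current = await handler.get_workflow_status(document_id, "extraction_status")
    #   await handler.set_workflow_status(document_id, "extraction_status", "success")
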
    async def get_document_ids_by_status(
        self,
        status_type: str,
        status: str | list[str],
        collection_id: Optional[UUID] = None,
    ):
        """Get the document IDs for a given status.

        Args:
            status_type (str): The type of status to retrieve.
            status (str | list[str]): The status or list of statuses to retrieve.
            collection_id (Optional[UUID]): Restrict the lookup to a single collection.
        """
        if isinstance(status, str):
            status = [status]

        out_model = self._get_status_model(status_type)
        return await self._get_ids_from_table(
            status, out_model.table_name(), status_type, collection_id
        )

    async def get_documents_overview(
        self,
        offset: int,
        limit: int,
        filter_user_ids: Optional[list[UUID]] = None,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
        include_summary_embedding: Optional[bool] = True,
        filters: Optional[dict[str, Any]] = None,
        sort_order: str = "DESC",
        owner_only: bool = False,
    ) -> dict[str, Any]:
        """Fetch overviews of documents with optional offset/limit pagination.

        You can use either:
        - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
        - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.

        If both the `filters` dict and any of the traditional filter arguments
        are provided, this method will raise an error.
        """
        filters = copy.deepcopy(filters)
        filters = transform_filter_fields(filters)  # type: ignore

        # Safety check: We do not allow mixing the old filter arguments with
        # the new `filters` dict. This keeps the query logic unambiguous.
        if filters and any(
            [
                filter_user_ids,
                filter_document_ids,
                filter_collection_ids,
            ]
        ):
            raise HTTPException(
                status_code=400,
                detail=(
                    "Cannot use both the 'filters' dictionary "
                    "and the 'filter_*_ids' parameters simultaneously."
                ),
            )

        conditions = []
        params: list[Any] = []
        param_index = 1

        # -------------------------------------------
        # 1) If using the new `filters` dict approach
        # -------------------------------------------
        if filters:
            # Apply the filters to generate a WHERE clause
            filter_condition, filter_params = apply_filters(
                filters, params, mode="condition_only"
            )
            if filter_condition:
                conditions.append(filter_condition)
            # Make sure we keep adding to the same params list
            params.extend(filter_params)
            param_index += len(filter_params)

        # -------------------------------------------
        # 2) If using the old filter_*_ids approach
        # -------------------------------------------
        else:
            # Handle document IDs with AND
            if filter_document_ids:
                conditions.append(f"id = ANY(${param_index})")
                params.append(filter_document_ids)
                param_index += 1

            # For owner/collection filters, we used OR logic previously,
            # so we combine them into a single sub-condition in parentheses
            owner_conditions = []
            collection_conditions = []

            if filter_user_ids:
                owner_conditions.append(f"owner_id = ANY(${param_index})")
                params.append(filter_user_ids)
                param_index += 1

            if filter_collection_ids:
                collection_conditions.append(
                    f"collection_ids && ${param_index}"
                )
                params.append(filter_collection_ids)
                param_index += 1

            if owner_only:
                if owner_conditions:
                    conditions.append(f"({' OR '.join(owner_conditions)})")
                if collection_conditions:
                    conditions.append(
                        f"({' OR '.join(collection_conditions)})"
                    )
            elif (
                combined_conditions := owner_conditions + collection_conditions
            ):
                conditions.append(f"({' OR '.join(combined_conditions)})")

        # -------------------------
        # Build the full query
        # -------------------------
        base_query = (
            f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
        )

        if conditions:
            # Combine everything with AND
            base_query += " WHERE " + " AND ".join(conditions)

        # Construct SELECT fields (including total_entries via window function)
        select_fields = """
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                COUNT(*) OVER() AS total_entries
        """

        query = f"""
            {select_fields}
            {base_query}
            ORDER BY created_at {sort_order}
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            query += f" LIMIT ${param_index}"
            params.append(limit)
            param_index += 1

        try:
            results = await self.connection_manager.fetch_query(query, params)
            total_entries = results[0]["total_entries"] if results else 0

            documents = []
            for row in results:
                # Safely handle the embedding
                embedding = None
                if (
                    "summary_embedding" in row
                    and row["summary_embedding"] is not None
                ):
                    try:
                        # The embedding is stored as a string like "[0.1, 0.2, ...]"
                        embedding_str = row["summary_embedding"]
                        if embedding_str.startswith(
                            "["
                        ) and embedding_str.endswith("]"):
                            embedding = [
                                float(x)
                                for x in embedding_str[1:-1].split(",")
                                if x
                            ]
                    except Exception as e:
                        logger.warning(
                            f"Failed to parse embedding for document {row['id']}: {e}"
                        )

                documents.append(
                    DocumentResponse(
                        id=row["id"],
                        collection_ids=row["collection_ids"],
                        owner_id=row["owner_id"],
                        document_type=DocumentType(row["type"]),
                        metadata=json.loads(row["metadata"]),
                        title=row["title"],
                        version=row["version"],
                        size_in_bytes=row["size_in_bytes"],
                        ingestion_status=IngestionStatus(
                            row["ingestion_status"]
                        ),
                        extraction_status=GraphExtractionStatus(
                            row["extraction_status"]
                        ),
                        created_at=row["created_at"],
                        updated_at=row["updated_at"],
                        summary=row["summary"] if "summary" in row else None,
                        summary_embedding=(
                            embedding if include_summary_embedding else None
                        ),
                        total_tokens=row["total_tokens"],
                    )
                )
            return {"results": documents, "total_entries": total_entries}
        except Exception as e:
            logger.error(f"Error in get_documents_overview: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Database query failed",
            ) from e

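    # Illustrative calls (hedged; the ID variables are placeholders). Either
    # filter style works, but not both at once:
    #
    #   await handler.get_documents_overview(
    #       offset=0, limit=10, filter_collection_ids=[collection_id]
    #   )
    #   await handler.get_documents_overview(
    #       offset=0, limit=10, filters={"document_id": {"$eq": document_id}}
    #   )
    #   # Passing both a `filters` dict and any `filter_*_ids` argument raises
    #   # an HTTPException with status 400.
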
    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: list[dict],
        overwrite: bool = False,
    ) -> DocumentResponse:
        """Update the metadata of a document, either by appending to existing
        metadata or overwriting it.

        Accepts a list of metadata dictionaries.
        """
        doc_result = await self.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )

        if not doc_result["results"]:
            raise HTTPException(
                status_code=404,
                detail=f"Document with ID {document_id} not found",
            )

        existing_doc = doc_result["results"][0]

        if overwrite:
            combined_metadata: dict[str, Any] = {}
            for meta_item in metadata:
                combined_metadata |= meta_item
            existing_doc.metadata = combined_metadata
        else:
            for meta_item in metadata:
                existing_doc.metadata.update(meta_item)

        await self.upsert_documents_overview(existing_doc)
        return existing_doc

    async def semantic_document_search(
        self, query_embedding: list[float], search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Search documents using semantic similarity with their summary
        embeddings."""
        where_clauses = ["summary_embedding IS NOT NULL"]
        params: list[str | int | bytes] = [str(query_embedding)]

        # vector_dim already includes the parentheses (e.g. "(512)"), so it is
        # interpolated directly after "vector" in the cast below.
        vector_dim = (
            "" if math.isnan(self.dimension) else f"({self.dimension})"
        )

        filters = copy.deepcopy(search_settings.filters)
        if filters:
            filter_condition, params = apply_filters(
                transform_filter_fields(filters), params, mode="condition_only"
            )
            if filter_condition:
                where_clauses.append(filter_condition)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                (summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY semantic_distance ASC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT *,
            1.0 - semantic_distance as semantic_score
        FROM document_scores
        """
        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["semantic_score"]),
                    "search_type": "semantic",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=GraphExtractionStatus(
                    row["extraction_status"]
                ),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=[
                    float(x)
                    for x in row["summary_embedding"][1:-1].split(",")
                    if x
                ],
                total_tokens=row["total_tokens"],
            )
            for row in results
        ]

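    # Note on scoring (assumes pgvector's `<=>` cosine-distance operator): the
    # query orders by semantic_distance ascending and reports
    # semantic_score = 1.0 - semantic_distance, so a hypothetical distance of
    # 0.12 surfaces as a search_score of 0.88 in the returned metadata.
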
    async def full_text_document_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Enhanced full-text search using generated tsvector."""
        where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
        params: list[str | int | bytes] = [query_text]

        filters = copy.deepcopy(search_settings.filters)
        if filters:
            filter_condition, params = apply_filters(
                transform_filter_fields(filters), params, mode="condition_only"
            )
            if filter_condition:
                where_clauses.append(filter_condition)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY text_score DESC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT * FROM document_scores
        """
        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["text_score"]),
                    "search_type": "full_text",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=GraphExtractionStatus(
                    row["extraction_status"]
                ),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=(
                    [
                        float(x)
                        for x in row["summary_embedding"][1:-1].split(",")
                        if x
                    ]
                    if row["summary_embedding"]
                    else None
                ),
                total_tokens=row["total_tokens"],
            )
            for row in results
        ]

    async def hybrid_document_search(
        self,
        query_text: str,
        query_embedding: list[float],
        search_settings: SearchSettings,
    ) -> list[DocumentResponse]:
        """Search documents using both semantic and full-text search with RRF
        fusion."""
        # Get more results than needed for better fusion
        extended_settings = copy.deepcopy(search_settings)
        extended_settings.limit = search_settings.limit * 3

        # Get results from both search methods
        semantic_results = await self.semantic_document_search(
            query_embedding, extended_settings
        )
        full_text_results = await self.full_text_document_search(
            query_text, extended_settings
        )

        # Combine results using RRF
        doc_scores: dict[str, dict] = {}

        # Process semantic results
        for rank, result in enumerate(semantic_results, 1):
            doc_id = str(result.id)
            doc_scores[doc_id] = {
                "semantic_rank": rank,
                "full_text_rank": len(full_text_results)
                + 1,  # Default rank if not found
                "data": result,
            }

        # Process full-text results
        for rank, result in enumerate(full_text_results, 1):
            doc_id = str(result.id)
            if doc_id in doc_scores:
                doc_scores[doc_id]["full_text_rank"] = rank
            else:
                doc_scores[doc_id] = {
                    "semantic_rank": len(semantic_results)
                    + 1,  # Default rank if not found
                    "full_text_rank": rank,
                    "data": result,
                }

        # Calculate RRF scores using hybrid search settings
        rrf_k = search_settings.hybrid_settings.rrf_k
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight

        for scores in doc_scores.values():
            semantic_score = 1 / (rrf_k + scores["semantic_rank"])
            full_text_score = 1 / (rrf_k + scores["full_text_rank"])

            # Weighted combination
            combined_score = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)
            scores["final_score"] = combined_score

        # Sort by final score and apply offset/limit
        sorted_results = sorted(
            doc_scores.values(), key=lambda x: x["final_score"], reverse=True
        )[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            DocumentResponse(
                **{
                    **result["data"].__dict__,
                    "metadata": {
                        **(
                            result["data"].metadata
                            if search_settings.include_metadatas
                            else {}
                        ),
                        "search_score": result["final_score"],
                        "semantic_rank": result["semantic_rank"],
                        "full_text_rank": result["full_text_rank"],
                        "search_type": "hybrid",
                    },
                }
            )
            for result in sorted_results
        ]

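    # Worked RRF example (hypothetical settings: rrf_k=60, semantic_weight=5.0,
    # full_text_weight=1.0). A document ranked 1st semantically and 3rd in full
    # text scores:
    #
    #   semantic_score  = 1 / (60 + 1) ≈ 0.01639
    #   full_text_score = 1 / (60 + 3) ≈ 0.01587
    #   final_score     = (0.01639 * 5.0 + 0.01587 * 1.0) / 6.0 ≈ 0.01631
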
    async def search_documents(
        self,
        query_text: str,
        query_embedding: Optional[list[float]] = None,
        settings: Optional[SearchSettings] = None,
    ) -> list[DocumentResponse]:
        """Main search method that delegates to the appropriate search method
        based on settings."""
        if settings is None:
            settings = SearchSettings()

        if (
            settings.use_semantic_search and settings.use_fulltext_search
        ) or settings.use_hybrid_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for hybrid search"
                )
            return await self.hybrid_document_search(
                query_text, query_embedding, settings
            )
        elif settings.use_semantic_search:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for vector search"
                )
            return await self.semantic_document_search(
                query_embedding, settings
            )
        else:
            return await self.full_text_document_search(query_text, settings)

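    # Illustrative usage (hedged; the query text and embedding are placeholders,
    # and the SearchSettings kwargs assume its usual constructor):
    #
    #   results = await handler.search_documents(
    #       "acme quarterly report",
    #       query_embedding=query_embedding,  # required for semantic or hybrid
    #       settings=SearchSettings(use_hybrid_search=True),
    #   )
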
    async def export_to_csv(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path to
        the temp file."""
        valid_columns = {
            "id",
            "collection_ids",
            "owner_id",
            "type",
            "metadata",
            "title",
            "summary",
            "version",
            "size_in_bytes",
            "ingestion_status",
            "extraction_status",
            "created_at",
            "updated_at",
            "total_tokens",
        }
        filters = copy.deepcopy(filters)
        filters = transform_filter_fields(filters)  # type: ignore

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_stmt = f"""
            SELECT
                id::text,
                collection_ids::text,
                owner_id::text,
                type::text,
                metadata::text AS metadata,
                title,
                summary,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
                total_tokens
            FROM {self._get_table_name(self.TABLE_NAME)}
        """

        conditions = []
        params: list[Any] = []
        param_index = 1

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            row_dict = {
                                "id": row[0],
                                "collection_ids": row[1],
                                "owner_id": row[2],
                                "type": row[3],
                                "metadata": row[4],
                                "title": row[5],
                                "summary": row[6],
                                "version": row[7],
                                "size_in_bytes": row[8],
                                "ingestion_status": row[9],
                                "extraction_status": row[10],
                                "created_at": row[11],
                                "updated_at": row[12],
                                "total_tokens": row[13],
                            }
                            writer.writerow([row_dict[col] for col in columns])

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e

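    # Illustrative usage (hedged; the column list and filter value are
    # hypothetical). The caller owns the returned temp file and should close it
    # when done, since it was created with delete=True:
    #
    #   csv_path, csv_file = await handler.export_to_csv(
    #       columns=["id", "title", "ingestion_status"],
    #       filters={"ingestion_status": {"$eq": "success"}},
    #   )
    #   try:
    #       ...  # stream or copy csv_path
    #   finally:
    #       csv_file.close()  # the NamedTemporaryFile is removed on close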