collections.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. import csv
  2. import json
  3. import logging
  4. import tempfile
  5. from typing import IO, Any, Optional
  6. from uuid import UUID, uuid4
  7. from asyncpg.exceptions import UniqueViolationError
  8. from fastapi import HTTPException
  9. from core.base import (
  10. DatabaseConfig,
  11. Handler,
  12. KGExtractionStatus,
  13. R2RException,
  14. generate_default_user_collection_id,
  15. )
  16. from core.base.abstractions import (
  17. DocumentResponse,
  18. DocumentType,
  19. IngestionStatus,
  20. )
  21. from core.base.api.models import CollectionResponse
  22. from core.utils import generate_default_user_collection_id
  23. from .base import PostgresConnectionManager
  24. logger = logging.getLogger()
  25. class PostgresCollectionsHandler(Handler):
  26. TABLE_NAME = "collections"
  27. def __init__(
  28. self,
  29. project_name: str,
  30. connection_manager: PostgresConnectionManager,
  31. config: DatabaseConfig,
  32. ):
  33. self.config = config
  34. super().__init__(project_name, connection_manager)
  35. async def create_tables(self) -> None:
  36. query = f"""
  37. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
  38. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  39. owner_id UUID,
  40. name TEXT NOT NULL,
  41. description TEXT,
  42. graph_sync_status TEXT DEFAULT 'pending',
  43. graph_cluster_status TEXT DEFAULT 'pending',
  44. created_at TIMESTAMPTZ DEFAULT NOW(),
  45. updated_at TIMESTAMPTZ DEFAULT NOW(),
  46. user_count INT DEFAULT 0,
  47. document_count INT DEFAULT 0
  48. );
  49. """
  50. await self.connection_manager.execute_query(query)
  51. async def collection_exists(self, collection_id: UUID) -> bool:
  52. """Check if a collection exists."""
  53. query = f"""
  54. SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  55. WHERE id = $1
  56. """
  57. result = await self.connection_manager.fetchrow_query(
  58. query, [collection_id]
  59. )
  60. return result is not None
  61. async def create_collection(
  62. self,
  63. owner_id: UUID,
  64. name: Optional[str] = None,
  65. description: str = "",
  66. collection_id: Optional[UUID] = None,
  67. ) -> CollectionResponse:
  68. if not name and not collection_id:
  69. name = self.config.default_collection_name
  70. collection_id = generate_default_user_collection_id(owner_id)
  71. query = f"""
  72. INSERT INTO {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  73. (id, owner_id, name, description)
  74. VALUES ($1, $2, $3, $4)
  75. RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
  76. """
  77. params = [
  78. collection_id or uuid4(),
  79. owner_id,
  80. name,
  81. description,
  82. ]
  83. try:
  84. result = await self.connection_manager.fetchrow_query(
  85. query=query,
  86. params=params,
  87. )
  88. if not result:
  89. raise R2RException(
  90. status_code=404, message="Collection not found"
  91. )
  92. return CollectionResponse(
  93. id=result["id"],
  94. owner_id=result["owner_id"],
  95. name=result["name"],
  96. description=result["description"],
  97. graph_cluster_status=result["graph_cluster_status"],
  98. graph_sync_status=result["graph_sync_status"],
  99. created_at=result["created_at"],
  100. updated_at=result["updated_at"],
  101. user_count=0,
  102. document_count=0,
  103. )
  104. except UniqueViolationError:
  105. raise R2RException(
  106. message="Collection with this ID already exists",
  107. status_code=409,
  108. )
  109. except Exception as e:
  110. raise HTTPException(
  111. status_code=500,
  112. detail=f"An error occurred while creating the collection: {e}",
  113. ) from e
  114. async def update_collection(
  115. self,
  116. collection_id: UUID,
  117. name: Optional[str] = None,
  118. description: Optional[str] = None,
  119. ) -> CollectionResponse:
  120. """Update an existing collection."""
  121. if not await self.collection_exists(collection_id):
  122. raise R2RException(status_code=404, message="Collection not found")
  123. update_fields = []
  124. params: list = []
  125. param_index = 1
  126. if name is not None:
  127. update_fields.append(f"name = ${param_index}")
  128. params.append(name)
  129. param_index += 1
  130. if description is not None:
  131. update_fields.append(f"description = ${param_index}")
  132. params.append(description)
  133. param_index += 1
  134. if not update_fields:
  135. raise R2RException(status_code=400, message="No fields to update")
  136. update_fields.append("updated_at = NOW()")
  137. params.append(collection_id)
  138. query = f"""
  139. WITH updated_collection AS (
  140. UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  141. SET {', '.join(update_fields)}
  142. WHERE id = ${param_index}
  143. RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
  144. )
  145. SELECT
  146. uc.*,
  147. COUNT(DISTINCT u.id) FILTER (WHERE u.id IS NOT NULL) as user_count,
  148. COUNT(DISTINCT d.id) FILTER (WHERE d.id IS NOT NULL) as document_count
  149. FROM updated_collection uc
  150. LEFT JOIN {self._get_table_name('users')} u ON uc.id = ANY(u.collection_ids)
  151. LEFT JOIN {self._get_table_name('documents')} d ON uc.id = ANY(d.collection_ids)
  152. GROUP BY uc.id, uc.owner_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
  153. """
  154. try:
  155. result = await self.connection_manager.fetchrow_query(
  156. query, params
  157. )
  158. if not result:
  159. raise R2RException(
  160. status_code=404, message="Collection not found"
  161. )
  162. return CollectionResponse(
  163. id=result["id"],
  164. owner_id=result["owner_id"],
  165. name=result["name"],
  166. description=result["description"],
  167. graph_sync_status=result["graph_sync_status"],
  168. graph_cluster_status=result["graph_cluster_status"],
  169. created_at=result["created_at"],
  170. updated_at=result["updated_at"],
  171. user_count=result["user_count"],
  172. document_count=result["document_count"],
  173. )
  174. except Exception as e:
  175. raise HTTPException(
  176. status_code=500,
  177. detail=f"An error occurred while updating the collection: {e}",
  178. ) from e
  179. async def delete_collection_relational(self, collection_id: UUID) -> None:
  180. # Remove collection_id from users
  181. user_update_query = f"""
  182. UPDATE {self._get_table_name('users')}
  183. SET collection_ids = array_remove(collection_ids, $1)
  184. WHERE $1 = ANY(collection_ids)
  185. """
  186. await self.connection_manager.execute_query(
  187. user_update_query, [collection_id]
  188. )
  189. # Remove collection_id from documents
  190. document_update_query = f"""
  191. WITH updated AS (
  192. UPDATE {self._get_table_name('documents')}
  193. SET collection_ids = array_remove(collection_ids, $1)
  194. WHERE $1 = ANY(collection_ids)
  195. RETURNING 1
  196. )
  197. SELECT COUNT(*) AS affected_rows FROM updated
  198. """
  199. await self.connection_manager.fetchrow_query(
  200. document_update_query, [collection_id]
  201. )
  202. # Delete the collection
  203. delete_query = f"""
  204. DELETE FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  205. WHERE id = $1
  206. RETURNING id
  207. """
  208. deleted = await self.connection_manager.fetchrow_query(
  209. delete_query, [collection_id]
  210. )
  211. if not deleted:
  212. raise R2RException(status_code=404, message="Collection not found")
  213. async def documents_in_collection(
  214. self, collection_id: UUID, offset: int, limit: int
  215. ) -> dict[str, list[DocumentResponse] | int]:
  216. """
  217. Get all documents in a specific collection with pagination.
  218. Args:
  219. collection_id (UUID): The ID of the collection to get documents from.
  220. offset (int): The number of documents to skip.
  221. limit (int): The maximum number of documents to return.
  222. Returns:
  223. List[DocumentResponse]: A list of DocumentResponse objects representing the documents in the collection.
  224. Raises:
  225. R2RException: If the collection doesn't exist.
  226. """
  227. if not await self.collection_exists(collection_id):
  228. raise R2RException(status_code=404, message="Collection not found")
  229. query = f"""
  230. SELECT d.id, d.owner_id, d.type, d.metadata, d.title, d.version,
  231. d.size_in_bytes, d.ingestion_status, d.extraction_status, d.created_at, d.updated_at, d.summary,
  232. COUNT(*) OVER() AS total_entries
  233. FROM {self._get_table_name('documents')} d
  234. WHERE $1 = ANY(d.collection_ids)
  235. ORDER BY d.created_at DESC
  236. OFFSET $2
  237. """
  238. conditions = [collection_id, offset]
  239. if limit != -1:
  240. query += " LIMIT $3"
  241. conditions.append(limit)
  242. results = await self.connection_manager.fetch_query(query, conditions)
  243. documents = [
  244. DocumentResponse(
  245. id=row["id"],
  246. collection_ids=[collection_id],
  247. owner_id=row["owner_id"],
  248. document_type=DocumentType(row["type"]),
  249. metadata=json.loads(row["metadata"]),
  250. title=row["title"],
  251. version=row["version"],
  252. size_in_bytes=row["size_in_bytes"],
  253. ingestion_status=IngestionStatus(row["ingestion_status"]),
  254. extraction_status=KGExtractionStatus(row["extraction_status"]),
  255. created_at=row["created_at"],
  256. updated_at=row["updated_at"],
  257. summary=row["summary"],
  258. )
  259. for row in results
  260. ]
  261. total_entries = results[0]["total_entries"] if results else 0
  262. return {"results": documents, "total_entries": total_entries}
  263. async def get_collections_overview(
  264. self,
  265. offset: int,
  266. limit: int,
  267. filter_user_ids: Optional[list[UUID]] = None,
  268. filter_document_ids: Optional[list[UUID]] = None,
  269. filter_collection_ids: Optional[list[UUID]] = None,
  270. ) -> dict[str, list[CollectionResponse] | int]:
  271. conditions = []
  272. params: list[Any] = []
  273. param_index = 1
  274. if filter_user_ids:
  275. conditions.append(
  276. f"""
  277. c.id IN (
  278. SELECT unnest(collection_ids)
  279. FROM {self.project_name}.users
  280. WHERE id = ANY(${param_index})
  281. )
  282. """
  283. )
  284. params.append(filter_user_ids)
  285. param_index += 1
  286. if filter_document_ids:
  287. conditions.append(
  288. f"""
  289. c.id IN (
  290. SELECT unnest(collection_ids)
  291. FROM {self.project_name}.documents
  292. WHERE id = ANY(${param_index})
  293. )
  294. """
  295. )
  296. params.append(filter_document_ids)
  297. param_index += 1
  298. if filter_collection_ids:
  299. conditions.append(f"c.id = ANY(${param_index})")
  300. params.append(filter_collection_ids)
  301. param_index += 1
  302. where_clause = (
  303. f"WHERE {' AND '.join(conditions)}" if conditions else ""
  304. )
  305. query = f"""
  306. SELECT
  307. c.*,
  308. COUNT(*) OVER() as total_entries
  309. FROM {self.project_name}.collections c
  310. {where_clause}
  311. ORDER BY created_at DESC
  312. OFFSET ${param_index}
  313. """
  314. params.append(offset)
  315. param_index += 1
  316. if limit != -1:
  317. query += f" LIMIT ${param_index}"
  318. params.append(limit)
  319. try:
  320. results = await self.connection_manager.fetch_query(query, params)
  321. if not results:
  322. return {"results": [], "total_entries": 0}
  323. total_entries = results[0]["total_entries"] if results else 0
  324. collections = [CollectionResponse(**row) for row in results]
  325. return {"results": collections, "total_entries": total_entries}
  326. except Exception as e:
  327. raise HTTPException(
  328. status_code=500,
  329. detail=f"An error occurred while fetching collections: {e}",
  330. ) from e
  331. async def assign_document_to_collection_relational(
  332. self,
  333. document_id: UUID,
  334. collection_id: UUID,
  335. ) -> UUID:
  336. """
  337. Assign a document to a collection.
  338. Args:
  339. document_id (UUID): The ID of the document to assign.
  340. collection_id (UUID): The ID of the collection to assign the document to.
  341. Raises:
  342. R2RException: If the collection doesn't exist, if the document is not found,
  343. or if there's a database error.
  344. """
  345. try:
  346. if not await self.collection_exists(collection_id):
  347. raise R2RException(
  348. status_code=404, message="Collection not found"
  349. )
  350. # First, check if the document exists
  351. document_check_query = f"""
  352. SELECT 1 FROM {self._get_table_name('documents')}
  353. WHERE id = $1
  354. """
  355. document_exists = await self.connection_manager.fetchrow_query(
  356. document_check_query, [document_id]
  357. )
  358. if not document_exists:
  359. raise R2RException(
  360. status_code=404, message="Document not found"
  361. )
  362. # If document exists, proceed with the assignment
  363. assign_query = f"""
  364. UPDATE {self._get_table_name('documents')}
  365. SET collection_ids = array_append(collection_ids, $1)
  366. WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
  367. RETURNING id
  368. """
  369. result = await self.connection_manager.fetchrow_query(
  370. assign_query, [collection_id, document_id]
  371. )
  372. if not result:
  373. # Document exists but was already assigned to the collection
  374. raise R2RException(
  375. status_code=409,
  376. message="Document is already assigned to the collection",
  377. )
  378. update_collection_query = f"""
  379. UPDATE {self._get_table_name('collections')}
  380. SET document_count = document_count + 1
  381. WHERE id = $1
  382. """
  383. await self.connection_manager.execute_query(
  384. query=update_collection_query, params=[collection_id]
  385. )
  386. return collection_id
  387. except R2RException:
  388. # Re-raise R2RExceptions as they are already handled
  389. raise
  390. except Exception as e:
  391. raise HTTPException(
  392. status_code=500,
  393. detail=f"An error '{e}' occurred while assigning the document to the collection",
  394. ) from e
  395. async def remove_document_from_collection_relational(
  396. self, document_id: UUID, collection_id: UUID
  397. ) -> None:
  398. """
  399. Remove a document from a collection.
  400. Args:
  401. document_id (UUID): The ID of the document to remove.
  402. collection_id (UUID): The ID of the collection to remove the document from.
  403. Raises:
  404. R2RException: If the collection doesn't exist or if the document is not in the collection.
  405. """
  406. if not await self.collection_exists(collection_id):
  407. raise R2RException(status_code=404, message="Collection not found")
  408. query = f"""
  409. UPDATE {self._get_table_name('documents')}
  410. SET collection_ids = array_remove(collection_ids, $1)
  411. WHERE id = $2 AND $1 = ANY(collection_ids)
  412. RETURNING id
  413. """
  414. result = await self.connection_manager.fetchrow_query(
  415. query, [collection_id, document_id]
  416. )
  417. if not result:
  418. raise R2RException(
  419. status_code=404,
  420. message="Document not found in the specified collection",
  421. )
  422. async def export_to_csv(
  423. self,
  424. columns: Optional[list[str]] = None,
  425. filters: Optional[dict] = None,
  426. include_header: bool = True,
  427. ) -> tuple[str, IO]:
  428. """
  429. Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
  430. """
  431. valid_columns = {
  432. "id",
  433. "owner_id",
  434. "name",
  435. "description",
  436. "graph_sync_status",
  437. "graph_cluster_status",
  438. "created_at",
  439. "updated_at",
  440. "user_count",
  441. "document_count",
  442. }
  443. if not columns:
  444. columns = list(valid_columns)
  445. elif invalid_cols := set(columns) - valid_columns:
  446. raise ValueError(f"Invalid columns: {invalid_cols}")
  447. select_stmt = f"""
  448. SELECT
  449. id::text,
  450. owner_id::text,
  451. name,
  452. description,
  453. graph_sync_status,
  454. graph_cluster_status,
  455. to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
  456. to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
  457. user_count,
  458. document_count
  459. FROM {self._get_table_name(self.TABLE_NAME)}
  460. """
  461. params = []
  462. if filters:
  463. conditions = []
  464. param_index = 1
  465. for field, value in filters.items():
  466. if field not in valid_columns:
  467. continue
  468. if isinstance(value, dict):
  469. for op, val in value.items():
  470. if op == "$eq":
  471. conditions.append(f"{field} = ${param_index}")
  472. params.append(val)
  473. param_index += 1
  474. elif op == "$gt":
  475. conditions.append(f"{field} > ${param_index}")
  476. params.append(val)
  477. param_index += 1
  478. elif op == "$lt":
  479. conditions.append(f"{field} < ${param_index}")
  480. params.append(val)
  481. param_index += 1
  482. else:
  483. # Direct equality
  484. conditions.append(f"{field} = ${param_index}")
  485. params.append(value)
  486. param_index += 1
  487. if conditions:
  488. select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
  489. select_stmt = f"{select_stmt} ORDER BY created_at DESC"
  490. temp_file = None
  491. try:
  492. temp_file = tempfile.NamedTemporaryFile(
  493. mode="w", delete=True, suffix=".csv"
  494. )
  495. writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
  496. async with self.connection_manager.pool.get_connection() as conn: # type: ignore
  497. async with conn.transaction():
  498. cursor = await conn.cursor(select_stmt, *params)
  499. if include_header:
  500. writer.writerow(columns)
  501. chunk_size = 1000
  502. while True:
  503. rows = await cursor.fetch(chunk_size)
  504. if not rows:
  505. break
  506. for row in rows:
  507. writer.writerow(row)
  508. temp_file.flush()
  509. return temp_file.name, temp_file
  510. except Exception as e:
  511. if temp_file:
  512. temp_file.close()
  513. raise HTTPException(
  514. status_code=500,
  515. detail=f"Failed to export data: {str(e)}",
  516. ) from e