collections.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import json
  2. import logging
  3. from typing import Any, Optional
  4. from uuid import UUID, uuid4
  5. from asyncpg.exceptions import UniqueViolationError
  6. from fastapi import HTTPException
  7. from core.base import (
  8. Handler,
  9. DatabaseConfig,
  10. KGExtractionStatus,
  11. R2RException,
  12. generate_default_user_collection_id,
  13. )
  14. from core.base.abstractions import (
  15. DocumentResponse,
  16. DocumentType,
  17. IngestionStatus,
  18. )
  19. from core.base.api.models import CollectionResponse
  20. from core.utils import generate_default_user_collection_id
  21. from .base import PostgresConnectionManager
  22. logger = logging.getLogger()
  23. class PostgresCollectionsHandler(Handler):
  24. TABLE_NAME = "collections"
  25. def __init__(
  26. self,
  27. project_name: str,
  28. connection_manager: PostgresConnectionManager,
  29. config: DatabaseConfig,
  30. ):
  31. self.config = config
  32. super().__init__(project_name, connection_manager)
  33. async def create_tables(self) -> None:
  34. query = f"""
  35. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)} (
  36. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  37. owner_id UUID,
  38. name TEXT NOT NULL,
  39. description TEXT,
  40. graph_sync_status TEXT DEFAULT 'pending',
  41. graph_cluster_status TEXT DEFAULT 'pending',
  42. created_at TIMESTAMPTZ DEFAULT NOW(),
  43. updated_at TIMESTAMPTZ DEFAULT NOW(),
  44. user_count INT DEFAULT 0,
  45. document_count INT DEFAULT 0
  46. );
  47. """
  48. await self.connection_manager.execute_query(query)
  49. async def collection_exists(self, collection_id: UUID) -> bool:
  50. """Check if a collection exists."""
  51. query = f"""
  52. SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  53. WHERE id = $1
  54. """
  55. result = await self.connection_manager.fetchrow_query(
  56. query, [collection_id]
  57. )
  58. return result is not None
  59. async def create_collection(
  60. self,
  61. owner_id: UUID,
  62. name: Optional[str] = None,
  63. description: str = "",
  64. collection_id: Optional[UUID] = None,
  65. ) -> CollectionResponse:
  66. if not name and not collection_id:
  67. name = self.config.default_collection_name
  68. collection_id = generate_default_user_collection_id(owner_id)
  69. query = f"""
  70. INSERT INTO {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  71. (id, owner_id, name, description)
  72. VALUES ($1, $2, $3, $4)
  73. RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
  74. """
  75. params = [
  76. collection_id or uuid4(),
  77. owner_id,
  78. name,
  79. description,
  80. ]
  81. try:
  82. result = await self.connection_manager.fetchrow_query(
  83. query=query,
  84. params=params,
  85. )
  86. if not result:
  87. raise R2RException(
  88. status_code=404, message="Collection not found"
  89. )
  90. return CollectionResponse(
  91. id=result["id"],
  92. owner_id=result["owner_id"],
  93. name=result["name"],
  94. description=result["description"],
  95. graph_cluster_status=result["graph_cluster_status"],
  96. graph_sync_status=result["graph_sync_status"],
  97. created_at=result["created_at"],
  98. updated_at=result["updated_at"],
  99. user_count=0,
  100. document_count=0,
  101. )
  102. except UniqueViolationError:
  103. raise R2RException(
  104. message="Collection with this ID already exists",
  105. status_code=409,
  106. )
  107. async def update_collection(
  108. self,
  109. collection_id: UUID,
  110. name: Optional[str] = None,
  111. description: Optional[str] = None,
  112. ) -> CollectionResponse:
  113. """Update an existing collection."""
  114. if not await self.collection_exists(collection_id):
  115. raise R2RException(status_code=404, message="Collection not found")
  116. update_fields = []
  117. params: list = []
  118. param_index = 1
  119. if name is not None:
  120. update_fields.append(f"name = ${param_index}")
  121. params.append(name)
  122. param_index += 1
  123. if description is not None:
  124. update_fields.append(f"description = ${param_index}")
  125. params.append(description)
  126. param_index += 1
  127. if not update_fields:
  128. raise R2RException(status_code=400, message="No fields to update")
  129. update_fields.append("updated_at = NOW()")
  130. params.append(collection_id)
  131. query = f"""
  132. WITH updated_collection AS (
  133. UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  134. SET {', '.join(update_fields)}
  135. WHERE id = ${param_index}
  136. RETURNING id, owner_id, name, description, graph_sync_status, graph_cluster_status, created_at, updated_at
  137. )
  138. SELECT
  139. uc.*,
  140. COUNT(DISTINCT u.id) FILTER (WHERE u.id IS NOT NULL) as user_count,
  141. COUNT(DISTINCT d.id) FILTER (WHERE d.id IS NOT NULL) as document_count
  142. FROM updated_collection uc
  143. LEFT JOIN {self._get_table_name('users')} u ON uc.id = ANY(u.collection_ids)
  144. LEFT JOIN {self._get_table_name('documents')} d ON uc.id = ANY(d.collection_ids)
  145. GROUP BY uc.id, uc.owner_id, uc.name, uc.description, uc.graph_sync_status, uc.graph_cluster_status, uc.created_at, uc.updated_at
  146. """
  147. try:
  148. result = await self.connection_manager.fetchrow_query(
  149. query, params
  150. )
  151. if not result:
  152. raise R2RException(
  153. status_code=404, message="Collection not found"
  154. )
  155. return CollectionResponse(
  156. id=result["id"],
  157. owner_id=result["owner_id"],
  158. name=result["name"],
  159. description=result["description"],
  160. graph_sync_status=result["graph_sync_status"],
  161. graph_cluster_status=result["graph_cluster_status"],
  162. created_at=result["created_at"],
  163. updated_at=result["updated_at"],
  164. user_count=result["user_count"],
  165. document_count=result["document_count"],
  166. )
  167. except Exception as e:
  168. raise HTTPException(
  169. status_code=500,
  170. detail=f"An error occurred while updating the collection: {e}",
  171. )
  172. async def delete_collection_relational(self, collection_id: UUID) -> None:
  173. # Remove collection_id from users
  174. user_update_query = f"""
  175. UPDATE {self._get_table_name('users')}
  176. SET collection_ids = array_remove(collection_ids, $1)
  177. WHERE $1 = ANY(collection_ids)
  178. """
  179. await self.connection_manager.execute_query(
  180. user_update_query, [collection_id]
  181. )
  182. # Remove collection_id from documents
  183. document_update_query = f"""
  184. WITH updated AS (
  185. UPDATE {self._get_table_name('documents')}
  186. SET collection_ids = array_remove(collection_ids, $1)
  187. WHERE $1 = ANY(collection_ids)
  188. RETURNING 1
  189. )
  190. SELECT COUNT(*) AS affected_rows FROM updated
  191. """
  192. await self.connection_manager.fetchrow_query(
  193. document_update_query, [collection_id]
  194. )
  195. # Delete the collection
  196. delete_query = f"""
  197. DELETE FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  198. WHERE id = $1
  199. RETURNING id
  200. """
  201. deleted = await self.connection_manager.fetchrow_query(
  202. delete_query, [collection_id]
  203. )
  204. if not deleted:
  205. raise R2RException(status_code=404, message="Collection not found")
  206. async def documents_in_collection(
  207. self, collection_id: UUID, offset: int, limit: int
  208. ) -> dict[str, list[DocumentResponse] | int]:
  209. """
  210. Get all documents in a specific collection with pagination.
  211. Args:
  212. collection_id (UUID): The ID of the collection to get documents from.
  213. offset (int): The number of documents to skip.
  214. limit (int): The maximum number of documents to return.
  215. Returns:
  216. List[DocumentResponse]: A list of DocumentResponse objects representing the documents in the collection.
  217. Raises:
  218. R2RException: If the collection doesn't exist.
  219. """
  220. if not await self.collection_exists(collection_id):
  221. raise R2RException(status_code=404, message="Collection not found")
  222. query = f"""
  223. SELECT d.id, d.owner_id, d.type, d.metadata, d.title, d.version,
  224. d.size_in_bytes, d.ingestion_status, d.extraction_status, d.created_at, d.updated_at, d.summary,
  225. COUNT(*) OVER() AS total_entries
  226. FROM {self._get_table_name('documents')} d
  227. WHERE $1 = ANY(d.collection_ids)
  228. ORDER BY d.created_at DESC
  229. OFFSET $2
  230. """
  231. conditions = [collection_id, offset]
  232. if limit != -1:
  233. query += " LIMIT $3"
  234. conditions.append(limit)
  235. results = await self.connection_manager.fetch_query(query, conditions)
  236. documents = [
  237. DocumentResponse(
  238. id=row["id"],
  239. collection_ids=[collection_id],
  240. owner_id=row["owner_id"],
  241. document_type=DocumentType(row["type"]),
  242. metadata=json.loads(row["metadata"]),
  243. title=row["title"],
  244. version=row["version"],
  245. size_in_bytes=row["size_in_bytes"],
  246. ingestion_status=IngestionStatus(row["ingestion_status"]),
  247. extraction_status=KGExtractionStatus(row["extraction_status"]),
  248. created_at=row["created_at"],
  249. updated_at=row["updated_at"],
  250. summary=row["summary"],
  251. )
  252. for row in results
  253. ]
  254. total_entries = results[0]["total_entries"] if results else 0
  255. return {"results": documents, "total_entries": total_entries}
  256. async def get_collections_overview(
  257. self,
  258. offset: int,
  259. limit: int,
  260. filter_user_ids: Optional[list[UUID]] = None,
  261. filter_document_ids: Optional[list[UUID]] = None,
  262. filter_collection_ids: Optional[list[UUID]] = None,
  263. ) -> dict[str, list[CollectionResponse] | int]:
  264. conditions = []
  265. params: list[Any] = []
  266. param_index = 1
  267. if filter_user_ids:
  268. conditions.append(
  269. f"""
  270. c.id IN (
  271. SELECT unnest(collection_ids)
  272. FROM {self.project_name}.users
  273. WHERE id = ANY(${param_index})
  274. )
  275. """
  276. )
  277. params.append(filter_user_ids)
  278. param_index += 1
  279. if filter_document_ids:
  280. conditions.append(
  281. f"""
  282. c.id IN (
  283. SELECT unnest(collection_ids)
  284. FROM {self.project_name}.documents
  285. WHERE id = ANY(${param_index})
  286. )
  287. """
  288. )
  289. params.append(filter_document_ids)
  290. param_index += 1
  291. if filter_collection_ids:
  292. conditions.append(f"c.id = ANY(${param_index})")
  293. params.append(filter_collection_ids)
  294. param_index += 1
  295. where_clause = (
  296. f"WHERE {' AND '.join(conditions)}" if conditions else ""
  297. )
  298. query = f"""
  299. SELECT
  300. c.*,
  301. COUNT(*) OVER() as total_entries
  302. FROM {self.project_name}.collections c
  303. {where_clause}
  304. ORDER BY created_at DESC
  305. OFFSET ${param_index}
  306. """
  307. params.append(offset)
  308. param_index += 1
  309. if limit != -1:
  310. query += f" LIMIT ${param_index}"
  311. params.append(limit)
  312. try:
  313. results = await self.connection_manager.fetch_query(query, params)
  314. if not results:
  315. return {"results": [], "total_entries": 0}
  316. total_entries = results[0]["total_entries"] if results else 0
  317. collections = [CollectionResponse(**row) for row in results]
  318. return {"results": collections, "total_entries": total_entries}
  319. except Exception as e:
  320. raise HTTPException(
  321. status_code=500,
  322. detail=f"An error occurred while fetching collections: {e}",
  323. )
  324. async def assign_document_to_collection_relational(
  325. self,
  326. document_id: UUID,
  327. collection_id: UUID,
  328. ) -> UUID:
  329. """
  330. Assign a document to a collection.
  331. Args:
  332. document_id (UUID): The ID of the document to assign.
  333. collection_id (UUID): The ID of the collection to assign the document to.
  334. Raises:
  335. R2RException: If the collection doesn't exist, if the document is not found,
  336. or if there's a database error.
  337. """
  338. try:
  339. if not await self.collection_exists(collection_id):
  340. raise R2RException(
  341. status_code=404, message="Collection not found"
  342. )
  343. # First, check if the document exists
  344. document_check_query = f"""
  345. SELECT 1 FROM {self._get_table_name('documents')}
  346. WHERE id = $1
  347. """
  348. document_exists = await self.connection_manager.fetchrow_query(
  349. document_check_query, [document_id]
  350. )
  351. if not document_exists:
  352. raise R2RException(
  353. status_code=404, message="Document not found"
  354. )
  355. # If document exists, proceed with the assignment
  356. assign_query = f"""
  357. UPDATE {self._get_table_name('documents')}
  358. SET collection_ids = array_append(collection_ids, $1)
  359. WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
  360. RETURNING id
  361. """
  362. result = await self.connection_manager.fetchrow_query(
  363. assign_query, [collection_id, document_id]
  364. )
  365. if not result:
  366. # Document exists but was already assigned to the collection
  367. raise R2RException(
  368. status_code=409,
  369. message="Document is already assigned to the collection",
  370. )
  371. update_collection_query = f"""
  372. UPDATE {self._get_table_name('collections')}
  373. SET document_count = document_count + 1
  374. WHERE id = $1
  375. """
  376. await self.connection_manager.execute_query(
  377. query=update_collection_query, params=[collection_id]
  378. )
  379. return collection_id
  380. except R2RException:
  381. # Re-raise R2RExceptions as they are already handled
  382. raise
  383. except Exception as e:
  384. raise HTTPException(
  385. status_code=500,
  386. detail=f"An error '{e}' occurred while assigning the document to the collection",
  387. )
  388. async def remove_document_from_collection_relational(
  389. self, document_id: UUID, collection_id: UUID
  390. ) -> None:
  391. """
  392. Remove a document from a collection.
  393. Args:
  394. document_id (UUID): The ID of the document to remove.
  395. collection_id (UUID): The ID of the collection to remove the document from.
  396. Raises:
  397. R2RException: If the collection doesn't exist or if the document is not in the collection.
  398. """
  399. if not await self.collection_exists(collection_id):
  400. raise R2RException(status_code=404, message="Collection not found")
  401. query = f"""
  402. UPDATE {self._get_table_name('documents')}
  403. SET collection_ids = array_remove(collection_ids, $1)
  404. WHERE id = $2 AND $1 = ANY(collection_ids)
  405. RETURNING id
  406. """
  407. result = await self.connection_manager.fetchrow_query(
  408. query, [collection_id, document_id]
  409. )
  410. if not result:
  411. raise R2RException(
  412. status_code=404,
  413. message="Document not found in the specified collection",
  414. )