# graphs.py
import asyncio
import contextlib
import csv
import datetime
import json
import logging
import os
import tempfile
import time
from typing import IO, Any, AsyncGenerator, Optional, Tuple
from uuid import UUID

import asyncpg
import httpx
from asyncpg.exceptions import UndefinedTableError, UniqueViolationError
from fastapi import HTTPException

from core.base.abstractions import (
    Community,
    Entity,
    Graph,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGExtractionStatus,
    R2RException,
    Relationship,
    StoreType,
    VectorQuantizationType,
)
from core.base.api.models import GraphResponse
from core.base.providers.database import Handler
from core.base.utils import (
    _decorate_vector_type,
    _get_str_estimation_output,
    llm_cost_per_million_tokens,
)

from .base import PostgresConnectionManager
from .collections import PostgresCollectionsHandler

logger = logging.getLogger()
class PostgresEntitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_entity_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        return f"{store_type.value}_entities"

    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """

    async def create_tables(self) -> None:
        """Create separate tables for graph and document entities."""
        vector_column_str = _decorate_vector_type(
            f"({self.dimension})", self.quantization_type
        )
        for store_type in StoreType:
            table_name = self._get_entity_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    name TEXT NOT NULL,
                    category TEXT,
                    description TEXT,
                    parent_id UUID NOT NULL,
                    description_embedding {vector_column_str},
                    chunk_ids UUID[],
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_name_idx
                    ON {self._get_table_name(table_name)} (name);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_category_idx
                    ON {self._get_table_name(table_name)} (category);
            """
            await self.connection_manager.execute_query(QUERY)
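
    # NOTE: the DDL above assumes uuid_generate_v4() is available (i.e. the
    # uuid-ossp extension is enabled) and that the decorated vector column
    # type is provided by pgvector; both are presumed to be installed
    # elsewhere during database setup.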
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        category: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        chunk_ids: Optional[list[UUID]] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Entity:
        """Create a new entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        params = [
            name,
            category,
            description,
            parent_id,
            description_embedding,
            chunk_ids,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )
        return Entity(
            id=result["id"],
            name=result["name"],
            category=result["category"],
            description=result["description"],
            parent_id=result["parent_id"],
            chunk_ids=result["chunk_ids"],
            metadata=result["metadata"],
        )
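
    # Example (illustrative sketch, not part of the module): creating a
    # graph-scoped entity. `entities_handler` and `graph_id` are assumed to
    # exist in the caller's scope.
    #
    #     entity = await entities_handler.create(
    #         parent_id=graph_id,
    #         store_type=StoreType.GRAPHS,
    #         name="Ada Lovelace",
    #         category="person",
    #         description="Early computing pioneer",
    #     )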
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve entities from the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            # Convert the Record to a dictionary
            entity_dict = dict(row)

            # Process metadata if it exists and is a string
            if isinstance(entity_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )

            entities.append(Entity(**entity_dict))

        return entities, count
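
    # The handlers in this module build their SQL with a running
    # `param_index`: each optional filter appends a `$n` placeholder to
    # `conditions` and a matching value to `params`. A minimal sketch of the
    # pattern (hypothetical values):
    #
    #     conditions = ["parent_id = $1"]      # params = [parent_id]
    #     conditions.append("name = ANY($2)")  # params = [parent_id, names]
    #     sql = f"... WHERE {' AND '.join(conditions)}"
    #
    # asyncpg binds every value server-side, so user input is never
    # interpolated into the SQL string itself.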
    async def update(
        self,
        entity_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        """Update an entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if category is not None:
            update_fields.append(f"category = ${param_index}")
            params.append(category)
            param_index += 1

        if metadata is not None:
            update_fields.append(f"metadata = ${param_index}")
            params.append(json.dumps(metadata))
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(entity_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Entity(
                id=result["id"],
                name=result["name"],
                category=result["category"],
                description=result["description"],
                parent_id=result["parent_id"],
                chunk_ids=result["chunk_ids"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the entity: {e}",
            ) from e
    async def delete(
        self,
        parent_id: UUID,
        entity_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """
        Delete entities from the specified store.

        If entity_ids is not provided, deletes all entities for the given parent_id.

        Args:
            parent_id (UUID): Parent ID (collection_id or document_id)
            entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
            store_type (StoreType): Type of store (graph or document)

        Raises:
            R2RException: If specific entities were requested but not all found
        """
        table_name = self._get_entity_table_for_store(store_type)

        if entity_ids is None:
            # Delete all entities for the parent_id
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            # Delete specific entities
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [entity_ids, parent_id]
            )

        # Check if all requested entities were deleted
        deleted_ids = [row["id"] for row in results]
        if entity_ids and len(deleted_ids) != len(entity_ids):
            raise R2RException(
                f"Some entities not found in {store_type} store or no permission to delete",
                404,
            )
    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """
        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
        """
        valid_columns = {
            "id",
            "name",
            "category",
            "description",
            "parent_id",
            "chunk_ids",
            "metadata",
            "created_at",
            "updated_at",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_stmt = f"""
            SELECT
                id::text,
                name,
                category,
                description,
                parent_id::text,
                chunk_ids::text,
                metadata::text,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
            FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
        """

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            writer.writerow(row)

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
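
    # Because the temp file above is created with delete=True, it is removed
    # from disk as soon as the returned handle is closed; callers are expected
    # to stream or copy the CSV before closing (or dropping) the handle.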
class PostgresRelationshipsHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        return f"{store_type.value}_relationships"

    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """

    async def create_tables(self) -> None:
        """Create separate tables for graph and document relationships."""
        for store_type in StoreType:
            table_name = self._get_relationship_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            vector_column_str = _decorate_vector_type(
                f"({self.dimension})", self.quantization_type
            )
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    subject TEXT NOT NULL,
                    predicate TEXT NOT NULL,
                    object TEXT NOT NULL,
                    description TEXT,
                    description_embedding {vector_column_str},
                    subject_id UUID,
                    object_id UUID,
                    weight FLOAT DEFAULT 1.0,
                    chunk_ids UUID[],
                    parent_id UUID NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
                    ON {self._get_table_name(table_name)} (subject);
                CREATE INDEX IF NOT EXISTS {table_name}_object_idx
                    ON {self._get_table_name(table_name)} (object);
                CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
                    ON {self._get_table_name(table_name)} (predicate);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
                    ON {self._get_table_name(table_name)} (subject_id);
                CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
                    ON {self._get_table_name(table_name)} (object_id);
            """
            await self.connection_manager.execute_query(QUERY)
    async def create(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        store_type: StoreType,
        description: str | None = None,
        weight: float | None = 1.0,
        chunk_ids: Optional[list[UUID]] = None,
        description_embedding: Optional[list[float] | str] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        """Create a new relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (subject, predicate, object, description, subject_id, object_id,
             weight, chunk_ids, parent_id, description_embedding, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        params = [
            subject,
            predicate,
            object,
            description,
            subject_id,
            object_id,
            weight,
            chunk_ids,
            parent_id,
            description_embedding,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )
        return Relationship(
            id=result["id"],
            subject=result["subject"],
            predicate=result["predicate"],
            object=result["object"],
            description=result["description"],
            subject_id=result["subject_id"],
            object_id=result["object_id"],
            weight=result["weight"],
            chunk_ids=result["chunk_ids"],
            parent_id=result["parent_id"],
            metadata=result["metadata"],
        )
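
    # Example (illustrative sketch, not part of the module): linking two
    # previously created entities. `rels`, `graph_id`, `ada_id`, and
    # `engine_id` are assumed to exist in the caller's scope.
    #
    #     relationship = await rels.create(
    #         subject="Ada Lovelace",
    #         subject_id=ada_id,
    #         predicate="wrote_notes_on",
    #         object="Analytical Engine",
    #         object_id=engine_id,
    #         parent_id=graph_id,
    #         store_type=StoreType.GRAPHS,
    #         weight=0.9,
    #     )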
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        relationship_types: Optional[list[str]] = None,
        include_metadata: bool = False,
    ):
        """
        Get relationships from the specified store.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            store_type: Type of store (graph or document)
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of specific relationship IDs to retrieve
            entity_names: Optional list of entity names to filter by (matches subject or object)
            relationship_types: Optional list of relationship types (predicates) to filter by
            include_metadata: Whether to include metadata in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        table_name = self._get_relationship_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if entity_names:
            conditions.append(
                f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
            )
            params.append(entity_names)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        select_fields = """
            id, subject, predicate, object, description,
            subject_id, object_id, weight, chunk_ids,
            parent_id
        """
        if include_metadata:
            select_fields += ", metadata"

        # Count query
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        # Main query
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if include_metadata and isinstance(
                relationship_dict["metadata"], str
            ):
                with contextlib.suppress(json.JSONDecodeError):
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
            elif not include_metadata:
                relationship_dict.pop("metadata", None)
            relationships.append(Relationship(**relationship_dict))

        return relationships, count
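
    # Note that the entity_names filter above reuses a single placeholder on
    # both sides of the OR ("subject = ANY($n) OR object = ANY($n)").
    # PostgreSQL allows the same numbered parameter to appear more than once,
    # so only one value is appended to `params` for that condition.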
    async def update(
        self,
        relationship_id: UUID,
        store_type: StoreType,
        subject: Optional[str],
        subject_id: Optional[UUID],
        predicate: Optional[str],
        object: Optional[str],
        object_id: Optional[UUID],
        description: Optional[str],
        description_embedding: Optional[list[float] | str],
        weight: Optional[float],
        metadata: Optional[dict[str, Any] | str],
    ) -> Relationship:
        """Update a relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)
        update_fields = []
        params: list = []
        param_index = 1

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if subject is not None:
            update_fields.append(f"subject = ${param_index}")
            params.append(subject)
            param_index += 1

        if subject_id is not None:
            update_fields.append(f"subject_id = ${param_index}")
            params.append(subject_id)
            param_index += 1

        if predicate is not None:
            update_fields.append(f"predicate = ${param_index}")
            params.append(predicate)
            param_index += 1

        if object is not None:
            update_fields.append(f"object = ${param_index}")
            params.append(object)
            param_index += 1

        if object_id is not None:
            update_fields.append(f"object_id = ${param_index}")
            params.append(object_id)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if weight is not None:
            update_fields.append(f"weight = ${param_index}")
            params.append(weight)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(relationship_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Relationship(
                id=result["id"],
                subject=result["subject"],
                predicate=result["predicate"],
                object=result["object"],
                description=result["description"],
                subject_id=result["subject_id"],
                object_id=result["object_id"],
                weight=result["weight"],
                chunk_ids=result["chunk_ids"],
                parent_id=result["parent_id"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the relationship: {e}",
            ) from e
    async def delete(
        self,
        parent_id: UUID,
        relationship_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """
        Delete relationships from the specified store.

        If relationship_ids is not provided, deletes all relationships for the given parent_id.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            relationship_ids: Optional list of specific relationship IDs to delete
            store_type: Type of store (graph or document)

        Raises:
            R2RException: If specific relationships were requested but not all found
        """
        table_name = self._get_relationship_table_for_store(store_type)

        if relationship_ids is None:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [relationship_ids, parent_id]
            )

        deleted_ids = [row["id"] for row in results]
        if relationship_ids and len(deleted_ids) != len(relationship_ids):
            raise R2RException(
                f"Some relationships not found in {store_type} store or no permission to delete",
                404,
            )
    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """
        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
        """
        valid_columns = {
            "id",
            "subject",
            "predicate",
            "object",
            "description",
            "subject_id",
            "object_id",
            "weight",
            "chunk_ids",
            "parent_id",
            "metadata",
            "created_at",
            "updated_at",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_stmt = f"""
            SELECT
                id::text,
                subject,
                predicate,
                object,
                description,
                subject_id::text,
                object_id::text,
                weight,
                chunk_ids::text,
                parent_id::text,
                metadata::text,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
            FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
        """

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            writer.writerow(row)

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
class PostgresCommunitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    async def create_tables(self) -> None:
        vector_column_str = _decorate_vector_type(
            f"({self.dimension})", self.quantization_type
        )
        query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID,
                community_id UUID,
                level INT,
                name TEXT NOT NULL,
                summary TEXT NOT NULL,
                findings TEXT[],
                rating FLOAT,
                rating_explanation TEXT,
                description_embedding {vector_column_str} NOT NULL,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                metadata JSONB,
                UNIQUE (community_id, level, collection_id)
            );"""
        await self.connection_manager.execute_query(query)
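
    # Unlike entities and relationships, communities live in a single
    # graphs_communities table: the store_type argument accepted by the
    # methods below does not change which table is used.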
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
        description_embedding: Optional[list[float] | str] = None,
    ) -> Community:
        table_name = "graphs_communities"

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        params = [
            parent_id,
            name,
            summary,
            findings,
            rating,
            rating_explanation,
            description_embedding,
        ]
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Community(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the community: {e}",
            ) from e
    async def update(
        self,
        community_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        summary: Optional[str] = None,
        summary_embedding: Optional[list[float] | str] = None,
        findings: Optional[list[str]] = None,
        rating: Optional[float] = None,
        rating_explanation: Optional[str] = None,
    ) -> Community:
        table_name = "graphs_communities"
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if summary is not None:
            update_fields.append(f"summary = ${param_index}")
            params.append(summary)
            param_index += 1

        if summary_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(summary_embedding)
            param_index += 1

        if findings is not None:
            update_fields.append(f"findings = ${param_index}")
            params.append(findings)
            param_index += 1

        if rating is not None:
            update_fields.append(f"rating = ${param_index}")
            params.append(rating)
            param_index += 1

        if rating_explanation is not None:
            update_fields.append(f"rating_explanation = ${param_index}")
            params.append(rating_explanation)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(community_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )
            return Community(
                id=result["id"],
                community_id=result["community_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the community: {e}",
            ) from e
    async def delete(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        table_name = "graphs_communities"
        params = [community_id, parent_id]

        # Delete the community
        query = f"""
            DELETE FROM {self._get_table_name(table_name)}
            WHERE id = $1 AND collection_id = $2
        """
        try:
            await self.connection_manager.execute_query(query, params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting the community: {e}",
            ) from e

    async def delete_all_communities(
        self,
        parent_id: UUID,
    ) -> None:
        table_name = "graphs_communities"
        params = [parent_id]

        # Delete all communities for the parent_id
        query = f"""
            DELETE FROM {self._get_table_name(table_name)}
            WHERE collection_id = $1
        """
        try:
            await self.connection_manager.execute_query(query, params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting communities: {e}",
            ) from e
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve communities from the specified store."""
        # Do we ever want to get communities from document store?
        table_name = "graphs_communities"

        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        if community_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(community_names)
            param_index += 1

        select_fields = """
            id, community_id, name, summary, findings, rating,
            rating_explanation, level, created_at, updated_at
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, params[: param_index - 1]
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count
    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """
        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.
        """
        valid_columns = {
            "id",
            "collection_id",
            "community_id",
            "level",
            "name",
            "summary",
            "findings",
            "rating",
            "rating_explanation",
            "created_at",
            "updated_at",
            "metadata",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        table_name = "graphs_communities"
        select_stmt = f"""
            SELECT
                id::text,
                collection_id::text,
                community_id::text,
                level,
                name,
                summary,
                findings::text,
                rating,
                rating_explanation,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
                metadata::text
            FROM {self._get_table_name(table_name)}
        """

        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            writer.writerow(row)

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
class PostgresGraphsHandler(Handler):
    """Handler for Knowledge Graph METHODS in PostgreSQL."""

    TABLE_NAME = "graphs"

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
        self.collections_handler: PostgresCollectionsHandler = kwargs.get("collections_handler")  # type: ignore

        self.entities = PostgresEntitiesHandler(*args, **kwargs)
        self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
        self.communities = PostgresCommunitiesHandler(*args, **kwargs)

        self.handlers = [
            self.entities,
            self.relationships,
            self.communities,
        ]

        import networkx as nx

        self.nx = nx
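        # networkx is imported lazily here and stashed on the instance rather
        # than imported at module load time, presumably to defer the
        # dependency until a graph handler is actually constructed.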

    async def create_tables(self) -> None:
        """Create the graph tables with mandatory collection_id support."""
        QUERY = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID NOT NULL,
                name TEXT NOT NULL,
                description TEXT,
                status TEXT NOT NULL,
                document_ids UUID[],
                metadata JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS graph_collection_id_idx
                ON {self._get_table_name("graphs")} (collection_id);
        """
        await self.connection_manager.execute_query(QUERY)

        for handler in self.handlers:
            await handler.create_tables()
    async def create(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        status: str = "pending",
    ) -> GraphResponse:
        """Create a new graph associated with a collection."""
        name = name or f"Graph {collection_id}"
        description = description or ""

        query = f"""
            INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            (id, collection_id, name, description, status)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
        """
        params = [
            collection_id,
            collection_id,
            name,
            description,
            status,
        ]
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
                document_ids=result["document_ids"] or [],
            )
        except UniqueViolationError:
            raise R2RException(
                message="Graph with this ID already exists",
                status_code=409,
            )
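
    # The INSERT above passes collection_id twice: once as the graph's id and
    # once as its collection_id. Each collection therefore gets at most one
    # graph, and a duplicate create surfaces as the 409 above.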
    async def reset(self, parent_id: UUID) -> None:
        """Completely reset a graph and all associated data."""
        await self.entities.delete(
            parent_id=parent_id, store_type=StoreType.GRAPHS
        )
        await self.relationships.delete(
            parent_id=parent_id, store_type=StoreType.GRAPHS
        )
        await self.communities.delete_all_communities(parent_id=parent_id)
        return

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # filter_user_ids: Optional[list[UUID]] = None,
        filter_graph_ids: Optional[list[UUID]] = None,
        filter_collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        conditions = []
        params: list[Any] = []
        param_index = 1

        if filter_graph_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(filter_graph_ids)
            param_index += 1

        # if filter_user_ids:
        #     conditions.append(f"user_id = ANY(${param_index})")
        #     params.append(filter_user_ids)
        #     param_index += 1

        if filter_collection_id:
            conditions.append(f"collection_id = ${param_index}")
            params.append(filter_collection_id)
            param_index += 1

        where_clause = (
            f"WHERE {' AND '.join(conditions)}" if conditions else ""
        )

        # The rn = 1 filter keeps only the most recent graph per collection.
        query = f"""
            WITH RankedGraphs AS (
                SELECT
                    id, collection_id, name, description, status, created_at, updated_at, document_ids,
                    COUNT(*) OVER() as total_entries,
                    ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
                FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                {where_clause}
            )
            SELECT * FROM RankedGraphs
            WHERE rn = 1
            ORDER BY created_at DESC
            OFFSET ${param_index} LIMIT ${param_index + 1}
        """
        params.extend([offset, limit])

        try:
            results = await self.connection_manager.fetch_query(query, params)
            if not results:
                return {"results": [], "total_entries": 0}

            total_entries = results[0]["total_entries"] if results else 0

            graphs = [
                GraphResponse(
                    id=row["id"],
                    document_ids=row["document_ids"] or [],
                    name=row["name"],
                    collection_id=row["collection_id"],
                    description=row["description"],
                    status=row["status"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                )
                for row in results
            ]

            return {"results": graphs, "total_entries": total_entries}
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while fetching graphs: {e}",
            ) from e
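
    # Example call (sketch, hypothetical UUID): fetch the newest graph for a
    # single collection using the filter arguments above.
    #
    #     page = await graphs_handler.list_graphs(
    #         offset=0,
    #         limit=10,
    #         filter_collection_id=UUID("9fbe403b-c11c-5aae-8ade-ef22980c3ad1"),
    #     )
    #     for g in page["results"]:
    #         print(g.name, g.status)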

    async def get(
        self, offset: int, limit: int, graph_id: Optional[UUID] = None
    ):
        if graph_id is None:
            params = [offset, limit]

            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                OFFSET $1 LIMIT $2
            """
            ret = await self.connection_manager.fetch_query(QUERY, params)

            COUNT_QUERY = f"""
                SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            """
            count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
                0
            ]["count"]

            return {
                "results": [Graph(**row) for row in ret],
                "total_entries": count,
            }
        else:
            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
            """
            params = [graph_id]  # type: ignore

            return {
                "results": [
                    Graph(
                        **await self.connection_manager.fetchrow_query(
                            QUERY, params
                        )
                    )
                ]
            }

    async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
        """Add documents to the graph by copying their entities and relationships."""
        # Copy entities from documents_entities to graphs_entities
        ENTITY_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_entities")} (
                name, category, description, parent_id, description_embedding,
                chunk_ids, metadata
            )
            SELECT
                name, category, description, $1, description_embedding,
                chunk_ids, metadata
            FROM {self._get_table_name("documents_entities")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            ENTITY_COPY_QUERY, [id, document_ids]
        )

        # Copy relationships from documents_relationships to graphs_relationships
        RELATIONSHIP_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_relationships")} (
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, parent_id, metadata, description_embedding
            )
            SELECT
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, $1, metadata, description_embedding
            FROM {self._get_table_name("documents_relationships")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            RELATIONSHIP_COPY_QUERY, [id, document_ids]
        )

        # Append the document ids to the graph's document_ids array
        UPDATE_GRAPH_QUERY = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET document_ids = array_cat(
                CASE
                    WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
                    ELSE document_ids
                END,
                $2::uuid[]
            )
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            UPDATE_GRAPH_QUERY, [id, document_ids]
        )
        return True
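
    # Sketch of a typical flow (hypothetical ids): copy two documents' extracted
    # entities and relationships into a graph, then verify membership.
    #
    #     await graphs_handler.add_documents(id=graph_id, document_ids=[doc_a, doc_b])
    #     assert await graphs_handler.has_document(graph_id, doc_a)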

    async def update(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        """Update an existing graph."""
        update_fields = []
        params: list = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(collection_id)

        query = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )

            if not result:
                raise R2RException(status_code=404, message="Graph not found")

            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                document_ids=result["document_ids"] or [],
                updated_at=result["updated_at"],
            )
        except R2RException:
            # Re-raise as-is so a 404 is not masked as a 500 below.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the graph: {e}",
            ) from e

    async def get_creation_estimate(
        self,
        graph_creation_settings: KGCreationSettings,
        document_id: Optional[UUID] = None,
        collection_id: Optional[UUID] = None,
    ):
        """Get the estimated cost and time for creating a KG."""
        # Exactly one of document_id or collection_id must be provided.
        if bool(document_id) == bool(collection_id):
            raise ValueError(
                "Exactly one of document_id or collection_id must be provided."
            )

        # TODO: harmonize the document_id and id fields: the postgres table
        # contains document_id, but other places use id.
        document_ids = (
            [document_id]
            if document_id
            else [
                doc.id for doc in (await self.collections_handler.documents_in_collection(collection_id, offset=0, limit=-1))["results"]  # type: ignore
            ]
        )

        chunk_counts = await self.connection_manager.fetch_query(
            f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} "
            f"WHERE document_id = ANY($1) GROUP BY document_id",
            [document_ids],
        )

        # Heuristic ranges: 10-20 entities per merged chunk, relationships at
        # 1.25-1.5x entities, and roughly 2,000 tokens per LLM call.
        total_chunks = (
            sum(doc["chunk_count"] for doc in chunk_counts)
            // graph_creation_settings.chunk_merge_count
        )
        estimated_entities = (total_chunks * 10, total_chunks * 20)
        estimated_relationships = (
            int(estimated_entities[0] * 1.25),
            int(estimated_entities[1] * 1.5),
        )
        estimated_llm_calls = (
            total_chunks * 2 + estimated_entities[0],
            total_chunks * 2 + estimated_entities[1],
        )
        total_in_out_tokens = tuple(
            2000 * calls // 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_creation_settings.generation_config.model
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in total_in_out_tokens
        )
        total_time_in_minutes = tuple(
            tokens * 10 / 60 for tokens in total_in_out_tokens
        )

        return {
            "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `extract-triples` with `--run` in the cli, or `run_type="run"` in the client.',
            "document_count": len(document_ids),
            "number_of_jobs_created": len(document_ids) + 1,
            "total_chunks": total_chunks,
            "estimated_entities": _get_str_estimation_output(
                estimated_entities
            ),
            "estimated_relationships": _get_str_estimation_output(
                estimated_relationships
            ),
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                total_in_out_tokens
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(total_time_in_minutes),
        }
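
    # Worked example of the estimate arithmetic above (illustrative numbers):
    # with 400 chunks and chunk_merge_count = 4, total_chunks = 100, so
    # estimated_entities = (1000, 2000), estimated_relationships = (1250, 3000),
    # estimated_llm_calls = (200 + 1000, 200 + 2000) = (1200, 2200), and
    # total_in_out_tokens = (2000 * 1200 // 10**6, 2000 * 2200 // 10**6) = (2, 4)
    # million tokens.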

    async def get_enrichment_estimate(
        self,
        collection_id: UUID | None = None,
        graph_id: UUID | None = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
    ):
        """Get the estimated cost and time for enriching a KG."""
        if collection_id is not None:
            document_ids = [
                doc.id
                for doc in (
                    await self.collections_handler.documents_in_collection(collection_id, offset=0, limit=-1)  # type: ignore
                )["results"]
            ]

            # Get entity and relationship counts
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1);",
                    [document_ids],
                )
            )[0]["count"]

            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )

            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"""SELECT COUNT(*) FROM {self._get_table_name("documents_relationships")} WHERE document_id = ANY($1);""",
                    [document_ids],
                )
            )[0]["count"]
        else:
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]

            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )

            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]

        # Calculate estimates
        estimated_llm_calls = (entity_count // 10, entity_count // 5)
        tokens_in_millions = tuple(
            2000 * calls / 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_enrichment_settings.generation_config.model  # type: ignore
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in tokens_in_millions
        )
        estimated_time = tuple(
            tokens * 10 / 60 for tokens in tokens_in_millions
        )

        return {
            "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the cli, or `run_type="run"` in the client.',
            "total_entities": entity_count,
            "total_relationships": relationship_count,
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                tokens_in_millions
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(estimated_time),
        }

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Entity], int]:
        """
        Get entities for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            entity_ids: Optional list of entity IDs to filter by
            entity_names: Optional list of entity names to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of entities, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_entities")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching entities with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_entities")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            entity_dict = dict(row)
            if isinstance(entity_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )
            entities.append(Entity(**entity_dict))

        return entities, count
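
    # Pagination sketch (hypothetical graph id): page through all entities of a
    # graph 100 at a time using the (entities, count) return shape.
    #
    #     offset = 0
    #     while True:
    #         entities, total = await graphs_handler.get_entities(
    #             parent_id=graph_id, offset=offset, limit=100
    #         )
    #         if not entities:
    #             break
    #         offset += len(entities)
    #         if offset >= total:
    #             break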

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        relationship_types: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Relationship], int]:
        """
        Get relationships for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of relationship IDs to filter by
            relationship_types: Optional list of relationship types to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching relationships with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if isinstance(relationship_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
            relationships.append(Relationship(**relationship_dict))

        return relationships, count

    async def add_entities(
        self,
        entities: list[Entity],
        table_name: str,
        conflict_columns: list[str] = [],
    ) -> list[UUID]:
        """
        Upsert raw entities extracted from a document into the given table.

        Args:
            entities: list of entities to upsert
            table_name: name of the target table
            conflict_columns: columns that trigger an update on conflict

        Returns:
            list of UUIDs of the upserted rows
        """
        cleaned_entities = []
        for entity in entities:
            entity_dict = entity.to_dict()
            entity_dict["chunk_ids"] = (
                entity_dict["chunk_ids"]
                if entity_dict.get("chunk_ids")
                else []
            )
            entity_dict["description_embedding"] = (
                str(entity_dict["description_embedding"])
                if entity_dict.get("description_embedding")  # type: ignore
                else None
            )
            cleaned_entities.append(entity_dict)

        return await _add_objects(
            objects=cleaned_entities,
            full_table_name=self._get_table_name(table_name),
            connection_manager=self.connection_manager,
            conflict_columns=conflict_columns,
        )

    async def get_all_relationships(
        self,
        collection_id: UUID | None,
        graph_id: UUID | None,
        document_ids: Optional[list[UUID]] = None,
    ) -> list[Relationship]:
        # Note: only collection_id scopes this query today; graph_id and
        # document_ids are accepted but currently unused.
        QUERY = f"""
            SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = $1
        """
        relationships = await self.connection_manager.fetch_query(
            QUERY, [collection_id]
        )

        return [Relationship(**relationship) for relationship in relationships]

    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
        """
        Check if a document exists in the graph's document_ids array.

        Args:
            graph_id (UUID): ID of the graph to check
            document_id (UUID): ID of the document to look for

        Returns:
            bool: True if document exists in graph, False otherwise

        Raises:
            R2RException: If graph not found
        """
        QUERY = f"""
            SELECT EXISTS (
                SELECT 1
                FROM {self._get_table_name("graphs")}
                WHERE id = $1
                AND document_ids IS NOT NULL
                AND $2 = ANY(document_ids)
            ) as exists;
        """
        result = await self.connection_manager.fetchrow_query(
            QUERY, [graph_id, document_id]
        )

        if result is None:
            raise R2RException(f"Graph {graph_id} not found", 404)

        return result["exists"]

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Community], int]:
        """
        Get communities for a graph.

        Args:
            parent_id: UUID of the collection the graph belongs to
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            community_ids: Optional list of community IDs to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of communities, total count)
        """
        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        select_fields = """
            id, collection_id, name, summary, findings, rating, rating_explanation
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_communities")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_communities")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count

    async def add_community(self, community: Community) -> None:
        # TODO: Fix in the short term.
        # We need to do this because the postgres insert expects a string.
        community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]

        non_null_attrs = {
            k: v for k, v in community.__dict__.items() if v is not None
        }
        columns = ", ".join(non_null_attrs.keys())
        placeholders = ", ".join(f"${i + 1}" for i in range(len(non_null_attrs)))

        conflict_columns = ", ".join(
            [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
        )

        QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
            VALUES ({placeholders})
            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
                {conflict_columns}
        """

        await self.connection_manager.execute_many(
            QUERY, [tuple(non_null_attrs.values())]
        )

    async def delete(self, collection_id: UUID) -> None:
        graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)

        if len(graphs["results"]) == 0:
            raise R2RException(
                message=f"Graph not found for collection {collection_id}",
                status_code=404,
            )
        await self.reset(collection_id)

        # Set status to PENDING for this collection.
        QUERY = f"""
            UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
        """
        await self.connection_manager.execute_query(
            QUERY, [KGExtractionStatus.PENDING, collection_id]
        )

        # Delete the graph record itself
        QUERY = f"""
            DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
        """
        await self.connection_manager.execute_query(QUERY, [collection_id])

    async def perform_graph_clustering(
        self,
        collection_id: UUID,
        leiden_params: dict[str, Any],
        clustering_mode: str,
    ) -> Tuple[int, Any]:
        """Calls the external clustering service to cluster the KG."""
        offset = 0
        page_size = 1000
        all_relationships = []

        while True:
            relationships, count = await self.relationships.get(
                parent_id=collection_id,
                store_type=StoreType.GRAPHS,
                offset=offset,
                limit=page_size,
            )
            if not relationships:
                break

            all_relationships.extend(relationships)
            offset += len(relationships)

            if offset >= count:
                break

        relationship_ids_cache = await self._get_relationship_ids_cache(
            all_relationships
        )

        logger.info(
            f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
        )

        return await self._cluster_and_add_community_info(
            relationships=all_relationships,
            relationship_ids_cache=relationship_ids_cache,
            leiden_params=leiden_params,
            collection_id=collection_id,
            clustering_mode=clustering_mode,
        )
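
    # Example invocation (sketch): cluster a collection's graph locally.
    # The leiden_params shown are assumptions; they are forwarded verbatim to
    # graspologic's hierarchical_leiden in local mode.
    #
    #     num_communities, mapping = await graphs_handler.perform_graph_clustering(
    #         collection_id=collection_id,
    #         leiden_params={"max_cluster_size": 1000, "random_seed": 7272},
    #         clustering_mode="local",
    #     )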

    async def _call_clustering_service(
        self, relationships: list[Relationship], leiden_params: dict[str, Any]
    ) -> list[dict]:
        """
        Calls the external Graspologic clustering service, sending relationships
        and parameters. Expects a response with a 'communities' field.
        """
        # Convert relationships to a JSON-friendly format
        rel_data = []
        for r in relationships:
            rel_data.append(
                {
                    "id": str(r.id),
                    "subject": r.subject,
                    "object": r.object,
                    "weight": r.weight if r.weight is not None else 1.0,
                }
            )

        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
        if not endpoint:
            raise ValueError("CLUSTERING_SERVICE_URL not set.")

        url = f"{endpoint}/cluster"
        payload = {"relationships": rel_data, "leiden_params": leiden_params}

        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, timeout=3600)
            response.raise_for_status()

        data = response.json()
        communities = data.get("communities", [])
        return communities
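
    # Wire-format sketch for the clustering service (shapes inferred from the
    # code above and from _cluster_and_add_community_info; the endpoint itself
    # is configured via CLUSTERING_SERVICE_URL):
    #
    #     request:  {"relationships": [{"id": "...", "subject": "A",
    #                                   "object": "B", "weight": 1.0}, ...],
    #                "leiden_params": {...}}
    #     response: {"communities": [{"node": "A", "cluster": 0, "level": 0}, ...]}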

    async def _create_graph_and_cluster(
        self,
        relationships: list[Relationship],
        leiden_params: dict[str, Any],
        clustering_mode: str = "remote",
    ) -> Any:
        """
        Create a graph and cluster it. If clustering_mode='local', run
        hierarchical_leiden locally; if clustering_mode='remote', call the
        external service.
        """
        if clustering_mode == "remote":
            logger.info("Sending request to external clustering service...")
            communities = await self._call_clustering_service(
                relationships, leiden_params
            )
            logger.info("Received communities from clustering service.")
            return communities
        else:
            # Local mode: run hierarchical_leiden directly
            G = self.nx.Graph()
            for relationship in relationships:
                G.add_edge(
                    relationship.subject,
                    relationship.object,
                    weight=relationship.weight,
                    id=relationship.id,
                )

            logger.info(
                f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges"
            )

            return await self._compute_leiden_communities(G, leiden_params)

    async def _cluster_and_add_community_info(
        self,
        relationships: list[Relationship],
        relationship_ids_cache: dict[str, list[int]],
        leiden_params: dict[str, Any],
        collection_id: Optional[UUID] = None,
        clustering_mode: str = "local",
    ) -> Tuple[int, Any]:
        # Clear out any old information
        conditions = []
        if collection_id is not None:
            conditions.append("collection_id = $1")

        await asyncio.sleep(0.1)

        start_time = time.time()
        logger.info(f"Creating graph and clustering for {collection_id}")

        hierarchical_communities = await self._create_graph_and_cluster(
            relationships=relationships,
            leiden_params=leiden_params,
            clustering_mode=clustering_mode,
        )

        logger.info(
            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
        )

        def relationship_ids(node: str) -> list[int]:
            return relationship_ids_cache.get(node, [])

        logger.info(
            f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds."
        )

        # If remote: hierarchical_communities is a list of dicts like
        #   [{"node": str, "cluster": int, "level": int}, ...]
        # If local: hierarchical_communities is the structure returned by
        #   hierarchical_leiden (a list of named tuples).
        if clustering_mode == "remote":
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item["cluster"] for item in hierarchical_communities)
                    + 1
                )
        else:
            # Local mode: items expose a .cluster attribute
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item.cluster for item in hierarchical_communities) + 1
                )

        logger.info(
            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
        )

        return num_communities, hierarchical_communities

    async def _get_relationship_ids_cache(
        self, relationships: list[Relationship]
    ) -> dict[str, list[int]]:
        relationship_ids_cache: dict[str, list[int]] = {}
        for relationship in relationships:
            if relationship.subject is not None:
                relationship_ids_cache.setdefault(relationship.subject, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.subject].append(
                        int(relationship.id)
                    )
            if relationship.object is not None:
                relationship_ids_cache.setdefault(relationship.object, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.object].append(
                        int(relationship.id)
                    )

        return relationship_ids_cache
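
    # Example of the resulting cache shape: for relationships (A -[r1]-> B) and
    # (A -[r2]-> C), the cache maps each node name to the ids of its incident
    # relationships: {"A": [r1, r2], "B": [r1], "C": [r2]}.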

    async def get_entity_map(
        self, offset: int, limit: int, document_id: UUID
    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
        QUERY1 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT e.name, e.description, e.category,
                   (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
                   e.parent_id
            FROM {self._get_table_name("documents_entities")} e
            JOIN entities_list el ON e.name = el.name
            GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
            ORDER BY e.name;"""

        entities_list = await self.connection_manager.fetch_query(
            QUERY1, [document_id]
        )
        entities_list = [Entity(**entity) for entity in entities_list]

        QUERY2 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
                   (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
            FROM {self._get_table_name("documents_relationships")} t
            JOIN entities_list el ON t.subject = el.name
            ORDER BY t.subject, t.predicate, t.object;
        """

        relationships_list = await self.connection_manager.fetch_query(
            QUERY2, [document_id]
        )
        relationships_list = [
            Relationship(**relationship) for relationship in relationships_list
        ]

        entity_map: dict[str, dict[str, list[Any]]] = {}
        for entity in entities_list:
            if entity.name not in entity_map:
                entity_map[entity.name] = {"entities": [], "relationships": []}
            entity_map[entity.name]["entities"].append(entity)

        for relationship in relationships_list:
            if relationship.subject in entity_map:
                entity_map[relationship.subject]["relationships"].append(
                    relationship
                )
            if relationship.object in entity_map:
                entity_map[relationship.object]["relationships"].append(
                    relationship
                )

        return entity_map

    async def graph_search(
        self, query: str, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """
        Perform semantic search with similarity scores while maintaining
        the exact same result structure.
        """
        query_embedding = kwargs.get("query_embedding", None)
        if query_embedding is None:
            raise ValueError(
                "query_embedding must be provided for semantic search"
            )

        search_type = kwargs.get(
            "search_type", "entities"
        )  # entities | relationships | communities
        embedding_type = kwargs.get("embedding_type", "description_embedding")
        property_names = kwargs.get("property_names", ["name", "description"])

        # Add metadata if not present
        if "metadata" not in property_names:
            property_names.append("metadata")

        filters = kwargs.get("filters", {})
        limit = kwargs.get("limit", 10)
        use_fulltext_search = kwargs.get("use_fulltext_search", True)
        use_hybrid_search = kwargs.get("use_hybrid_search", True)

        if use_hybrid_search or use_fulltext_search:
            logger.warning(
                "Hybrid and fulltext search are not supported for graph search; ignoring."
            )

        table_name = f"graphs_{search_type}"
        property_names_str = ", ".join(property_names)

        # Build the WHERE clause from filters
        params: list[str | int | bytes] = [
            json.dumps(query_embedding),
            limit,
        ]
        conditions_clause = self._build_filters(filters, params, search_type)
        where_clause = (
            f"WHERE {conditions_clause}" if conditions_clause else ""
        )

        # Note: for vector similarity, <=> computes distance; the smaller the
        # number, the more similar. We convert that to a similarity_score by
        # computing (1 - distance).
        QUERY = f"""
            SELECT
                {property_names_str},
                ({embedding_type} <=> $1) as similarity_score
            FROM {self._get_table_name(table_name)}
            {where_clause}
            ORDER BY {embedding_type} <=> $1
            LIMIT $2;
        """
        results = await self.connection_manager.fetch_query(
            QUERY, tuple(params)
        )

        for result in results:
            output = {
                prop: result[prop] for prop in property_names if prop in result
            }
            output["similarity_score"] = 1 - float(result["similarity_score"])
            yield output
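
    # Consumption sketch (hypothetical embedding variable): graph_search is an
    # async generator, so results are iterated with `async for`.
    #
    #     async for hit in graphs_handler.graph_search(
    #         query="acme",
    #         query_embedding=embedding,   # list[float] from your embedder
    #         search_type="entities",
    #         filters={"parent_id": {"$eq": str(graph_id)}},
    #         limit=5,
    #     ):
    #         print(hit["name"], hit["similarity_score"])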

    def _build_filters(
        self, filter_dict: dict, parameters: list[Any], search_type: str
    ) -> str:
        """
        Build a WHERE clause from a nested filter dictionary for graph search.

        For communities we use collection_id as the primary key filter; for
        entities and relationships we use parent_id.
        """
        # Determine the primary identifier column depending on search_type:
        #   communities:              collection_id
        #   entities/relationships:   parent_id
        base_id_column = (
            "collection_id" if search_type == "communities" else "parent_id"
        )

        def parse_condition(key: str, value: Any) -> str:
            # Returns a single condition string, or empty if no valid condition.
            # Supported keys:
            #   - base_id_column (collection_id or parent_id)
            #   - metadata fields: metadata.some_field
            # Supported ops: $eq, $ne, $lt, $lte, $gt, $gte, $in, $contains
            if key == base_id_column:
                # e.g. {"collection_id": {"$eq": "<some-uuid>"}}
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(str(clause))
                        return f"{base_id_column} = ${len(parameters)}::uuid"
                    elif op == "$in":
                        # $in expects a list of UUIDs
                        parameters.append([str(x) for x in clause])
                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
                else:
                    # Direct equality
                    parameters.append(str(value))
                    return f"{base_id_column} = ${len(parameters)}::uuid"
            elif key.startswith("metadata."):
                # Handle metadata filters
                # Example: {"metadata.some_key": {"$eq": "value"}}
                field = key.split("metadata.")[1]

                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') != ${len(parameters)}"
                    elif op == "$lt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float < ${len(parameters)}::float"
                    elif op == "$lte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float <= ${len(parameters)}::float"
                    elif op == "$gt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
                    elif op == "$gte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float >= ${len(parameters)}::float"
                    elif op == "$in":
                        # Ensure clause is a list
                        if not isinstance(clause, list):
                            raise Exception(
                                "argument to $in filter must be a list"
                            )
                        # Append the Python list as a parameter; many drivers
                        # can convert Python lists to arrays
                        parameters.append(clause)
                        # Cast the parameter to a text array type
                        return f"(metadata->>'{field}')::text = ANY(${len(parameters)}::text[])"
                    elif op == "$contains":
                        # $contains for metadata means JSONB containment (@>).
                        # If clause is a dict or list, we use json containment.
                        parameters.append(json.dumps(clause))
                        return f"metadata @> ${len(parameters)}::jsonb"
                else:
                    # Direct equality
                    parameters.append(value)
                    return f"(metadata->>'{field}') = ${len(parameters)}"

            # If the key is not recognized, return empty so it doesn't break the query
            return ""

        def parse_filter(fd: dict) -> str:
            filter_conditions = []
            for k, v in fd.items():
                if k == "$and":
                    and_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    and_parts = [x for x in and_parts if x.strip()]
                    if and_parts:
                        filter_conditions.append(
                            f"({' AND '.join(and_parts)})"
                        )
                elif k == "$or":
                    or_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    or_parts = [x for x in or_parts if x.strip()]
                    if or_parts:
                        filter_conditions.append(f"({' OR '.join(or_parts)})")
                else:
                    # Regular condition
                    c = parse_condition(k, v)
                    if c and c.strip():
                        filter_conditions.append(c)

            if not filter_conditions:
                return ""
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            return " AND ".join(filter_conditions)

        return parse_filter(filter_dict)
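
    # Example (illustrative): the filter dict
    #     {"$and": [{"parent_id": {"$eq": "<uuid>"}},
    #               {"metadata.source": {"$eq": "wiki"}}]}
    # produces a clause like
    #     (parent_id = $3::uuid AND (metadata->>'source') = $4)
    # with "<uuid>" and "wiki" appended to `parameters`; positions start at $3
    # because graph_search has already bound the embedding and limit as $1/$2.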

    async def _compute_leiden_communities(
        self,
        graph: Any,
        leiden_params: dict[str, Any],
    ) -> Any:
        """Compute Leiden communities."""
        try:
            from graspologic.partition import hierarchical_leiden

            if "random_seed" not in leiden_params:
                # Add a seed to control randomness
                leiden_params["random_seed"] = 7272

            start_time = time.time()
            logger.info(
                f"Running Leiden clustering with params: {leiden_params}"
            )

            community_mapping = hierarchical_leiden(graph, **leiden_params)

            logger.info(
                f"Leiden clustering completed in {time.time() - start_time:.2f} seconds."
            )
            return community_mapping

        except ImportError as e:
            raise ImportError("Please install the graspologic package.") from e

    async def get_existing_document_entity_chunk_ids(
        self, document_id: UUID
    ) -> list[str]:
        QUERY = f"""
            SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
        """
        return [
            item["chunk_id"]
            for item in await self.connection_manager.fetch_query(
                QUERY, [document_id]
            )
        ]

    async def get_entity_count(
        self,
        collection_id: Optional[UUID] = None,
        document_id: Optional[UUID] = None,
        distinct: bool = False,
        entity_table_name: str = "entity",
    ) -> int:
        if collection_id is None and document_id is None:
            raise ValueError(
                "Either collection_id or document_id must be provided."
            )

        # Scope the count by parent_id, using whichever identifier was provided.
        conditions = ["parent_id = $1"]
        params = [str(document_id) if document_id else str(collection_id)]

        count_value = "DISTINCT name" if distinct else "*"

        QUERY = f"""
            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
            WHERE {" AND ".join(conditions)}
        """
        return (await self.connection_manager.fetch_query(QUERY, params))[0][
            "count"
        ]

    async def update_entity_descriptions(self, entities: list[Entity]):
        query = f"""
            UPDATE {self._get_table_name("graphs_entities")}
            SET description = $3, description_embedding = $4
            WHERE name = $1 AND graph_id = $2
        """
        inputs = [
            (
                entity.name,
                entity.parent_id,
                entity.description,
                entity.description_embedding,
            )
            for entity in entities
        ]
        await self.connection_manager.execute_many(query, inputs)  # type: ignore


def _json_serialize(obj):
    if isinstance(obj, UUID):
        return str(obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


async def _add_objects(
    objects: list[dict],
    full_table_name: str,
    connection_manager: PostgresConnectionManager,
    conflict_columns: list[str] = [],
    exclude_metadata: list[str] = [],
) -> list[UUID]:
    """Bulk insert objects into the specified table using jsonb_to_recordset."""
    # Exclude specified metadata keys and drop null values
    cleaned_objects = []
    for obj in objects:
        cleaned_obj = {
            k: v
            for k, v in obj.items()
            if k not in exclude_metadata and v is not None
        }
        cleaned_objects.append(cleaned_obj)

    if not cleaned_objects:
        return []

    # Serialize the list of objects to JSON
    json_data = json.dumps(cleaned_objects, default=_json_serialize)

    # Prepare the column definitions for jsonb_to_recordset
    columns = cleaned_objects[0].keys()
    column_defs = []
    for col in columns:
        # Map Python types to PostgreSQL types
        sample_value = cleaned_objects[0][col]
        if "embedding" in col:
            pg_type = "vector"
        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
            pg_type = "uuid[]"
        elif col == "id" or "_id" in col:
            pg_type = "uuid"
        elif isinstance(sample_value, str):
            pg_type = "text"
        elif isinstance(sample_value, UUID):
            pg_type = "uuid"
        # bool must be checked before int/float, since bool is a subclass of int
        elif isinstance(sample_value, bool):
            pg_type = "boolean"
        elif isinstance(sample_value, (int, float)):
            pg_type = "numeric"
        elif isinstance(sample_value, list) and all(
            isinstance(x, UUID) for x in sample_value
        ):
            pg_type = "uuid[]"
        elif isinstance(sample_value, list):
            pg_type = "jsonb"
        elif isinstance(sample_value, dict):
            pg_type = "jsonb"
        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
            pg_type = "timestamp"
        else:
            raise TypeError(
                f"Unsupported data type for column '{col}': {type(sample_value)}"
            )
        column_defs.append(f"{col} {pg_type}")

    columns_str = ", ".join(columns)
    column_defs_str = ", ".join(column_defs)

    if conflict_columns:
        conflict_columns_str = ", ".join(conflict_columns)
        update_columns_str = ", ".join(
            f"{col}=EXCLUDED.{col}"
            for col in columns
            if col not in conflict_columns
        )
        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
    else:
        on_conflict_clause = ""

    QUERY = f"""
        INSERT INTO {full_table_name} ({columns_str})
        SELECT {columns_str}
        FROM jsonb_to_recordset($1::jsonb)
        AS x({column_defs_str})
        {on_conflict_clause}
        RETURNING id;
    """

    # Execute the query
    result = await connection_manager.fetch_query(QUERY, [json_data])

    # Extract and return the IDs
    return [record["id"] for record in result]
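
# Usage sketch for _add_objects (hypothetical table name and connection
# manager; the qualified-name format is an assumption, since _get_table_name
# normally produces it):
#
#     ids = await _add_objects(
#         objects=[{"id": uuid4(), "name": "Acme", "parent_id": graph_id}],
#         full_table_name="myproject.graphs_entities",
#         connection_manager=connection_manager,
#         conflict_columns=["name", "parent_id"],
#     )
#
# Columns are inferred from the first object, so every object in the batch
# should share the same key set.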