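"""Postgres handlers for the knowledge-graph layer: entities, relationships,
and communities, each backed by its own set of project-scoped tables."""
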
import asyncio
import contextlib
import csv
import datetime
import json
import logging
import os
import tempfile
import time
from typing import IO, Any, AsyncGenerator, Optional, Tuple
from uuid import UUID

import asyncpg
import httpx
from asyncpg.exceptions import UniqueViolationError
from fastapi import HTTPException

from core.base.abstractions import (
    Community,
    Entity,
    Graph,
    GraphExtractionStatus,
    R2RException,
    Relationship,
    StoreType,
    VectorQuantizationType,
)
from core.base.api.models import GraphResponse
from core.base.providers.database import Handler
from core.base.utils import (
    _get_vector_column_str,
    generate_entity_document_id,
)

from .base import PostgresConnectionManager
from .collections import PostgresCollectionsHandler

logger = logging.getLogger()

class PostgresEntitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get(
            "connection_manager"
        )  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get(
            "quantization_type"
        )  # type: ignore
        self.relationships_handler: PostgresRelationshipsHandler = (
            PostgresRelationshipsHandler(*args, **kwargs)
        )

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_entity_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        return f"{store_type.value}_entities"

    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """
    async def create_tables(self) -> None:
        """Create separate tables for graph and document entities."""
        vector_column_str = _get_vector_column_str(
            self.dimension, self.quantization_type
        )
        for store_type in StoreType:
            table_name = self._get_entity_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    name TEXT NOT NULL,
                    category TEXT,
                    description TEXT,
                    parent_id UUID NOT NULL,
                    description_embedding {vector_column_str},
                    chunk_ids UUID[],
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_name_idx
                    ON {self._get_table_name(table_name)} (name);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_category_idx
                    ON {self._get_table_name(table_name)} (category);
            """
            await self.connection_manager.execute_query(QUERY)
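
    # Typical usage (sketch; the handler instance and argument values are
    # illustrative, not taken from this file):
    #   entity = await entities_handler.create(
    #       parent_id=graph_id,
    #       store_type=StoreType.GRAPHS,
    #       name="Aspirin",
    #       description="A common NSAID",
    #   )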
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        category: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        chunk_ids: Optional[list[UUID]] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Entity:
        """Create a new entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        params = [
            name,
            category,
            description,
            parent_id,
            description_embedding,
            chunk_ids,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )

        return Entity(
            id=result["id"],
            name=result["name"],
            category=result["category"],
            description=result["description"],
            parent_id=result["parent_id"],
            chunk_ids=result["chunk_ids"],
            metadata=result["metadata"],
        )
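
    # Pagination contract: `offset` is always applied, and `limit == -1`
    # disables the LIMIT clause so every matching row is returned.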
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve entities from the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            # Convert the Record to a dictionary
            entity_dict = dict(row)

            # Process metadata if it exists and is a string
            if isinstance(entity_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )

            entities.append(Entity(**entity_dict))

        return entities, count
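
    # Updates build a dynamic SET list so only the supplied fields are
    # touched; positional parameters are numbered as the list grows, with
    # the row id appended last.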
    async def update(
        self,
        entity_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        """Update an entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if category is not None:
            update_fields.append(f"category = ${param_index}")
            params.append(category)
            param_index += 1

        if metadata is not None:
            update_fields.append(f"metadata = ${param_index}")
            params.append(json.dumps(metadata))
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(entity_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )

            return Entity(
                id=result["id"],
                name=result["name"],
                category=result["category"],
                description=result["description"],
                parent_id=result["parent_id"],
                chunk_ids=result["chunk_ids"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the entity: {e}",
            ) from e

    async def delete(
        self,
        parent_id: UUID,
        entity_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """Delete entities from the specified store. If entity_ids is not
        provided, deletes all entities for the given parent_id.

        Args:
            parent_id (UUID): Parent ID (collection_id or document_id)
            entity_ids (Optional[list[UUID]]): Specific entity IDs to delete.
                If None, deletes all entities for parent_id
            store_type (StoreType): Type of store (graph or document)

        Raises:
            R2RException: If specific entities were requested but not all found
        """
        table_name = self._get_entity_table_for_store(store_type)

        if entity_ids is None:
            # Delete all entities for the parent_id
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            # Delete specific entities
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [entity_ids, parent_id]
            )

        # Check if all requested entities were deleted
        deleted_ids = [row["id"] for row in results]
        if entity_ids and len(deleted_ids) != len(entity_ids):
            raise R2RException(
                f"Some entities not found in {store_type} store or no permission to delete",
                404,
            )
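
    # Entity de-duplication is a two-step flow: get_duplicate_name_blocks
    # groups rows that share a name, and merge_duplicate_name_blocks then
    # collapses each group into one entity and rewires its relationships.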
    async def get_duplicate_name_blocks(
        self,
        parent_id: UUID,
        store_type: StoreType,
    ) -> list[list[Entity]]:
        """Find all groups of entities that share identical names within the
        same parent.

        Returns a list of entity groups, where each group contains every
        entity that shares the same name.
        """
        table_name = self._get_entity_table_for_store(store_type)

        # First find the duplicate names, then fetch every entity that has
        # one of those names
        query = f"""
            WITH duplicates AS (
                SELECT name
                FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                GROUP BY name
                HAVING COUNT(*) > 1
            )
            SELECT
                e.id, e.name, e.category, e.description,
                e.parent_id, e.chunk_ids, e.metadata
            FROM {self._get_table_name(table_name)} e
            WHERE e.parent_id = $1
            AND e.name IN (SELECT name FROM duplicates)
            ORDER BY e.name;
        """
        rows = await self.connection_manager.fetch_query(query, [parent_id])

        # Group entities by name
        name_groups: dict[str, list[Entity]] = {}
        for row in rows:
            entity_dict = dict(row)
            if isinstance(entity_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )
            entity = Entity(**entity_dict)
            name_groups.setdefault(entity.name, []).append(entity)

        return list(name_groups.values())

    async def merge_duplicate_name_blocks(
        self,
        parent_id: UUID,
        store_type: StoreType,
    ) -> list[tuple[list[Entity], Entity]]:
        """Merge entities that share identical names.

        Returns list of tuples: (original_entities, merged_entity)
        """
        duplicate_blocks = await self.get_duplicate_name_blocks(
            parent_id, store_type
        )
        merged_results: list[tuple[list[Entity], Entity]] = []

        for block in duplicate_blocks:
            # Create a new merged entity from the block
            merged_entity = await self._create_merged_entity(block)
            merged_results.append((block, merged_entity))

            table_name = self._get_entity_table_for_store(store_type)
            async with self.connection_manager.transaction():
                # Insert the merged entity
                new_id = await self._insert_merged_entity(
                    merged_entity, table_name
                )
                merged_entity.id = new_id

                # Get the old entity IDs
                old_ids = [str(entity.id) for entity in block]

                relationship_table = (
                    self.relationships_handler._get_relationship_table_for_store(
                        store_type
                    )
                )

                # Update relationships where old entities appear as subjects
                subject_update_query = f"""
                    UPDATE {self._get_table_name(relationship_table)}
                    SET subject_id = $1
                    WHERE subject_id = ANY($2::uuid[])
                    AND parent_id = $3
                """
                await self.connection_manager.execute_query(
                    subject_update_query, [new_id, old_ids, parent_id]
                )

                # Update relationships where old entities appear as objects
                object_update_query = f"""
                    UPDATE {self._get_table_name(relationship_table)}
                    SET object_id = $1
                    WHERE object_id = ANY($2::uuid[])
                    AND parent_id = $3
                """
                await self.connection_manager.execute_query(
                    object_update_query, [new_id, old_ids, parent_id]
                )

                # Delete the original entities
                delete_query = f"""
                    DELETE FROM {self._get_table_name(table_name)}
                    WHERE id = ANY($1::uuid[])
                """
                await self.connection_manager.execute_query(
                    delete_query, [old_ids]
                )

        return merged_results

    async def _insert_merged_entity(
        self, entity: Entity, table_name: str
    ) -> UUID:
        """Insert merged entity and return its new ID."""
        new_id = generate_entity_document_id()

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (id, name, category, description, parent_id, chunk_ids, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id
        """
        values = [
            new_id,
            entity.name,
            entity.category,
            entity.description,
            entity.parent_id,
            entity.chunk_ids,
            json.dumps(entity.metadata) if entity.metadata else None,
        ]
        result = await self.connection_manager.fetch_query(query, values)
        return result[0]["id"]

    async def _create_merged_entity(self, entities: list[Entity]) -> Entity:
        """Create a merged entity from a list of duplicate entities.

        Uses various strategies to combine fields.
        """
        if not entities:
            raise ValueError("Cannot merge empty list of entities")

        # Take the first non-None category, or None if all are None
        category = next(
            (e.category for e in entities if e.category is not None), None
        )

        # Combine descriptions with newlines if they differ
        descriptions = {e.description for e in entities if e.description}
        description = "\n\n".join(descriptions) if descriptions else None

        # Combine chunk_ids, removing duplicates
        chunk_ids = list(
            {
                chunk_id
                for entity in entities
                for chunk_id in (entity.chunk_ids or [])
            }
        )

        # Merge metadata dictionaries
        merged_metadata: dict[str, Any] = {}
        for entity in entities:
            if entity.metadata:
                merged_metadata |= entity.metadata

        # Create new merged entity (without actually inserting to DB)
        return Entity(
            id=UUID(
                "00000000-0000-0000-0000-000000000000"
            ),  # Placeholder UUID
            name=entities[0].name,  # All entities in block have same name
            category=category,
            description=description,
            parent_id=entities[0].parent_id,
            chunk_ids=chunk_ids or None,
            metadata=merged_metadata or None,
        )
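
    # The export streams rows through a server-side cursor in batches of
    # 1000 and writes them to a NamedTemporaryFile. Because delete=True, the
    # file is removed as soon as the returned handle is closed, so callers
    # must read it first.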
    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path
        to the temp file."""
        valid_columns = {
            "id",
            "name",
            "category",
            "description",
            "parent_id",
            "chunk_ids",
            "metadata",
            "created_at",
            "updated_at",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_stmt = f"""
            SELECT
                id::text,
                name,
                category,
                description,
                parent_id::text,
                chunk_ids::text,
                metadata::text,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
            FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
        """

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            row_dict = {
                                "id": row[0],
                                "name": row[1],
                                "category": row[2],
                                "description": row[3],
                                "parent_id": row[4],
                                "chunk_ids": row[5],
                                "metadata": row[6],
                                "created_at": row[7],
                                "updated_at": row[8],
                            }
                            writer.writerow(
                                [row_dict[col] for col in columns]
                            )

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e

class PostgresRelationshipsHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get(
            "connection_manager"
        )  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get(
            "quantization_type"
        )  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        return f"{store_type.value}_relationships"

    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """

    async def create_tables(self) -> None:
        """Create separate tables for graph and document relationships."""
        for store_type in StoreType:
            table_name = self._get_relationship_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            vector_column_str = _get_vector_column_str(
                self.dimension, self.quantization_type
            )
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    subject TEXT NOT NULL,
                    predicate TEXT NOT NULL,
                    object TEXT NOT NULL,
                    description TEXT,
                    description_embedding {vector_column_str},
                    subject_id UUID,
                    object_id UUID,
                    weight FLOAT DEFAULT 1.0,
                    chunk_ids UUID[],
                    parent_id UUID NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
                    ON {self._get_table_name(table_name)} (subject);
                CREATE INDEX IF NOT EXISTS {table_name}_object_idx
                    ON {self._get_table_name(table_name)} (object);
                CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
                    ON {self._get_table_name(table_name)} (predicate);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
                    ON {self._get_table_name(table_name)} (subject_id);
                CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
                    ON {self._get_table_name(table_name)} (object_id);
            """
            await self.connection_manager.execute_query(QUERY)
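
    # Typical usage (sketch; the handler instance and argument values are
    # illustrative, not taken from this file):
    #   rel = await relationships_handler.create(
    #       subject="Aspirin",
    #       subject_id=aspirin_id,
    #       predicate="treats",
    #       object="Headache",
    #       object_id=headache_id,
    #       parent_id=graph_id,
    #       store_type=StoreType.GRAPHS,
    #   )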
    async def create(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        store_type: StoreType,
        description: str | None = None,
        weight: float | None = 1.0,
        chunk_ids: Optional[list[UUID]] = None,
        description_embedding: Optional[list[float] | str] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        """Create a new relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)

        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (subject, predicate, object, description, subject_id, object_id,
             weight, chunk_ids, parent_id, description_embedding, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        params = [
            subject,
            predicate,
            object,
            description,
            subject_id,
            object_id,
            weight,
            chunk_ids,
            parent_id,
            description_embedding,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )

        return Relationship(
            id=result["id"],
            subject=result["subject"],
            predicate=result["predicate"],
            object=result["object"],
            description=result["description"],
            subject_id=result["subject_id"],
            object_id=result["object_id"],
            weight=result["weight"],
            chunk_ids=result["chunk_ids"],
            parent_id=result["parent_id"],
            metadata=result["metadata"],
        )

    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        relationship_types: Optional[list[str]] = None,
        include_metadata: bool = False,
    ):
        """Get relationships from the specified store.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            store_type: Type of store (graph or document)
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of specific relationship IDs to retrieve
            entity_names: Optional list of entity names to filter by (matches subject or object)
            relationship_types: Optional list of relationship types (predicates) to filter by
            include_metadata: Whether to include metadata in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        table_name = self._get_relationship_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if entity_names:
            conditions.append(
                f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
            )
            params.append(entity_names)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        select_fields = """
            id, subject, predicate, object, description,
            subject_id, object_id, weight, chunk_ids,
            parent_id
        """
        if include_metadata:
            select_fields += ", metadata"

        # Count query
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        # Main query
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if include_metadata and isinstance(
                relationship_dict["metadata"], str
            ):
                with contextlib.suppress(json.JSONDecodeError):
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
            elif not include_metadata:
                relationship_dict.pop("metadata", None)
            relationships.append(Relationship(**relationship_dict))

        return relationships, count

    async def update(
        self,
        relationship_id: UUID,
        store_type: StoreType,
        subject: Optional[str],
        subject_id: Optional[UUID],
        predicate: Optional[str],
        object: Optional[str],
        object_id: Optional[UUID],
        description: Optional[str],
        description_embedding: Optional[list[float] | str],
        weight: Optional[float],
        metadata: Optional[dict[str, Any] | str],
    ) -> Relationship:
        """Update a single relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)
        update_fields = []
        params: list = []
        param_index = 1

        # NOTE: metadata is parsed here but is not currently written back in
        # the UPDATE statement below.
        if isinstance(metadata, str):
            with contextlib.suppress(json.JSONDecodeError):
                metadata = json.loads(metadata)

        if subject is not None:
            update_fields.append(f"subject = ${param_index}")
            params.append(subject)
            param_index += 1

        if subject_id is not None:
            update_fields.append(f"subject_id = ${param_index}")
            params.append(subject_id)
            param_index += 1

        if predicate is not None:
            update_fields.append(f"predicate = ${param_index}")
            params.append(predicate)
            param_index += 1

        if object is not None:
            update_fields.append(f"object = ${param_index}")
            params.append(object)
            param_index += 1

        if object_id is not None:
            update_fields.append(f"object_id = ${param_index}")
            params.append(object_id)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if weight is not None:
            update_fields.append(f"weight = ${param_index}")
            params.append(weight)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(relationship_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )

            return Relationship(
                id=result["id"],
                subject=result["subject"],
                predicate=result["predicate"],
                object=result["object"],
                description=result["description"],
                subject_id=result["subject_id"],
                object_id=result["object_id"],
                weight=result["weight"],
                chunk_ids=result["chunk_ids"],
                parent_id=result["parent_id"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the relationship: {e}",
            ) from e

    async def delete(
        self,
        parent_id: UUID,
        relationship_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """Delete relationships from the specified store. If relationship_ids
        is not provided, deletes all relationships for the given parent_id.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            relationship_ids: Optional list of specific relationship IDs to delete
            store_type: Type of store (graph or document)

        Raises:
            R2RException: If specific relationships were requested but not all found
        """
        table_name = self._get_relationship_table_for_store(store_type)

        if relationship_ids is None:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [relationship_ids, parent_id]
            )

        deleted_ids = [row["id"] for row in results]
        if relationship_ids and len(deleted_ids) != len(relationship_ids):
            raise R2RException(
                f"Some relationships not found in {store_type} store or no permission to delete",
                404,
            )

    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path
        to the temp file."""
        valid_columns = {
            "id",
            "subject",
            "predicate",
            "object",
            "description",
            "subject_id",
            "object_id",
            "weight",
            "chunk_ids",
            "parent_id",
            "metadata",
            "created_at",
            "updated_at",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        select_stmt = f"""
            SELECT
                id::text,
                subject,
                predicate,
                object,
                description,
                subject_id::text,
                object_id::text,
                weight,
                chunk_ids::text,
                parent_id::text,
                metadata::text,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
            FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
        """

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            row_dict = {
                                "id": row["id"],
                                "subject": row["subject"],
                                "predicate": row["predicate"],
                                "object": row["object"],
                                "description": row["description"],
                                "subject_id": row["subject_id"],
                                "object_id": row["object_id"],
                                "weight": row["weight"],
                                "chunk_ids": row["chunk_ids"],
                                "parent_id": row["parent_id"],
                                "metadata": row["metadata"],
                                "created_at": row["created_at"],
                                "updated_at": row["updated_at"],
                            }
                            writer.writerow(
                                [row_dict[col] for col in columns]
                            )

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e

class PostgresCommunitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get(
            "connection_manager"
        )  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get(
            "quantization_type"
        )  # type: ignore

    async def create_tables(self) -> None:
        vector_column_str = _get_vector_column_str(
            self.dimension, self.quantization_type
        )
        query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID,
                community_id UUID,
                level INT,
                name TEXT NOT NULL,
                summary TEXT NOT NULL,
                findings TEXT[],
                rating FLOAT,
                rating_explanation TEXT,
                description_embedding {vector_column_str} NOT NULL,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                metadata JSONB,
                UNIQUE (community_id, level, collection_id)
            );"""
        await self.connection_manager.execute_query(query)
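
    # Communities are stored only at the graph level: every method below
    # writes to the single graphs_communities table, and the store_type
    # argument is currently unused.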
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
        description_embedding: Optional[list[float] | str] = None,
    ) -> Community:
        table_name = "graphs_communities"

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        params = [
            parent_id,
            name,
            summary,
            findings,
            rating,
            rating_explanation,
            description_embedding,
        ]
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )

            return Community(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the community: {e}",
            ) from e
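
    # Note: update() stores summary_embedding in the description_embedding
    # column, which is the only embedding column on graphs_communities.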
    async def update(
        self,
        community_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        summary: Optional[str] = None,
        summary_embedding: Optional[list[float] | str] = None,
        findings: Optional[list[str]] = None,
        rating: Optional[float] = None,
        rating_explanation: Optional[str] = None,
    ) -> Community:
        table_name = "graphs_communities"
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if summary is not None:
            update_fields.append(f"summary = ${param_index}")
            params.append(summary)
            param_index += 1

        if summary_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(summary_embedding)
            param_index += 1

        if findings is not None:
            update_fields.append(f"findings = ${param_index}")
            params.append(findings)
            param_index += 1

        if rating is not None:
            update_fields.append(f"rating = ${param_index}")
            params.append(rating)
            param_index += 1

        if rating_explanation is not None:
            update_fields.append(f"rating_explanation = ${param_index}")
            params.append(rating_explanation)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(community_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )

            return Community(
                id=result["id"],
                community_id=result["community_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the community: {e}",
            ) from e

    async def delete(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        table_name = "graphs_communities"
        params = [community_id, parent_id]

        # Delete the community
        query = f"""
            DELETE FROM {self._get_table_name(table_name)}
            WHERE id = $1 AND collection_id = $2
        """
        try:
            await self.connection_manager.execute_query(query, params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting the community: {e}",
            ) from e

    async def delete_all_communities(
        self,
        parent_id: UUID,
    ) -> None:
        table_name = "graphs_communities"
        params = [parent_id]

        # Delete all communities for the parent_id
        query = f"""
            DELETE FROM {self._get_table_name(table_name)}
            WHERE collection_id = $1
        """
        try:
            await self.connection_manager.execute_query(query, params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting communities: {e}",
            ) from e
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve communities from the specified store."""
        # Do we ever want to get communities from document store?
        table_name = "graphs_communities"

        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        if community_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(community_names)
            param_index += 1

        select_fields = """
            id, community_id, name, summary, findings, rating,
            rating_explanation, level, created_at, updated_at
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, params[: param_index - 1]
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count
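    # Illustrative usage sketch (not executed here): assuming an initialized
    # handler `communities` and a collection UUID `collection_id`, a paginated
    # fetch might look like the following; passing limit=-1 returns all rows
    # after the offset.
    #
    #     page, total = await communities.get(
    #         parent_id=collection_id,
    #         store_type=StoreType.GRAPHS,
    #         offset=0,
    #         limit=100,
    #     )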
    async def export_to_csv(
        self,
        parent_id: UUID,
        store_type: StoreType,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path to
        the temp file."""
        valid_columns = {
            "id",
            "collection_id",
            "community_id",
            "level",
            "name",
            "summary",
            "findings",
            "rating",
            "rating_explanation",
            "created_at",
            "updated_at",
            "metadata",
        }

        if not columns:
            columns = list(valid_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        table_name = "graphs_communities"

        select_stmt = f"""
            SELECT
                id::text,
                collection_id::text,
                community_id::text,
                level,
                name,
                summary,
                findings::text,
                rating,
                rating_explanation,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
                metadata::text
            FROM {self._get_table_name(table_name)}
        """

        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if filters:
            for field, value in filters.items():
                if field not in valid_columns:
                    continue

                if isinstance(value, dict):
                    for op, val in value.items():
                        if op == "$eq":
                            conditions.append(f"{field} = ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$gt":
                            conditions.append(f"{field} > ${param_index}")
                            params.append(val)
                            param_index += 1
                        elif op == "$lt":
                            conditions.append(f"{field} < ${param_index}")
                            params.append(val)
                            param_index += 1
                else:
                    # Direct equality
                    conditions.append(f"{field} = ${param_index}")
                    params.append(value)
                    param_index += 1

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(
                mode="w", delete=True, suffix=".csv"
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while True:
                        rows = await cursor.fetch(chunk_size)
                        if not rows:
                            break
                        for row in rows:
                            row_dict = {
                                "id": row[0],
                                "collection_id": row[1],
                                "community_id": row[2],
                                "level": row[3],
                                "name": row[4],
                                "summary": row[5],
                                "findings": row[6],
                                "rating": row[7],
                                "rating_explanation": row[8],
                                "created_at": row[9],
                                "updated_at": row[10],
                                "metadata": row[11],
                            }
                            writer.writerow([row_dict[col] for col in columns])

            temp_file.flush()
            return temp_file.name, temp_file
        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
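    # Illustrative usage sketch (not executed here): exporting high-rated
    # communities for a hypothetical `collection_id`. Because the temp file is
    # created with delete=True above, it disappears when the returned handle
    # is closed, so read it before closing.
    #
    #     path, handle = await communities.export_to_csv(
    #         parent_id=collection_id,
    #         store_type=StoreType.GRAPHS,
    #         columns=["id", "name", "rating"],
    #         filters={"rating": {"$gt": 5}},
    #     )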
class PostgresGraphsHandler(Handler):
    """Handler for knowledge graph methods in PostgreSQL."""

    TABLE_NAME = "graphs"

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get(
            "connection_manager"
        )  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get(
            "quantization_type"
        )  # type: ignore
        self.collections_handler: PostgresCollectionsHandler = kwargs.get(
            "collections_handler"
        )  # type: ignore

        self.entities = PostgresEntitiesHandler(*args, **kwargs)
        self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
        self.communities = PostgresCommunitiesHandler(*args, **kwargs)

        self.handlers = [
            self.entities,
            self.relationships,
            self.communities,
        ]
    async def create_tables(self) -> None:
        """Create the graph tables with mandatory collection_id support."""
        QUERY = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID NOT NULL,
                name TEXT NOT NULL,
                description TEXT,
                status TEXT NOT NULL,
                document_ids UUID[],
                metadata JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW()
            );

            CREATE INDEX IF NOT EXISTS graph_collection_id_idx
                ON {self._get_table_name("graphs")} (collection_id);
        """
        await self.connection_manager.execute_query(QUERY)

        for handler in self.handlers:
            await handler.create_tables()
    async def create(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        status: str = "pending",
    ) -> GraphResponse:
        """Create a new graph associated with a collection."""
        name = name or f"Graph {collection_id}"
        description = description or ""

        query = f"""
            INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            (id, collection_id, name, description, status)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
        """
        # The graph deliberately shares its id with the collection it belongs
        # to, so collection_id is passed for both $1 and $2.
        params = [
            collection_id,
            collection_id,
            name,
            description,
            status,
        ]

        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )

            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
                document_ids=result["document_ids"] or [],
            )
        except UniqueViolationError:
            raise R2RException(
                message="Graph with this ID already exists",
                status_code=409,
            ) from None
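    # Illustrative usage sketch (not executed here), assuming a hypothetical
    # `graphs_handler` wired up with a live connection manager:
    #
    #     graph = await graphs_handler.create(
    #         collection_id=collection_id,
    #         name="Research corpus graph",
    #     )
    #     # graph.id == collection_id by construction (see params above).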
    async def reset(self, parent_id: UUID) -> None:
        """Completely reset a graph and all associated data."""
        await self.entities.delete(
            parent_id=parent_id, store_type=StoreType.GRAPHS
        )
        await self.relationships.delete(
            parent_id=parent_id, store_type=StoreType.GRAPHS
        )
        await self.communities.delete_all_communities(parent_id=parent_id)

        # Now, update the graph record to remove any attached document IDs.
        # This sets document_ids to an empty UUID array.
        query = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET document_ids = ARRAY[]::uuid[]
            WHERE id = $1;
        """
        await self.connection_manager.execute_query(query, [parent_id])
    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # filter_user_ids: Optional[list[UUID]] = None,
        filter_graph_ids: Optional[list[UUID]] = None,
        filter_collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        conditions = []
        params: list[Any] = []
        param_index = 1

        if filter_graph_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(filter_graph_ids)
            param_index += 1

        # if filter_user_ids:
        #     conditions.append(f"user_id = ANY(${param_index})")
        #     params.append(filter_user_ids)
        #     param_index += 1

        if filter_collection_id:
            conditions.append(f"collection_id = ${param_index}")
            params.append(filter_collection_id)
            param_index += 1

        where_clause = (
            f"WHERE {' AND '.join(conditions)}" if conditions else ""
        )

        query = f"""
            WITH RankedGraphs AS (
                SELECT
                    id, collection_id, name, description, status, created_at, updated_at, document_ids,
                    COUNT(*) OVER() as total_entries,
                    ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
                FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                {where_clause}
            )
            SELECT * FROM RankedGraphs
            WHERE rn = 1
            ORDER BY created_at DESC
            OFFSET ${param_index} LIMIT ${param_index + 1}
        """
        params.extend([offset, limit])

        try:
            results = await self.connection_manager.fetch_query(query, params)
            if not results:
                return {"results": [], "total_entries": 0}

            total_entries = results[0]["total_entries"]

            graphs = [
                GraphResponse(
                    id=row["id"],
                    document_ids=row["document_ids"] or [],
                    name=row["name"],
                    collection_id=row["collection_id"],
                    description=row["description"],
                    status=row["status"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                )
                for row in results
            ]

            return {"results": graphs, "total_entries": total_entries}
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while fetching graphs: {e}",
            ) from e
    async def get(
        self, offset: int, limit: int, graph_id: Optional[UUID] = None
    ):
        if graph_id is None:
            params = [offset, limit]

            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                OFFSET $1 LIMIT $2
            """
            ret = await self.connection_manager.fetch_query(QUERY, params)

            COUNT_QUERY = f"""
                SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            """
            count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
                0
            ]["count"]

            return {
                "results": [Graph(**row) for row in ret],
                "total_entries": count,
            }
        else:
            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
            """
            params = [graph_id]  # type: ignore

            # Guard against a missing row; Graph(**None) would raise a
            # confusing TypeError otherwise.
            row = await self.connection_manager.fetchrow_query(QUERY, params)
            if row is None:
                raise R2RException(f"Graph {graph_id} not found", 404)

            return {"results": [Graph(**row)]}
    async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
        """Add documents to the graph by copying their entities and
        relationships."""
        # Copy entities from documents_entities to graphs_entities
        ENTITY_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_entities")} (
                name, category, description, parent_id, description_embedding,
                chunk_ids, metadata
            )
            SELECT
                name, category, description, $1, description_embedding,
                chunk_ids, metadata
            FROM {self._get_table_name("documents_entities")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            ENTITY_COPY_QUERY, [id, document_ids]
        )

        # Copy relationships from documents_relationships to graphs_relationships
        RELATIONSHIP_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_relationships")} (
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, parent_id, metadata, description_embedding
            )
            SELECT
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, $1, metadata, description_embedding
            FROM {self._get_table_name("documents_relationships")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            RELATIONSHIP_COPY_QUERY, [id, document_ids]
        )

        # Add document_ids to the graph
        UPDATE_GRAPH_QUERY = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET document_ids = array_cat(
                CASE
                    WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
                    ELSE document_ids
                END,
                $2::uuid[]
            )
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            UPDATE_GRAPH_QUERY, [id, document_ids]
        )

        return True
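    # Illustrative usage sketch (not executed here): attaching two previously
    # extracted documents to a graph. Document-level entities and relationships
    # are copied into the graph tables with parent_id rewritten to the graph id.
    #
    #     await graphs_handler.add_documents(
    #         id=graph_id,
    #         document_ids=[doc_a_id, doc_b_id],
    #     )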
    async def update(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        """Update an existing graph."""
        update_fields = []
        params: list = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(collection_id)

        query = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )

            if not result:
                raise R2RException(status_code=404, message="Graph not found")

            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                document_ids=result["document_ids"] or [],
                updated_at=result["updated_at"],
            )
        except R2RException:
            # Preserve intentional errors (e.g. the 404 above) instead of
            # rewrapping them as a 500 below.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the graph: {e}",
            ) from e
    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Entity], int]:
        """Get entities for a graph.

        Args:
            parent_id: UUID of the collection
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            entity_ids: Optional list of entity IDs to filter by
            entity_names: Optional list of entity names to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of entities, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_entities")}
            WHERE {" AND ".join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching entities with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_entities")}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            entity_dict = dict(row)
            if isinstance(entity_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )
            entities.append(Entity(**entity_dict))

        return entities, count
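    # Illustrative usage sketch (not executed here): walking all entities of a
    # graph in pages, using the returned total to know when to stop.
    #
    #     offset, page_size = 0, 500
    #     while True:
    #         entities, total = await graphs_handler.get_entities(
    #             parent_id=collection_id, offset=offset, limit=page_size
    #         )
    #         if not entities:
    #             break
    #         offset += len(entities)
    #         if offset >= total:
    #             break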
    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        relationship_types: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Relationship], int]:
        """Get relationships for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of relationship IDs to filter by
            relationship_types: Optional list of relationship types to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {" AND ".join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching relationships with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if isinstance(relationship_dict["metadata"], str):
                with contextlib.suppress(json.JSONDecodeError):
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
            relationships.append(Relationship(**relationship_dict))

        return relationships, count
    async def add_entities(
        self,
        entities: list[Entity],
        table_name: str,
        conflict_columns: list[str] | None = None,
    ) -> asyncpg.Record:
        """Upsert entities into the given entities table. These are raw
        entities extracted from the document.

        Args:
            entities: list of entities to upsert
            table_name: name of the target table
            conflict_columns: columns that trigger an update on conflict

        Returns:
            result: asyncpg.Record: result of the upsert operation
        """
        if not conflict_columns:
            conflict_columns = []

        cleaned_entities = []
        for entity in entities:
            entity_dict = entity.to_dict()
            entity_dict["chunk_ids"] = (
                entity_dict["chunk_ids"]
                if entity_dict.get("chunk_ids")
                else []
            )
            entity_dict["description_embedding"] = (
                str(entity_dict["description_embedding"])
                if entity_dict.get("description_embedding")  # type: ignore
                else None
            )
            cleaned_entities.append(entity_dict)

        return await _add_objects(
            objects=cleaned_entities,
            full_table_name=self._get_table_name(table_name),
            connection_manager=self.connection_manager,
            conflict_columns=conflict_columns,
        )
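    # Illustrative usage sketch (not executed here): upserting extracted
    # entities into the graph store. The conflict columns shown are an
    # assumption for the example; they must match a unique constraint on the
    # target table.
    #
    #     await graphs_handler.add_entities(
    #         entities=extracted_entities,
    #         table_name="graphs_entities",
    #         conflict_columns=["name", "parent_id"],
    #     )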
    async def get_all_relationships(
        self,
        collection_id: UUID | None,
        graph_id: UUID | None,
        document_ids: Optional[list[UUID]] = None,
    ) -> list[Relationship]:
        QUERY = f"""
            SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = ANY($1)
        """

        # ANY($1) is typed as uuid[], so the single collection_id must be
        # wrapped in a list rather than passed as a bare UUID.
        relationships = await self.connection_manager.fetch_query(
            QUERY, [[collection_id]]
        )

        return [Relationship(**relationship) for relationship in relationships]
    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
        """Check if a document exists in the graph's document_ids array.

        Args:
            graph_id (UUID): ID of the graph to check
            document_id (UUID): ID of the document to look for

        Returns:
            bool: True if document exists in graph, False otherwise

        Raises:
            R2RException: If graph not found
        """
        QUERY = f"""
            SELECT EXISTS (
                SELECT 1
                FROM {self._get_table_name("graphs")}
                WHERE id = $1
                AND document_ids IS NOT NULL
                AND $2 = ANY(document_ids)
            ) as exists;
        """
        result = await self.connection_manager.fetchrow_query(
            QUERY, [graph_id, document_id]
        )

        if result is None:
            raise R2RException(f"Graph {graph_id} not found", 404)

        return result["exists"]
    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Community], int]:
        """Get communities for a graph.

        Args:
            parent_id: UUID of the collection
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            community_ids: Optional list of community IDs to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of communities, total count)
        """
        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        select_fields = """
            id, collection_id, name, summary, findings, rating, rating_explanation
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_communities")}
            WHERE {" AND ".join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_communities")}
            WHERE {" AND ".join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count
    async def add_community(self, community: Community) -> None:
        # TODO: Fix in the short term.
        # We have to do this because the Postgres insert expects the vector
        # column as a string.
        community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]

        non_null_attrs = {
            k: v for k, v in community.__dict__.items() if v is not None
        }
        columns = ", ".join(non_null_attrs.keys())
        placeholders = ", ".join(
            f"${i + 1}" for i in range(len(non_null_attrs))
        )

        conflict_columns = ", ".join(
            [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
        )

        QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
            VALUES ({placeholders})
            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
            {conflict_columns}
        """

        await self.connection_manager.execute_many(
            QUERY, [tuple(non_null_attrs.values())]
        )
    async def delete(self, collection_id: UUID) -> None:
        graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)

        if len(graphs["results"]) == 0:
            raise R2RException(
                message=f"Graph not found for collection {collection_id}",
                status_code=404,
            )
        await self.reset(collection_id)

        # Set status to PENDING for this collection.
        QUERY = f"""
            UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
        """
        await self.connection_manager.execute_query(
            QUERY, [GraphExtractionStatus.PENDING, collection_id]
        )

        # Delete the graph
        QUERY = f"""
            DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
        """
        try:
            await self.connection_manager.execute_query(QUERY, [collection_id])
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting the graph: {e}",
            ) from e
    async def perform_graph_clustering(
        self,
        collection_id: UUID,
        leiden_params: dict[str, Any],
    ) -> Tuple[int, Any]:
        """Calls the external clustering service to cluster the graph."""
        offset = 0
        page_size = 1000
        all_relationships = []

        while True:
            relationships, count = await self.relationships.get(
                parent_id=collection_id,
                store_type=StoreType.GRAPHS,
                offset=offset,
                limit=page_size,
            )

            if not relationships:
                break

            all_relationships.extend(relationships)
            offset += len(relationships)

            if offset >= count:
                break

        logger.info(
            f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
        )

        if len(all_relationships) == 0:
            raise R2RException(
                message="No relationships found for clustering",
                status_code=400,
            )

        return await self._cluster_and_add_community_info(
            relationships=all_relationships,
            leiden_params=leiden_params,
            collection_id=collection_id,
        )
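    # Illustrative usage sketch (not executed here). The leiden_params keys
    # shown are assumptions for the example; the dict is forwarded verbatim to
    # the external clustering service.
    #
    #     num_communities, hierarchy = await graphs_handler.perform_graph_clustering(
    #         collection_id=collection_id,
    #         leiden_params={"max_cluster_size": 1000, "random_seed": 42},
    #     )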
    async def _call_clustering_service(
        self, relationships: list[Relationship], leiden_params: dict[str, Any]
    ) -> list[dict]:
        """Calls the external Graspologic clustering service, sending
        relationships and parameters.

        Expects a response with a 'communities' field.
        """
        # Convert relationships to a JSON-friendly format
        rel_data = []
        for r in relationships:
            rel_data.append(
                {
                    "id": str(r.id),
                    "subject": r.subject,
                    "object": r.object,
                    "weight": r.weight if r.weight is not None else 1.0,
                }
            )

        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
        if not endpoint:
            raise ValueError("CLUSTERING_SERVICE_URL not set.")

        url = f"{endpoint}/cluster"
        payload = {"relationships": rel_data, "leiden_params": leiden_params}

        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, timeout=3600)
            response.raise_for_status()

        data = response.json()
        return data.get("communities", [])
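    # For reference, a payload/response shape consistent with the code above
    # (illustrative only; the exact contract is defined by the deployed
    # service, and only the "cluster" key is read downstream in
    # _cluster_and_add_community_info):
    #
    #     POST {CLUSTERING_SERVICE_URL}/cluster
    #     {"relationships": [{"id": "...", "subject": "A", "object": "B", "weight": 1.0}],
    #      "leiden_params": {"max_cluster_size": 1000}}
    #     -> {"communities": [{"cluster": 0, ...}, {"cluster": 1, ...}]}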
    async def _create_graph_and_cluster(
        self,
        relationships: list[Relationship],
        leiden_params: dict[str, Any],
    ) -> Any:
        """Create a graph and cluster it."""
        return await self._call_clustering_service(
            relationships, leiden_params
        )
    async def _cluster_and_add_community_info(
        self,
        relationships: list[Relationship],
        leiden_params: dict[str, Any],
        collection_id: UUID,
    ) -> Tuple[int, Any]:
        logger.info(f"Creating graph and clustering for {collection_id}")

        await asyncio.sleep(0.1)
        start_time = time.time()

        hierarchical_communities = await self._create_graph_and_cluster(
            relationships=relationships,
            leiden_params=leiden_params,
        )

        logger.info(
            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
        )

        if not hierarchical_communities:
            num_communities = 0
        else:
            # Cluster labels are zero-based, so the community count is the
            # largest label plus one.
            num_communities = (
                max(item["cluster"] for item in hierarchical_communities) + 1
            )

        logger.info(
            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
        )

        return num_communities, hierarchical_communities
    async def get_entity_map(
        self, offset: int, limit: int, document_id: UUID
    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
        QUERY1 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT e.name, e.description, e.category,
                (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
                e.parent_id
            FROM {self._get_table_name("documents_entities")} e
            JOIN entities_list el ON e.name = el.name
            GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
            ORDER BY e.name;"""

        entities_list = await self.connection_manager.fetch_query(
            QUERY1, [document_id]
        )
        entities_list = [Entity(**entity) for entity in entities_list]

        QUERY2 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
                (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
            FROM {self._get_table_name("documents_relationships")} t
            JOIN entities_list el ON t.subject = el.name
            ORDER BY t.subject, t.predicate, t.object;
        """

        relationships_list = await self.connection_manager.fetch_query(
            QUERY2, [document_id]
        )
        relationships_list = [
            Relationship(**relationship) for relationship in relationships_list
        ]

        entity_map: dict[str, dict[str, list[Any]]] = {}
        for entity in entities_list:
            if entity.name not in entity_map:
                entity_map[entity.name] = {"entities": [], "relationships": []}
            entity_map[entity.name]["entities"].append(entity)

        for relationship in relationships_list:
            if relationship.subject in entity_map:
                entity_map[relationship.subject]["relationships"].append(
                    relationship
                )
            if relationship.object in entity_map:
                entity_map[relationship.object]["relationships"].append(
                    relationship
                )

        return entity_map
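    # Shape of the returned map (illustrative values):
    #
    #     {
    #         "Ada Lovelace": {
    #             "entities": [Entity(name="Ada Lovelace", ...)],
    #             "relationships": [
    #                 Relationship(subject="Ada Lovelace",
    #                              predicate="collaborated_with",
    #                              object="Charles Babbage", ...),
    #             ],
    #         },
    #         ...
    #     }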
    async def graph_search(
        self, query: str, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """Perform semantic search with similarity scores while maintaining
        the exact same structure."""
        query_embedding = kwargs.get("query_embedding", None)
        if query_embedding is None:
            raise ValueError(
                "query_embedding must be provided for semantic search"
            )

        search_type = kwargs.get(
            "search_type", "entities"
        )  # entities | relationships | communities
        embedding_type = kwargs.get("embedding_type", "description_embedding")
        property_names = kwargs.get("property_names", ["name", "description"])

        # Add metadata if not present
        if "metadata" not in property_names:
            property_names.append("metadata")

        filters = kwargs.get("filters", {})
        limit = kwargs.get("limit", 10)
        use_fulltext_search = kwargs.get("use_fulltext_search", True)
        use_hybrid_search = kwargs.get("use_hybrid_search", True)

        if use_hybrid_search or use_fulltext_search:
            logger.warning(
                "Hybrid and fulltext search are not supported for graph search, ignoring."
            )

        table_name = f"graphs_{search_type}"
        property_names_str = ", ".join(property_names)

        # Build the WHERE clause from filters
        params: list[str | int | bytes] = [
            json.dumps(query_embedding),
            limit,
        ]
        conditions_clause = self._build_filters(filters, params, search_type)
        where_clause = (
            f"WHERE {conditions_clause}" if conditions_clause else ""
        )

        # Construct the query.
        # Note: For vector similarity, we use <=> for distance. The smaller the
        # number, the more similar. We convert that to similarity_score by
        # doing (1 - distance).
        QUERY = f"""
            SELECT
                {property_names_str},
                ({embedding_type} <=> $1) as similarity_score
            FROM {self._get_table_name(table_name)}
            {where_clause}
            ORDER BY {embedding_type} <=> $1
            LIMIT $2;
        """

        results = await self.connection_manager.fetch_query(
            QUERY, tuple(params)
        )

        for result in results:
            output = {
                prop: result[prop] for prop in property_names if prop in result
            }
            # Compare against None explicitly: a distance of exactly 0 is a
            # perfect match, not a missing score.
            output["similarity_score"] = (
                1 - float(result["similarity_score"])
                if result.get("similarity_score") is not None
                else "n/a"
            )
            yield output
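    # Illustrative consumption sketch (not executed here): graph_search is an
    # async generator, so results are iterated with `async for`. The embedding
    # shown is a placeholder; in practice it comes from the embedding provider.
    #
    #     async for hit in graphs_handler.graph_search(
    #         "who founded the company?",
    #         query_embedding=[0.1, 0.2, ...],
    #         search_type="entities",
    #         filters={"collection_ids": {"$overlap": [str(collection_id)]}},
    #         limit=5,
    #     ):
    #         print(hit["name"], hit["similarity_score"])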
    def _build_filters(
        self, filter_dict: dict, parameters: list[Any], search_type: str
    ) -> str:
        """Build a WHERE clause from a nested filter dictionary for the graph
        search.

        - If search_type == "communities", we normally filter by `collection_id`.
        - Otherwise (entities/relationships), we normally filter by `parent_id`.
        - If the user provides `"collection_ids": {...}`, we interpret that as
          wanting to filter by multiple collection IDs (i.e. 'parent_id IN (...)'
          or 'collection_id IN (...)').
        """
        # The usual "base" column used by this code
        base_id_column = (
            "collection_id" if search_type == "communities" else "parent_id"
        )

        def parse_condition(key: str, value: Any) -> str:
            # ------------------------------------------------------------------
            # 1) The normal base_id_column ("parent_id" or "collection_id")
            # ------------------------------------------------------------------
            if key == base_id_column:
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        # single equality
                        parameters.append(str(clause))
                        return f"{base_id_column} = ${len(parameters)}::uuid"
                    elif op in ("$in", "$overlap"):
                        # treat both $in/$overlap as "IN the set" for a single column
                        array_val = [str(x) for x in clause]
                        parameters.append(array_val)
                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
                    # handle other operators as needed
                else:
                    # direct equality
                    parameters.append(str(value))
                    return f"{base_id_column} = ${len(parameters)}::uuid"

            # ------------------------------------------------------------------
            # 2) SPECIAL: the user explicitly sets "collection_ids" in filters.
            #    We interpret that to mean "look for rows whose parent_id (or
            #    collection_id) is in the array of values" - i.e. the same
            #    logic, forcibly directed at the appropriate column.
            # ------------------------------------------------------------------
            elif key == "collection_ids":
                # If we are searching communities, the relevant field is
                # `collection_id`. If searching entities/relationships, the
                # relevant field is `parent_id`.
                col_to_use = (
                    "collection_id"
                    if search_type == "communities"
                    else "parent_id"
                )
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        # single equality => col_to_use = clause
                        parameters.append(str(clause))
                        return f"{col_to_use} = ${len(parameters)}::uuid"
                    elif op in ("$in", "$overlap"):
                        # "col_to_use = ANY($param::uuid[])"
                        array_val = [str(x) for x in clause]
                        parameters.append(array_val)
                        return (
                            f"{col_to_use} = ANY(${len(parameters)}::uuid[])"
                        )
                    # more operators ($ne, $gt, ...) could be added here
                else:
                    # direct equality scenario: "collection_ids": "some-uuid"
                    parameters.append(str(value))
                    return f"{col_to_use} = ${len(parameters)}::uuid"

            # ------------------------------------------------------------------
            # 3) Keys starting with "metadata." handle metadata-based filters
            # ------------------------------------------------------------------
            elif key.startswith("metadata."):
                field = key.split("metadata.")[1]
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') != ${len(parameters)}"
                    elif op == "$gt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
                    # etc...
                else:
                    parameters.append(value)
                    return f"(metadata->>'{field}') = ${len(parameters)}"

            # ------------------------------------------------------------------
            # 4) Not recognized => return empty so we skip it
            # ------------------------------------------------------------------
            return ""

        # ----------------------------------------------------------------------
        # 5) parse_filter() is the recursive walker that handles $and/$or and
        #    normal fields
        # ----------------------------------------------------------------------
        def parse_filter(fd: dict) -> str:
            filter_conditions = []
            for k, v in fd.items():
                if k == "$and":
                    and_parts = [parse_filter(sub) for sub in v if sub]
                    and_parts = [x for x in and_parts if x.strip()]
                    if and_parts:
                        filter_conditions.append(
                            f"({' AND '.join(and_parts)})"
                        )
                elif k == "$or":
                    or_parts = [parse_filter(sub) for sub in v if sub]
                    or_parts = [x for x in or_parts if x.strip()]
                    if or_parts:
                        filter_conditions.append(f"({' OR '.join(or_parts)})")
                else:
                    c = parse_condition(k, v)
                    if c and c.strip():
                        filter_conditions.append(c)

            if not filter_conditions:
                return ""
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            return " AND ".join(filter_conditions)

        return parse_filter(filter_dict)
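    # Example of what the builder produces (illustrative): for
    #
    #     filters = {"$and": [
    #         {"collection_ids": {"$overlap": ["<uuid>"]}},
    #         {"metadata.source": {"$eq": "wiki"}},
    #     ]}
    #
    # with search_type="entities" and parameters pre-seeded with the embedding
    # and limit ($1, $2), the returned clause is roughly:
    #
    #     (parent_id = ANY($3::uuid[]) AND (metadata->>'source') = $4)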
    async def get_existing_document_entity_chunk_ids(
        self, document_id: UUID
    ) -> list[str]:
        QUERY = f"""
            SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
        """
        return [
            item["chunk_id"]
            for item in await self.connection_manager.fetch_query(
                QUERY, [document_id]
            )
        ]
    async def get_entity_count(
        self,
        collection_id: Optional[UUID] = None,
        document_id: Optional[UUID] = None,
        distinct: bool = False,
        entity_table_name: str = "entity",
    ) -> int:
        if collection_id is None and document_id is None:
            raise ValueError(
                "Either collection_id or document_id must be provided."
            )

        conditions = ["parent_id = $1"]
        # Fall back to collection_id when no document_id is given; otherwise
        # the query would compare parent_id against the string "None".
        params = [str(document_id if document_id is not None else collection_id)]

        count_value = "DISTINCT name" if distinct else "*"

        QUERY = f"""
            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
            WHERE {" AND ".join(conditions)}
        """
        return (await self.connection_manager.fetch_query(QUERY, params))[0][
            "count"
        ]
    async def update_entity_descriptions(self, entities: list[Entity]):
        # graphs_entities keys rows by parent_id (see the copy queries in
        # add_documents above), and $2 below is bound to entity.parent_id.
        query = f"""
            UPDATE {self._get_table_name("graphs_entities")}
            SET description = $3, description_embedding = $4
            WHERE name = $1 AND parent_id = $2
        """

        inputs = [
            (
                entity.name,
                entity.parent_id,
                entity.description,
                entity.description_embedding,
            )
            for entity in entities
        ]

        await self.connection_manager.execute_many(query, inputs)  # type: ignore
def _json_serialize(obj):
    if isinstance(obj, UUID):
        return str(obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
async def _add_objects(
    objects: list[dict],
    full_table_name: str,
    connection_manager: PostgresConnectionManager,
    conflict_columns: list[str] | None = None,
    exclude_metadata: list[str] | None = None,
) -> list[UUID]:
    """Bulk insert objects into the specified table using
    jsonb_to_recordset."""
    if conflict_columns is None:
        conflict_columns = []
    if exclude_metadata is None:
        exclude_metadata = []

    # Exclude specified metadata and prepare data
    cleaned_objects = []
    for obj in objects:
        cleaned_obj = {
            k: v
            for k, v in obj.items()
            if k not in exclude_metadata and v is not None
        }
        cleaned_objects.append(cleaned_obj)

    # Guard against an empty batch; cleaned_objects[0] below would otherwise
    # raise an IndexError.
    if not cleaned_objects:
        return []

    # Serialize the list of objects to JSON
    json_data = json.dumps(cleaned_objects, default=_json_serialize)

    # Prepare the column definitions for jsonb_to_recordset
    columns = cleaned_objects[0].keys()
    column_defs = []
    for col in columns:
        # Map Python types to PostgreSQL types
        sample_value = cleaned_objects[0][col]
        if "embedding" in col:
            pg_type = "vector"
        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
            pg_type = "uuid[]"
        elif col == "id" or "_id" in col:
            pg_type = "uuid"
        elif isinstance(sample_value, bool):
            # bool must be checked before int/float: bool is a subclass of
            # int, so the numeric branch would otherwise shadow this one.
            pg_type = "boolean"
        elif isinstance(sample_value, str):
            pg_type = "text"
        elif isinstance(sample_value, UUID):
            pg_type = "uuid"
        elif isinstance(sample_value, (int, float)):
            pg_type = "numeric"
        elif isinstance(sample_value, list) and all(
            isinstance(x, UUID) for x in sample_value
        ):
            pg_type = "uuid[]"
        elif isinstance(sample_value, (list, dict)):
            pg_type = "jsonb"
        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
            pg_type = "timestamp"
        else:
            raise TypeError(
                f"Unsupported data type for column '{col}': {type(sample_value)}"
            )

        column_defs.append(f"{col} {pg_type}")

    columns_str = ", ".join(columns)
    column_defs_str = ", ".join(column_defs)

    if conflict_columns:
        conflict_columns_str = ", ".join(conflict_columns)
        update_columns_str = ", ".join(
            f"{col}=EXCLUDED.{col}"
            for col in columns
            if col not in conflict_columns
        )
        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
    else:
        on_conflict_clause = ""

    QUERY = f"""
        INSERT INTO {full_table_name} ({columns_str})
        SELECT {columns_str}
        FROM jsonb_to_recordset($1::jsonb)
        AS x({column_defs_str})
        {on_conflict_clause}
        RETURNING id;
    """

    # Execute the query
    result = await connection_manager.fetch_query(QUERY, [json_data])

    # Extract and return the IDs
    return [record["id"] for record in result]
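# Illustrative usage sketch for _add_objects (not executed here): bulk-upsert
# two entity rows, updating descriptions for rows that collide on the assumed
# (name, parent_id) unique constraint. All identifiers below are placeholders.
#
#     ids = await _add_objects(
#         objects=[
#             {"name": "Ada", "parent_id": graph_id, "description": "..."},
#             {"name": "Babbage", "parent_id": graph_id, "description": "..."},
#         ],
#         full_table_name=handler._get_table_name("graphs_entities"),
#         connection_manager=connection_manager,
#         conflict_columns=["name", "parent_id"],
#     )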