# graphs.py
import asyncio
import datetime
import json
import logging
import os
import time
from enum import Enum
from typing import Any, AsyncGenerator, Optional, Tuple, Union
from uuid import UUID

import asyncpg
import httpx
from asyncpg.exceptions import UndefinedTableError, UniqueViolationError
from fastapi import HTTPException

from core.base.abstractions import (
    Community,
    Entity,
    Graph,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    KGEntityDeduplicationSettings,
    KGExtractionStatus,
    R2RException,
    Relationship,
    VectorQuantizationType,
)
from core.base.api.models import GraphResponse
from core.base.providers.database import Handler
from core.base.utils import (
    _decorate_vector_type,
    _get_str_estimation_output,
    llm_cost_per_million_tokens,
)

from .base import PostgresConnectionManager
from .collections import PostgresCollectionsHandler


class StoreType(str, Enum):
    GRAPHS = "graphs"
    DOCUMENTS = "documents"


logger = logging.getLogger()


class PostgresEntitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_entity_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        if isinstance(store_type, StoreType):
            store_type = store_type.value
        return f"{store_type}_entities"
    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """
    async def create_tables(self) -> None:
        """Create separate tables for graph and document entities."""
        vector_column_str = _decorate_vector_type(
            f"({self.dimension})", self.quantization_type
        )
        for store_type in StoreType:
            table_name = self._get_entity_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    name TEXT NOT NULL,
                    category TEXT,
                    description TEXT,
                    parent_id UUID NOT NULL,
                    description_embedding {vector_column_str},
                    chunk_ids UUID[],
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_name_idx
                    ON {self._get_table_name(table_name)} (name);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_category_idx
                    ON {self._get_table_name(table_name)} (category);
            """
            await self.connection_manager.execute_query(QUERY)
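
    # A minimal sketch of what the naming helpers above produce (values are
    # illustrative; "my_project" is a hypothetical project/schema name):
    #
    #     handler._get_entity_table_for_store(StoreType.GRAPHS)
    #     # -> "graphs_entities"
    #     handler._get_table_name("graphs_entities")
    #     # -> '"my_project"."graphs_entities"'
    #
    # so the loop above creates "my_project"."graphs_entities" and
    # "my_project"."documents_entities", each with a cascading foreign key
    # to its parent graphs/documents table.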
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        category: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        chunk_ids: Optional[list[UUID]] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Entity:
        """Create a new entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                pass

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        params = [
            name,
            category,
            description,
            parent_id,
            description_embedding,
            chunk_ids,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )
        return Entity(
            id=result["id"],
            name=result["name"],
            category=result["category"],
            description=result["description"],
            parent_id=result["parent_id"],
            chunk_ids=result["chunk_ids"],
            metadata=result["metadata"],
        )
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve entities from the specified store."""
        table_name = self._get_entity_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            # Convert the Record to a dictionary
            entity_dict = dict(row)

            # Process metadata if it exists and is a string
            if isinstance(entity_dict["metadata"], str):
                try:
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )
                except json.JSONDecodeError:
                    pass

            entities.append(Entity(**entity_dict))

        return entities, count
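
    # A sketch of how the dynamic filtering above numbers its placeholders
    # (hypothetical call; the fragment shown follows from the code above):
    #
    #     entities, total = await handler.get(
    #         parent_id=parent_id,
    #         store_type=StoreType.GRAPHS,
    #         offset=0,
    #         limit=25,
    #         entity_names=["Marie Curie", "Pierre Curie"],
    #     )
    #     # WHERE clause becomes: parent_id = $1 AND name = ANY($2),
    #     # then OFFSET $3 and, because limit != -1, LIMIT $4 are appended.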
    async def update(
        self,
        entity_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        description: Optional[str] = None,
        description_embedding: Optional[list[float] | str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        """Update an entity in the specified store."""
        table_name = self._get_entity_table_for_store(store_type)
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                pass

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if category is not None:
            update_fields.append(f"category = ${param_index}")
            params.append(category)
            param_index += 1

        if metadata is not None:
            update_fields.append(f"metadata = ${param_index}")
            params.append(json.dumps(metadata))
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(entity_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, category, description, parent_id, chunk_ids, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Entity(
                id=result["id"],
                name=result["name"],
                category=result["category"],
                description=result["description"],
                parent_id=result["parent_id"],
                chunk_ids=result["chunk_ids"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the entity: {e}",
            )
    async def delete(
        self,
        parent_id: UUID,
        entity_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """
        Delete entities from the specified store.

        If entity_ids is not provided, deletes all entities for the given parent_id.

        Args:
            parent_id (UUID): Parent ID (collection_id or document_id)
            entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
            store_type (StoreType): Type of store (graph or document)

        Raises:
            R2RException: If specific entities were requested but not all were found
        """
        table_name = self._get_entity_table_for_store(store_type)

        if entity_ids is None:
            # Delete all entities for the parent_id
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            # Delete specific entities
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [entity_ids, parent_id]
            )

        # Check whether all requested entities were deleted
        deleted_ids = [row["id"] for row in results]
        if entity_ids and len(deleted_ids) != len(entity_ids):
            raise R2RException(
                f"Some entities not found in {store_type} store or no permission to delete",
                404,
            )
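

# Usage sketch for the entities handler (illustrative only; assumes a
# configured PostgresConnectionManager and a pre-existing parent document —
# the identifier names are assumptions, not part of this module):
#
#     handler = PostgresEntitiesHandler(
#         project_name="my_project",            # hypothetical schema name
#         connection_manager=connection_manager,
#         dimension=768,
#         quantization_type=quantization_type,  # a VectorQuantizationType value
#     )
#     await handler.create_tables()
#     entity = await handler.create(
#         parent_id=document_id,
#         store_type=StoreType.DOCUMENTS,
#         name="Marie Curie",
#         category="person",
#         description="Polish-French physicist and chemist",
#     )
#     await handler.delete(
#         parent_id=document_id,
#         entity_ids=[entity.id],
#         store_type=StoreType.DOCUMENTS,
#     )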


class PostgresRelationshipsHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
        """Get the appropriate table name for the store type."""
        if isinstance(store_type, StoreType):
            store_type = store_type.value
        return f"{store_type}_relationships"

    def _get_parent_constraint(self, store_type: StoreType) -> str:
        """Get the appropriate foreign key constraint for the store type."""
        if store_type == StoreType.GRAPHS:
            return f"""
                CONSTRAINT fk_graph
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("graphs")}(id)
                    ON DELETE CASCADE
            """
        else:
            return f"""
                CONSTRAINT fk_document
                    FOREIGN KEY(parent_id)
                    REFERENCES {self._get_table_name("documents")}(id)
                    ON DELETE CASCADE
            """

    async def create_tables(self) -> None:
        """Create separate tables for graph and document relationships."""
        for store_type in StoreType:
            table_name = self._get_relationship_table_for_store(store_type)
            parent_constraint = self._get_parent_constraint(store_type)
            vector_column_str = _decorate_vector_type(
                f"({self.dimension})", self.quantization_type
            )
            QUERY = f"""
                CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
                    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                    subject TEXT NOT NULL,
                    predicate TEXT NOT NULL,
                    object TEXT NOT NULL,
                    description TEXT,
                    description_embedding {vector_column_str},
                    subject_id UUID,
                    object_id UUID,
                    weight FLOAT DEFAULT 1.0,
                    chunk_ids UUID[],
                    parent_id UUID NOT NULL,
                    metadata JSONB,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    {parent_constraint}
                );
                CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
                    ON {self._get_table_name(table_name)} (subject);
                CREATE INDEX IF NOT EXISTS {table_name}_object_idx
                    ON {self._get_table_name(table_name)} (object);
                CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
                    ON {self._get_table_name(table_name)} (predicate);
                CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
                    ON {self._get_table_name(table_name)} (parent_id);
                CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
                    ON {self._get_table_name(table_name)} (subject_id);
                CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
                    ON {self._get_table_name(table_name)} (object_id);
            """
            await self.connection_manager.execute_query(QUERY)
    async def create(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        store_type: StoreType,
        description: str | None = None,
        weight: float | None = 1.0,
        chunk_ids: Optional[list[UUID]] = None,
        description_embedding: Optional[list[float] | str] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        """Create a new relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                pass

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (subject, predicate, object, description, subject_id, object_id,
             weight, chunk_ids, parent_id, description_embedding, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        params = [
            subject,
            predicate,
            object,
            description,
            subject_id,
            object_id,
            weight,
            chunk_ids,
            parent_id,
            description_embedding,
            json.dumps(metadata) if metadata else None,
        ]
        result = await self.connection_manager.fetchrow_query(
            query=query,
            params=params,
        )
        return Relationship(
            id=result["id"],
            subject=result["subject"],
            predicate=result["predicate"],
            object=result["object"],
            description=result["description"],
            subject_id=result["subject_id"],
            object_id=result["object_id"],
            weight=result["weight"],
            chunk_ids=result["chunk_ids"],
            parent_id=result["parent_id"],
            metadata=result["metadata"],
        )
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        relationship_types: Optional[list[str]] = None,
        include_metadata: bool = False,
    ):
        """
        Get relationships from the specified store.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            store_type: Type of store (graph or document)
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of specific relationship IDs to retrieve
            entity_names: Optional list of entity names to filter by (matches subject or object)
            relationship_types: Optional list of relationship types (predicates) to filter by
            include_metadata: Whether to include metadata in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        table_name = self._get_relationship_table_for_store(store_type)

        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if entity_names:
            conditions.append(
                f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
            )
            params.append(entity_names)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        select_fields = """
            id, subject, predicate, object, description,
            subject_id, object_id, weight, chunk_ids,
            parent_id
        """
        if include_metadata:
            select_fields += ", metadata"

        # Count query
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count_params = params[: param_index - 1]
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, count_params
            )
        )[0]["count"]

        # Main query
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if include_metadata and isinstance(
                relationship_dict["metadata"], str
            ):
                try:
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
                except json.JSONDecodeError:
                    pass
            elif not include_metadata:
                relationship_dict.pop("metadata", None)
            relationships.append(Relationship(**relationship_dict))

        return relationships, count
    async def update(
        self,
        relationship_id: UUID,
        store_type: StoreType,
        subject: Optional[str],
        subject_id: Optional[UUID],
        predicate: Optional[str],
        object: Optional[str],
        object_id: Optional[UUID],
        description: Optional[str],
        description_embedding: Optional[list[float] | str],
        weight: Optional[float],
        metadata: Optional[dict[str, Any] | str],
    ) -> Relationship:
        """Update a relationship in the specified store."""
        table_name = self._get_relationship_table_for_store(store_type)
        update_fields = []
        params: list = []
        param_index = 1

        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                pass

        if subject is not None:
            update_fields.append(f"subject = ${param_index}")
            params.append(subject)
            param_index += 1

        if subject_id is not None:
            update_fields.append(f"subject_id = ${param_index}")
            params.append(subject_id)
            param_index += 1

        if predicate is not None:
            update_fields.append(f"predicate = ${param_index}")
            params.append(predicate)
            param_index += 1

        if object is not None:
            update_fields.append(f"object = ${param_index}")
            params.append(object)
            param_index += 1

        if object_id is not None:
            update_fields.append(f"object_id = ${param_index}")
            params.append(object_id)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if description_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(description_embedding)
            param_index += 1

        if weight is not None:
            update_fields.append(f"weight = ${param_index}")
            params.append(weight)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(relationship_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Relationship(
                id=result["id"],
                subject=result["subject"],
                predicate=result["predicate"],
                object=result["object"],
                description=result["description"],
                subject_id=result["subject_id"],
                object_id=result["object_id"],
                weight=result["weight"],
                chunk_ids=result["chunk_ids"],
                parent_id=result["parent_id"],
                metadata=result["metadata"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the relationship: {e}",
            )
    async def delete(
        self,
        parent_id: UUID,
        relationship_ids: Optional[list[UUID]] = None,
        store_type: StoreType = StoreType.GRAPHS,
    ) -> None:
        """
        Delete relationships from the specified store.

        If relationship_ids is not provided, deletes all relationships for the given parent_id.

        Args:
            parent_id: UUID of the parent (collection_id or document_id)
            relationship_ids: Optional list of specific relationship IDs to delete
            store_type: Type of store (graph or document)

        Raises:
            R2RException: If specific relationships were requested but not all were found
        """
        table_name = self._get_relationship_table_for_store(store_type)

        if relationship_ids is None:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE parent_id = $1
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [parent_id]
            )
        else:
            QUERY = f"""
                DELETE FROM {self._get_table_name(table_name)}
                WHERE id = ANY($1) AND parent_id = $2
                RETURNING id
            """
            results = await self.connection_manager.fetch_query(
                QUERY, [relationship_ids, parent_id]
            )

        deleted_ids = [row["id"] for row in results]
        if relationship_ids and len(deleted_ids) != len(relationship_ids):
            raise R2RException(
                f"Some relationships not found in {store_type} store or no permission to delete",
                404,
            )
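

# Usage sketch for the relationships handler (illustrative; the identifiers
# are assumptions, not part of this module):
#
#     rel = await relationships_handler.create(
#         subject="Marie Curie",
#         subject_id=curie_entity_id,
#         predicate="won",
#         object="Nobel Prize in Physics",
#         object_id=prize_entity_id,
#         parent_id=document_id,
#         store_type=StoreType.DOCUMENTS,
#         weight=1.0,
#     )
#     rels, total = await relationships_handler.get(
#         parent_id=document_id,
#         store_type=StoreType.DOCUMENTS,
#         offset=0,
#         limit=10,
#         relationship_types=["won"],
#     )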


class PostgresCommunitiesHandler(Handler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'

    async def create_tables(self) -> None:
        vector_column_str = _decorate_vector_type(
            f"({self.dimension})", self.quantization_type
        )
        query = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID,
                community_id UUID,
                level INT,
                name TEXT NOT NULL,
                summary TEXT NOT NULL,
                findings TEXT[],
                rating FLOAT,
                rating_explanation TEXT,
                description_embedding {vector_column_str} NOT NULL,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                metadata JSONB,
                UNIQUE (community_id, level, collection_id)
            );"""
        await self.connection_manager.execute_query(query)
    async def create(
        self,
        parent_id: UUID,
        store_type: StoreType,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
        description_embedding: Optional[list[float] | str] = None,
    ) -> Community:
        # Do we ever want to get communities from the document store?
        table_name = "graphs_communities"

        if isinstance(description_embedding, list):
            description_embedding = str(description_embedding)

        query = f"""
            INSERT INTO {self._get_table_name(table_name)}
            (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        params = [
            parent_id,
            name,
            summary,
            findings,
            rating,
            rating_explanation,
            description_embedding,
        ]
        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return Community(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the community: {e}",
            )
    async def update(
        self,
        community_id: UUID,
        store_type: StoreType,
        name: Optional[str] = None,
        summary: Optional[str] = None,
        summary_embedding: Optional[list[float] | str] = None,
        findings: Optional[list[str]] = None,
        rating: Optional[float] = None,
        rating_explanation: Optional[str] = None,
    ) -> Community:
        table_name = "graphs_communities"
        update_fields = []
        params: list[Any] = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if summary is not None:
            update_fields.append(f"summary = ${param_index}")
            params.append(summary)
            param_index += 1

        if summary_embedding is not None:
            update_fields.append(f"description_embedding = ${param_index}")
            params.append(summary_embedding)
            param_index += 1

        if findings is not None:
            update_fields.append(f"findings = ${param_index}")
            params.append(findings)
            param_index += 1

        if rating is not None:
            update_fields.append(f"rating = ${param_index}")
            params.append(rating)
            param_index += 1

        if rating_explanation is not None:
            update_fields.append(f"rating_explanation = ${param_index}")
            params.append(rating_explanation)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(community_id)

        query = f"""
            UPDATE {self._get_table_name(table_name)}
            SET {", ".join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )
            return Community(
                id=result["id"],
                community_id=result["community_id"],
                name=result["name"],
                summary=result["summary"],
                findings=result["findings"],
                rating=result["rating"],
                rating_explanation=result["rating_explanation"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
            )
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the community: {e}",
            )
    async def delete(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        table_name = "graphs_communities"
        query = f"""
            DELETE FROM {self._get_table_name(table_name)}
            WHERE id = $1 AND collection_id = $2
        """
        params = [community_id, parent_id]
        try:
            await self.connection_manager.execute_query(query, params)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while deleting the community: {e}",
            )
    async def get(
        self,
        parent_id: UUID,
        store_type: StoreType,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        """Retrieve communities from the specified store."""
        # Do we ever want to get communities from the document store?
        table_name = "graphs_communities"

        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        if community_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(community_names)
            param_index += 1

        select_fields = """
            id, community_id, name, summary, findings, rating,
            rating_explanation, level, created_at, updated_at
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(
                COUNT_QUERY, params[: param_index - 1]
            )
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count
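

# Usage sketch for the communities handler (illustrative; note that
# communities are always stored at the graph/collection level, per the
# hard-coded "graphs_communities" table above):
#
#     community = await communities_handler.create(
#         parent_id=collection_id,
#         store_type=StoreType.GRAPHS,
#         name="Radioactivity researchers",
#         summary="Scientists connected through early radioactivity work.",
#         findings=["Shared laboratory lineage", "Co-authored publications"],
#         rating=4.5,
#         rating_explanation="Dense, well-evidenced cluster.",
#     )
#     communities, total = await communities_handler.get(
#         parent_id=collection_id,
#         store_type=StoreType.GRAPHS,
#         offset=0,
#         limit=10,
#     )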


class PostgresGraphsHandler(Handler):
    """Handler for knowledge-graph methods in PostgreSQL."""

    TABLE_NAME = "graphs"

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self.project_name: str = kwargs.get("project_name")  # type: ignore
        self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager")  # type: ignore
        self.dimension: int = kwargs.get("dimension")  # type: ignore
        self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type")  # type: ignore
        self.collections_handler: PostgresCollectionsHandler = kwargs.get("collections_handler")  # type: ignore

        self.entities = PostgresEntitiesHandler(*args, **kwargs)
        self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
        self.communities = PostgresCommunitiesHandler(*args, **kwargs)

        self.handlers = [
            self.entities,
            self.relationships,
            self.communities,
        ]

        # Imported lazily so networkx is only required once this handler is used
        import networkx as nx

        self.nx = nx

    def _get_table_name(self, table: str) -> str:
        """Get the fully qualified table name."""
        return f'"{self.project_name}"."{table}"'
    async def create_tables(self) -> None:
        """Create the graph tables with mandatory collection_id support."""
        QUERY = f"""
            CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
                collection_id UUID NOT NULL,
                name TEXT NOT NULL,
                description TEXT,
                status TEXT NOT NULL,
                document_ids UUID[],
                metadata JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS graph_collection_id_idx
                ON {self._get_table_name("graphs")} (collection_id);
        """
        await self.connection_manager.execute_query(QUERY)

        for handler in self.handlers:
            await handler.create_tables()
    async def create(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        status: str = "pending",
    ) -> GraphResponse:
        """Create a new graph associated with a collection."""
        name = name or f"Graph {collection_id}"
        description = description or ""

        query = f"""
            INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            (id, collection_id, name, description, status)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
        """
        # The graph id is deliberately set to the collection id
        params = [
            collection_id,
            collection_id,
            name,
            description,
            status,
        ]

        try:
            result = await self.connection_manager.fetchrow_query(
                query=query,
                params=params,
            )
            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
                document_ids=result["document_ids"] or [],
            )
        except UniqueViolationError:
            raise R2RException(
                message="Graph with this ID already exists",
                status_code=409,
            )
    async def reset(self, parent_id: UUID) -> None:
        """Completely reset a graph and all associated data."""
        try:
            # Delete all graph entities
            entity_delete_query = f"""
                DELETE FROM {self._get_table_name("graphs_entities")}
                WHERE parent_id = $1
            """
            await self.connection_manager.execute_query(
                entity_delete_query, [parent_id]
            )

            # Delete all graph relationships
            relationship_delete_query = f"""
                DELETE FROM {self._get_table_name("graphs_relationships")}
                WHERE parent_id = $1
            """
            await self.connection_manager.execute_query(
                relationship_delete_query, [parent_id]
            )

            # Delete all graph communities
            community_delete_query = f"""
                DELETE FROM {self._get_table_name("graphs_communities")}
                WHERE collection_id = $1
            """
            await self.connection_manager.execute_query(
                community_delete_query, [parent_id]
            )
        except Exception as e:
            logger.error(f"Error deleting graph {parent_id}: {str(e)}")
            raise R2RException(f"Failed to delete graph: {str(e)}", 500)
    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # filter_user_ids: Optional[list[UUID]] = None,
        filter_graph_ids: Optional[list[UUID]] = None,
        filter_collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        conditions = []
        params: list[Any] = []
        param_index = 1

        if filter_graph_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(filter_graph_ids)
            param_index += 1

        # if filter_user_ids:
        #     conditions.append(f"user_id = ANY(${param_index})")
        #     params.append(filter_user_ids)
        #     param_index += 1

        if filter_collection_id:
            conditions.append(f"collection_id = ${param_index}")
            params.append(filter_collection_id)
            param_index += 1

        where_clause = (
            f"WHERE {' AND '.join(conditions)}" if conditions else ""
        )

        query = f"""
            WITH RankedGraphs AS (
                SELECT
                    id, collection_id, name, description, status, created_at, updated_at, document_ids,
                    COUNT(*) OVER() as total_entries,
                    ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
                FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                {where_clause}
            )
            SELECT * FROM RankedGraphs
            WHERE rn = 1
            ORDER BY created_at DESC
            OFFSET ${param_index} LIMIT ${param_index + 1}
        """
        params.extend([offset, limit])

        try:
            results = await self.connection_manager.fetch_query(query, params)
            if not results:
                return {"results": [], "total_entries": 0}

            total_entries = results[0]["total_entries"] if results else 0

            graphs = [
                GraphResponse(
                    id=row["id"],
                    document_ids=row["document_ids"] or [],
                    name=row["name"],
                    collection_id=row["collection_id"],
                    description=row["description"],
                    status=row["status"],
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                )
                for row in results
            ]

            return {"results": graphs, "total_entries": total_entries}
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while fetching graphs: {e}",
            )
    async def get(
        self, offset: int, limit: int, graph_id: Optional[UUID] = None
    ):
        if graph_id is None:
            params = [offset, limit]
            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
                OFFSET $1 LIMIT $2
            """
            ret = await self.connection_manager.fetch_query(QUERY, params)

            COUNT_QUERY = f"""
                SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            """
            count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
                0
            ]["count"]

            return {
                "results": [Graph(**row) for row in ret],
                "total_entries": count,
            }
        else:
            QUERY = f"""
                SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
            """
            params = [graph_id]  # type: ignore
            return {
                "results": [
                    Graph(
                        **await self.connection_manager.fetchrow_query(
                            QUERY, params
                        )
                    )
                ]
            }
    async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
        """
        Add documents to the graph by copying their entities and relationships.
        """
        # Copy entities from documents_entities to graphs_entities
        ENTITY_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_entities")} (
                name, category, description, parent_id, description_embedding,
                chunk_ids, metadata
            )
            SELECT
                name, category, description, $1, description_embedding,
                chunk_ids, metadata
            FROM {self._get_table_name("documents_entities")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            ENTITY_COPY_QUERY, [id, document_ids]
        )

        # Copy relationships from documents_relationships to graphs_relationships
        RELATIONSHIP_COPY_QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_relationships")} (
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, parent_id, metadata, description_embedding
            )
            SELECT
                subject, predicate, object, description, subject_id, object_id,
                weight, chunk_ids, $1, metadata, description_embedding
            FROM {self._get_table_name("documents_relationships")}
            WHERE parent_id = ANY($2)
        """
        await self.connection_manager.execute_query(
            RELATIONSHIP_COPY_QUERY, [id, document_ids]
        )

        # Add document_ids to the graph
        UPDATE_GRAPH_QUERY = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET document_ids = array_cat(
                CASE
                    WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
                    ELSE document_ids
                END,
                $2::uuid[]
            )
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            UPDATE_GRAPH_QUERY, [id, document_ids]
        )

        return True
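
    # Usage sketch for the graph lifecycle (illustrative; assumes a configured
    # graphs_handler and pre-existing document IDs, which are not part of this
    # module):
    #
    #     graph = await graphs_handler.create(collection_id=collection_id)
    #     # The graph id mirrors the collection id, so later calls can address
    #     # the graph by collection_id.
    #     await graphs_handler.add_documents(
    #         id=graph.id, document_ids=[doc_a_id, doc_b_id]
    #     )
    #     await graphs_handler.reset(parent_id=graph.id)  # wipe copied data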
    async def update(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        """Update an existing graph."""
        update_fields = []
        params: list = []
        param_index = 1

        if name is not None:
            update_fields.append(f"name = ${param_index}")
            params.append(name)
            param_index += 1

        if description is not None:
            update_fields.append(f"description = ${param_index}")
            params.append(description)
            param_index += 1

        if not update_fields:
            raise R2RException(status_code=400, message="No fields to update")

        update_fields.append("updated_at = NOW()")
        params.append(collection_id)

        query = f"""
            UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
            SET {', '.join(update_fields)}
            WHERE id = ${param_index}
            RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
        """
        try:
            result = await self.connection_manager.fetchrow_query(
                query, params
            )
            if not result:
                raise R2RException(status_code=404, message="Graph not found")

            return GraphResponse(
                id=result["id"],
                collection_id=result["collection_id"],
                name=result["name"],
                description=result["description"],
                status=result["status"],
                created_at=result["created_at"],
                document_ids=result["document_ids"] or [],
                updated_at=result["updated_at"],
            )
        except R2RException:
            # Propagate the 404 rather than masking it as a 500
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while updating the graph: {e}",
            )

    async def get_creation_estimate(
        self,
        graph_creation_settings: KGCreationSettings,
        document_id: Optional[UUID] = None,
        collection_id: Optional[UUID] = None,
    ):
        """Get the estimated cost and time for creating a KG."""
        # Exactly one of the two identifiers must be supplied.
        if (document_id is None) == (collection_id is None):
            raise ValueError(
                "Exactly one of document_id or collection_id must be provided."
            )

        # TODO: harmonize the document_id and id fields: the postgres table
        # contains document_id, but other places use id.
        if document_id:
            document_ids = [document_id]
        else:
            document_ids = [
                doc.id
                for doc in (
                    await self.collections_handler.documents_in_collection(
                        collection_id, offset=0, limit=-1
                    )  # type: ignore
                )["results"]
            ]

        chunk_counts = await self.connection_manager.fetch_query(
            f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} "
            f"WHERE document_id = ANY($1) GROUP BY document_id",
            [document_ids],
        )

        total_chunks = (
            sum(doc["chunk_count"] for doc in chunk_counts)
            // graph_creation_settings.chunk_merge_count
        )
        estimated_entities = (total_chunks * 10, total_chunks * 20)
        estimated_relationships = (
            int(estimated_entities[0] * 1.25),
            int(estimated_entities[1] * 1.5),
        )
        estimated_llm_calls = (
            total_chunks * 2 + estimated_entities[0],
            total_chunks * 2 + estimated_entities[1],
        )
        # Use true division so small workloads do not round down to zero
        # millions of tokens (the enrichment estimate below does the same).
        total_in_out_tokens = tuple(
            2000 * calls / 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_creation_settings.generation_config.model
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in total_in_out_tokens
        )
        total_time_in_minutes = tuple(
            tokens * 10 / 60 for tokens in total_in_out_tokens
        )

        return {
            "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges; actual values may vary. To run the KG creation process, run `extract-triples` with `--run` in the CLI, or `run_type="run"` in the client.',
            "document_count": len(document_ids),
            "number_of_jobs_created": len(document_ids) + 1,
            "total_chunks": total_chunks,
            "estimated_entities": _get_str_estimation_output(
                estimated_entities
            ),
            "estimated_relationships": _get_str_estimation_output(
                estimated_relationships
            ),
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                total_in_out_tokens
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(total_time_in_minutes),
        }
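
    # Worked example of the heuristic above (illustrative numbers): a
    # collection whose documents total 1,000 merged chunks yields an estimated
    # 10,000-20,000 entities, hence roughly 12,000-22,000 LLM calls (two per
    # chunk plus one per entity) and, at ~2,000 tokens per call, about 24-44
    # million tokens in and out.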

    async def get_enrichment_estimate(
        self,
        collection_id: UUID | None = None,
        graph_id: UUID | None = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
    ):
        """Get the estimated cost and time for enriching a KG."""
        if collection_id is None and graph_id is None:
            raise ValueError(
                "Either collection_id or graph_id must be provided."
            )

        if collection_id is not None:
            document_ids = [
                doc.id
                for doc in (
                    await self.collections_handler.documents_in_collection(
                        collection_id, offset=0, limit=-1
                    )  # type: ignore
                )["results"]
            ]

            # Get entity and relationship counts
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1);",
                    [document_ids],
                )
            )[0]["count"]

            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )

            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"""SELECT COUNT(*) FROM {self._get_table_name("documents_relationships")} WHERE document_id = ANY($1);""",
                    [document_ids],
                )
            )[0]["count"]
        else:
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]

            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )

            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]

        # Calculate estimates
        estimated_llm_calls = (entity_count // 10, entity_count // 5)
        tokens_in_millions = tuple(
            2000 * calls / 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_enrichment_settings.generation_config.model  # type: ignore
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in tokens_in_millions
        )
        estimated_time = tuple(
            tokens * 10 / 60 for tokens in tokens_in_millions
        )

        return {
            "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges; actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the CLI, or `run_type="run"` in the client.',
            "total_entities": entity_count,
            "total_relationships": relationship_count,
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                tokens_in_millions
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(estimated_time),
        }

    async def get_deduplication_estimate(
        self,
        collection_id: UUID,
        kg_deduplication_settings: KGEntityDeduplicationSettings,
    ):
        """Get the estimated cost and time for deduplicating entities in a KG."""
        try:
            query = f"""
                SELECT name, count(name)
                FROM {self._get_table_name("entity")}
                WHERE document_id = ANY(
                    SELECT document_id FROM {self._get_table_name("documents")}
                    WHERE $1 = ANY(collection_ids)
                )
                GROUP BY name
                HAVING count(name) >= 5
            """
            entities = await self.connection_manager.fetch_query(
                query, [collection_id]
            )
            num_entities = len(entities)

            estimated_llm_calls = (num_entities, num_entities)
            tokens_in_millions = (
                estimated_llm_calls[0] * 1000 / 1000000,
                estimated_llm_calls[1] * 5000 / 1000000,
            )
            cost_per_million = llm_cost_per_million_tokens(
                kg_deduplication_settings.generation_config.model
            )
            estimated_cost = (
                tokens_in_millions[0] * cost_per_million,
                tokens_in_millions[1] * cost_per_million,
            )
            estimated_time = (
                tokens_in_millions[0] * 10 / 60,
                tokens_in_millions[1] * 10 / 60,
            )

            return {
                "message": "Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.",
                "num_entities": num_entities,
                "estimated_llm_calls": _get_str_estimation_output(
                    estimated_llm_calls
                ),
                "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                    tokens_in_millions
                ),
                "estimated_cost_in_usd": _get_str_estimation_output(
                    estimated_cost
                ),
                "estimated_total_time_in_minutes": _get_str_estimation_output(
                    estimated_time
                ),
            }
        except UndefinedTableError:
            raise R2RException(
                "Entity embedding table not found. Please run `extract-triples` first.",
                404,
            )
        except Exception as e:
            logger.error(f"Error in get_deduplication_estimate: {str(e)}")
            raise HTTPException(500, "Error fetching deduplication estimate.")
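
    # Illustrative note: the query above treats any entity name that appears
    # at least five times across the collection's documents as a deduplication
    # candidate, and the estimate budgets one LLM call per candidate at an
    # assumed 1,000-5,000 tokens per call.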

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Entity], int]:
        """
        Get entities for a graph.

        Args:
            parent_id: UUID of the collection
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            entity_ids: Optional list of entity IDs to filter by
            entity_names: Optional list of entity names to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of entities, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if entity_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(entity_ids)
            param_index += 1

        if entity_names:
            conditions.append(f"name = ANY(${param_index})")
            params.append(entity_names)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_entities")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, name, category, description, parent_id,
            chunk_ids, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching entities with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_entities")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        entities = []
        for row in rows:
            entity_dict = dict(row)
            if isinstance(entity_dict["metadata"], str):
                try:
                    entity_dict["metadata"] = json.loads(
                        entity_dict["metadata"]
                    )
                except json.JSONDecodeError:
                    pass
            entities.append(Entity(**entity_dict))

        return entities, count
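
    # Usage sketch (illustrative): page through a graph's entities 100 at a
    # time; `handler` and `graph_id` are assumed to exist.
    #
    #     offset = 0
    #     while True:
    #         entities, total = await handler.get_entities(
    #             parent_id=graph_id, offset=offset, limit=100
    #         )
    #         if not entities:
    #             break
    #         offset += len(entities)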

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        relationship_types: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Relationship], int]:
        """
        Get relationships for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of relationship IDs to filter by
            relationship_types: Optional list of relationship types to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
        conditions = ["parent_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if relationship_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(relationship_ids)
            param_index += 1

        if relationship_types:
            conditions.append(f"predicate = ANY(${param_index})")
            params.append(relationship_types)
            param_index += 1

        # Count query - uses the same conditions but without offset/limit
        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        # Define base columns to select
        select_fields = """
            id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        # Main query for fetching relationships with pagination
        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_relationships")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        relationships = []
        for row in rows:
            relationship_dict = dict(row)
            if isinstance(relationship_dict["metadata"], str):
                try:
                    relationship_dict["metadata"] = json.loads(
                        relationship_dict["metadata"]
                    )
                except json.JSONDecodeError:
                    pass
            relationships.append(Relationship(**relationship_dict))

        return relationships, count

    async def add_entities(
        self,
        entities: list[Entity],
        table_name: str,
        conflict_columns: list[str] = [],
    ) -> list[UUID]:
        """
        Upsert entities into the given entities table. These are raw entities
        extracted from the document.

        Args:
            entities: list of entities to upsert
            table_name: name of the table to upsert into
            conflict_columns: columns that trigger an update on conflict

        Returns:
            list of IDs of the upserted rows
        """
        cleaned_entities = []
        for entity in entities:
            entity_dict = entity.to_dict()
            entity_dict["chunk_ids"] = (
                entity_dict["chunk_ids"]
                if entity_dict.get("chunk_ids")
                else []
            )
            # Postgres expects the embedding as its string representation.
            entity_dict["description_embedding"] = (
                str(entity_dict["description_embedding"])
                if entity_dict.get("description_embedding")  # type: ignore
                else None
            )
            cleaned_entities.append(entity_dict)

        return await _add_objects(
            objects=cleaned_entities,
            full_table_name=self._get_table_name(table_name),
            connection_manager=self.connection_manager,
            conflict_columns=conflict_columns,
        )
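
    # Usage sketch (illustrative): upsert freshly extracted entities into the
    # document-level entity table. The variable `extracted_entities` and the
    # conflict columns are assumptions for the example.
    #
    #     ids = await handler.add_entities(
    #         entities=extracted_entities,
    #         table_name="documents_entities",
    #         conflict_columns=["name", "parent_id"],
    #     )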

    async def delete_node_via_document_id(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        # Don't delete if status is PROCESSING.
        QUERY = f"""
            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
        """
        status = (
            await self.connection_manager.fetch_query(QUERY, [collection_id])
        )[0]["graph_cluster_status"]

        if status == KGExtractionStatus.PROCESSING.value:
            return

        # Execute separate DELETE queries
        delete_queries = [
            f"""DELETE FROM {self._get_table_name("documents_relationships")} WHERE parent_id = $1""",
            f"""DELETE FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1""",
        ]
        for query in delete_queries:
            await self.connection_manager.execute_query(query, [document_id])

        return None

    async def get_all_relationships(
        self,
        collection_id: UUID | None,
        graph_id: UUID | None,
        document_ids: Optional[list[UUID]] = None,
    ) -> list[Relationship]:
        # NOTE: graph_id and document_ids are currently unused; relationships
        # are fetched for the collection-level graph. A single UUID is bound,
        # so the filter is a plain equality rather than ANY over an array.
        QUERY = f"""
            SELECT id, subject, predicate, weight, object, parent_id FROM {self._get_table_name("graphs_relationships")} WHERE parent_id = $1
        """
        relationships = await self.connection_manager.fetch_query(
            QUERY, [collection_id]
        )
        return [Relationship(**relationship) for relationship in relationships]

    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
        """
        Check if a document exists in the graph's document_ids array.

        Args:
            graph_id (UUID): ID of the graph to check
            document_id (UUID): ID of the document to look for

        Returns:
            bool: True if the document exists in the graph, False otherwise

        Raises:
            R2RException: If the graph is not found
        """
        QUERY = f"""
            SELECT EXISTS (
                SELECT 1
                FROM {self._get_table_name("graphs")}
                WHERE id = $1
                AND document_ids IS NOT NULL
                AND $2 = ANY(document_ids)
            ) as exists;
        """
        result = await self.connection_manager.fetchrow_query(
            QUERY, [graph_id, document_id]
        )
        if result is None:
            raise R2RException(f"Graph {graph_id} not found", 404)
        return result["exists"]

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Community], int]:
        """
        Get communities for a graph.

        Args:
            parent_id: UUID of the collection
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            community_ids: Optional list of community IDs to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of communities, total count)
        """
        conditions = ["collection_id = $1"]
        params: list[Any] = [parent_id]
        param_index = 2

        if community_ids:
            conditions.append(f"id = ANY(${param_index})")
            params.append(community_ids)
            param_index += 1

        select_fields = """
            id, collection_id, name, summary, findings, rating, rating_explanation
        """
        if include_embeddings:
            select_fields += ", description_embedding"

        COUNT_QUERY = f"""
            SELECT COUNT(*)
            FROM {self._get_table_name("graphs_communities")}
            WHERE {' AND '.join(conditions)}
        """
        count = (
            await self.connection_manager.fetch_query(COUNT_QUERY, params)
        )[0]["count"]

        QUERY = f"""
            SELECT {select_fields}
            FROM {self._get_table_name("graphs_communities")}
            WHERE {' AND '.join(conditions)}
            ORDER BY created_at
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            QUERY += f" LIMIT ${param_index}"
            params.append(limit)

        rows = await self.connection_manager.fetch_query(QUERY, params)

        communities = []
        for row in rows:
            community_dict = dict(row)
            communities.append(Community(**community_dict))

        return communities, count

    async def add_community(self, community: Community) -> None:
        # TODO: Fix in the short term.
        # We need to do this because the postgres insert expects the embedding
        # as a string.
        community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]

        non_null_attrs = {
            k: v for k, v in community.__dict__.items() if v is not None
        }
        columns = ", ".join(non_null_attrs.keys())
        placeholders = ", ".join(
            f"${i + 1}" for i in range(len(non_null_attrs))
        )
        conflict_columns = ", ".join(
            [f"{k} = EXCLUDED.{k}" for k in non_null_attrs]
        )

        QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
            VALUES ({placeholders})
            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
                {conflict_columns}
        """
        await self.connection_manager.execute_many(
            QUERY, [tuple(non_null_attrs.values())]
        )
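
    # Design note: the INSERT above is a generic upsert - every non-null
    # attribute of the Community becomes a column, and on a
    # (community_id, level, collection_id) collision each of those columns is
    # overwritten from EXCLUDED, i.e. the incoming row wins.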

    async def delete_graph_for_collection(
        self, collection_id: UUID, cascade: bool = False
    ) -> None:
        # Don't delete if status is PROCESSING.
        QUERY = f"""
            SELECT graph_cluster_status FROM {self._get_table_name("collections")} WHERE id = $1
        """
        status = (
            await self.connection_manager.fetch_query(QUERY, [collection_id])
        )[0]["graph_cluster_status"]

        if status == KGExtractionStatus.PROCESSING.value:
            return

        # Remove all communities for this collection.
        DELETE_QUERIES = [
            f"DELETE FROM {self._get_table_name('graphs_communities')} WHERE collection_id = $1;",
        ]

        # FIXME: This was using the pagination defaults from before... We need
        # to review if this is as intended.
        document_ids_response = (
            await self.collections_handler.documents_in_collection(
                offset=0,
                limit=100,
                collection_id=collection_id,
            )
        )
        # This type ignore is due to insufficient typing of the
        # documents_in_collection method.
        document_ids = [doc.id for doc in document_ids_response["results"]]  # type: ignore

        # TODO: make these queries more efficient. Pass the document_ids as params.
        if cascade:
            DELETE_QUERIES += [
                f"DELETE FROM {self._get_table_name('graphs_relationships')} WHERE document_id = ANY($1::uuid[]);",
                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE document_id = ANY($1::uuid[]);",
                f"DELETE FROM {self._get_table_name('graphs_entities')} WHERE collection_id = $1;",
            ]

            # Set the extraction_status back to PENDING for this collection's
            # documents.
            QUERY = f"""
                UPDATE {self._get_table_name("documents")} SET extraction_status = $1 WHERE $2::uuid = ANY(collection_ids)
            """
            await self.connection_manager.execute_query(
                QUERY, [KGExtractionStatus.PENDING, collection_id]
            )

        if document_ids:
            for query in DELETE_QUERIES:
                # Queries that filter on a document_id array take the document
                # ids; the rest take the collection id.
                if "ANY($1::uuid[])" in query:
                    await self.connection_manager.execute_query(
                        query, [document_ids]
                    )
                else:
                    await self.connection_manager.execute_query(
                        query, [collection_id]
                    )

        # Set the cluster status back to PENDING for this collection.
        QUERY = f"""
            UPDATE {self._get_table_name("collections")} SET graph_cluster_status = $1 WHERE id = $2
        """
        await self.connection_manager.execute_query(
            QUERY, [KGExtractionStatus.PENDING, collection_id]
        )

    async def perform_graph_clustering(
        self,
        collection_id: UUID,
        leiden_params: dict[str, Any],
        clustering_mode: str,
    ) -> Tuple[int, Any]:
        """
        Cluster the KG, either via the external clustering service or locally,
        depending on clustering_mode.
        """
        # Fetch all relationships for the collection in pages of 1000.
        offset = 0
        page_size = 1000
        all_relationships = []

        while True:
            relationships, count = await self.relationships.get(
                parent_id=collection_id,
                store_type=StoreType.GRAPHS,
                offset=offset,
                limit=page_size,
            )
            if not relationships:
                break

            all_relationships.extend(relationships)
            offset += len(relationships)

            if offset >= count:
                break

        relationship_ids_cache = await self._get_relationship_ids_cache(
            all_relationships
        )

        logger.info(
            f"Clustering over {len(all_relationships)} relationships for {collection_id} with settings: {leiden_params}"
        )
        return await self._cluster_and_add_community_info(
            relationships=all_relationships,
            relationship_ids_cache=relationship_ids_cache,
            leiden_params=leiden_params,
            collection_id=collection_id,
            clustering_mode=clustering_mode,
        )
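
    # Usage sketch (illustrative): leiden_params are forwarded to
    # graspologic's hierarchical_leiden (or to the remote service), so its
    # keyword arguments can be supplied. The values below are assumptions for
    # the example, not defaults of this module.
    #
    #     num_communities, communities = await handler.perform_graph_clustering(
    #         collection_id=collection_id,
    #         leiden_params={"max_cluster_size": 1000, "random_seed": 7272},
    #         clustering_mode="local",
    #     )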

    async def _call_clustering_service(
        self, relationships: list[Relationship], leiden_params: dict[str, Any]
    ) -> list[dict]:
        """
        Calls the external Graspologic clustering service, sending
        relationships and parameters. Expects a response with a 'communities'
        field.
        """
        # Convert relationships to a JSON-friendly format
        rel_data = []
        for r in relationships:
            rel_data.append(
                {
                    "id": str(r.id),
                    "subject": r.subject,
                    "object": r.object,
                    "weight": r.weight if r.weight is not None else 1.0,
                }
            )

        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
        if not endpoint:
            raise ValueError("CLUSTERING_SERVICE_URL not set.")

        url = f"{endpoint}/cluster"
        payload = {"relationships": rel_data, "leiden_params": leiden_params}

        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, timeout=3600)
            response.raise_for_status()

        data = response.json()
        communities = data.get("communities", [])
        return communities

    async def _create_graph_and_cluster(
        self,
        relationships: list[Relationship],
        leiden_params: dict[str, Any],
        clustering_mode: str = "remote",
    ) -> Any:
        """
        Create a graph and cluster it. If clustering_mode='local', use
        hierarchical_leiden locally. If clustering_mode='remote', call the
        external service.
        """
        if clustering_mode == "remote":
            logger.info("Sending request to external clustering service...")
            communities = await self._call_clustering_service(
                relationships, leiden_params
            )
            logger.info("Received communities from clustering service.")
            return communities
        else:
            # Local mode: run hierarchical_leiden directly
            G = self.nx.Graph()
            for relationship in relationships:
                G.add_edge(
                    relationship.subject,
                    relationship.object,
                    weight=relationship.weight,
                    id=relationship.id,
                )

            logger.info(
                f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges"
            )
            return await self._compute_leiden_communities(G, leiden_params)

    async def _cluster_and_add_community_info(
        self,
        relationships: list[Relationship],
        relationship_ids_cache: dict[str, list[int]],
        leiden_params: dict[str, Any],
        collection_id: Optional[UUID] = None,
        clustering_mode: str = "local",
    ) -> Tuple[int, Any]:
        # Clear any old information. NOTE: currently this only builds the
        # condition; no delete is issued.
        conditions = []
        if collection_id is not None:
            conditions.append("collection_id = $1")

        await asyncio.sleep(0.1)

        start_time = time.time()
        logger.info(f"Creating graph and clustering for {collection_id}")

        hierarchical_communities = await self._create_graph_and_cluster(
            relationships=relationships,
            leiden_params=leiden_params,
            clustering_mode=clustering_mode,
        )
        logger.info(
            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
        )

        def relationship_ids(node: str) -> list[int]:
            # NOTE: currently unused.
            return relationship_ids_cache.get(node, [])

        logger.info(
            f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds."
        )

        # If remote: hierarchical_communities is a list of dicts like:
        #   [{"node": str, "cluster": int, "level": int}, ...]
        # If local: hierarchical_communities is the structure returned by
        # hierarchical_leiden (a list of named tuples).
        if clustering_mode == "remote":
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item["cluster"] for item in hierarchical_communities)
                    + 1
                )
        else:
            # Local mode: named tuples expose the cluster as an attribute.
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item.cluster for item in hierarchical_communities) + 1
                )

        logger.info(
            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
        )

        return num_communities, hierarchical_communities

    async def _get_relationship_ids_cache(
        self, relationships: list[Relationship]
    ) -> dict[str, list[int]]:
        # Map each node (entity name) to the ids of the relationships that
        # touch it, on either the subject or the object side.
        relationship_ids_cache: dict[str, list[int]] = {}
        for relationship in relationships:
            if relationship.subject is not None:
                relationship_ids_cache.setdefault(relationship.subject, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.subject].append(
                        relationship.id
                    )
            if relationship.object is not None:
                relationship_ids_cache.setdefault(relationship.object, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.object].append(
                        relationship.id
                    )
        return relationship_ids_cache
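
    # Illustrative shape of the cache: for relationships (A)-[r1]->(B) and
    # (A)-[r2]->(C), the result is
    #     {"A": [r1.id, r2.id], "B": [r1.id], "C": [r2.id]}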

    async def get_entity_map(
        self, offset: int, limit: int, document_id: UUID
    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
        QUERY1 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT e.name, e.description, e.category,
                   (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
                   e.parent_id
            FROM {self._get_table_name("documents_entities")} e
            JOIN entities_list el ON e.name = el.name
            GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
            ORDER BY e.name;"""
        entities_list = await self.connection_manager.fetch_query(
            QUERY1, [document_id]
        )
        entities_list = [Entity(**entity) for entity in entities_list]

        QUERY2 = f"""
            WITH entities_list AS (
                SELECT DISTINCT name
                FROM {self._get_table_name("documents_entities")}
                WHERE parent_id = $1
                ORDER BY name ASC
                LIMIT {limit} OFFSET {offset}
            )
            SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
                   (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
            FROM {self._get_table_name("documents_relationships")} t
            JOIN entities_list el ON t.subject = el.name
            ORDER BY t.subject, t.predicate, t.object;
        """
        relationships_list = await self.connection_manager.fetch_query(
            QUERY2, [document_id]
        )
        relationships_list = [
            Relationship(**relationship) for relationship in relationships_list
        ]

        entity_map: dict[str, dict[str, list[Any]]] = {}
        for entity in entities_list:
            if entity.name not in entity_map:
                entity_map[entity.name] = {"entities": [], "relationships": []}
            entity_map[entity.name]["entities"].append(entity)

        for relationship in relationships_list:
            if relationship.subject in entity_map:
                entity_map[relationship.subject]["relationships"].append(
                    relationship
                )
            if relationship.object in entity_map:
                entity_map[relationship.object]["relationships"].append(
                    relationship
                )

        return entity_map
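
    # Illustrative shape of the returned map: each entity name keys its entity
    # records plus every relationship in which it appears as subject or
    # object, e.g.
    #     {"Acme Corp": {"entities": [...], "relationships": [...]}, ...}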

    async def graph_search(
        self, query: str, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """
        Perform semantic search with similarity scores while maintaining the
        exact same result structure.
        """
        query_embedding = kwargs.get("query_embedding", None)
        if query_embedding is None:
            raise ValueError(
                "query_embedding must be provided for semantic search"
            )

        search_type = kwargs.get(
            "search_type", "entities"
        )  # entities | relationships | communities
        embedding_type = kwargs.get("embedding_type", "description_embedding")
        property_names = kwargs.get("property_names", ["name", "description"])

        # Add metadata if not present
        if "metadata" not in property_names:
            property_names.append("metadata")

        filters = kwargs.get("filters", {})
        limit = kwargs.get("limit", 10)
        use_fulltext_search = kwargs.get("use_fulltext_search", True)
        use_hybrid_search = kwargs.get("use_hybrid_search", True)
        if use_hybrid_search or use_fulltext_search:
            logger.warning(
                "Hybrid and fulltext search are not supported for graph search, ignoring."
            )

        table_name = f"graphs_{search_type}"
        property_names_str = ", ".join(property_names)

        # Build the WHERE clause from filters
        params: list[Union[str, int, bytes]] = [
            json.dumps(query_embedding),
            limit,
        ]
        conditions_clause = self._build_filters(filters, params, search_type)
        where_clause = (
            f"WHERE {conditions_clause}" if conditions_clause else ""
        )

        # Construct the query.
        # Note: for vector similarity, <=> computes a distance - the smaller
        # the value, the more similar. We convert it to a similarity score
        # below by taking (1 - distance).
        QUERY = f"""
            SELECT
                {property_names_str},
                ({embedding_type} <=> $1) as similarity_score
            FROM {self._get_table_name(table_name)}
            {where_clause}
            ORDER BY {embedding_type} <=> $1
            LIMIT $2;
        """
        results = await self.connection_manager.fetch_query(
            QUERY, tuple(params)
        )

        for result in results:
            output = {
                prop: result[prop] for prop in property_names if prop in result
            }
            output["similarity_score"] = 1 - float(result["similarity_score"])
            yield output
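
    # Usage sketch (illustrative): consume the async generator. The embedding
    # is assumed to come from the caller's embedding provider and to match the
    # dimensionality of the stored description_embedding column.
    #
    #     async for hit in handler.graph_search(
    #         "Who acquired Acme?",
    #         query_embedding=embedding,
    #         search_type="relationships",
    #         filters={"parent_id": {"$eq": str(graph_id)}},
    #         limit=5,
    #     ):
    #         print(hit["similarity_score"], hit.get("description"))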

    def _build_filters(
        self, filter_dict: dict, parameters: list[Any], search_type: str
    ) -> str:
        """
        Build a WHERE clause from a nested filter dictionary for the graph
        search. For communities we use collection_id as the primary key
        filter; for entities and relationships we use parent_id.
        """
        # Determine the primary identifier column depending on search_type:
        #   communities            -> collection_id
        #   entities/relationships -> parent_id
        base_id_column = (
            "collection_id" if search_type == "communities" else "parent_id"
        )

        def parse_condition(key: str, value: Any) -> str:
            # Returns a single condition string, or empty if there is no valid
            # condition. Supported keys:
            #   - base_id_column (collection_id or parent_id)
            #   - metadata fields: metadata.some_field
            # Supported ops: $eq, $ne, $lt, $lte, $gt, $gte, $in, $contains
            if key == base_id_column:
                # e.g. {"collection_id": {"$eq": "<some-uuid>"}}
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(str(clause))
                        return f"{base_id_column} = ${len(parameters)}::uuid"
                    elif op == "$in":
                        # $in expects a list of UUIDs
                        parameters.append([str(x) for x in clause])
                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
                else:
                    # Bare value: direct equality
                    parameters.append(str(value))
                    return f"{base_id_column} = ${len(parameters)}::uuid"
            elif key.startswith("metadata."):
                # Handle metadata filters,
                # e.g. {"metadata.some_key": {"$eq": "value"}}
                field = key.split("metadata.")[1]
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') != ${len(parameters)}"
                    elif op == "$lt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float < ${len(parameters)}::float"
                    elif op == "$lte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float <= ${len(parameters)}::float"
                    elif op == "$gt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
                    elif op == "$gte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float >= ${len(parameters)}::float"
                    elif op == "$in":
                        # Ensure clause is a list
                        if not isinstance(clause, list):
                            raise Exception(
                                "argument to $in filter must be a list"
                            )
                        # Append the Python list as a parameter; the driver
                        # converts Python lists to arrays.
                        parameters.append(clause)
                        # Compare the metadata field (as text) against the
                        # text array parameter.
                        return f"(metadata->>'{field}')::text = ANY(${len(parameters)}::text[])"
                    elif op == "$contains":
                        # $contains maps to JSONB containment
                        # (metadata @> clause) for dict or list values.
                        parameters.append(json.dumps(clause))
                        return f"metadata @> ${len(parameters)}::jsonb"
                else:
                    # Bare value: direct equality
                    parameters.append(value)
                    return f"(metadata->>'{field}') = ${len(parameters)}"

            # Additional conditions for other columns can be added here.
            # If the key is not recognized, return empty so it doesn't break
            # the query.
            return ""

        def parse_filter(fd: dict) -> str:
            filter_conditions = []
            for k, v in fd.items():
                if k == "$and":
                    and_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    and_parts = [x for x in and_parts if x.strip()]
                    if and_parts:
                        filter_conditions.append(
                            f"({' AND '.join(and_parts)})"
                        )
                elif k == "$or":
                    or_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    or_parts = [x for x in or_parts if x.strip()]
                    if or_parts:
                        filter_conditions.append(f"({' OR '.join(or_parts)})")
                else:
                    # Regular condition
                    c = parse_condition(k, v)
                    if c and c.strip():
                        filter_conditions.append(c)

            if not filter_conditions:
                return ""
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            return " AND ".join(filter_conditions)

        return parse_filter(filter_dict)
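
    # Illustrative example of the filter grammar accepted above (the values
    # are assumptions for the example):
    #
    #     filters = {
    #         "$and": [
    #             {"parent_id": {"$eq": "9fbe403b-..."}},
    #             {"metadata.rating": {"$gte": 4}},
    #         ]
    #     }
    #
    # would compile to
    #     (parent_id = $3::uuid AND (metadata->>'rating')::float >= $4::float)
    # with both values appended to `parameters` ($1 and $2 are already taken
    # by the embedding and the limit in graph_search).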

    async def _compute_leiden_communities(
        self,
        graph: Any,
        leiden_params: dict[str, Any],
    ) -> Any:
        """Compute Leiden communities."""
        try:
            from graspologic.partition import hierarchical_leiden

            if "random_seed" not in leiden_params:
                # Add a seed to control randomness
                leiden_params["random_seed"] = 7272

            start_time = time.time()
            logger.info(
                f"Running Leiden clustering with params: {leiden_params}"
            )

            community_mapping = hierarchical_leiden(graph, **leiden_params)

            logger.info(
                f"Leiden clustering completed in {time.time() - start_time:.2f} seconds."
            )
            return community_mapping
        except ImportError as e:
            raise ImportError("Please install the graspologic package.") from e

    async def get_existing_document_entity_chunk_ids(
        self, document_id: UUID
    ) -> list[str]:
        QUERY = f"""
            SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("documents_entities")} WHERE parent_id = $1
        """
        return [
            item["chunk_id"]
            for item in await self.connection_manager.fetch_query(
                QUERY, [document_id]
            )
        ]

    async def get_entity_count(
        self,
        collection_id: Optional[UUID] = None,
        document_id: Optional[UUID] = None,
        distinct: bool = False,
        entity_table_name: str = "entity",
    ) -> int:
        if collection_id is None and document_id is None:
            raise ValueError(
                "Either collection_id or document_id must be provided."
            )

        # Filter by the document when given; otherwise count across every
        # document in the collection.
        if document_id:
            conditions = ["parent_id = $1"]
            params = [str(document_id)]
        else:
            conditions = [
                f"""parent_id = ANY(
                    SELECT document_id FROM {self._get_table_name("documents")}
                    WHERE $1 = ANY(collection_ids)
                )"""
            ]
            params = [str(collection_id)]

        count_value = "DISTINCT name" if distinct else "*"

        QUERY = f"""
            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
            WHERE {" AND ".join(conditions)}
        """
        return (await self.connection_manager.fetch_query(QUERY, params))[0][
            "count"
        ]

    async def update_entity_descriptions(self, entities: list[Entity]):
        query = f"""
            UPDATE {self._get_table_name("graphs_entities")}
            SET description = $3, description_embedding = $4
            WHERE name = $1 AND graph_id = $2
        """
        inputs = [
            (
                entity.name,
                entity.parent_id,
                entity.description,
                entity.description_embedding,
            )
            for entity in entities
        ]
        await self.connection_manager.execute_many(query, inputs)  # type: ignore


def _json_serialize(obj):
    if isinstance(obj, UUID):
        return str(obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


async def _add_objects(
    objects: list[dict],
    full_table_name: str,
    connection_manager: PostgresConnectionManager,
    conflict_columns: list[str] = [],
    exclude_metadata: list[str] = [],
) -> list[UUID]:
    """
    Bulk insert objects into the specified table using jsonb_to_recordset.
    """
    if not objects:
        return []

    # Exclude specified metadata and prepare data
    cleaned_objects = []
    for obj in objects:
        cleaned_obj = {
            k: v
            for k, v in obj.items()
            if k not in exclude_metadata and v is not None
        }
        cleaned_objects.append(cleaned_obj)

    # Serialize the list of objects to JSON
    json_data = json.dumps(cleaned_objects, default=_json_serialize)

    # Prepare the column definitions for jsonb_to_recordset, mapping Python
    # types to PostgreSQL types from the first object's values.
    columns = cleaned_objects[0].keys()
    column_defs = []
    for col in columns:
        sample_value = cleaned_objects[0][col]
        if "embedding" in col:
            pg_type = "vector"
        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
            pg_type = "uuid[]"
        elif col == "id" or "_id" in col:
            pg_type = "uuid"
        elif isinstance(sample_value, bool):
            # Check bool before int/float: bool is a subclass of int, so the
            # numeric branch below would otherwise swallow it.
            pg_type = "boolean"
        elif isinstance(sample_value, str):
            pg_type = "text"
        elif isinstance(sample_value, UUID):
            pg_type = "uuid"
        elif isinstance(sample_value, (int, float)):
            pg_type = "numeric"
        elif isinstance(sample_value, list) and all(
            isinstance(x, UUID) for x in sample_value
        ):
            pg_type = "uuid[]"
        elif isinstance(sample_value, list):
            pg_type = "jsonb"
        elif isinstance(sample_value, dict):
            pg_type = "jsonb"
        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
            pg_type = "timestamp"
        else:
            raise TypeError(
                f"Unsupported data type for column '{col}': {type(sample_value)}"
            )
        column_defs.append(f"{col} {pg_type}")

    columns_str = ", ".join(columns)
    column_defs_str = ", ".join(column_defs)

    if conflict_columns:
        conflict_columns_str = ", ".join(conflict_columns)
        update_columns_str = ", ".join(
            f"{col}=EXCLUDED.{col}"
            for col in columns
            if col not in conflict_columns
        )
        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
    else:
        on_conflict_clause = ""

    QUERY = f"""
        INSERT INTO {full_table_name} ({columns_str})
        SELECT {columns_str}
        FROM jsonb_to_recordset($1::jsonb)
        AS x({column_defs_str})
        {on_conflict_clause}
        RETURNING id;
    """

    # Execute the query and return the ids of the inserted/updated rows.
    result = await connection_manager.fetch_query(QUERY, [json_data])
    return [record["id"] for record in result]
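
# Illustrative example of the SQL generated by _add_objects for objects with
# columns (id, name, parent_id), assuming a table named
# "myapp.documents_entities" and conflict_columns=["id"]:
#
#     INSERT INTO myapp.documents_entities (id, name, parent_id)
#     SELECT id, name, parent_id
#     FROM jsonb_to_recordset($1::jsonb)
#     AS x(id uuid, name text, parent_id uuid)
#     ON CONFLICT (id) DO UPDATE SET name=EXCLUDED.name, parent_id=EXCLUDED.parent_id
#     RETURNING id;
#
# with $1 bound to the JSON-serialized list of objects.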