- import asyncio
- import contextlib
- import csv
- import datetime
- import json
- import logging
- import os
- import tempfile
- import time
- from typing import IO, Any, AsyncGenerator, Optional, Tuple
- from uuid import UUID
- import asyncpg
- import httpx
- from asyncpg.exceptions import UndefinedTableError, UniqueViolationError
- from fastapi import HTTPException
- from core.base.abstractions import (
- Community,
- Entity,
- Graph,
- KGCreationSettings,
- KGEnrichmentSettings,
- KGExtractionStatus,
- R2RException,
- Relationship,
- StoreType,
- VectorQuantizationType,
- )
- from core.base.api.models import GraphResponse
- from core.base.providers.database import Handler
- from core.base.utils import (
- _decorate_vector_type,
- _get_str_estimation_output,
- llm_cost_per_million_tokens,
- )
- from .base import PostgresConnectionManager
- from .collections import PostgresCollectionsHandler
- logger = logging.getLogger()
- class PostgresEntitiesHandler(Handler):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- self.project_name: str = kwargs.get("project_name") # type: ignore
- self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore
- self.dimension: int = kwargs.get("dimension") # type: ignore
- self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore
- def _get_table_name(self, table: str) -> str:
- """Get the fully qualified table name."""
- return f'"{self.project_name}"."{table}"'
- def _get_entity_table_for_store(self, store_type: StoreType) -> str:
- """Get the appropriate table name for the store type."""
- return f"{store_type.value}_entities"
- def _get_parent_constraint(self, store_type: StoreType) -> str:
- """Get the appropriate foreign key constraint for the store type."""
- if store_type == StoreType.GRAPHS:
- return f"""
- CONSTRAINT fk_graph
- FOREIGN KEY(parent_id)
- REFERENCES {self._get_table_name("graphs")}(id)
- ON DELETE CASCADE
- """
- else:
- return f"""
- CONSTRAINT fk_document
- FOREIGN KEY(parent_id)
- REFERENCES {self._get_table_name("documents")}(id)
- ON DELETE CASCADE
- """
- async def create_tables(self) -> None:
- """Create separate tables for graph and document entities."""
- vector_column_str = _decorate_vector_type(
- f"({self.dimension})", self.quantization_type
- )
- for store_type in StoreType:
- table_name = self._get_entity_table_for_store(store_type)
- parent_constraint = self._get_parent_constraint(store_type)
- QUERY = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- name TEXT NOT NULL,
- category TEXT,
- description TEXT,
- parent_id UUID NOT NULL,
- description_embedding {vector_column_str},
- chunk_ids UUID[],
- metadata JSONB,
- created_at TIMESTAMPTZ DEFAULT NOW(),
- updated_at TIMESTAMPTZ DEFAULT NOW(),
- {parent_constraint}
- );
- CREATE INDEX IF NOT EXISTS {table_name}_name_idx
- ON {self._get_table_name(table_name)} (name);
- CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
- ON {self._get_table_name(table_name)} (parent_id);
- CREATE INDEX IF NOT EXISTS {table_name}_category_idx
- ON {self._get_table_name(table_name)} (category);
- """
- await self.connection_manager.execute_query(QUERY)
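- # Setup sketch (illustrative, not part of the handler API): the handler is
- # constructed with the project schema name, an initialized
- # PostgresConnectionManager, the embedding dimension, and a quantization type,
- # e.g. (assuming a `conn_mgr` built elsewhere; FP32 is an assumed member):
- #
- #     handler = PostgresEntitiesHandler(
- #         project_name="my_project",
- #         connection_manager=conn_mgr,
- #         dimension=512,
- #         quantization_type=VectorQuantizationType.FP32,
- #     )
- #     await handler.create_tables()  # creates graphs_entities and documents_entities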
- async def create(
- self,
- parent_id: UUID,
- store_type: StoreType,
- name: str,
- category: Optional[str] = None,
- description: Optional[str] = None,
- description_embedding: Optional[list[float] | str] = None,
- chunk_ids: Optional[list[UUID]] = None,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Entity:
- """Create a new entity in the specified store."""
- table_name = self._get_entity_table_for_store(store_type)
- if isinstance(metadata, str):
- with contextlib.suppress(json.JSONDecodeError):
- metadata = json.loads(metadata)
- if isinstance(description_embedding, list):
- description_embedding = str(description_embedding)
- query = f"""
- INSERT INTO {self._get_table_name(table_name)}
- (name, category, description, parent_id, description_embedding, chunk_ids, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- RETURNING id, name, category, description, parent_id, chunk_ids, metadata
- """
- params = [
- name,
- category,
- description,
- parent_id,
- description_embedding,
- chunk_ids,
- json.dumps(metadata) if metadata else None,
- ]
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return Entity(
- id=result["id"],
- name=result["name"],
- category=result["category"],
- description=result["description"],
- parent_id=result["parent_id"],
- chunk_ids=result["chunk_ids"],
- metadata=result["metadata"],
- )
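- # Usage sketch (illustrative, not part of the handler API): creating a
- # graph-store entity. The parent UUID below is a placeholder for an existing
- # graph id.
- async def _example_create_entity(handler: "PostgresEntitiesHandler") -> Entity:
-     return await handler.create(
-         parent_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
-         store_type=StoreType.GRAPHS,
-         name="Ada Lovelace",
-         category="person",
-         description="Early computing pioneer",
-         metadata={"source": "example"},
-     )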
- async def get(
- self,
- parent_id: UUID,
- store_type: StoreType,
- offset: int,
- limit: int,
- entity_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- """Retrieve entities from the specified store."""
- table_name = self._get_entity_table_for_store(store_type)
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if entity_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(entity_ids)
- param_index += 1
- if entity_names:
- conditions.append(f"name = ANY(${param_index})")
- params.append(entity_names)
- param_index += 1
- select_fields = """
- id, name, category, description, parent_id,
- chunk_ids, metadata
- """
- if include_embeddings:
- select_fields += ", description_embedding"
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- """
- count_params = params[: param_index - 1]
- count = (
- await self.connection_manager.fetch_query(
- COUNT_QUERY, count_params
- )
- )[0]["count"]
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- entities = []
- for row in rows:
- # Convert the Record to a dictionary
- entity_dict = dict(row)
- # Process metadata if it exists and is a string
- if isinstance(entity_dict["metadata"], str):
- with contextlib.suppress(json.JSONDecodeError):
- entity_dict["metadata"] = json.loads(
- entity_dict["metadata"]
- )
- entities.append(Entity(**entity_dict))
- return entities, count
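- # Usage sketch (illustrative, not part of the handler API): paging through
- # every entity under a parent. `get` returns (entities, total count);
- # limit=-1 would disable the LIMIT clause entirely.
- async def _example_list_entities(
-     handler: "PostgresEntitiesHandler", parent_id: UUID
- ) -> list[Entity]:
-     collected: list[Entity] = []
-     offset = 0
-     while True:
-         entities, total = await handler.get(
-             parent_id=parent_id,
-             store_type=StoreType.GRAPHS,
-             offset=offset,
-             limit=100,
-         )
-         collected.extend(entities)
-         offset += len(entities)
-         if not entities or offset >= total:
-             break
-     return collected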
- async def update(
- self,
- entity_id: UUID,
- store_type: StoreType,
- name: Optional[str] = None,
- description: Optional[str] = None,
- description_embedding: Optional[list[float] | str] = None,
- category: Optional[str] = None,
- metadata: Optional[dict] = None,
- ) -> Entity:
- """Update an entity in the specified store."""
- table_name = self._get_entity_table_for_store(store_type)
- update_fields = []
- params: list[Any] = []
- param_index = 1
- if isinstance(metadata, str):
- with contextlib.suppress(json.JSONDecodeError):
- metadata = json.loads(metadata)
- if name is not None:
- update_fields.append(f"name = ${param_index}")
- params.append(name)
- param_index += 1
- if description is not None:
- update_fields.append(f"description = ${param_index}")
- params.append(description)
- param_index += 1
- if description_embedding is not None:
- update_fields.append(f"description_embedding = ${param_index}")
- params.append(description_embedding)
- param_index += 1
- if category is not None:
- update_fields.append(f"category = ${param_index}")
- params.append(category)
- param_index += 1
- if metadata is not None:
- update_fields.append(f"metadata = ${param_index}")
- params.append(json.dumps(metadata))
- param_index += 1
- if not update_fields:
- raise R2RException(status_code=400, message="No fields to update")
- update_fields.append("updated_at = NOW()")
- params.append(entity_id)
- query = f"""
- UPDATE {self._get_table_name(table_name)}
- SET {', '.join(update_fields)}
- WHERE id = ${param_index}
- RETURNING id, name, category, description, parent_id, chunk_ids, metadata
- """
- try:
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return Entity(
- id=result["id"],
- name=result["name"],
- category=result["category"],
- description=result["description"],
- parent_id=result["parent_id"],
- chunk_ids=result["chunk_ids"],
- metadata=result["metadata"],
- )
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while updating the entity: {e}",
- ) from e
- async def delete(
- self,
- parent_id: UUID,
- entity_ids: Optional[list[UUID]] = None,
- store_type: StoreType = StoreType.GRAPHS,
- ) -> None:
- """
- Delete entities from the specified store.
- If entity_ids is not provided, deletes all entities for the given parent_id.
- Args:
- parent_id (UUID): Parent ID (collection_id or document_id)
- entity_ids (Optional[list[UUID]]): Specific entity IDs to delete. If None, deletes all entities for parent_id
- store_type (StoreType): Type of store (graph or document)
- Raises:
- R2RException: If specific entities were requested but not all found
- """
- table_name = self._get_entity_table_for_store(store_type)
- if entity_ids is None:
- # Delete all entities for the parent_id
- QUERY = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE parent_id = $1
- RETURNING id
- """
- results = await self.connection_manager.fetch_query(
- QUERY, [parent_id]
- )
- else:
- # Delete specific entities
- QUERY = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE id = ANY($1) AND parent_id = $2
- RETURNING id
- """
- results = await self.connection_manager.fetch_query(
- QUERY, [entity_ids, parent_id]
- )
- # Check if all requested entities were deleted
- deleted_ids = [row["id"] for row in results]
- if entity_ids and len(deleted_ids) != len(entity_ids):
- raise R2RException(
- message=f"Some entities not found in {store_type.value} store or no permission to delete",
- status_code=404,
- )
- async def export_to_csv(
- self,
- parent_id: UUID,
- store_type: StoreType,
- columns: Optional[list[str]] = None,
- filters: Optional[dict] = None,
- include_header: bool = True,
- ) -> tuple[str, IO]:
- """
- Creates a CSV file from the PostgreSQL data and returns the temp file's path along with its open handle.
- """
- valid_columns = {
- "id",
- "name",
- "category",
- "description",
- "parent_id",
- "chunk_ids",
- "metadata",
- "created_at",
- "updated_at",
- }
- if not columns:
- columns = list(valid_columns)
- elif invalid_cols := set(columns) - valid_columns:
- raise ValueError(f"Invalid columns: {invalid_cols}")
- select_stmt = f"""
- SELECT
- id::text,
- name,
- category,
- description,
- parent_id::text,
- chunk_ids::text,
- metadata::text,
- to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
- to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
- FROM {self._get_table_name(self._get_entity_table_for_store(store_type))}
- """
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if filters:
- for field, value in filters.items():
- if field not in valid_columns:
- continue
- if isinstance(value, dict):
- for op, val in value.items():
- if op == "$eq":
- conditions.append(f"{field} = ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$gt":
- conditions.append(f"{field} > ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$lt":
- conditions.append(f"{field} < ${param_index}")
- params.append(val)
- param_index += 1
- else:
- # Direct equality
- conditions.append(f"{field} = ${param_index}")
- params.append(value)
- param_index += 1
- if conditions:
- select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
- select_stmt = f"{select_stmt} ORDER BY created_at DESC"
- temp_file = None
- try:
- temp_file = tempfile.NamedTemporaryFile(
- mode="w", delete=True, suffix=".csv"
- )
- writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
- async with self.connection_manager.pool.get_connection() as conn: # type: ignore
- async with conn.transaction():
- cursor = await conn.cursor(select_stmt, *params)
- if include_header:
- writer.writerow(columns)
- chunk_size = 1000
- while True:
- rows = await cursor.fetch(chunk_size)
- if not rows:
- break
- for row in rows:
- writer.writerow(row)
- temp_file.flush()
- return temp_file.name, temp_file
- except Exception as e:
- if temp_file:
- temp_file.close()
- raise HTTPException(
- status_code=500,
- detail=f"Failed to export data: {str(e)}",
- ) from e
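- # Usage sketch (illustrative, not part of the handler API): exporting
- # entities to CSV. The temp file is created with delete=True, so it vanishes
- # once closed; read or copy it before closing.
- async def _example_export_entities(
-     handler: "PostgresEntitiesHandler", parent_id: UUID
- ) -> str:
-     path, temp_file = await handler.export_to_csv(
-         parent_id=parent_id,
-         store_type=StoreType.GRAPHS,
-         columns=["id", "name", "category"],
-         filters={"category": {"$eq": "person"}},
-     )
-     try:
-         with open(path, newline="") as f:
-             return f.read()  # consume before the handle is closed
-     finally:
-         temp_file.close()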
- class PostgresRelationshipsHandler(Handler):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- self.project_name: str = kwargs.get("project_name") # type: ignore
- self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore
- self.dimension: int = kwargs.get("dimension") # type: ignore
- self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore
- def _get_table_name(self, table: str) -> str:
- """Get the fully qualified table name."""
- return f'"{self.project_name}"."{table}"'
- def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
- """Get the appropriate table name for the store type."""
- return f"{store_type.value}_relationships"
- def _get_parent_constraint(self, store_type: StoreType) -> str:
- """Get the appropriate foreign key constraint for the store type."""
- if store_type == StoreType.GRAPHS:
- return f"""
- CONSTRAINT fk_graph
- FOREIGN KEY(parent_id)
- REFERENCES {self._get_table_name("graphs")}(id)
- ON DELETE CASCADE
- """
- else:
- return f"""
- CONSTRAINT fk_document
- FOREIGN KEY(parent_id)
- REFERENCES {self._get_table_name("documents")}(id)
- ON DELETE CASCADE
- """
- async def create_tables(self) -> None:
- """Create separate tables for graph and document relationships."""
- for store_type in StoreType:
- table_name = self._get_relationship_table_for_store(store_type)
- parent_constraint = self._get_parent_constraint(store_type)
- vector_column_str = _decorate_vector_type(
- f"({self.dimension})", self.quantization_type
- )
- QUERY = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name(table_name)} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- subject TEXT NOT NULL,
- predicate TEXT NOT NULL,
- object TEXT NOT NULL,
- description TEXT,
- description_embedding {vector_column_str},
- subject_id UUID,
- object_id UUID,
- weight FLOAT DEFAULT 1.0,
- chunk_ids UUID[],
- parent_id UUID NOT NULL,
- metadata JSONB,
- created_at TIMESTAMPTZ DEFAULT NOW(),
- updated_at TIMESTAMPTZ DEFAULT NOW(),
- {parent_constraint}
- );
- CREATE INDEX IF NOT EXISTS {table_name}_subject_idx
- ON {self._get_table_name(table_name)} (subject);
- CREATE INDEX IF NOT EXISTS {table_name}_object_idx
- ON {self._get_table_name(table_name)} (object);
- CREATE INDEX IF NOT EXISTS {table_name}_predicate_idx
- ON {self._get_table_name(table_name)} (predicate);
- CREATE INDEX IF NOT EXISTS {table_name}_parent_id_idx
- ON {self._get_table_name(table_name)} (parent_id);
- CREATE INDEX IF NOT EXISTS {table_name}_subject_id_idx
- ON {self._get_table_name(table_name)} (subject_id);
- CREATE INDEX IF NOT EXISTS {table_name}_object_id_idx
- ON {self._get_table_name(table_name)} (object_id);
- """
- await self.connection_manager.execute_query(QUERY)
- async def create(
- self,
- subject: str,
- subject_id: UUID,
- predicate: str,
- object: str,
- object_id: UUID,
- parent_id: UUID,
- store_type: StoreType,
- description: str | None = None,
- weight: float | None = 1.0,
- chunk_ids: Optional[list[UUID]] = None,
- description_embedding: Optional[list[float] | str] = None,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Relationship:
- """Create a new relationship in the specified store."""
- table_name = self._get_relationship_table_for_store(store_type)
- if isinstance(metadata, str):
- with contextlib.suppress(json.JSONDecodeError):
- metadata = json.loads(metadata)
- if isinstance(description_embedding, list):
- description_embedding = str(description_embedding)
- query = f"""
- INSERT INTO {self._get_table_name(table_name)}
- (subject, predicate, object, description, subject_id, object_id,
- weight, chunk_ids, parent_id, description_embedding, metadata)
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
- RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
- """
- params = [
- subject,
- predicate,
- object,
- description,
- subject_id,
- object_id,
- weight,
- chunk_ids,
- parent_id,
- description_embedding,
- json.dumps(metadata) if metadata else None,
- ]
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return Relationship(
- id=result["id"],
- subject=result["subject"],
- predicate=result["predicate"],
- object=result["object"],
- description=result["description"],
- subject_id=result["subject_id"],
- object_id=result["object_id"],
- weight=result["weight"],
- chunk_ids=result["chunk_ids"],
- parent_id=result["parent_id"],
- metadata=result["metadata"],
- )
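- # Usage sketch (illustrative, not part of the handler API): linking two
- # previously created entities with a weighted, described relationship.
- async def _example_create_relationship(
-     handler: "PostgresRelationshipsHandler",
-     parent_id: UUID,
-     subject_entity: Entity,
-     object_entity: Entity,
- ) -> Relationship:
-     return await handler.create(
-         subject=subject_entity.name,
-         subject_id=subject_entity.id,
-         predicate="collaborated_with",
-         object=object_entity.name,
-         object_id=object_entity.id,
-         parent_id=parent_id,
-         store_type=StoreType.GRAPHS,
-         description="Worked together on the Analytical Engine",
-         weight=0.9,
-     )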
- async def get(
- self,
- parent_id: UUID,
- store_type: StoreType,
- offset: int,
- limit: int,
- relationship_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- relationship_types: Optional[list[str]] = None,
- include_metadata: bool = False,
- ):
- """
- Get relationships from the specified store.
- Args:
- parent_id: UUID of the parent (collection_id or document_id)
- store_type: Type of store (graph or document)
- offset: Number of records to skip
- limit: Maximum number of records to return (-1 for no limit)
- relationship_ids: Optional list of specific relationship IDs to retrieve
- entity_names: Optional list of entity names to filter by (matches subject or object)
- relationship_types: Optional list of relationship types (predicates) to filter by
- include_metadata: Whether to include metadata in the response
- Returns:
- Tuple of (list of relationships, total count)
- """
- table_name = self._get_relationship_table_for_store(store_type)
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if relationship_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(relationship_ids)
- param_index += 1
- if entity_names:
- conditions.append(
- f"(subject = ANY(${param_index}) OR object = ANY(${param_index}))"
- )
- params.append(entity_names)
- param_index += 1
- if relationship_types:
- conditions.append(f"predicate = ANY(${param_index})")
- params.append(relationship_types)
- param_index += 1
- select_fields = """
- id, subject, predicate, object, description,
- subject_id, object_id, weight, chunk_ids,
- parent_id
- """
- if include_metadata:
- select_fields += ", metadata"
- # Count query
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- """
- count_params = params[: param_index - 1]
- count = (
- await self.connection_manager.fetch_query(
- COUNT_QUERY, count_params
- )
- )[0]["count"]
- # Main query
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- relationships = []
- for row in rows:
- relationship_dict = dict(row)
- if include_metadata and isinstance(
- relationship_dict["metadata"], str
- ):
- with contextlib.suppress(json.JSONDecodeError):
- relationship_dict["metadata"] = json.loads(
- relationship_dict["metadata"]
- )
- elif not include_metadata:
- relationship_dict.pop("metadata", None)
- relationships.append(Relationship(**relationship_dict))
- return relationships, count
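- # Usage sketch (illustrative, not part of the handler API): fetching only the
- # relationships that touch a given entity name and use a given predicate.
- async def _example_filter_relationships(
-     handler: "PostgresRelationshipsHandler", parent_id: UUID
- ) -> tuple[list[Relationship], int]:
-     return await handler.get(
-         parent_id=parent_id,
-         store_type=StoreType.GRAPHS,
-         offset=0,
-         limit=50,
-         entity_names=["Ada Lovelace"],
-         relationship_types=["collaborated_with"],
-     )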
- async def update(
- self,
- relationship_id: UUID,
- store_type: StoreType,
- subject: Optional[str],
- subject_id: Optional[UUID],
- predicate: Optional[str],
- object: Optional[str],
- object_id: Optional[UUID],
- description: Optional[str],
- description_embedding: Optional[list[float] | str],
- weight: Optional[float],
- metadata: Optional[dict[str, Any] | str],
- ) -> Relationship:
- """Update multiple relationships in the specified store."""
- table_name = self._get_relationship_table_for_store(store_type)
- update_fields = []
- params: list = []
- param_index = 1
- if isinstance(metadata, str):
- with contextlib.suppress(json.JSONDecodeError):
- metadata = json.loads(metadata)
- if subject is not None:
- update_fields.append(f"subject = ${param_index}")
- params.append(subject)
- param_index += 1
- if subject_id is not None:
- update_fields.append(f"subject_id = ${param_index}")
- params.append(subject_id)
- param_index += 1
- if predicate is not None:
- update_fields.append(f"predicate = ${param_index}")
- params.append(predicate)
- param_index += 1
- if object is not None:
- update_fields.append(f"object = ${param_index}")
- params.append(object)
- param_index += 1
- if object_id is not None:
- update_fields.append(f"object_id = ${param_index}")
- params.append(object_id)
- param_index += 1
- if description is not None:
- update_fields.append(f"description = ${param_index}")
- params.append(description)
- param_index += 1
- if description_embedding is not None:
- update_fields.append(f"description_embedding = ${param_index}")
- params.append(description_embedding)
- param_index += 1
- if weight is not None:
- update_fields.append(f"weight = ${param_index}")
- params.append(weight)
- param_index += 1
- if not update_fields:
- raise R2RException(status_code=400, message="No fields to update")
- update_fields.append("updated_at = NOW()")
- params.append(relationship_id)
- query = f"""
- UPDATE {self._get_table_name(table_name)}
- SET {', '.join(update_fields)}
- WHERE id = ${param_index}
- RETURNING id, subject, predicate, object, description, subject_id, object_id, weight, chunk_ids, parent_id, metadata
- """
- try:
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return Relationship(
- id=result["id"],
- subject=result["subject"],
- predicate=result["predicate"],
- object=result["object"],
- description=result["description"],
- subject_id=result["subject_id"],
- object_id=result["object_id"],
- weight=result["weight"],
- chunk_ids=result["chunk_ids"],
- parent_id=result["parent_id"],
- metadata=result["metadata"],
- )
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while updating the relationship: {e}",
- ) from e
- async def delete(
- self,
- parent_id: UUID,
- relationship_ids: Optional[list[UUID]] = None,
- store_type: StoreType = StoreType.GRAPHS,
- ) -> None:
- """
- Delete relationships from the specified store.
- If relationship_ids is not provided, deletes all relationships for the given parent_id.
- Args:
- parent_id: UUID of the parent (collection_id or document_id)
- relationship_ids: Optional list of specific relationship IDs to delete
- store_type: Type of store (graph or document)
- Raises:
- R2RException: If specific relationships were requested but not all found
- """
- table_name = self._get_relationship_table_for_store(store_type)
- if relationship_ids is None:
- QUERY = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE parent_id = $1
- RETURNING id
- """
- results = await self.connection_manager.fetch_query(
- QUERY, [parent_id]
- )
- else:
- QUERY = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE id = ANY($1) AND parent_id = $2
- RETURNING id
- """
- results = await self.connection_manager.fetch_query(
- QUERY, [relationship_ids, parent_id]
- )
- deleted_ids = [row["id"] for row in results]
- if relationship_ids and len(deleted_ids) != len(relationship_ids):
- raise R2RException(
- message=f"Some relationships not found in {store_type.value} store or no permission to delete",
- status_code=404,
- )
- async def export_to_csv(
- self,
- parent_id: UUID,
- store_type: StoreType,
- columns: Optional[list[str]] = None,
- filters: Optional[dict] = None,
- include_header: bool = True,
- ) -> tuple[str, IO]:
- """
- Creates a CSV file from the PostgreSQL data and returns the temp file's path along with its open handle.
- """
- valid_columns = {
- "id",
- "subject",
- "predicate",
- "object",
- "description",
- "subject_id",
- "object_id",
- "weight",
- "chunk_ids",
- "parent_id",
- "metadata",
- "created_at",
- "updated_at",
- }
- if not columns:
- columns = list(valid_columns)
- elif invalid_cols := set(columns) - valid_columns:
- raise ValueError(f"Invalid columns: {invalid_cols}")
- select_stmt = f"""
- SELECT
- id::text,
- subject,
- predicate,
- object,
- description,
- subject_id::text,
- object_id::text,
- weight,
- chunk_ids::text,
- parent_id::text,
- metadata::text,
- to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
- to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
- FROM {self._get_table_name(self._get_relationship_table_for_store(store_type))}
- """
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if filters:
- for field, value in filters.items():
- if field not in valid_columns:
- continue
- if isinstance(value, dict):
- for op, val in value.items():
- if op == "$eq":
- conditions.append(f"{field} = ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$gt":
- conditions.append(f"{field} > ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$lt":
- conditions.append(f"{field} < ${param_index}")
- params.append(val)
- param_index += 1
- else:
- # Direct equality
- conditions.append(f"{field} = ${param_index}")
- params.append(value)
- param_index += 1
- if conditions:
- select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
- select_stmt = f"{select_stmt} ORDER BY created_at DESC"
- temp_file = None
- try:
- temp_file = tempfile.NamedTemporaryFile(
- mode="w", delete=True, suffix=".csv"
- )
- writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
- async with self.connection_manager.pool.get_connection() as conn: # type: ignore
- async with conn.transaction():
- cursor = await conn.cursor(select_stmt, *params)
- if include_header:
- writer.writerow(columns)
- chunk_size = 1000
- while True:
- rows = await cursor.fetch(chunk_size)
- if not rows:
- break
- for row in rows:
- writer.writerow(row)
- temp_file.flush()
- return temp_file.name, temp_file
- except Exception as e:
- if temp_file:
- temp_file.close()
- raise HTTPException(
- status_code=500,
- detail=f"Failed to export data: {str(e)}",
- ) from e
- class PostgresCommunitiesHandler(Handler):
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- self.project_name: str = kwargs.get("project_name") # type: ignore
- self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore
- self.dimension: int = kwargs.get("dimension") # type: ignore
- self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore
- async def create_tables(self) -> None:
- vector_column_str = _decorate_vector_type(
- f"({self.dimension})", self.quantization_type
- )
- query = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name("graphs_communities")} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- collection_id UUID,
- community_id UUID,
- level INT,
- name TEXT NOT NULL,
- summary TEXT NOT NULL,
- findings TEXT[],
- rating FLOAT,
- rating_explanation TEXT,
- description_embedding {vector_column_str} NOT NULL,
- created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
- metadata JSONB,
- UNIQUE (community_id, level, collection_id)
- );"""
- await self.connection_manager.execute_query(query)
- async def create(
- self,
- parent_id: UUID,
- store_type: StoreType,
- name: str,
- summary: str,
- findings: Optional[list[str]],
- rating: Optional[float],
- rating_explanation: Optional[str],
- description_embedding: Optional[list[float] | str] = None,
- ) -> Community:
- table_name = "graphs_communities"
- if isinstance(description_embedding, list):
- description_embedding = str(description_embedding)
- query = f"""
- INSERT INTO {self._get_table_name(table_name)}
- (collection_id, name, summary, findings, rating, rating_explanation, description_embedding)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- RETURNING id, collection_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
- """
- params = [
- parent_id,
- name,
- summary,
- findings,
- rating,
- rating_explanation,
- description_embedding,
- ]
- try:
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return Community(
- id=result["id"],
- collection_id=result["collection_id"],
- name=result["name"],
- summary=result["summary"],
- findings=result["findings"],
- rating=result["rating"],
- rating_explanation=result["rating_explanation"],
- created_at=result["created_at"],
- updated_at=result["updated_at"],
- )
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while creating the community: {e}",
- ) from e
- async def update(
- self,
- community_id: UUID,
- store_type: StoreType,
- name: Optional[str] = None,
- summary: Optional[str] = None,
- summary_embedding: Optional[list[float] | str] = None,
- findings: Optional[list[str]] = None,
- rating: Optional[float] = None,
- rating_explanation: Optional[str] = None,
- ) -> Community:
- table_name = "graphs_communities"
- update_fields = []
- params: list[Any] = []
- param_index = 1
- if name is not None:
- update_fields.append(f"name = ${param_index}")
- params.append(name)
- param_index += 1
- if summary is not None:
- update_fields.append(f"summary = ${param_index}")
- params.append(summary)
- param_index += 1
- if summary_embedding is not None:
- update_fields.append(f"description_embedding = ${param_index}")
- params.append(summary_embedding)
- param_index += 1
- if findings is not None:
- update_fields.append(f"findings = ${param_index}")
- params.append(findings)
- param_index += 1
- if rating is not None:
- update_fields.append(f"rating = ${param_index}")
- params.append(rating)
- param_index += 1
- if rating_explanation is not None:
- update_fields.append(f"rating_explanation = ${param_index}")
- params.append(rating_explanation)
- param_index += 1
- if not update_fields:
- raise R2RException(status_code=400, message="No fields to update")
- update_fields.append("updated_at = NOW()")
- params.append(community_id)
- query = f"""
- UPDATE {self._get_table_name(table_name)}
- SET {", ".join(update_fields)}
- WHERE id = ${param_index}
- RETURNING id, community_id, name, summary, findings, rating, rating_explanation, created_at, updated_at
- """
- try:
- result = await self.connection_manager.fetchrow_query(
- query, params
- )
- return Community(
- id=result["id"],
- community_id=result["community_id"],
- name=result["name"],
- summary=result["summary"],
- findings=result["findings"],
- rating=result["rating"],
- rating_explanation=result["rating_explanation"],
- created_at=result["created_at"],
- updated_at=result["updated_at"],
- )
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while updating the community: {e}",
- ) from e
- async def delete(
- self,
- parent_id: UUID,
- community_id: UUID,
- ) -> None:
- table_name = "graphs_communities"
- params = [community_id, parent_id]
- # Delete the community
- query = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE id = $1 AND collection_id = $2
- """
- try:
- await self.connection_manager.execute_query(query, params)
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while deleting the community: {e}",
- ) from e
- async def delete_all_communities(
- self,
- parent_id: UUID,
- ) -> None:
- table_name = "graphs_communities"
- params = [parent_id]
- # Delete all communities for the parent_id
- query = f"""
- DELETE FROM {self._get_table_name(table_name)}
- WHERE collection_id = $1
- """
- try:
- await self.connection_manager.execute_query(query, params)
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while deleting communities: {e}",
- ) from e
- async def get(
- self,
- parent_id: UUID,
- store_type: StoreType,
- offset: int,
- limit: int,
- community_ids: Optional[list[UUID]] = None,
- community_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- """Retrieve communities from the specified store."""
- # Do we ever want to get communities from document store?
- table_name = "graphs_communities"
- conditions = ["collection_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if community_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(community_ids)
- param_index += 1
- if community_names:
- conditions.append(f"name = ANY(${param_index})")
- params.append(community_names)
- param_index += 1
- select_fields = """
- id, community_id, name, summary, findings, rating,
- rating_explanation, level, created_at, updated_at
- """
- if include_embeddings:
- select_fields += ", description_embedding"
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- """
- count = (
- await self.connection_manager.fetch_query(
- COUNT_QUERY, params[: param_index - 1]
- )
- )[0]["count"]
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name(table_name)}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- communities = []
- for row in rows:
- community_dict = dict(row)
- communities.append(Community(**community_dict))
- return communities, count
- async def export_to_csv(
- self,
- parent_id: UUID,
- store_type: StoreType,
- columns: Optional[list[str]] = None,
- filters: Optional[dict] = None,
- include_header: bool = True,
- ) -> tuple[str, IO]:
- """
- Creates a CSV file from the PostgreSQL data and returns the temp file's path along with its open handle.
- """
- valid_columns = {
- "id",
- "collection_id",
- "community_id",
- "level",
- "name",
- "summary",
- "findings",
- "rating",
- "rating_explanation",
- "created_at",
- "updated_at",
- "metadata",
- }
- if not columns:
- columns = list(valid_columns)
- elif invalid_cols := set(columns) - valid_columns:
- raise ValueError(f"Invalid columns: {invalid_cols}")
- table_name = "graphs_communities"
- select_stmt = f"""
- SELECT
- id::text,
- collection_id::text,
- community_id::text,
- level,
- name,
- summary,
- findings::text,
- rating,
- rating_explanation,
- to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
- to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
- metadata::text
- FROM {self._get_table_name(table_name)}
- """
- conditions = ["collection_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if filters:
- for field, value in filters.items():
- if field not in valid_columns:
- continue
- if isinstance(value, dict):
- for op, val in value.items():
- if op == "$eq":
- conditions.append(f"{field} = ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$gt":
- conditions.append(f"{field} > ${param_index}")
- params.append(val)
- param_index += 1
- elif op == "$lt":
- conditions.append(f"{field} < ${param_index}")
- params.append(val)
- param_index += 1
- else:
- # Direct equality
- conditions.append(f"{field} = ${param_index}")
- params.append(value)
- param_index += 1
- if conditions:
- select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
- select_stmt = f"{select_stmt} ORDER BY created_at DESC"
- temp_file = None
- try:
- temp_file = tempfile.NamedTemporaryFile(
- mode="w", delete=True, suffix=".csv"
- )
- writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
- async with self.connection_manager.pool.get_connection() as conn: # type: ignore
- async with conn.transaction():
- cursor = await conn.cursor(select_stmt, *params)
- if include_header:
- writer.writerow(columns)
- chunk_size = 1000
- while True:
- rows = await cursor.fetch(chunk_size)
- if not rows:
- break
- for row in rows:
- writer.writerow(row)
- temp_file.flush()
- return temp_file.name, temp_file
- except Exception as e:
- if temp_file:
- temp_file.close()
- raise HTTPException(
- status_code=500,
- detail=f"Failed to export data: {str(e)}",
- ) from e
- class PostgresGraphsHandler(Handler):
- """Handler for Knowledge Graph METHODS in PostgreSQL."""
- TABLE_NAME = "graphs"
- def __init__(
- self,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- self.project_name: str = kwargs.get("project_name") # type: ignore
- self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore
- self.dimension: int = kwargs.get("dimension") # type: ignore
- self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore
- self.collections_handler: PostgresCollectionsHandler = kwargs.get("collections_handler") # type: ignore
- self.entities = PostgresEntitiesHandler(*args, **kwargs)
- self.relationships = PostgresRelationshipsHandler(*args, **kwargs)
- self.communities = PostgresCommunitiesHandler(*args, **kwargs)
- self.handlers = [
- self.entities,
- self.relationships,
- self.communities,
- ]
- # networkx is imported lazily here so the dependency is only required when a
- # graphs handler is actually constructed.
- import networkx as nx
- self.nx = nx
- async def create_tables(self) -> None:
- """Create the graph tables with mandatory collection_id support."""
- QUERY = f"""
- CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} (
- id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
- collection_id UUID NOT NULL,
- name TEXT NOT NULL,
- description TEXT,
- status TEXT NOT NULL,
- document_ids UUID[],
- metadata JSONB,
- created_at TIMESTAMPTZ DEFAULT NOW(),
- updated_at TIMESTAMPTZ DEFAULT NOW()
- );
- CREATE INDEX IF NOT EXISTS graph_collection_id_idx
- ON {self._get_table_name("graphs")} (collection_id);
- """
- await self.connection_manager.execute_query(QUERY)
- for handler in self.handlers:
- await handler.create_tables()
- async def create(
- self,
- collection_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- status: str = "pending",
- ) -> GraphResponse:
- """Create a new graph associated with a collection."""
- name = name or f"Graph {collection_id}"
- description = description or ""
- query = f"""
- INSERT INTO {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- (id, collection_id, name, description, status)
- VALUES ($1, $2, $3, $4, $5)
- RETURNING id, collection_id, name, description, status, created_at, updated_at, document_ids
- """
- params = [
- collection_id,
- collection_id,
- name,
- description,
- status,
- ]
- try:
- result = await self.connection_manager.fetchrow_query(
- query=query,
- params=params,
- )
- return GraphResponse(
- id=result["id"],
- collection_id=result["collection_id"],
- name=result["name"],
- description=result["description"],
- status=result["status"],
- created_at=result["created_at"],
- updated_at=result["updated_at"],
- document_ids=result["document_ids"] or [],
- )
- except UniqueViolationError:
- raise R2RException(
- message="Graph with this ID already exists",
- status_code=409,
- )
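- # Usage sketch (illustrative, not part of the handler API): each collection
- # gets exactly one graph, since the INSERT above reuses the collection id as
- # the graph id; a second create for the same collection raises a 409.
- async def _example_create_graph(
-     handler: "PostgresGraphsHandler", collection_id: UUID
- ) -> GraphResponse:
-     return await handler.create(
-         collection_id=collection_id,
-         name="Research graph",
-         description="Graph built from the research collection",
-     )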
- async def reset(self, parent_id: UUID) -> None:
- """
- Completely reset a graph and all associated data.
- """
- await self.entities.delete(
- parent_id=parent_id, store_type=StoreType.GRAPHS
- )
- await self.relationships.delete(
- parent_id=parent_id, store_type=StoreType.GRAPHS
- )
- await self.communities.delete_all_communities(parent_id=parent_id)
- return
- async def list_graphs(
- self,
- offset: int,
- limit: int,
- # filter_user_ids: Optional[list[UUID]] = None,
- filter_graph_ids: Optional[list[UUID]] = None,
- filter_collection_id: Optional[UUID] = None,
- ) -> dict[str, list[GraphResponse] | int]:
- conditions = []
- params: list[Any] = []
- param_index = 1
- if filter_graph_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(filter_graph_ids)
- param_index += 1
- # if filter_user_ids:
- # conditions.append(f"user_id = ANY(${param_index})")
- # params.append(filter_user_ids)
- # param_index += 1
- if filter_collection_id:
- conditions.append(f"collection_id = ${param_index}")
- params.append(filter_collection_id)
- param_index += 1
- where_clause = (
- f"WHERE {' AND '.join(conditions)}" if conditions else ""
- )
- query = f"""
- WITH RankedGraphs AS (
- SELECT
- id, collection_id, name, description, status, created_at, updated_at, document_ids,
- COUNT(*) OVER() as total_entries,
- ROW_NUMBER() OVER (PARTITION BY collection_id ORDER BY created_at DESC) as rn
- FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- {where_clause}
- )
- SELECT * FROM RankedGraphs
- WHERE rn = 1
- ORDER BY created_at DESC
- OFFSET ${param_index} LIMIT ${param_index + 1}
- """
- params.extend([offset, limit])
- try:
- results = await self.connection_manager.fetch_query(query, params)
- if not results:
- return {"results": [], "total_entries": 0}
- total_entries = results[0]["total_entries"] if results else 0
- graphs = [
- GraphResponse(
- id=row["id"],
- document_ids=row["document_ids"] or [],
- name=row["name"],
- collection_id=row["collection_id"],
- description=row["description"],
- status=row["status"],
- created_at=row["created_at"],
- updated_at=row["updated_at"],
- )
- for row in results
- ]
- return {"results": graphs, "total_entries": total_entries}
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while fetching graphs: {e}",
- ) from e
- async def get(
- self, offset: int, limit: int, graph_id: Optional[UUID] = None
- ):
- if graph_id is None:
- params = [offset, limit]
- QUERY = f"""
- SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- OFFSET $1 LIMIT $2
- """
- ret = await self.connection_manager.fetch_query(QUERY, params)
- COUNT_QUERY = f"""
- SELECT COUNT(*) FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- """
- count = (await self.connection_manager.fetch_query(COUNT_QUERY))[
- 0
- ]["count"]
- return {
- "results": [Graph(**row) for row in ret],
- "total_entries": count,
- }
- else:
- QUERY = f"""
- SELECT * FROM {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)} WHERE id = $1
- """
- params = [graph_id] # type: ignore
- return {
- "results": [
- Graph(
- **await self.connection_manager.fetchrow_query(
- QUERY, params
- )
- )
- ]
- }
- async def add_documents(self, id: UUID, document_ids: list[UUID]) -> bool:
- """
- Add documents to the graph by copying their entities and relationships.
- """
- # Copy entities from documents_entities to graphs_entities
- ENTITY_COPY_QUERY = f"""
- INSERT INTO {self._get_table_name("graphs_entities")} (
- name, category, description, parent_id, description_embedding,
- chunk_ids, metadata
- )
- SELECT
- name, category, description, $1, description_embedding,
- chunk_ids, metadata
- FROM {self._get_table_name("documents_entities")}
- WHERE parent_id = ANY($2)
- """
- await self.connection_manager.execute_query(
- ENTITY_COPY_QUERY, [id, document_ids]
- )
- # Copy relationships from documents_relationships to graphs_relationships
- RELATIONSHIP_COPY_QUERY = f"""
- INSERT INTO {self._get_table_name("graphs_relationships")} (
- subject, predicate, object, description, subject_id, object_id,
- weight, chunk_ids, parent_id, metadata, description_embedding
- )
- SELECT
- subject, predicate, object, description, subject_id, object_id,
- weight, chunk_ids, $1, metadata, description_embedding
- FROM {self._get_table_name("documents_relationships")}
- WHERE parent_id = ANY($2)
- """
- await self.connection_manager.execute_query(
- RELATIONSHIP_COPY_QUERY, [id, document_ids]
- )
- # Add document_ids to the graph
- UPDATE_GRAPH_QUERY = f"""
- UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- SET document_ids = array_cat(
- CASE
- WHEN document_ids IS NULL THEN ARRAY[]::uuid[]
- ELSE document_ids
- END,
- $2::uuid[]
- )
- WHERE id = $1
- """
- await self.connection_manager.execute_query(
- UPDATE_GRAPH_QUERY, [id, document_ids]
- )
- return True
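- # Usage sketch (illustrative, not part of the handler API): a typical build
- # flow, creating the graph and then pulling in the extracted entities and
- # relationships of chosen documents.
- async def _example_build_graph(
-     handler: "PostgresGraphsHandler",
-     collection_id: UUID,
-     document_ids: list[UUID],
- ) -> GraphResponse:
-     graph = await handler.create(collection_id=collection_id)
-     await handler.add_documents(id=graph.id, document_ids=document_ids)
-     return graph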
- async def update(
- self,
- collection_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ) -> GraphResponse:
- """Update an existing graph."""
- update_fields = []
- params: list = []
- param_index = 1
- if name is not None:
- update_fields.append(f"name = ${param_index}")
- params.append(name)
- param_index += 1
- if description is not None:
- update_fields.append(f"description = ${param_index}")
- params.append(description)
- param_index += 1
- if not update_fields:
- raise R2RException(status_code=400, message="No fields to update")
- update_fields.append("updated_at = NOW()")
- params.append(collection_id)
- query = f"""
- UPDATE {self._get_table_name(PostgresGraphsHandler.TABLE_NAME)}
- SET {', '.join(update_fields)}
- WHERE id = ${param_index}
- RETURNING id, name, description, status, created_at, updated_at, collection_id, document_ids
- """
- try:
- result = await self.connection_manager.fetchrow_query(
- query, params
- )
- if not result:
- raise R2RException(status_code=404, message="Graph not found")
- return GraphResponse(
- id=result["id"],
- collection_id=result["collection_id"],
- name=result["name"],
- description=result["description"],
- status=result["status"],
- created_at=result["created_at"],
- document_ids=result["document_ids"] or [],
- updated_at=result["updated_at"],
- )
- except Exception as e:
- raise HTTPException(
- status_code=500,
- detail=f"An error occurred while updating the graph: {e}",
- ) from e
- async def get_creation_estimate(
- self,
- graph_creation_settings: KGCreationSettings,
- document_id: Optional[UUID] = None,
- collection_id: Optional[UUID] = None,
- ):
- """Get the estimated cost and time for creating a KG."""
- # Exactly one of document_id / collection_id must be provided (XOR check).
- if not (bool(document_id) ^ bool(collection_id)):
- raise ValueError(
- "Exactly one of document_id or collection_id must be provided."
- )
        # TODO: harmonize the document_id and id fields: the postgres table
        # contains document_id, but other places use id.
        document_ids = (
            [document_id]
            if document_id
            else [
                doc.id
                for doc in (
                    await self.collections_handler.documents_in_collection(
                        collection_id, offset=0, limit=-1
                    )  # type: ignore
                )["results"]
            ]
        )

        chunk_counts = await self.connection_manager.fetch_query(
            f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} "
            f"WHERE document_id = ANY($1) GROUP BY document_id",
            [document_ids],
        )

        total_chunks = (
            sum(doc["chunk_count"] for doc in chunk_counts)
            // graph_creation_settings.chunk_merge_count
        )
        estimated_entities = (total_chunks * 10, total_chunks * 20)
        estimated_relationships = (
            int(estimated_entities[0] * 1.25),
            int(estimated_entities[1] * 1.5),
        )
        estimated_llm_calls = (
            total_chunks * 2 + estimated_entities[0],
            total_chunks * 2 + estimated_entities[1],
        )
        total_in_out_tokens = tuple(
            2000 * calls // 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_creation_settings.generation_config.model
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in total_in_out_tokens
        )
        total_time_in_minutes = tuple(
            tokens * 10 / 60 for tokens in total_in_out_tokens
        )
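        # Illustrative arithmetic, assuming 500 raw chunks and
        # chunk_merge_count=5: total_chunks=100, estimated_entities=(1000, 2000),
        # estimated_llm_calls=(1200, 2200), and total_in_out_tokens=(2, 4)
        # (2000 tokens per call, integer-divided into millions).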
        return {
            "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges; actual values may vary. To run the KG creation process, run `extract-triples` with `--run` in the CLI, or `run_type="run"` in the client.',
            "document_count": len(document_ids),
            "number_of_jobs_created": len(document_ids) + 1,
            "total_chunks": total_chunks,
            "estimated_entities": _get_str_estimation_output(
                estimated_entities
            ),
            "estimated_relationships": _get_str_estimation_output(
                estimated_relationships
            ),
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                total_in_out_tokens
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(total_time_in_minutes),
        }

    async def get_enrichment_estimate(
        self,
        collection_id: UUID | None = None,
        graph_id: UUID | None = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
    ):
        """Get the estimated cost and time for enriching a KG."""
        if collection_id is not None:
            document_ids = [
                doc.id
                for doc in (
                    await self.collections_handler.documents_in_collection(
                        collection_id, offset=0, limit=-1
                    )  # type: ignore
                )["results"]
            ]

            # Get entity and relationship counts
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE document_id = ANY($1);",
                    [document_ids],
                )
            )[0]["count"]
            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )
            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"""SELECT COUNT(*) FROM {self._get_table_name("documents_relationships")} WHERE document_id = ANY($1);""",
                    [document_ids],
                )
            )[0]["count"]
        else:
            entity_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('entity')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]
            if not entity_count:
                raise ValueError(
                    "No entities found in the graph. Please run `extract-triples` first."
                )
            relationship_count = (
                await self.connection_manager.fetch_query(
                    f"SELECT COUNT(*) FROM {self._get_table_name('relationship')} WHERE $1 = ANY(graph_ids);",
                    [graph_id],
                )
            )[0]["count"]

        # Calculate estimates
        estimated_llm_calls = (entity_count // 10, entity_count // 5)
        tokens_in_millions = tuple(
            2000 * calls / 1000000 for calls in estimated_llm_calls
        )
        cost_per_million = llm_cost_per_million_tokens(
            graph_enrichment_settings.generation_config.model  # type: ignore
        )
        estimated_cost = tuple(
            tokens * cost_per_million for tokens in tokens_in_millions
        )
        estimated_time = tuple(
            tokens * 10 / 60 for tokens in tokens_in_millions
        )
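        # Illustrative arithmetic, assuming entity_count=1000:
        # estimated_llm_calls=(100, 200), tokens_in_millions=(0.2, 0.4), and,
        # at a hypothetical $1 per million tokens, an estimated cost of
        # ($0.20, $0.40).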
        return {
            "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges; actual values may vary. To run the KG enrichment process, run `build-communities` with `--run` in the CLI, or `run_type="run"` in the client.',
            "total_entities": entity_count,
            "total_relationships": relationship_count,
            "estimated_llm_calls": _get_str_estimation_output(
                estimated_llm_calls
            ),
            "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(
                tokens_in_millions
            ),
            "estimated_cost_in_usd": _get_str_estimation_output(
                estimated_cost
            ),
            "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: "
            + _get_str_estimation_output(estimated_time),
        }

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Entity], int]:
        """
        Get entities for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            entity_ids: Optional list of entity IDs to filter by
            entity_names: Optional list of entity names to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of entities, total count)
        """
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if entity_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(entity_ids)
- param_index += 1
- if entity_names:
- conditions.append(f"name = ANY(${param_index})")
- params.append(entity_names)
- param_index += 1
- # Count query - uses the same conditions but without offset/limit
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name("graphs_entities")}
- WHERE {' AND '.join(conditions)}
- """
- count = (
- await self.connection_manager.fetch_query(COUNT_QUERY, params)
- )[0]["count"]
- # Define base columns to select
- select_fields = """
- id, name, category, description, parent_id,
- chunk_ids, metadata
- """
- if include_embeddings:
- select_fields += ", description_embedding"
- # Main query for fetching entities with pagination
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name("graphs_entities")}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- entities = []
- for row in rows:
- entity_dict = dict(row)
- if isinstance(entity_dict["metadata"], str):
- with contextlib.suppress(json.JSONDecodeError):
- entity_dict["metadata"] = json.loads(
- entity_dict["metadata"]
- )
- entities.append(Entity(**entity_dict))
- return entities, count

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        relationship_types: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Relationship], int]:
        """
        Get relationships for a graph.

        Args:
            parent_id: UUID of the graph
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            relationship_ids: Optional list of relationship IDs to filter by
            relationship_types: Optional list of relationship types to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of relationships, total count)
        """
- conditions = ["parent_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if relationship_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(relationship_ids)
- param_index += 1
- if relationship_types:
- conditions.append(f"predicate = ANY(${param_index})")
- params.append(relationship_types)
- param_index += 1
- # Count query - uses the same conditions but without offset/limit
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name("graphs_relationships")}
- WHERE {' AND '.join(conditions)}
- """
- count = (
- await self.connection_manager.fetch_query(COUNT_QUERY, params)
- )[0]["count"]
- # Define base columns to select
- select_fields = """
- id, subject, predicate, object, weight, chunk_ids, parent_id, metadata
- """
- if include_embeddings:
- select_fields += ", description_embedding"
- # Main query for fetching relationships with pagination
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name("graphs_relationships")}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- relationships = []
- for row in rows:
- relationship_dict = dict(row)
- if isinstance(relationship_dict["metadata"], str):
- with contextlib.suppress(json.JSONDecodeError):
- relationship_dict["metadata"] = json.loads(
- relationship_dict["metadata"]
- )
- relationships.append(Relationship(**relationship_dict))
- return relationships, count

    async def add_entities(
        self,
        entities: list[Entity],
        table_name: str,
        conflict_columns: Optional[list[str]] = None,
    ) -> list[UUID]:
        """
        Upsert entities into the given entities table. These are raw entities
        extracted from the document.

        Args:
            entities: list of entities to upsert
            table_name: name of the table to upsert into
            conflict_columns: columns that define the upsert conflict target

        Returns:
            ids of the inserted or updated rows
        """
        cleaned_entities = []
        for entity in entities:
            entity_dict = entity.to_dict()
            entity_dict["chunk_ids"] = (
                entity_dict["chunk_ids"]
                if entity_dict.get("chunk_ids")
                else []
            )
            entity_dict["description_embedding"] = (
                str(entity_dict["description_embedding"])
                if entity_dict.get("description_embedding")  # type: ignore
                else None
            )
            cleaned_entities.append(entity_dict)

        return await _add_objects(
            objects=cleaned_entities,
            full_table_name=self._get_table_name(table_name),
            connection_manager=self.connection_manager,
            conflict_columns=conflict_columns,
        )

    async def get_all_relationships(
        self,
        collection_id: UUID | None,
        graph_id: UUID | None,
        document_ids: Optional[list[UUID]] = None,
    ) -> list[Relationship]:
        # NOTE: only collection_id is consulted today; graph_id and
        # document_ids are accepted for API symmetry but not yet used.
        QUERY = f"""
            SELECT id, subject, predicate, weight, object, parent_id
            FROM {self._get_table_name("graphs_relationships")}
            WHERE parent_id = ANY($1)
        """
        relationships = await self.connection_manager.fetch_query(
            QUERY, [collection_id]
        )
        return [Relationship(**relationship) for relationship in relationships]

    async def has_document(self, graph_id: UUID, document_id: UUID) -> bool:
        """
        Check if a document exists in the graph's document_ids array.

        Args:
            graph_id (UUID): ID of the graph to check
            document_id (UUID): ID of the document to look for

        Returns:
            bool: True if document exists in graph, False otherwise

        Raises:
            R2RException: If graph not found
        """
- QUERY = f"""
- SELECT EXISTS (
- SELECT 1
- FROM {self._get_table_name("graphs")}
- WHERE id = $1
- AND document_ids IS NOT NULL
- AND $2 = ANY(document_ids)
- ) as exists;
- """
- result = await self.connection_manager.fetchrow_query(
- QUERY, [graph_id, document_id]
- )
- if result is None:
- raise R2RException(f"Graph {graph_id} not found", 404)
- return result["exists"]

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        include_embeddings: bool = False,
    ) -> tuple[list[Community], int]:
        """
        Get communities for a graph.

        Args:
            parent_id: UUID of the collection the graph belongs to
            offset: Number of records to skip
            limit: Maximum number of records to return (-1 for no limit)
            community_ids: Optional list of community IDs to filter by
            include_embeddings: Whether to include embeddings in the response

        Returns:
            Tuple of (list of communities, total count)
        """
- conditions = ["collection_id = $1"]
- params: list[Any] = [parent_id]
- param_index = 2
- if community_ids:
- conditions.append(f"id = ANY(${param_index})")
- params.append(community_ids)
- param_index += 1
- select_fields = """
- id, collection_id, name, summary, findings, rating, rating_explanation
- """
- if include_embeddings:
- select_fields += ", description_embedding"
- COUNT_QUERY = f"""
- SELECT COUNT(*)
- FROM {self._get_table_name("graphs_communities")}
- WHERE {' AND '.join(conditions)}
- """
- count = (
- await self.connection_manager.fetch_query(COUNT_QUERY, params)
- )[0]["count"]
- QUERY = f"""
- SELECT {select_fields}
- FROM {self._get_table_name("graphs_communities")}
- WHERE {' AND '.join(conditions)}
- ORDER BY created_at
- OFFSET ${param_index}
- """
- params.append(offset)
- param_index += 1
- if limit != -1:
- QUERY += f" LIMIT ${param_index}"
- params.append(limit)
- rows = await self.connection_manager.fetch_query(QUERY, params)
- communities = []
- for row in rows:
- community_dict = dict(row)
- communities.append(Community(**community_dict))
- return communities, count

    async def add_community(self, community: Community) -> None:
        # TODO: Fix in the short term.
        # The embedding is stringified because the postgres insert expects a
        # string-encoded vector; only do so when one is actually present,
        # otherwise str(None) would slip past the non-null filter below.
        if community.description_embedding is not None:
            community.description_embedding = str(community.description_embedding)  # type: ignore[assignment]

        non_null_attrs = {
            k: v for k, v in community.__dict__.items() if v is not None
        }
        columns = ", ".join(non_null_attrs.keys())
        placeholders = ", ".join(f"${i + 1}" for i in range(len(non_null_attrs)))
        conflict_columns = ", ".join(
            f"{k} = EXCLUDED.{k}" for k in non_null_attrs
        )

        QUERY = f"""
            INSERT INTO {self._get_table_name("graphs_communities")} ({columns})
            VALUES ({placeholders})
            ON CONFLICT (community_id, level, collection_id) DO UPDATE SET
                {conflict_columns}
        """
        await self.connection_manager.execute_many(
            QUERY, [tuple(non_null_attrs.values())]
        )

    async def delete(self, collection_id: UUID) -> None:
        graphs = await self.get(graph_id=collection_id, offset=0, limit=-1)
        if len(graphs["results"]) == 0:
            raise R2RException(
                message=f"Graph not found for collection {collection_id}",
                status_code=404,
            )
        await self.reset(collection_id)

        # Set the cluster status back to PENDING for this collection.
        QUERY = f"""
            UPDATE {self._get_table_name("collections")}
            SET graph_cluster_status = $1
            WHERE id = $2
        """
        await self.connection_manager.execute_query(
            QUERY, [KGExtractionStatus.PENDING, collection_id]
        )

        # Delete the graph row itself.
        QUERY = f"""
            DELETE FROM {self._get_table_name("graphs")} WHERE collection_id = $1
        """
        await self.connection_manager.execute_query(QUERY, [collection_id])

    async def perform_graph_clustering(
        self,
        collection_id: UUID,
        leiden_params: dict[str, Any],
        clustering_mode: str,
    ) -> Tuple[int, Any]:
        """
        Calls the external clustering service to cluster the KG.
        """
        offset = 0
        page_size = 1000
        all_relationships = []
        while True:
            relationships, count = await self.relationships.get(
                parent_id=collection_id,
                store_type=StoreType.GRAPHS,
                offset=offset,
                limit=page_size,
            )
            if not relationships:
                break
            all_relationships.extend(relationships)
            offset += len(relationships)
            if offset >= count:
                break

        relationship_ids_cache = await self._get_relationship_ids_cache(
            all_relationships
        )
        logger.info(
            f"Clustering over {len(all_relationships)} relationships for "
            f"{collection_id} with settings: {leiden_params}"
        )
        return await self._cluster_and_add_community_info(
            relationships=all_relationships,
            relationship_ids_cache=relationship_ids_cache,
            leiden_params=leiden_params,
            collection_id=collection_id,
            clustering_mode=clustering_mode,
        )

    async def _call_clustering_service(
        self, relationships: list[Relationship], leiden_params: dict[str, Any]
    ) -> list[dict]:
        """
        Calls the external Graspologic clustering service, sending
        relationships and parameters. Expects a response with a
        'communities' field.
        """
        # Convert relationships to a JSON-friendly format
        rel_data = []
        for r in relationships:
            rel_data.append(
                {
                    "id": str(r.id),
                    "subject": r.subject,
                    "object": r.object,
                    "weight": r.weight if r.weight is not None else 1.0,
                }
            )

        endpoint = os.environ.get("CLUSTERING_SERVICE_URL")
        if not endpoint:
            raise ValueError("CLUSTERING_SERVICE_URL not set.")
        url = f"{endpoint}/cluster"

        payload = {"relationships": rel_data, "leiden_params": leiden_params}
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, timeout=3600)
            response.raise_for_status()

        data = response.json()
        communities = data.get("communities", [])
        return communities

    async def _create_graph_and_cluster(
        self,
        relationships: list[Relationship],
        leiden_params: dict[str, Any],
        clustering_mode: str = "remote",
    ) -> Any:
        """
        Create a graph and cluster it. If clustering_mode='local', use
        hierarchical_leiden locally. If clustering_mode='remote', call the
        external service.
        """
        if clustering_mode == "remote":
            logger.info("Sending request to external clustering service...")
            communities = await self._call_clustering_service(
                relationships, leiden_params
            )
            logger.info("Received communities from clustering service.")
            return communities
        else:
            # Local mode: run hierarchical_leiden directly
            G = self.nx.Graph()
            for relationship in relationships:
                G.add_edge(
                    relationship.subject,
                    relationship.object,
                    weight=relationship.weight,
                    id=relationship.id,
                )
            logger.info(
                f"Graph has {len(G.nodes)} nodes and {len(G.edges)} edges"
            )
            return await self._compute_leiden_communities(G, leiden_params)

    async def _cluster_and_add_community_info(
        self,
        relationships: list[Relationship],
        relationship_ids_cache: dict[str, list[int]],
        leiden_params: dict[str, Any],
        collection_id: Optional[UUID] = None,
        clustering_mode: str = "local",
    ) -> Tuple[int, Any]:
        # TODO: clear any stale community information for this collection
        # before re-clustering; the condition built here is not yet applied.
        conditions = []
        if collection_id is not None:
            conditions.append("collection_id = $1")

        await asyncio.sleep(0.1)

        start_time = time.time()
        logger.info(f"Creating graph and clustering for {collection_id}")

        hierarchical_communities = await self._create_graph_and_cluster(
            relationships=relationships,
            leiden_params=leiden_params,
            clustering_mode=clustering_mode,
        )
        logger.info(
            f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds."
        )

        def relationship_ids(node: str) -> list[int]:
            return relationship_ids_cache.get(node, [])

        logger.info(
            f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds."
        )

        # If remote: hierarchical_communities is a list of dicts like
        #   [{"node": str, "cluster": int, "level": int}, ...]
        # If local: it is the structure returned by hierarchical_leiden
        # (a list of named tuples exposing a .cluster attribute).
        if clustering_mode == "remote":
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item["cluster"] for item in hierarchical_communities)
                    + 1
                )
        else:
            if not hierarchical_communities:
                num_communities = 0
            else:
                num_communities = (
                    max(item.cluster for item in hierarchical_communities) + 1
                )

        logger.info(
            f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds."
        )
        return num_communities, hierarchical_communities

    async def _get_relationship_ids_cache(
        self, relationships: list[Relationship]
    ) -> dict[str, list[int]]:
        relationship_ids_cache: dict[str, list[int]] = {}
        for relationship in relationships:
            if relationship.subject is not None:
                relationship_ids_cache.setdefault(relationship.subject, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.subject].append(
                        int(relationship.id)
                    )
            if relationship.object is not None:
                relationship_ids_cache.setdefault(relationship.object, [])
                if relationship.id is not None:
                    relationship_ids_cache[relationship.object].append(
                        int(relationship.id)
                    )
        return relationship_ids_cache

    async def get_entity_map(
        self, offset: int, limit: int, document_id: UUID
    ) -> dict[str, dict[str, list[Any]]]:
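        # Builds, for one OFFSET/LIMIT page of distinct entity names, a map
        # shaped like:
        #   {"Aristotle": {"entities": [Entity(...)],
        #                  "relationships": [Relationship(...)]}}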
- QUERY1 = f"""
- WITH entities_list AS (
- SELECT DISTINCT name
- FROM {self._get_table_name("documents_entities")}
- WHERE parent_id = $1
- ORDER BY name ASC
- LIMIT {limit} OFFSET {offset}
- )
- SELECT e.name, e.description, e.category,
- (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids,
- e.parent_id
- FROM {self._get_table_name("documents_entities")} e
- JOIN entities_list el ON e.name = el.name
- GROUP BY e.name, e.description, e.category, e.chunk_ids, e.parent_id
- ORDER BY e.name;"""
- entities_list = await self.connection_manager.fetch_query(
- QUERY1, [document_id]
- )
- entities_list = [Entity(**entity) for entity in entities_list]
- QUERY2 = f"""
- WITH entities_list AS (
- SELECT DISTINCT name
- FROM {self._get_table_name("documents_entities")}
- WHERE parent_id = $1
- ORDER BY name ASC
- LIMIT {limit} OFFSET {offset}
- )
- SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
- (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.parent_id
- FROM {self._get_table_name("documents_relationships")} t
- JOIN entities_list el ON t.subject = el.name
- ORDER BY t.subject, t.predicate, t.object;
- """
- relationships_list = await self.connection_manager.fetch_query(
- QUERY2, [document_id]
- )
- relationships_list = [
- Relationship(**relationship) for relationship in relationships_list
- ]
- entity_map: dict[str, dict[str, list[Any]]] = {}
- for entity in entities_list:
- if entity.name not in entity_map:
- entity_map[entity.name] = {"entities": [], "relationships": []}
- entity_map[entity.name]["entities"].append(entity)
- for relationship in relationships_list:
- if relationship.subject in entity_map:
- entity_map[relationship.subject]["relationships"].append(
- relationship
- )
- if relationship.object in entity_map:
- entity_map[relationship.object]["relationships"].append(
- relationship
- )
- return entity_map

    async def graph_search(
        self, query: str, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """
        Perform semantic (vector) search over graph entities, relationships,
        or communities, yielding each row's selected properties plus a
        similarity score.
        """
        query_embedding = kwargs.get("query_embedding", None)
        if query_embedding is None:
            raise ValueError(
                "query_embedding must be provided for semantic search"
            )

        search_type = kwargs.get(
            "search_type", "entities"
        )  # entities | relationships | communities
        embedding_type = kwargs.get("embedding_type", "description_embedding")
        property_names = kwargs.get("property_names", ["name", "description"])

        # Add metadata if not present
        if "metadata" not in property_names:
            property_names.append("metadata")

        filters = kwargs.get("filters", {})
        limit = kwargs.get("limit", 10)
        use_fulltext_search = kwargs.get("use_fulltext_search", True)
        use_hybrid_search = kwargs.get("use_hybrid_search", True)
        if use_hybrid_search or use_fulltext_search:
            logger.warning(
                "Hybrid and fulltext search are not supported for graph search; ignoring."
            )

        table_name = f"graphs_{search_type}"
        property_names_str = ", ".join(property_names)

        # Build the WHERE clause from filters
        params: list[Any] = [
            json.dumps(query_embedding),
            limit,
        ]
        conditions_clause = self._build_filters(filters, params, search_type)
        where_clause = (
            f"WHERE {conditions_clause}" if conditions_clause else ""
        )

        # Construct the query.
        # Note: for vector similarity, <=> yields a distance; the smaller the
        # value, the more similar. We convert it to a similarity score below
        # as (1 - distance).
        QUERY = f"""
            SELECT
                {property_names_str},
                ({embedding_type} <=> $1) as similarity_score
            FROM {self._get_table_name(table_name)}
            {where_clause}
            ORDER BY {embedding_type} <=> $1
            LIMIT $2;
        """
        results = await self.connection_manager.fetch_query(
            QUERY, tuple(params)
        )

        for result in results:
            output = {
                prop: result[prop] for prop in property_names if prop in result
            }
            output["similarity_score"] = 1 - float(result["similarity_score"])
            yield output

    def _build_filters(
        self, filter_dict: dict, parameters: list[Any], search_type: str
    ) -> str:
        """
        Build a WHERE clause from a nested filter dictionary for the graph
        search. For communities we use collection_id as the primary key
        filter; for entities/relationships we use parent_id.
        """
        # Determine the primary identifier column depending on search_type:
        #   communities             -> collection_id
        #   entities/relationships  -> parent_id
        base_id_column = (
            "collection_id" if search_type == "communities" else "parent_id"
        )

        def parse_condition(key: str, value: Any) -> str:
            # Returns a single condition string, or "" if no valid condition.
            # Supported keys:
            #   - base_id_column (collection_id or parent_id)
            #   - metadata fields: metadata.some_field
            # Supported ops: $eq, $ne, $lt, $lte, $gt, $gte, $in, $contains
            if key == base_id_column:
                # e.g. {"collection_id": {"$eq": "<some-uuid>"}}
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(str(clause))
                        return f"{base_id_column} = ${len(parameters)}::uuid"
                    elif op == "$in":
                        # $in expects a list of UUIDs
                        parameters.append([str(x) for x in clause])
                        return f"{base_id_column} = ANY(${len(parameters)}::uuid[])"
                else:
                    # Bare value: treat as direct equality
                    parameters.append(str(value))
                    return f"{base_id_column} = ${len(parameters)}::uuid"
            elif key.startswith("metadata."):
                # Handle metadata filters,
                # e.g. {"metadata.some_key": {"$eq": "value"}}
                field = key.split("metadata.")[1]
                if isinstance(value, dict):
                    op, clause = next(iter(value.items()))
                    if op == "$eq":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') = ${len(parameters)}"
                    elif op == "$ne":
                        parameters.append(clause)
                        return f"(metadata->>'{field}') != ${len(parameters)}"
                    elif op == "$lt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float < ${len(parameters)}::float"
                    elif op == "$lte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float <= ${len(parameters)}::float"
                    elif op == "$gt":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float > ${len(parameters)}::float"
                    elif op == "$gte":
                        parameters.append(clause)
                        return f"(metadata->>'{field}')::float >= ${len(parameters)}::float"
                    elif op == "$in":
                        # Ensure clause is a list
                        if not isinstance(clause, list):
                            raise Exception(
                                "argument to $in filter must be a list"
                            )
                        # Append the Python list as a parameter; the driver
                        # converts Python lists to Postgres arrays.
                        parameters.append(clause)
                        # Cast the parameter to a text array and test membership
                        return f"(metadata->>'{field}')::text = ANY(${len(parameters)}::text[])"
                    elif op == "$contains":
                        # $contains maps to JSONB containment (metadata @> clause).
                        # If clause is a dict or list, JSON containment applies.
                        parameters.append(json.dumps(clause))
                        return f"metadata @> ${len(parameters)}::jsonb"
                else:
                    # Bare value: direct equality on the metadata field
                    parameters.append(value)
                    return f"(metadata->>'{field}') = ${len(parameters)}"
            # If the key is not recognized, return "" so it doesn't break the query.
            return ""
        def parse_filter(fd: dict) -> str:
            filter_conditions = []
            for k, v in fd.items():
                if k == "$and":
                    and_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    and_parts = [x for x in and_parts if x.strip()]
                    if and_parts:
                        filter_conditions.append(
                            f"({' AND '.join(and_parts)})"
                        )
                elif k == "$or":
                    or_parts = [parse_filter(sub) for sub in v if sub]
                    # Remove empty strings
                    or_parts = [x for x in or_parts if x.strip()]
                    if or_parts:
                        filter_conditions.append(f"({' OR '.join(or_parts)})")
                else:
                    # Regular condition
                    c = parse_condition(k, v)
                    if c and c.strip():
                        filter_conditions.append(c)

            if not filter_conditions:
                return ""
            if len(filter_conditions) == 1:
                return filter_conditions[0]
            return " AND ".join(filter_conditions)

        return parse_filter(filter_dict)

    async def _compute_leiden_communities(
        self,
        graph: Any,
        leiden_params: dict[str, Any],
    ) -> Any:
        """Compute Leiden communities."""
        try:
            from graspologic.partition import hierarchical_leiden

            if "random_seed" not in leiden_params:
                # Pin a seed to make clustering runs reproducible.
                leiden_params["random_seed"] = 7272

            start_time = time.time()
            logger.info(
                f"Running Leiden clustering with params: {leiden_params}"
            )
            community_mapping = hierarchical_leiden(graph, **leiden_params)
            logger.info(
                f"Leiden clustering completed in {time.time() - start_time:.2f} seconds."
            )
            return community_mapping
        except ImportError as e:
            raise ImportError("Please install the graspologic package.") from e

    async def get_existing_document_entity_chunk_ids(
        self, document_id: UUID
    ) -> list[str]:
        QUERY = f"""
            SELECT DISTINCT unnest(chunk_ids) AS chunk_id
            FROM {self._get_table_name("documents_entities")}
            WHERE parent_id = $1
        """
        return [
            item["chunk_id"]
            for item in await self.connection_manager.fetch_query(
                QUERY, [document_id]
            )
        ]

    async def get_entity_count(
        self,
        collection_id: Optional[UUID] = None,
        document_id: Optional[UUID] = None,
        distinct: bool = False,
        entity_table_name: str = "entity",
    ) -> int:
        if collection_id is None and document_id is None:
            raise ValueError(
                "Either collection_id or document_id must be provided."
            )

        # Both document- and graph-scoped entity tables key their rows by
        # parent_id, so filter on whichever identifier was supplied.
        conditions = ["parent_id = $1"]
        params = [str(document_id if document_id is not None else collection_id)]

        count_value = "DISTINCT name" if distinct else "*"

        QUERY = f"""
            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
            WHERE {" AND ".join(conditions)}
        """
        return (await self.connection_manager.fetch_query(QUERY, params))[0][
            "count"
        ]

    async def update_entity_descriptions(self, entities: list[Entity]):
        # graphs_entities keys rows by parent_id (see the copy query in
        # add_documents), and parent_id is what gets bound as $2 below.
        query = f"""
            UPDATE {self._get_table_name("graphs_entities")}
            SET description = $3, description_embedding = $4
            WHERE name = $1 AND parent_id = $2
        """
        inputs = [
            (
                entity.name,
                entity.parent_id,
                entity.description,
                entity.description_embedding,
            )
            for entity in entities
        ]
        await self.connection_manager.execute_many(query, inputs)  # type: ignore


def _json_serialize(obj):
    if isinstance(obj, UUID):
        return str(obj)
    elif isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


async def _add_objects(
    objects: list[dict],
    full_table_name: str,
    connection_manager: PostgresConnectionManager,
    conflict_columns: Optional[list[str]] = None,
    exclude_metadata: Optional[list[str]] = None,
) -> list[UUID]:
    """
    Bulk insert objects into the specified table using jsonb_to_recordset.
    """
    # Avoid mutable default arguments: normalize to fresh lists.
    conflict_columns = conflict_columns or []
    exclude_metadata = exclude_metadata or []

    # Exclude specified metadata and prepare data
    cleaned_objects = []
    for obj in objects:
        cleaned_obj = {
            k: v
            for k, v in obj.items()
            if k not in exclude_metadata and v is not None
        }
        cleaned_objects.append(cleaned_obj)

    if not cleaned_objects:
        return []

    # Serialize the list of objects to JSON
    json_data = json.dumps(cleaned_objects, default=_json_serialize)

    # Prepare the column definitions for jsonb_to_recordset
    columns = cleaned_objects[0].keys()
    column_defs = []
    for col in columns:
        # Map Python types to PostgreSQL types
        sample_value = cleaned_objects[0][col]
        if "embedding" in col:
            pg_type = "vector"
        elif "chunk_ids" in col or "document_ids" in col or "graph_ids" in col:
            pg_type = "uuid[]"
        elif col == "id" or "_id" in col:
            pg_type = "uuid"
        elif isinstance(sample_value, bool):
            # bool must be checked before int: bool is a subclass of int in
            # Python, so the numeric branch below would otherwise shadow it.
            pg_type = "boolean"
        elif isinstance(sample_value, str):
            pg_type = "text"
        elif isinstance(sample_value, UUID):
            pg_type = "uuid"
        elif isinstance(sample_value, (int, float)):
            pg_type = "numeric"
        elif isinstance(sample_value, list) and all(
            isinstance(x, UUID) for x in sample_value
        ):
            pg_type = "uuid[]"
        elif isinstance(sample_value, (list, dict)):
            pg_type = "jsonb"
        elif isinstance(sample_value, (datetime.datetime, datetime.date)):
            pg_type = "timestamp"
        else:
            raise TypeError(
                f"Unsupported data type for column '{col}': {type(sample_value)}"
            )
        column_defs.append(f"{col} {pg_type}")

    columns_str = ", ".join(columns)
    column_defs_str = ", ".join(column_defs)

    if conflict_columns:
        conflict_columns_str = ", ".join(conflict_columns)
        update_columns_str = ", ".join(
            f"{col}=EXCLUDED.{col}"
            for col in columns
            if col not in conflict_columns
        )
        on_conflict_clause = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {update_columns_str}"
    else:
        on_conflict_clause = ""

    QUERY = f"""
        INSERT INTO {full_table_name} ({columns_str})
        SELECT {columns_str}
        FROM jsonb_to_recordset($1::jsonb)
        AS x({column_defs_str})
        {on_conflict_clause}
        RETURNING id;
    """
    # Execute the query
    result = await connection_manager.fetch_query(QUERY, [json_data])

    # Extract and return the IDs
    return [record["id"] for record in result]