12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
import logging
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    AsyncPipe,
    AsyncState,
    CompletionProvider,
    EmbeddingProvider,
)

# from ...database.postgres import PostgresDatabaseProvider
from core.database import PostgresDatabaseProvider
# Module-scoped logger so records can be filtered/configured per module;
# it still propagates to the root logger's handlers.
logger = logging.getLogger(__name__)
class GraphClusteringPipe(AsyncPipe):
    """
    Pipe that clusters a collection's knowledge graph into communities.

    The clustering work itself is delegated to
    ``database_provider.graphs_handler.perform_graph_clustering``; this pipe
    only pulls the parameters out of the incoming message, forwards them, and
    yields a small summary of the result.

    NOTE(review): earlier docs described this as hierarchical Leiden via the
    graspologic library — presumably that is how the graphs handler is
    implemented; confirm there.
    """

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        llm_provider: CompletionProvider,
        embedding_provider: EmbeddingProvider,
        config: Optional[AsyncPipe.PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """
        Store the providers used by the pipe and initialise the base pipe.

        Args:
            database_provider: Provider whose ``graphs_handler`` performs the
                clustering.
            llm_provider: Completion provider; stored on the instance but not
                used by the code in this class.
            embedding_provider: Embedding provider; stored on the instance but
                not used by the code in this class.
            config: Pipe configuration. When ``None`` (or otherwise falsy), a
                default ``AsyncPipe.PipeConfig(name="kg_cluster_pipe")`` is
                used instead.
            *args, **kwargs: Accepted for call-site compatibility only; they
                are NOT forwarded to the base class.
        """
        super().__init__(
            config=config or AsyncPipe.PipeConfig(name="kg_cluster_pipe"),
        )
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def cluster_kg(
        self,
        collection_id: Optional[UUID],
        leiden_params: dict,
        clustering_mode: str,
    ) -> dict:
        """
        Run graph clustering for one collection and summarise the outcome.

        Args:
            collection_id: Collection whose graph is clustered. ``_run_logic``
                passes ``None`` when the message carries no
                ``"collection_id"``; how the graphs handler treats ``None`` is
                not visible here — TODO confirm.
            leiden_params: Forwarded unchanged to the graphs handler
                (presumably Leiden algorithm settings — verify against
                ``perform_graph_clustering``).
            clustering_mode: Forwarded unchanged to the graphs handler.

        Returns:
            ``{"num_communities": <value returned by perform_graph_clustering>}``.
        """
        num_communities = (
            await self.database_provider.graphs_handler.perform_graph_clustering(
                collection_id=collection_id,
                leiden_params=leiden_params,
                clustering_mode=clustering_mode,
            )
        )
        return {"num_communities": num_communities}

    async def _run_logic(  # type: ignore
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[dict, None]:
        """
        Read clustering parameters from ``input.message`` and yield one result.

        ``input.message`` must contain ``"leiden_params"`` and
        ``"clustering_mode"`` (a missing key raises ``KeyError``);
        ``"collection_id"`` is optional and defaults to ``None``.
        ``state``, ``run_id``, ``*args`` and ``**kwargs`` are accepted but not
        used here.

        Yields:
            The dict returned by ``cluster_kg``, exactly once.
        """
        message = input.message
        # Keys are read in the same order as before so a missing required key
        # fails before any clustering work is started.
        collection_id = message.get("collection_id", None)
        leiden_params = message["leiden_params"]
        clustering_mode = message["clustering_mode"]

        yield await self.cluster_kg(
            collection_id=collection_id,
            leiden_params=leiden_params,
            clustering_mode=clustering_mode,
        )
|