clustering.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import logging
  2. from typing import Any, AsyncGenerator
  3. from uuid import UUID
  4. from core.base import (
  5. AsyncPipe,
  6. AsyncState,
  7. CompletionProvider,
  8. EmbeddingProvider,
  9. )
  10. # from ...database.postgres import PostgresDatabaseProvider
  11. from core.database import PostgresDatabaseProvider
  # Module-scoped logger (a child of the root logger) instead of the root logger itself.
  logger = logging.getLogger(__name__)
  class GraphClusteringPipe(AsyncPipe):
      """
      Pipe that groups knowledge-graph entities and relationships into communities.

      The clustering itself is not implemented in this class: the request is
      forwarded to ``database_provider.graphs_handler.perform_graph_clustering``
      and the handler's result is wrapped in a dict. The ``leiden_params``
      argument name suggests a (hierarchical) Leiden algorithm is used by that
      handler -- NOTE(review): confirm against the handler implementation.
      """

      def __init__(
          self,
          database_provider: PostgresDatabaseProvider,
          llm_provider: CompletionProvider,
          embedding_provider: EmbeddingProvider,
          config: AsyncPipe.PipeConfig,
          *args,
          **kwargs,
      ):
          """
          Set up the clustering pipe.

          Args:
              database_provider: Provider whose ``graphs_handler`` performs
                  the clustering.
              llm_provider: Completion provider, stored on the instance
                  (not referenced by this class's own methods).
              embedding_provider: Embedding provider, stored on the instance
                  (not referenced by this class's own methods).
              config: Pipe configuration. Any falsy value (e.g. ``None``) is
                  replaced by ``AsyncPipe.PipeConfig(name="kg_cluster_pipe")``.
              *args, **kwargs: Accepted and ignored; they are not forwarded
                  to ``AsyncPipe.__init__``.
          """
          effective_config = (
              config if config else AsyncPipe.PipeConfig(name="kg_cluster_pipe")
          )
          super().__init__(config=effective_config)

          self.database_provider = database_provider
          self.llm_provider = llm_provider
          self.embedding_provider = embedding_provider

      async def cluster_kg(
          self,
          collection_id: UUID,
          leiden_params: dict,
          clustering_mode: str,
      ) -> dict:
          """
          Cluster one collection's graph and report the resulting community count.

          Args:
              collection_id: Collection whose graph is clustered.
              leiden_params: Clustering parameters, passed through unchanged.
              clustering_mode: Mode string, passed through unchanged.

          Returns:
              ``{"num_communities": <value returned by perform_graph_clustering>}``
              -- presumably an integer count; not verifiable from this file.
          """
          graphs_handler = self.database_provider.graphs_handler
          community_count = await graphs_handler.perform_graph_clustering(
              collection_id=collection_id,
              leiden_params=leiden_params,
              clustering_mode=clustering_mode,
          )
          return dict(num_communities=community_count)

      async def _run_logic(  # type: ignore
          self,
          input: AsyncPipe.Input,
          state: AsyncState,
          run_id: UUID,
          *args: Any,
          **kwargs: Any,
      ) -> AsyncGenerator[dict, None]:
          """
          Run the pipe: read clustering inputs from the message and yield the result.

          ``input.message`` must provide ``"leiden_params"`` and
          ``"clustering_mode"`` (a missing key raises ``KeyError``);
          ``"collection_id"`` is optional and falls back to ``None``.
          ``state``, ``run_id`` and any extra arguments are not used here.

          Yields:
              Exactly one dict: the value returned by ``cluster_kg``.
          """
          message = input.message
          collection_id = message.get("collection_id", None)
          leiden_params = message["leiden_params"]
          clustering_mode = message["clustering_mode"]

          clustering_result = await self.cluster_kg(
              collection_id=collection_id,
              leiden_params=leiden_params,
              clustering_mode=clustering_mode,
          )
          yield clustering_result