graph.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import json
  2. from dataclasses import dataclass
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Any, Optional
  6. from uuid import UUID
  7. from pydantic import Field
  8. from ..abstractions.llm import GenerationConfig
  9. from .base import R2RSerializable
  10. class Entity(R2RSerializable):
  11. """An entity extracted from a document."""
  12. name: str
  13. description: Optional[str] = None
  14. category: Optional[str] = None
  15. metadata: Optional[dict[str, Any]] = None
  16. id: Optional[UUID] = None
  17. parent_id: Optional[UUID] = None # graph_id | document_id
  18. description_embedding: Optional[list[float] | str] = None
  19. chunk_ids: Optional[list[UUID]] = []
  20. def __str__(self):
  21. return f"{self.name}:{self.category}"
  22. def __init__(self, **kwargs):
  23. super().__init__(**kwargs)
  24. if isinstance(self.metadata, str):
  25. try:
  26. self.metadata = json.loads(self.metadata)
  27. except json.JSONDecodeError:
  28. self.metadata = self.metadata
  29. class Relationship(R2RSerializable):
  30. """A relationship between two entities.
  31. This is a generic relationship, and can be used to represent any type of
  32. relationship between any two entities.
  33. """
  34. id: Optional[UUID] = None
  35. subject: str
  36. predicate: str
  37. object: str
  38. description: Optional[str] = None
  39. subject_id: Optional[UUID] = None
  40. object_id: Optional[UUID] = None
  41. weight: float | None = 1.0
  42. chunk_ids: Optional[list[UUID]] = []
  43. parent_id: Optional[UUID] = None
  44. description_embedding: Optional[list[float] | str] = None
  45. metadata: Optional[dict[str, Any] | str] = None
  46. def __init__(self, **kwargs):
  47. super().__init__(**kwargs)
  48. if isinstance(self.metadata, str):
  49. try:
  50. self.metadata = json.loads(self.metadata)
  51. except json.JSONDecodeError:
  52. self.metadata = self.metadata
  53. @dataclass
  54. class Community(R2RSerializable):
  55. name: str = ""
  56. summary: str = ""
  57. level: Optional[int] = None
  58. findings: list[str] = []
  59. id: Optional[int | UUID] = None
  60. community_id: Optional[UUID] = None
  61. collection_id: Optional[UUID] = None
  62. rating: Optional[float] = None
  63. rating_explanation: Optional[str] = None
  64. description_embedding: Optional[list[float]] = None
  65. attributes: dict[str, Any] | None = None
  66. created_at: datetime = Field(
  67. default_factory=datetime.utcnow,
  68. )
  69. updated_at: datetime = Field(
  70. default_factory=datetime.utcnow,
  71. )
  72. def __init__(self, **kwargs):
  73. if isinstance(kwargs.get("attributes", None), str):
  74. kwargs["attributes"] = json.loads(kwargs["attributes"])
  75. if isinstance(kwargs.get("embedding", None), str):
  76. kwargs["embedding"] = json.loads(kwargs["embedding"])
  77. super().__init__(**kwargs)
  78. @classmethod
  79. def from_dict(cls, data: dict[str, Any] | str) -> "Community":
  80. parsed_data: dict[str, Any] = (
  81. json.loads(data) if isinstance(data, str) else data
  82. )
  83. if isinstance(parsed_data.get("embedding", None), str):
  84. parsed_data["embedding"] = json.loads(parsed_data["embedding"])
  85. return cls(**parsed_data)
  86. class GraphExtraction(R2RSerializable):
  87. """A protocol for a knowledge graph extraction."""
  88. entities: list[Entity]
  89. relationships: list[Relationship]
  90. class Graph(R2RSerializable):
  91. id: UUID | None = Field()
  92. name: str
  93. description: Optional[str] = None
  94. created_at: datetime = Field(
  95. default_factory=datetime.utcnow,
  96. )
  97. updated_at: datetime = Field(
  98. default_factory=datetime.utcnow,
  99. )
  100. status: str = "pending"
  101. class Config:
  102. populate_by_name = True
  103. from_attributes = True
  104. @classmethod
  105. def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
  106. """Create a Graph instance from a dictionary."""
  107. # Convert string to dict if needed
  108. parsed_data: dict[str, Any] = (
  109. json.loads(data) if isinstance(data, str) else data
  110. )
  111. return cls(**parsed_data)
  112. def __init__(self, **kwargs):
  113. super().__init__(**kwargs)
  114. class StoreType(str, Enum):
  115. GRAPHS = "graphs"
  116. DOCUMENTS = "documents"
  117. class GraphCreationSettings(R2RSerializable):
  118. """Settings for knowledge graph creation."""
  119. graph_extraction_prompt: str = Field(
  120. default="graph_extraction",
  121. description="The prompt to use for knowledge graph extraction.",
  122. )
  123. graph_entity_description_prompt: str = Field(
  124. default="graph_entity_description",
  125. description="The prompt to use for entity description generation.",
  126. )
  127. entity_types: list[str] = Field(
  128. default=[],
  129. description="The types of entities to extract.",
  130. )
  131. relation_types: list[str] = Field(
  132. default=[],
  133. description="The types of relations to extract.",
  134. )
  135. chunk_merge_count: int = Field(
  136. default=2,
  137. description="""The number of extractions to merge into a single graph
  138. extraction.""",
  139. )
  140. max_knowledge_relationships: int = Field(
  141. default=100,
  142. description="""The maximum number of knowledge relationships to extract
  143. from each chunk.""",
  144. )
  145. max_description_input_length: int = Field(
  146. default=65536,
  147. description="""The maximum length of the description for a node in the
  148. graph.""",
  149. )
  150. generation_config: Optional[GenerationConfig] = Field(
  151. default=None,
  152. description="Configuration for text generation during graph enrichment.",
  153. )
  154. automatic_deduplication: bool = Field(
  155. default=False,
  156. description="Whether to automatically deduplicate entities.",
  157. )
  158. class GraphEnrichmentSettings(R2RSerializable):
  159. """Settings for knowledge graph enrichment."""
  160. force_graph_search_results_enrichment: bool = Field(
  161. default=False,
  162. description="""Force run the enrichment step even if graph creation is
  163. still in progress for some documents.""",
  164. )
  165. graph_communities_prompt: str = Field(
  166. default="graph_communities",
  167. description="The prompt to use for knowledge graph enrichment.",
  168. )
  169. max_summary_input_length: int = Field(
  170. default=65536,
  171. description="The maximum length of the summary for a community.",
  172. )
  173. generation_config: Optional[GenerationConfig] = Field(
  174. default=None,
  175. description="Configuration for text generation during graph enrichment.",
  176. )
  177. leiden_params: dict = Field(
  178. default_factory=dict,
  179. description="Parameters for the Leiden algorithm.",
  180. )
  181. class GraphCommunitySettings(R2RSerializable):
  182. """Settings for knowledge graph community enrichment."""
  183. force_graph_search_results_enrichment: bool = Field(
  184. default=False,
  185. description="""Force run the enrichment step even if graph creation is
  186. still in progress for some documents.""",
  187. )
  188. graph_communities: str = Field(
  189. default="graph_communities",
  190. description="The prompt to use for knowledge graph enrichment.",
  191. )
  192. max_summary_input_length: int = Field(
  193. default=65536,
  194. description="The maximum length of the summary for a community.",
  195. )
  196. generation_config: Optional[GenerationConfig] = Field(
  197. default=None,
  198. description="Configuration for text generation during graph enrichment.",
  199. )
  200. leiden_params: dict = Field(
  201. default_factory=dict,
  202. description="Parameters for the Leiden algorithm.",
  203. )