# kg.py
from enum import Enum

from pydantic import Field

from .base import R2RSerializable
from .llm import GenerationConfig
  5. class KGRunType(str, Enum):
  6. """Type of KG run."""
  7. ESTIMATE = "estimate"
  8. RUN = "run" # deprecated
  9. def __str__(self):
  10. return self.value
  11. GraphRunType = KGRunType
  12. class KGCreationSettings(R2RSerializable):
  13. """Settings for knowledge graph creation."""
  14. clustering_mode: str = Field(
  15. default="local",
  16. description="Whether to use remote clustering for graph creation.",
  17. )
  18. graphrag_relationships_extraction_few_shot: str = Field(
  19. default="graphrag_relationships_extraction_few_shot",
  20. description="The prompt to use for knowledge graph extraction.",
  21. alias="graphrag_relationships_extraction_few_shot", # TODO - mark deprecated & remove
  22. )
  23. graph_entity_description_prompt: str = Field(
  24. default="graphrag_entity_description",
  25. description="The prompt to use for entity description generation.",
  26. alias="graphrag_entity_description_prompt", # TODO - mark deprecated & remove
  27. )
  28. entity_types: list[str] = Field(
  29. default=[],
  30. description="The types of entities to extract.",
  31. )
  32. relation_types: list[str] = Field(
  33. default=[],
  34. description="The types of relations to extract.",
  35. )
  36. chunk_merge_count: int = Field(
  37. default=4,
  38. description="The number of extractions to merge into a single KG extraction.",
  39. )
  40. max_knowledge_relationships: int = Field(
  41. default=100,
  42. description="The maximum number of knowledge relationships to extract from each chunk.",
  43. )
  44. max_description_input_length: int = Field(
  45. default=65536,
  46. description="The maximum length of the description for a node in the graph.",
  47. )
  48. generation_config: GenerationConfig = Field(
  49. default_factory=GenerationConfig,
  50. description="Configuration for text generation during graph enrichment.",
  51. )
  52. class KGEnrichmentSettings(R2RSerializable):
  53. """Settings for knowledge graph enrichment."""
  54. force_kg_enrichment: bool = Field(
  55. default=False,
  56. description="Force run the enrichment step even if graph creation is still in progress for some documents.",
  57. )
  58. graphrag_communities: str = Field(
  59. default="graphrag_communities",
  60. description="The prompt to use for knowledge graph enrichment.",
  61. alias="graphrag_communities", # TODO - mark deprecated & remove
  62. )
  63. max_summary_input_length: int = Field(
  64. default=65536,
  65. description="The maximum length of the summary for a community.",
  66. )
  67. generation_config: GenerationConfig = Field(
  68. default_factory=GenerationConfig,
  69. description="Configuration for text generation during graph enrichment.",
  70. )
  71. leiden_params: dict = Field(
  72. default_factory=dict,
  73. description="Parameters for the Leiden algorithm.",
  74. )
  75. class GraphCommunitySettings(R2RSerializable):
  76. """Settings for knowledge graph community enrichment."""
  77. force_kg_enrichment: bool = Field(
  78. default=False,
  79. description="Force run the enrichment step even if graph creation is still in progress for some documents.",
  80. )
  81. graphrag_communities: str = Field(
  82. default="graphrag_communities",
  83. description="The prompt to use for knowledge graph enrichment.",
  84. )
  85. max_summary_input_length: int = Field(
  86. default=65536,
  87. description="The maximum length of the summary for a community.",
  88. )
  89. generation_config: GenerationConfig = Field(
  90. default_factory=GenerationConfig,
  91. description="Configuration for text generation during graph enrichment.",
  92. )
  93. leiden_params: dict = Field(
  94. default_factory=dict,
  95. description="Parameters for the Leiden algorithm.",
  96. )