graph.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import json
  2. from dataclasses import dataclass
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Any, Optional
  6. from uuid import UUID
  7. from pydantic import Field
  8. from .base import R2RSerializable
  9. class Entity(R2RSerializable):
  10. """An entity extracted from a document."""
  11. name: str
  12. description: Optional[str] = None
  13. category: Optional[str] = None
  14. metadata: Optional[dict[str, Any] | str] = None
  15. id: Optional[UUID] = None
  16. parent_id: Optional[UUID] = None # graph_id | document_id
  17. description_embedding: Optional[list[float] | str] = None
  18. chunk_ids: Optional[list[UUID]] = []
  19. def __str__(self):
  20. return f"{self.name}:{self.category}"
  21. def __init__(self, **kwargs):
  22. super().__init__(**kwargs)
  23. if isinstance(self.metadata, str):
  24. try:
  25. self.metadata = json.loads(self.metadata)
  26. except json.JSONDecodeError:
  27. self.metadata = self.metadata
  28. class Relationship(R2RSerializable):
  29. """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities."""
  30. id: Optional[UUID] = None
  31. subject: str
  32. predicate: str
  33. object: str
  34. description: str | None = None
  35. subject_id: Optional[UUID] = None
  36. object_id: Optional[UUID] = None
  37. weight: float | None = 1.0
  38. chunk_ids: Optional[list[UUID]] = []
  39. parent_id: Optional[UUID] = None
  40. description_embedding: Optional[list[float] | str] = None
  41. metadata: Optional[dict[str, Any] | str] = None
  42. def __init__(self, **kwargs):
  43. super().__init__(**kwargs)
  44. if isinstance(self.metadata, str):
  45. try:
  46. self.metadata = json.loads(self.metadata)
  47. except json.JSONDecodeError:
  48. self.metadata = self.metadata
  49. @dataclass
  50. class Community(R2RSerializable):
  51. name: str = ""
  52. summary: str = ""
  53. level: Optional[int] = None
  54. findings: list[str] = []
  55. id: Optional[int | UUID] = None
  56. community_id: Optional[UUID] = None
  57. collection_id: Optional[UUID] = None
  58. rating: float | None = None
  59. rating_explanation: str | None = None
  60. description_embedding: list[float] | None = None
  61. attributes: dict[str, Any] | None = None
  62. created_at: datetime = Field(
  63. default_factory=datetime.utcnow,
  64. )
  65. updated_at: datetime = Field(
  66. default_factory=datetime.utcnow,
  67. )
  68. def __init__(self, **kwargs):
  69. if isinstance(kwargs.get("attributes", None), str):
  70. kwargs["attributes"] = json.loads(kwargs["attributes"])
  71. if isinstance(kwargs.get("embedding", None), str):
  72. kwargs["embedding"] = json.loads(kwargs["embedding"])
  73. super().__init__(**kwargs)
  74. @classmethod
  75. def from_dict(cls, data: dict[str, Any] | str) -> "Community":
  76. parsed_data: dict[str, Any] = (
  77. json.loads(data) if isinstance(data, str) else data
  78. )
  79. if isinstance(parsed_data.get("embedding", None), str):
  80. parsed_data["embedding"] = json.loads(parsed_data["embedding"])
  81. return cls(**parsed_data)
  82. class KGExtraction(R2RSerializable):
  83. """A protocol for a knowledge graph extraction."""
  84. entities: list[Entity]
  85. relationships: list[Relationship]
  86. class Graph(R2RSerializable):
  87. id: UUID | None = Field()
  88. name: str
  89. description: Optional[str] = None
  90. created_at: datetime = Field(
  91. default_factory=datetime.utcnow,
  92. )
  93. updated_at: datetime = Field(
  94. default_factory=datetime.utcnow,
  95. )
  96. status: str = "pending"
  97. class Config:
  98. populate_by_name = True
  99. from_attributes = True
  100. @classmethod
  101. def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
  102. """Create a Graph instance from a dictionary."""
  103. # Convert string to dict if needed
  104. parsed_data: dict[str, Any] = (
  105. json.loads(data) if isinstance(data, str) else data
  106. )
  107. return cls(**parsed_data)
  108. def __init__(self, **kwargs):
  109. super().__init__(**kwargs)
  110. class StoreType(str, Enum):
  111. GRAPHS = "graphs"
  112. DOCUMENTS = "documents"