graph.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import json
  2. from dataclasses import dataclass
  3. from datetime import datetime
  4. from typing import Any, Optional
  5. from uuid import UUID
  6. from pydantic import Field
  7. from .base import R2RSerializable
  8. class Entity(R2RSerializable):
  9. """An entity extracted from a document."""
  10. name: str
  11. description: Optional[str] = None
  12. category: Optional[str] = None
  13. metadata: Optional[dict[str, Any] | str] = None
  14. id: Optional[UUID] = None
  15. parent_id: Optional[UUID] = None # graph_id | document_id
  16. description_embedding: Optional[list[float] | str] = None
  17. chunk_ids: Optional[list[UUID]] = []
  18. def __str__(self):
  19. return f"{self.name}:{self.category}"
  20. def __init__(self, **kwargs):
  21. super().__init__(**kwargs)
  22. if isinstance(self.metadata, str):
  23. try:
  24. self.metadata = json.loads(self.metadata)
  25. except json.JSONDecodeError:
  26. self.metadata = self.metadata
  27. class Relationship(R2RSerializable):
  28. """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities."""
  29. id: Optional[UUID] = None
  30. subject: str
  31. predicate: str
  32. object: str
  33. description: str | None = None
  34. subject_id: Optional[UUID] = None
  35. object_id: Optional[UUID] = None
  36. weight: float | None = 1.0
  37. chunk_ids: Optional[list[UUID]] = []
  38. parent_id: Optional[UUID] = None
  39. description_embedding: Optional[list[float] | str] = None
  40. metadata: Optional[dict[str, Any] | str] = None
  41. def __init__(self, **kwargs):
  42. super().__init__(**kwargs)
  43. if isinstance(self.metadata, str):
  44. try:
  45. self.metadata = json.loads(self.metadata)
  46. except json.JSONDecodeError:
  47. self.metadata = self.metadata
  48. @dataclass
  49. class Community(R2RSerializable):
  50. name: str = ""
  51. summary: str = ""
  52. level: Optional[int] = None
  53. findings: list[str] = []
  54. id: Optional[int | UUID] = None
  55. community_id: Optional[UUID] = None
  56. collection_id: Optional[UUID] = None
  57. rating: float | None = None
  58. rating_explanation: str | None = None
  59. description_embedding: list[float] | None = None
  60. attributes: dict[str, Any] | None = None
  61. created_at: datetime = Field(
  62. default_factory=datetime.utcnow,
  63. )
  64. updated_at: datetime = Field(
  65. default_factory=datetime.utcnow,
  66. )
  67. def __init__(self, **kwargs):
  68. if isinstance(kwargs.get("attributes", None), str):
  69. kwargs["attributes"] = json.loads(kwargs["attributes"])
  70. if isinstance(kwargs.get("embedding", None), str):
  71. kwargs["embedding"] = json.loads(kwargs["embedding"])
  72. super().__init__(**kwargs)
  73. @classmethod
  74. def from_dict(cls, data: dict[str, Any] | str) -> "Community":
  75. parsed_data: dict[str, Any] = (
  76. json.loads(data) if isinstance(data, str) else data
  77. )
  78. if isinstance(parsed_data.get("embedding", None), str):
  79. parsed_data["embedding"] = json.loads(parsed_data["embedding"])
  80. return cls(**parsed_data)
  81. class KGExtraction(R2RSerializable):
  82. """A protocol for a knowledge graph extraction."""
  83. entities: list[Entity]
  84. relationships: list[Relationship]
  85. class Graph(R2RSerializable):
  86. id: UUID | None = Field()
  87. name: str
  88. description: Optional[str] = None
  89. created_at: datetime = Field(
  90. default_factory=datetime.utcnow,
  91. )
  92. updated_at: datetime = Field(
  93. default_factory=datetime.utcnow,
  94. )
  95. status: str = "pending"
  96. class Config:
  97. populate_by_name = True
  98. from_attributes = True
  99. @classmethod
  100. def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
  101. """Create a Graph instance from a dictionary."""
  102. # Convert string to dict if needed
  103. parsed_data: dict[str, Any] = (
  104. json.loads(data) if isinstance(data, str) else data
  105. )
  106. return cls(**parsed_data)
  107. def __init__(self, **kwargs):
  108. super().__init__(**kwargs)