123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- import json
- from dataclasses import dataclass
- from datetime import datetime
- from enum import Enum
- from typing import Any, Optional
- from uuid import UUID
- from pydantic import Field
- from ..abstractions.llm import GenerationConfig
- from .base import R2RSerializable
- class Entity(R2RSerializable):
- """An entity extracted from a document."""
- name: str
- description: Optional[str] = None
- category: Optional[str] = None
- metadata: Optional[dict[str, Any]] = None
- id: Optional[UUID] = None
- parent_id: Optional[UUID] = None # graph_id | document_id
- description_embedding: Optional[list[float] | str] = None
- chunk_ids: Optional[list[UUID]] = []
- def __str__(self):
- return f"{self.name}:{self.category}"
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- if isinstance(self.metadata, str):
- try:
- self.metadata = json.loads(self.metadata)
- except json.JSONDecodeError:
- self.metadata = self.metadata
- class Relationship(R2RSerializable):
- """A relationship between two entities.
- This is a generic relationship, and can be used to represent any type of
- relationship between any two entities.
- """
- id: Optional[UUID] = None
- subject: str
- predicate: str
- object: str
- description: Optional[str] = None
- subject_id: Optional[UUID] = None
- object_id: Optional[UUID] = None
- weight: float | None = 1.0
- chunk_ids: Optional[list[UUID]] = []
- parent_id: Optional[UUID] = None
- description_embedding: Optional[list[float] | str] = None
- metadata: Optional[dict[str, Any] | str] = None
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- if isinstance(self.metadata, str):
- try:
- self.metadata = json.loads(self.metadata)
- except json.JSONDecodeError:
- self.metadata = self.metadata
- @dataclass
- class Community(R2RSerializable):
- name: str = ""
- summary: str = ""
- level: Optional[int] = None
- findings: list[str] = []
- id: Optional[int | UUID] = None
- community_id: Optional[UUID] = None
- collection_id: Optional[UUID] = None
- rating: Optional[float] = None
- rating_explanation: Optional[str] = None
- description_embedding: Optional[list[float]] = None
- attributes: dict[str, Any] | None = None
- created_at: datetime = Field(
- default_factory=datetime.utcnow,
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow,
- )
- def __init__(self, **kwargs):
- if isinstance(kwargs.get("attributes", None), str):
- kwargs["attributes"] = json.loads(kwargs["attributes"])
- if isinstance(kwargs.get("embedding", None), str):
- kwargs["embedding"] = json.loads(kwargs["embedding"])
- super().__init__(**kwargs)
- @classmethod
- def from_dict(cls, data: dict[str, Any] | str) -> "Community":
- parsed_data: dict[str, Any] = (
- json.loads(data) if isinstance(data, str) else data
- )
- if isinstance(parsed_data.get("embedding", None), str):
- parsed_data["embedding"] = json.loads(parsed_data["embedding"])
- return cls(**parsed_data)
- class GraphExtraction(R2RSerializable):
- """A protocol for a knowledge graph extraction."""
- entities: list[Entity]
- relationships: list[Relationship]
- class Graph(R2RSerializable):
- id: UUID | None = Field()
- name: str
- description: Optional[str] = None
- created_at: datetime = Field(
- default_factory=datetime.utcnow,
- )
- updated_at: datetime = Field(
- default_factory=datetime.utcnow,
- )
- status: str = "pending"
- class Config:
- populate_by_name = True
- from_attributes = True
- @classmethod
- def from_dict(cls, data: dict[str, Any] | str) -> "Graph":
- """Create a Graph instance from a dictionary."""
- # Convert string to dict if needed
- parsed_data: dict[str, Any] = (
- json.loads(data) if isinstance(data, str) else data
- )
- return cls(**parsed_data)
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- class StoreType(str, Enum):
- GRAPHS = "graphs"
- DOCUMENTS = "documents"
- class GraphCreationSettings(R2RSerializable):
- """Settings for knowledge graph creation."""
- graph_extraction_prompt: str = Field(
- default="graph_extraction",
- description="The prompt to use for knowledge graph extraction.",
- )
- graph_entity_description_prompt: str = Field(
- default="graph_entity_description",
- description="The prompt to use for entity description generation.",
- )
- entity_types: list[str] = Field(
- default=[],
- description="The types of entities to extract.",
- )
- relation_types: list[str] = Field(
- default=[],
- description="The types of relations to extract.",
- )
- chunk_merge_count: int = Field(
- default=2,
- description="""The number of extractions to merge into a single graph
- extraction.""",
- )
- max_knowledge_relationships: int = Field(
- default=100,
- description="""The maximum number of knowledge relationships to extract
- from each chunk.""",
- )
- max_description_input_length: int = Field(
- default=65536,
- description="""The maximum length of the description for a node in the
- graph.""",
- )
- generation_config: Optional[GenerationConfig] = Field(
- default=None,
- description="Configuration for text generation during graph enrichment.",
- )
- automatic_deduplication: bool = Field(
- default=False,
- description="Whether to automatically deduplicate entities.",
- )
- class GraphEnrichmentSettings(R2RSerializable):
- """Settings for knowledge graph enrichment."""
- force_graph_search_results_enrichment: bool = Field(
- default=False,
- description="""Force run the enrichment step even if graph creation is
- still in progress for some documents.""",
- )
- graph_communities_prompt: str = Field(
- default="graph_communities",
- description="The prompt to use for knowledge graph enrichment.",
- )
- max_summary_input_length: int = Field(
- default=65536,
- description="The maximum length of the summary for a community.",
- )
- generation_config: Optional[GenerationConfig] = Field(
- default=None,
- description="Configuration for text generation during graph enrichment.",
- )
- leiden_params: dict = Field(
- default_factory=dict,
- description="Parameters for the Leiden algorithm.",
- )
- class GraphCommunitySettings(R2RSerializable):
- """Settings for knowledge graph community enrichment."""
- force_graph_search_results_enrichment: bool = Field(
- default=False,
- description="""Force run the enrichment step even if graph creation is
- still in progress for some documents.""",
- )
- graph_communities: str = Field(
- default="graph_communities",
- description="The prompt to use for knowledge graph enrichment.",
- )
- max_summary_input_length: int = Field(
- default=65536,
- description="The maximum length of the summary for a community.",
- )
- generation_config: Optional[GenerationConfig] = Field(
- default=None,
- description="Configuration for text generation during graph enrichment.",
- )
- leiden_params: dict = Field(
- default_factory=dict,
- description="Parameters for the Leiden algorithm.",
- )
|