123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480 |
- """Abstractions for search functionality."""
- from copy import copy
- from enum import Enum
- from typing import Any, Optional
- from uuid import UUID
- from pydantic import Field
- from .base import R2RSerializable
- from .llm import GenerationConfig
- from .vector import IndexMeasure
class ChunkSearchResult(R2RSerializable):
    """One scored chunk produced by a search operation."""

    id: UUID
    document_id: UUID
    owner_id: Optional[UUID]
    collection_ids: list[UUID]
    score: float
    text: str
    metadata: dict[str, Any]

    def __str__(self) -> str:
        """Render a one-line summary of the result.

        ``owner_id`` and ``collection_ids`` are not part of the summary.
        """
        shown = ("id", "document_id", "score", "text", "metadata")
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"ChunkSearchResult({body})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)

    def as_dict(self) -> dict:
        """Return every field as a plain ``dict`` keyed by field name.

        Values are returned as stored (e.g. ``UUID`` objects are not
        converted to strings).
        """
        field_names = (
            "id",
            "document_id",
            "owner_id",
            "collection_ids",
            "score",
            "text",
            "metadata",
        )
        return {name: getattr(self, name) for name in field_names}

    class Config:
        populate_by_name = True
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
            "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
            "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
            "collection_ids": [],
            "score": 0.23943702876567796,
            "text": "Example text from the document",
            "metadata": {
                "title": "example_document.pdf",
                "associated_query": "What is the capital of France?",
            },
        }
class KGSearchResultType(str, Enum):
    """Kinds of knowledge-graph search result.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"
class KGEntityResult(R2RSerializable):
    """Entity payload of a knowledge-graph search result."""

    name: str
    description: str
    # Optional free-form extra data; defaults to ``None`` rather than ``{}``.
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload attached to the generated JSON schema. It is also
        # read by other classes via ``KGEntityResult.Config.json_schema_extra``.
        json_schema_extra = {
            "name": "Entity Name",
            "description": "Entity Description",
            "metadata": {},
        }
class KGRelationshipResult(R2RSerializable):
    """Relationship payload of a knowledge-graph search result.

    A relationship is stored as a ``subject`` / ``predicate`` / ``object``
    triple of strings, with optional metadata and an optional score.
    """

    subject: str
    predicate: str
    object: str
    metadata: Optional[dict[str, Any]] = None
    score: Optional[float] = None

    class Config:
        # Example payload attached to the generated JSON schema. The keys
        # mirror the fields declared above; the previous example listed
        # ``name``/``description``, which this model does not define.
        json_schema_extra = {
            "subject": "Relationship Subject",
            "predicate": "Relationship Predicate",
            "object": "Relationship Object",
            "metadata": {},
        }
class KGCommunityResult(R2RSerializable):
    """Community payload of a knowledge-graph search result."""

    name: str
    summary: str
    rating: float
    rating_explanation: str
    findings: list[str]
    # Optional free-form extra data; defaults to ``None`` rather than ``{}``.
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "name": "Community Name",
            "summary": "Community Summary",
            "rating": 9,
            "rating_explanation": "Rating Explanation",
            "findings": ["Finding 1", "Finding 2"],
            "metadata": {},
        }
class GraphSearchResult(R2RSerializable):
    """A knowledge-graph search hit wrapping an entity, relationship or community payload."""

    content: KGEntityResult | KGRelationshipResult | KGCommunityResult
    # Which kind of payload ``content`` holds; may be left unset.
    result_type: Optional[KGSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    metadata: dict[str, Any] = {}
    score: Optional[float] = None

    class Config:
        # Example payload attached to the generated JSON schema; the
        # ``content`` example is borrowed from ``KGEntityResult``.
        json_schema_extra = {
            "content": KGEntityResult.Config.json_schema_extra,
            "result_type": "entity",
            "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
            "metadata": {"associated_query": "What is the capital of France?"},
        }
class WebSearchResult(R2RSerializable):
    """A single web search hit; ``type`` defaults to ``"organic"``."""

    title: str
    link: str
    snippet: str
    position: int
    type: str = "organic"
    # Optional extras; both default to ``None`` when not supplied.
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None
class RelatedSearchResult(R2RSerializable):
    """A related-search suggestion; ``type`` defaults to ``"related"``."""

    query: str
    type: str = "related"
class PeopleAlsoAskResult(R2RSerializable):
    """A "people also ask" entry; ``type`` defaults to ``"peopleAlsoAsk"``."""

    question: str
    snippet: str
    link: str
    title: str
    type: str = "peopleAlsoAsk"
- class WebSearchResponse(R2RSerializable):
- organic_results: list[WebSearchResult] = []
- related_searches: list[RelatedSearchResult] = []
- people_also_ask: list[PeopleAlsoAskResult] = []
- @classmethod
- def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse":
- organic = []
- related = []
- paa = []
- for result in results:
- if result["type"] == "organic":
- organic.append(WebSearchResult(**result))
- elif result["type"] == "relatedSearches":
- related.append(RelatedSearchResult(**result))
- elif result["type"] == "peopleAlsoAsk":
- paa.append(PeopleAlsoAskResult(**result))
- return cls(
- organic_results=organic,
- related_searches=related,
- people_also_ask=paa,
- )
class AggregateSearchResult(R2RSerializable):
    """Chunk, graph and web search results bundled into one object."""

    chunk_search_results: Optional[list[ChunkSearchResult]]
    graph_search_results: Optional[list[GraphSearchResult]] = None
    web_search_results: Optional[list[WebSearchResult]] = None

    def _summary(self) -> str:
        """Build the text shared by ``__str__`` and ``__repr__``."""
        sections = (
            "chunk_search_results",
            "graph_search_results",
            "web_search_results",
        )
        body = ", ".join(f"{name}={getattr(self, name)}" for name in sections)
        return f"AggregateSearchResult({body})"

    def __str__(self) -> str:
        return self._summary()

    def __repr__(self) -> str:
        return self._summary()

    def as_dict(self) -> dict:
        """Return the three result lists as lists of plain dicts.

        A missing (``None``) or empty list is reported as ``[]``. Chunk
        results are converted with ``as_dict()``, graph and web results with
        ``to_dict()``.
        """
        chunks = [hit.as_dict() for hit in self.chunk_search_results or []]
        graph = [hit.to_dict() for hit in self.graph_search_results or []]
        web = [hit.to_dict() for hit in self.web_search_results or []]
        return {
            "chunk_search_results": chunks,
            "graph_search_results": graph,
            "web_search_results": web,
        }
- from enum import Enum
- from typing import Any, Optional
- from uuid import UUID
- from pydantic import Field
- from .base import R2RSerializable
- from .llm import GenerationConfig
- from .vector import IndexMeasure
class HybridSearchSettings(R2RSerializable):
    """Settings for hybrid search combining full-text and semantic search."""

    full_text_weight: float = Field(
        default=1.0, description="Weight to apply to full text search"
    )
    semantic_weight: float = Field(
        default=5.0, description="Weight to apply to semantic search"
    )
    full_text_limit: int = Field(
        default=200,
        description="Maximum number of results to return from full text search",
    )
    # RRF stands for Reciprocal Rank Fusion; the description previously
    # spelled the expansion with the first two words swapped.
    rrf_k: int = Field(
        default=50, description="K-value for RRF (Reciprocal Rank Fusion)"
    )
class ChunkSearchSettings(R2RSerializable):
    """Tuning knobs that apply only to chunk/vector search."""

    index_measure: IndexMeasure = Field(
        description="The distance measure to use for indexing",
        default=IndexMeasure.cosine_distance,
    )
    probes: int = Field(
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
        default=10,
    )
    ef_search: int = Field(
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
        default=40,
    )
    enabled: bool = Field(
        description="Whether to enable chunk search",
        default=True,
    )
class GraphSearchSettings(R2RSerializable):
    """Tuning knobs that apply only to knowledge-graph search."""

    generation_config: GenerationConfig = Field(
        description="Configuration for text generation during graph search.",
        default_factory=GenerationConfig,
    )
    graphrag_map_system: str = Field(
        description="The system prompt for the graphrag map prompt.",
        default="graphrag_map_system",
    )
    graphrag_reduce_system: str = Field(
        description="The system prompt for the graphrag reduce prompt.",
        default="graphrag_reduce_system",
    )
    # The next three fields carry no schema description.
    max_community_description_length: int = Field(default=65536)
    max_llm_queries_for_global_search: int = Field(default=250)
    limits: dict[str, int] = Field(default={})
    enabled: bool = Field(
        description="Whether to enable graph search",
        default=True,
    )
class SearchSettings(R2RSerializable):
    """Main search settings class that combines shared settings with specialized settings for chunks and KG.

    Besides the declared fields, the constructor accepts a legacy
    ``search_filters`` keyword whose entries are merged into ``filters``.
    """

    # Search type flags
    use_hybrid_search: bool = Field(
        default=False,
        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
    )
    use_semantic_search: bool = Field(
        default=True,
        description="Whether to use semantic search",
    )
    use_fulltext_search: bool = Field(
        default=False,
        description="Whether to use full-text search",
    )

    # Common search parameters
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.\n"
            "Commonly seen filters include operations include the following:\n"
            '`{"document_id": {"$eq": "9fbe403b-..."}}`\n'
            '`{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`\n'
            '`{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`\n'
            '`{"$and": {"$document_id": ..., "collection_ids": ...}}`'
        ),
    )
    limit: int = Field(
        default=10,
        description="Maximum number of results to return",
        ge=1,
        le=1_000,
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset to paginate search results",
    )
    include_metadatas: bool = Field(
        default=True,
        description="Whether to include element metadata in the search results",
    )
    include_scores: bool = Field(
        default=True,
        description="Whether to include search score values in the search results",
    )

    # Search strategy and settings
    search_strategy: str = Field(
        default="vanilla",
        description="Search strategy to use (e.g., 'vanilla', 'query_fusion', 'hyde')",
    )
    hybrid_settings: HybridSearchSettings = Field(
        default_factory=HybridSearchSettings,
        description="Settings for hybrid search (only used if `use_semantic_search` and `use_fulltext_search` are both true)",
    )

    # Specialized settings
    chunk_settings: ChunkSearchSettings = Field(
        default_factory=ChunkSearchSettings,
        description="Settings specific to chunk/vector search",
    )
    graph_settings: GraphSearchSettings = Field(
        default_factory=GraphSearchSettings,
        description="Settings specific to knowledge graph search",
    )

    class Config:
        populate_by_name = True
        json_encoders = {UUID: str}
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "use_semantic_search": True,
            "use_fulltext_search": False,
            "use_hybrid_search": False,
            "filters": {"category": "technology"},
            "limit": 20,
            "offset": 0,
            "search_strategy": "vanilla",
            "hybrid_settings": {
                "full_text_weight": 1.0,
                "semantic_weight": 5.0,
                "full_text_limit": 200,
                "rrf_k": 50,
            },
            # NOTE(review): `include_metadata` is not a field declared on
            # ChunkSearchSettings in this file — confirm whether this
            # example key is stale.
            "chunk_settings": {
                "enabled": True,
                "index_measure": "cosine_distance",
                "include_metadata": True,
                "probes": 10,
                "ef_search": 40,
            },
            "graph_settings": {
                "enabled": True,
                "generation_config": GenerationConfig.Config.json_schema_extra,
                "max_community_description_length": 65536,
                "max_llm_queries_for_global_search": 250,
                "limits": {
                    "entity": 20,
                    "relationship": 20,
                    "community": 20,
                },
            },
        }

    def __init__(self, **data):
        """Create the settings, folding legacy ``search_filters`` into ``filters``.

        Entries from ``search_filters`` win over same-named entries in
        ``filters``. Either keyword may be absent or ``None``; both cases
        are treated as an empty mapping, so an explicit ``None`` no longer
        fails while unpacking.
        """
        # Handle legacy search_filters field
        current = data.get("filters") or {}
        legacy = data.get("search_filters") or {}
        data["filters"] = {**current, **legacy}
        super().__init__(**data)

    def model_dump(self, *args, **kwargs):
        """Delegate to the parent ``model_dump`` without modification."""
        return super().model_dump(*args, **kwargs)

    @classmethod
    def get_default(cls, mode: str) -> "SearchSettings":
        """Return default search settings for a given mode.

        Args:
            mode: ``"basic"`` or ``"advanced"``; any other value (including
                ``"custom"``) yields the plain field defaults.
        """
        if mode == "basic":
            # A simpler search that relies primarily on semantic search.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=False,
                use_hybrid_search=False,
                search_strategy="vanilla",
            )
        if mode == "advanced":
            # A more powerful, combined search that leverages both semantic
            # and fulltext.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=True,
                use_hybrid_search=True,
                search_strategy="hyde",
            )
        # For 'custom' or unrecognized modes, return a basic empty config.
        return cls()
class SearchMode(str, Enum):
    """Modes accepted by the search endpoint.

    Members are also ``str`` instances, so they compare equal to their
    string values.
    """

    basic = "basic"
    advanced = "advanced"
    custom = "custom"
def select_search_filters(
    auth_user: Any,
    search_settings: "SearchSettings",
) -> dict[str, Any]:
    """Return the search filters, scoped to what ``auth_user`` may access.

    Superusers get a shallow copy of ``search_settings.filters`` unchanged.
    For every other user the result is restricted to items the user owns
    or that overlap one of the user's collections. If the request names
    collections through a ``collection_ids`` ``$overlap`` filter, only the
    intersection with the user's own collections is kept.

    Args:
        auth_user: Object exposing ``is_superuser``, ``id`` and
            ``collection_ids``.
        search_settings: Settings whose ``filters`` mapping is read. Its
            top-level mapping is copied before any key is removed.

    Returns:
        The filter dictionary. For a non-superuser this always contains the
        owner/collection restriction, even when no other filter was given.

    Raises:
        KeyError: If a filter key containing ``collection_ids`` has no
            ``$overlap`` entry.
    """
    filters = copy(search_settings.filters)
    if auth_user.is_superuser:
        return filters

    user_collections = set(auth_user.collection_ids)

    # The first filter key mentioning collection_ids decides which
    # collections the caller asked for.
    selected_collections = None
    for key in filters:
        if "collection_ids" in key:
            selected_collections = set(filters[key]["$overlap"])
            break

    if selected_collections:
        allowed_collections = user_collections.intersection(
            selected_collections
        )
    else:
        allowed_collections = user_collections

    # For non-superusers, filter by owner and by selected & allowed
    # collections.
    collection_filters: dict[str, Any] = {
        "$or": [
            {"owner_id": {"$eq": auth_user.id}},
            {"collection_ids": {"$overlap": list(allowed_collections)}},
        ]
    }

    # The caller's own collection filter is superseded by the restriction
    # built above.
    filters.pop("collection_ids", None)
    if filters:
        return {"$and": [collection_filters, filters]}

    # No other filters remain: the access restriction must still apply.
    # Returning the empty mapping here would leave the search unrestricted.
    return collection_filters
|