search.py
"""Abstractions for search functionality."""

from copy import copy
from enum import Enum
from typing import Any, Optional
from uuid import UUID

from pydantic import Field

from .base import R2RSerializable
from .llm import GenerationConfig
from .vector import IndexMeasure


class ChunkSearchResult(R2RSerializable):
    """Result of a search operation."""

    id: UUID
    document_id: UUID
    owner_id: Optional[UUID]
    collection_ids: list[UUID]
    score: float
    text: str
    metadata: dict[str, Any]

    def __str__(self) -> str:
        return f"ChunkSearchResult(id={self.id}, document_id={self.document_id}, score={self.score}, text={self.text}, metadata={self.metadata})"

    def __repr__(self) -> str:
        return self.__str__()

    def as_dict(self) -> dict:
        return {
            "id": self.id,
            "document_id": self.document_id,
            "owner_id": self.owner_id,
            "collection_ids": self.collection_ids,
            "score": self.score,
            "text": self.text,
            "metadata": self.metadata,
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
            "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
            "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
            "collection_ids": [],
            "score": 0.23943702876567796,
            "text": "Example text from the document",
            "metadata": {
                "title": "example_document.pdf",
                "associated_query": "What is the capital of France?",
            },
        }

class KGSearchResultType(str, Enum):
    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"


class KGEntityResult(R2RSerializable):
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "name": "Entity Name",
            "description": "Entity Description",
            "metadata": {},
        }


class KGRelationshipResult(R2RSerializable):
    # name: str
    subject: str
    predicate: str
    object: str
    metadata: Optional[dict[str, Any]] = None
    score: Optional[float] = None
    # name: str
    # description: str
    # metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "name": "Relationship Name",
            "description": "Relationship Description",
            "metadata": {},
        }


class KGCommunityResult(R2RSerializable):
    name: str
    summary: str
    rating: float
    rating_explanation: str
    findings: list[str]
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "name": "Community Name",
            "summary": "Community Summary",
            "rating": 9,
            "rating_explanation": "Rating Explanation",
            "findings": ["Finding 1", "Finding 2"],
            "metadata": {},
        }

class GraphSearchResult(R2RSerializable):
    content: KGEntityResult | KGRelationshipResult | KGCommunityResult
    result_type: Optional[KGSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    metadata: dict[str, Any] = {}
    score: Optional[float] = None

    class Config:
        json_schema_extra = {
            "content": KGEntityResult.Config.json_schema_extra,
            "result_type": "entity",
            "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
            "metadata": {"associated_query": "What is the capital of France?"},
        }

class WebSearchResult(R2RSerializable):
    title: str
    link: str
    snippet: str
    position: int
    type: str = "organic"
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None


class RelatedSearchResult(R2RSerializable):
    query: str
    type: str = "related"


class PeopleAlsoAskResult(R2RSerializable):
    question: str
    snippet: str
    link: str
    title: str
    type: str = "peopleAlsoAsk"


class WebSearchResponse(R2RSerializable):
    organic_results: list[WebSearchResult] = []
    related_searches: list[RelatedSearchResult] = []
    people_also_ask: list[PeopleAlsoAskResult] = []

    @classmethod
    def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse":
        organic = []
        related = []
        paa = []
        for result in results:
            if result["type"] == "organic":
                organic.append(WebSearchResult(**result))
            elif result["type"] == "relatedSearches":
                related.append(RelatedSearchResult(**result))
            elif result["type"] == "peopleAlsoAsk":
                paa.append(PeopleAlsoAskResult(**result))
        return cls(
            organic_results=organic,
            related_searches=related,
            people_also_ask=paa,
        )
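
# Illustrative only: a minimal sketch of feeding Serper-style result dicts into
# `from_serper_results`. The payload below is fabricated example data, not a real
# Serper API response; the keys simply mirror the model fields defined above.
def _example_web_search_response() -> "WebSearchResponse":
    serper_results = [
        {
            "type": "organic",
            "title": "Paris - Wikipedia",
            "link": "https://en.wikipedia.org/wiki/Paris",
            "snippet": "Paris is the capital of France.",
            "position": 1,
        },
        {"type": "relatedSearches", "query": "capital of France population"},
        {
            "type": "peopleAlsoAsk",
            "question": "What is the capital of France?",
            "snippet": "Paris.",
            "link": "https://example.com/faq",
            "title": "FAQ",
        },
    ]
    return WebSearchResponse.from_serper_results(serper_results)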

class AggregateSearchResult(R2RSerializable):
    """Result of an aggregate search operation."""

    chunk_search_results: Optional[list[ChunkSearchResult]]
    graph_search_results: Optional[list[GraphSearchResult]] = None
    web_search_results: Optional[list[WebSearchResult]] = None

    def __str__(self) -> str:
        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"

    def __repr__(self) -> str:
        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"

    def as_dict(self) -> dict:
        return {
            "chunk_search_results": (
                [result.as_dict() for result in self.chunk_search_results]
                if self.chunk_search_results
                else []
            ),
            "graph_search_results": (
                [result.to_dict() for result in self.graph_search_results]
                if self.graph_search_results
                else []
            ),
            "web_search_results": (
                [result.to_dict() for result in self.web_search_results]
                if self.web_search_results
                else []
            ),
        }
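
# Illustrative only: a minimal sketch of assembling an `AggregateSearchResult` from the
# models above and serializing it with `as_dict`; all values reuse or mimic the example
# data in the schemas and are not real search output.
def _example_aggregate_search_result() -> dict:
    chunk = ChunkSearchResult(
        id=UUID("3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09"),
        document_id=UUID("3e157b3a-8469-51db-90d9-52e7d896b49b"),
        owner_id=None,
        collection_ids=[],
        score=0.24,
        text="Example text from the document",
        metadata={"associated_query": "What is the capital of France?"},
    )
    aggregate = AggregateSearchResult(
        chunk_search_results=[chunk],
        graph_search_results=None,
        web_search_results=None,
    )
    return aggregate.as_dict()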

class HybridSearchSettings(R2RSerializable):
    """Settings for hybrid search combining full-text and semantic search."""

    full_text_weight: float = Field(
        default=1.0, description="Weight to apply to full text search"
    )
    semantic_weight: float = Field(
        default=5.0, description="Weight to apply to semantic search"
    )
    full_text_limit: int = Field(
        default=200,
        description="Maximum number of results to return from full text search",
    )
    rrf_k: int = Field(
        default=50, description="K-value for RRF (Reciprocal Rank Fusion)"
    )
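
# A minimal sketch (not the engine's actual fusion query) of how the settings above are
# commonly applied in weighted Reciprocal Rank Fusion: each result contributes
# weight / (rrf_k + rank) for every ranking it appears in, and the fused score is the sum.
# `fulltext_ranks` and `semantic_ranks` are hypothetical inputs mapping a result id to
# its 1-based rank in the respective ranking.
def _rrf_fuse_sketch(
    fulltext_ranks: dict[str, int],
    semantic_ranks: dict[str, int],
    settings: HybridSearchSettings,
) -> dict[str, float]:
    scores: dict[str, float] = {}
    for result_id, rank in fulltext_ranks.items():
        scores[result_id] = scores.get(result_id, 0.0) + settings.full_text_weight / (
            settings.rrf_k + rank
        )
    for result_id, rank in semantic_ranks.items():
        scores[result_id] = scores.get(result_id, 0.0) + settings.semantic_weight / (
            settings.rrf_k + rank
        )
    return scores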

class ChunkSearchSettings(R2RSerializable):
    """Settings specific to chunk/vector search."""

    index_measure: IndexMeasure = Field(
        default=IndexMeasure.cosine_distance,
        description="The distance measure to use for indexing",
    )
    probes: int = Field(
        default=10,
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
    )
    ef_search: int = Field(
        default=40,
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable chunk search",
    )

class GraphSearchSettings(R2RSerializable):
    """Settings specific to knowledge graph search."""

    generation_config: GenerationConfig = Field(
        default_factory=GenerationConfig,
        description="Configuration for text generation during graph search.",
    )
    graphrag_map_system: str = Field(
        default="graphrag_map_system",
        description="The system prompt for the graphrag map prompt.",
    )
    graphrag_reduce_system: str = Field(
        default="graphrag_reduce_system",
        description="The system prompt for the graphrag reduce prompt.",
    )
    max_community_description_length: int = Field(
        default=65536,
    )
    max_llm_queries_for_global_search: int = Field(
        default=250,
    )
    limits: dict[str, int] = Field(
        default={},
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable graph search",
    )

class SearchSettings(R2RSerializable):
    """Main search settings class that combines shared settings with specialized settings for chunks and KG."""

    # Search type flags
    use_hybrid_search: bool = Field(
        default=False,
        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
    )
    use_semantic_search: bool = Field(
        default=True,
        description="Whether to use semantic search",
    )
    use_fulltext_search: bool = Field(
        default=False,
        description="Whether to use full-text search",
    )
    # Common search parameters
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.

        Commonly seen filters include the following:

        `{"document_id": {"$eq": "9fbe403b-..."}}`
        `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
        `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
        `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
    )
    limit: int = Field(
        default=10,
        description="Maximum number of results to return",
        ge=1,
        le=1_000,
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset to paginate search results",
    )
    include_metadatas: bool = Field(
        default=True,
        description="Whether to include element metadata in the search results",
    )
    include_scores: bool = Field(
        default=True,
        description="Whether to include search score values in the search results",
    )
    # Search strategy and settings
    search_strategy: str = Field(
        default="vanilla",
        description="Search strategy to use (e.g., 'vanilla', 'query_fusion', 'hyde')",
    )
    hybrid_settings: HybridSearchSettings = Field(
        default_factory=HybridSearchSettings,
        description="Settings for hybrid search (only used if `use_semantic_search` and `use_fulltext_search` are both true)",
    )
    # Specialized settings
    chunk_settings: ChunkSearchSettings = Field(
        default_factory=ChunkSearchSettings,
        description="Settings specific to chunk/vector search",
    )
    graph_settings: GraphSearchSettings = Field(
        default_factory=GraphSearchSettings,
        description="Settings specific to knowledge graph search",
    )

    class Config:
        populate_by_name = True
        json_encoders = {UUID: str}
        json_schema_extra = {
            "use_semantic_search": True,
            "use_fulltext_search": False,
            "use_hybrid_search": False,
            "filters": {"category": "technology"},
            "limit": 20,
            "offset": 0,
            "search_strategy": "vanilla",
            "hybrid_settings": {
                "full_text_weight": 1.0,
                "semantic_weight": 5.0,
                "full_text_limit": 200,
                "rrf_k": 50,
            },
            "chunk_settings": {
                "enabled": True,
                "index_measure": "cosine_distance",
                "include_metadata": True,
                "probes": 10,
                "ef_search": 40,
            },
            "graph_settings": {
                "enabled": True,
                "generation_config": GenerationConfig.Config.json_schema_extra,
                "max_community_description_length": 65536,
                "max_llm_queries_for_global_search": 250,
                "limits": {
                    "entity": 20,
                    "relationship": 20,
                    "community": 20,
                },
            },
        }

    def __init__(self, **data):
        # Handle legacy search_filters field
        data["filters"] = {
            **data.get("filters", {}),
            **data.get("search_filters", {}),
        }
        super().__init__(**data)

    def model_dump(self, *args, **kwargs):
        dump = super().model_dump(*args, **kwargs)
        return dump

    @classmethod
    def get_default(cls, mode: str) -> "SearchSettings":
        """Return default search settings for a given mode."""
        if mode == "basic":
            # A simpler search that relies primarily on semantic search.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=False,
                use_hybrid_search=False,
                search_strategy="vanilla",
                # Other relevant defaults can be provided here as needed
            )
        elif mode == "advanced":
            # A more powerful, combined search that leverages both semantic and fulltext.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=True,
                use_hybrid_search=True,
                search_strategy="hyde",
                # Other advanced defaults as needed
            )
        else:
            # For 'custom' or unrecognized modes, return a basic empty config.
            return cls()
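
# Illustrative only: a minimal sketch of building `SearchSettings`, either via the
# `get_default` presets above or by direct construction with overrides. The document id
# in the filter reuses the example value from the schemas and is not a real document.
def _example_search_settings(use_preset: bool = False) -> "SearchSettings":
    if use_preset:
        return SearchSettings.get_default("advanced")
    return SearchSettings(
        use_semantic_search=True,
        use_fulltext_search=True,
        use_hybrid_search=True,
        search_strategy="vanilla",
        limit=20,
        filters={"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}},
    )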

class SearchMode(str, Enum):
    """Search modes for the search endpoint."""

    basic = "basic"
    advanced = "advanced"
    custom = "custom"

def select_search_filters(
    auth_user: Any,
    search_settings: SearchSettings,
) -> dict[str, Any]:
    filters = copy(search_settings.filters)
    selected_collections = None
    if not auth_user.is_superuser:
        user_collections = set(auth_user.collection_ids)
        for key in filters.keys():
            if "collection_ids" in key:
                selected_collections = set(filters[key]["$overlap"])
                break
        if selected_collections:
            allowed_collections = user_collections.intersection(
                selected_collections
            )
        else:
            allowed_collections = user_collections
        # For non-superusers, we filter by owner_id and the selected & allowed collections.
        collection_filters = {
            "$or": [
                {"owner_id": {"$eq": auth_user.id}},
                {"collection_ids": {"$overlap": list(allowed_collections)}},
            ]  # type: ignore
        }
        filters.pop("collection_ids", None)
        if filters != {}:
            filters = {"$and": [collection_filters, filters]}  # type: ignore
        else:
            # No user-supplied filters: still scope the query to what the user can see.
            filters = collection_filters
    return filters
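
# Illustrative only: a minimal sketch of how `select_search_filters` scopes a
# non-superuser's query. `_ExampleUser` is a fabricated stand-in exposing only the
# attributes the function reads (`id`, `collection_ids`, `is_superuser`); the ids reuse
# example values from the schemas above or are obvious placeholders.
class _ExampleUser:
    def __init__(
        self, id: UUID, collection_ids: list[UUID], is_superuser: bool = False
    ):
        self.id = id
        self.collection_ids = collection_ids
        self.is_superuser = is_superuser


def _example_select_search_filters() -> dict[str, Any]:
    user = _ExampleUser(
        id=UUID("2acb499e-8428-543b-bd85-0d9098718220"),
        collection_ids=[UUID("00000000-0000-0000-0000-000000000001")],
    )
    settings = SearchSettings(
        filters={"metadata.title": {"$eq": "example_document.pdf"}}
    )
    # For this non-superuser the result has the shape:
    # {"$and": [{"$or": [{"owner_id": ...}, {"collection_ids": ...}]},
    #           {"metadata.title": {"$eq": "example_document.pdf"}}]}
    return select_search_filters(user, settings)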