search.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. """Abstractions for search functionality."""
  2. from copy import copy
  3. from enum import Enum
  4. from typing import Any, Optional
  5. from uuid import UUID
  6. from pydantic import Field
  7. from .base import R2RSerializable
  8. from .llm import GenerationConfig
  9. from .vector import IndexMeasure
  10. class ChunkSearchResult(R2RSerializable):
  11. """Result of a search operation."""
  12. id: UUID
  13. document_id: UUID
  14. owner_id: Optional[UUID]
  15. collection_ids: list[UUID]
  16. score: float
  17. text: str
  18. metadata: dict[str, Any]
  19. def __str__(self) -> str:
  20. return f"ChunkSearchResult(id={self.id}, document_id={self.document_id}, score={self.score}, text={self.text}, metadata={self.metadata})"
  21. def __repr__(self) -> str:
  22. return self.__str__()
  23. def as_dict(self) -> dict:
  24. return {
  25. "id": self.id,
  26. "document_id": self.document_id,
  27. "owner_id": self.owner_id,
  28. "collection_ids": self.collection_ids,
  29. "score": self.score,
  30. "text": self.text,
  31. "metadata": self.metadata,
  32. }
  33. class Config:
  34. populate_by_name = True
  35. json_schema_extra = {
  36. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  37. "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
  38. "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
  39. "collection_ids": [],
  40. "score": 0.23943702876567796,
  41. "text": "Example text from the document",
  42. "metadata": {
  43. "title": "example_document.pdf",
  44. "associated_query": "What is the capital of France?",
  45. },
  46. }
class KGSearchResultType(str, Enum):
    # Kinds of knowledge-graph search hit. Members are also plain strings
    # (str mixin), so they compare equal to their lowercase values.
    # Used as the type of GraphSearchResult.result_type.
    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"
class KGEntityResult(R2RSerializable):
    # Payload for an entity hit: a name, a description and optional
    # free-form metadata.
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload; also reused by GraphSearchResult.Config as its
        # "content" example.
        json_schema_extra = {
            "name": "Entity Name",
            "description": "Entity Description",
            "metadata": {},
        }
  61. class KGRelationshipResult(R2RSerializable):
  62. # name: str
  63. subject: str
  64. predicate: str
  65. object: str
  66. metadata: Optional[dict[str, Any]] = None
  67. score: Optional[float] = None
  68. # name: str
  69. # description: str
  70. # metadata: Optional[dict[str, Any]] = None
  71. class Config:
  72. json_schema_extra = {
  73. "name": "Relationship Name",
  74. "description": "Relationship Description",
  75. "metadata": {},
  76. }
class KGCommunityResult(R2RSerializable):
    # Payload for a community hit: a named summary with a numeric rating,
    # the explanation for that rating, and a list of finding strings.
    name: str
    summary: str
    rating: float
    rating_explanation: str
    findings: list[str]
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload.
        # NOTE(review): "rating" is the int 9 here while the field is a float.
        json_schema_extra = {
            "name": "Community Name",
            "summary": "Community Summary",
            "rating": 9,
            "rating_explanation": "Rating Explanation",
            "findings": ["Finding 1", "Finding 2"],
            "metadata": {},
        }
class KGGlobalResult(R2RSerializable):
    # Payload for a "global" graph result: name, description and optional
    # metadata (same shape as KGEntityResult).
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload.
        json_schema_extra = {
            "name": "Global Result Name",
            "description": "Global Result Description",
            "metadata": {},
        }
class GraphSearchResult(R2RSerializable):
    # One knowledge-graph search hit. `content` holds one of the four KG
    # payload models; `result_type` optionally labels which kind it is.
    content: (
        KGEntityResult
        | KGRelationshipResult
        | KGCommunityResult
        | KGGlobalResult
    )
    result_type: Optional[KGSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    # NOTE(review): literal `{}` default; presumably safe because the base
    # model copies defaults per instance -- confirm against R2RSerializable.
    metadata: dict[str, Any] = {}
    score: Optional[float] = None

    class Config:
        # Example payload; the "content" example is borrowed from
        # KGEntityResult's own schema example.
        json_schema_extra = {
            "content": KGEntityResult.Config.json_schema_extra,
            "result_type": "entity",
            "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
            "metadata": {"associated_query": "What is the capital of France?"},
        }
class WebSearchResult(R2RSerializable):
    # A single web search hit. `type` defaults to "organic", the tag that
    # WebSearchResponse.from_serper_results routes to this model.
    title: str
    link: str
    snippet: str
    position: int
    type: str = "organic"
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None
class RelatedSearchResult(R2RSerializable):
    # A related-search suggestion. `type` defaults to "related"; note that
    # WebSearchResponse.from_serper_results routes entries tagged
    # "relatedSearches" here and passes that tag through as `type`.
    query: str
    type: str = "related"
class PeopleAlsoAskResult(R2RSerializable):
    # A "people also ask" entry: the question plus the snippet, link and
    # title of its answer source. `type` defaults to "peopleAlsoAsk".
    question: str
    snippet: str
    link: str
    title: str
    type: str = "peopleAlsoAsk"
  138. class WebSearchResponse(R2RSerializable):
  139. organic_results: list[WebSearchResult] = []
  140. related_searches: list[RelatedSearchResult] = []
  141. people_also_ask: list[PeopleAlsoAskResult] = []
  142. @classmethod
  143. def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse":
  144. organic = []
  145. related = []
  146. paa = []
  147. for result in results:
  148. if result["type"] == "organic":
  149. organic.append(WebSearchResult(**result))
  150. elif result["type"] == "relatedSearches":
  151. related.append(RelatedSearchResult(**result))
  152. elif result["type"] == "peopleAlsoAsk":
  153. paa.append(PeopleAlsoAskResult(**result))
  154. return cls(
  155. organic_results=organic,
  156. related_searches=related,
  157. people_also_ask=paa,
  158. )
  159. class AggregateSearchResult(R2RSerializable):
  160. """Result of an aggregate search operation."""
  161. chunk_search_results: Optional[list[ChunkSearchResult]]
  162. graph_search_results: Optional[list[GraphSearchResult]] = None
  163. web_search_results: Optional[list[WebSearchResult]] = None
  164. def __str__(self) -> str:
  165. return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"
  166. def __repr__(self) -> str:
  167. return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results})"
  168. def as_dict(self) -> dict:
  169. return {
  170. "chunk_search_results": (
  171. [result.as_dict() for result in self.chunk_search_results]
  172. if self.chunk_search_results
  173. else []
  174. ),
  175. "graph_search_results": (
  176. [result.to_dict() for result in self.graph_search_results]
  177. if self.graph_search_results
  178. else []
  179. ),
  180. "web_search_results": (
  181. [result.to_dict() for result in self.web_search_results]
  182. if self.web_search_results
  183. else []
  184. ),
  185. }
  186. from enum import Enum
  187. from typing import Any, Optional
  188. from uuid import UUID
  189. from pydantic import Field
  190. from .base import R2RSerializable
  191. from .llm import GenerationConfig
  192. from .vector import IndexMeasure
  193. class HybridSearchSettings(R2RSerializable):
  194. """Settings for hybrid search combining full-text and semantic search."""
  195. full_text_weight: float = Field(
  196. default=1.0, description="Weight to apply to full text search"
  197. )
  198. semantic_weight: float = Field(
  199. default=5.0, description="Weight to apply to semantic search"
  200. )
  201. full_text_limit: int = Field(
  202. default=200,
  203. description="Maximum number of results to return from full text search",
  204. )
  205. rrf_k: int = Field(
  206. default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
  207. )
class ChunkSearchSettings(R2RSerializable):
    """Settings specific to chunk/vector search."""

    index_measure: IndexMeasure = Field(
        default=IndexMeasure.cosine_distance,
        description="The distance measure to use for indexing",
    )
    # Per the descriptions below, `probes` applies to ivfflat indexes and
    # `ef_search` to HNSW indexes; both trade speed for accuracy.
    probes: int = Field(
        default=10,
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
    )
    ef_search: int = Field(
        default=40,
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable chunk search",
    )
class GraphSearchSettings(R2RSerializable):
    """Settings specific to knowledge graph search."""

    generation_config: GenerationConfig = Field(
        default_factory=GenerationConfig,
        description="Configuration for text generation during graph search.",
    )
    # The two defaults below are identical to their field names; presumably
    # they are names of stored prompts rather than prompt text -- confirm.
    graphrag_map_system: str = Field(
        default="graphrag_map_system",
        description="The system prompt for the graphrag map prompt.",
    )
    graphrag_reduce_system: str = Field(
        default="graphrag_reduce_system",
        description="The system prompt for the graphrag reduce prompt.",
    )
    # The next three fields carry no description in the schema.
    max_community_description_length: int = Field(
        default=65536,
    )
    max_llm_queries_for_global_search: int = Field(
        default=250,
    )
    # Mapping of name -> integer limit. The SearchSettings schema example
    # uses the keys "entity", "relationship" and "community".
    limits: dict[str, int] = Field(
        default={},
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable graph search",
    )
  253. class SearchSettings(R2RSerializable):
  254. """Main search settings class that combines shared settings with specialized settings for chunks and KG."""
  255. # Search type flags
  256. use_hybrid_search: bool = Field(
  257. default=False,
  258. description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
  259. )
  260. use_semantic_search: bool = Field(
  261. default=True,
  262. description="Whether to use semantic search",
  263. )
  264. use_fulltext_search: bool = Field(
  265. default=False,
  266. description="Whether to use full-text search",
  267. )
  268. # Common search parameters
  269. filters: dict[str, Any] = Field(
  270. default_factory=dict,
  271. description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
  272. Commonly seen filters include operations include the following:
  273. `{"document_id": {"$eq": "9fbe403b-..."}}`
  274. `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
  275. `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
  276. `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
  277. )
  278. limit: int = Field(
  279. default=10,
  280. description="Maximum number of results to return",
  281. ge=1,
  282. le=1_000,
  283. )
  284. offset: int = Field(
  285. default=0,
  286. ge=0,
  287. description="Offset to paginate search results",
  288. )
  289. include_metadatas: bool = Field(
  290. default=True,
  291. description="Whether to include element metadata in the search results",
  292. )
  293. include_scores: bool = Field(
  294. default=True,
  295. description="Whether to include search score values in the search results",
  296. )
  297. # Search strategy and settings
  298. search_strategy: str = Field(
  299. default="vanilla",
  300. description="Search strategy to use (e.g., 'vanilla', 'query_fusion', 'hyde')",
  301. )
  302. hybrid_settings: HybridSearchSettings = Field(
  303. default_factory=HybridSearchSettings,
  304. description="Settings for hybrid search (only used if `use_semantic_search` and `use_fulltext_search` are both true)",
  305. )
  306. # Specialized settings
  307. chunk_settings: ChunkSearchSettings = Field(
  308. default_factory=ChunkSearchSettings,
  309. description="Settings specific to chunk/vector search",
  310. )
  311. graph_settings: GraphSearchSettings = Field(
  312. default_factory=GraphSearchSettings,
  313. description="Settings specific to knowledge graph search",
  314. )
  315. class Config:
  316. populate_by_name = True
  317. json_encoders = {UUID: str}
  318. json_schema_extra = {
  319. "use_semantic_search": True,
  320. "use_fulltext_search": False,
  321. "use_hybrid_search": False,
  322. "filters": {"category": "technology"},
  323. "limit": 20,
  324. "offset": 0,
  325. "search_strategy": "vanilla",
  326. "hybrid_settings": {
  327. "full_text_weight": 1.0,
  328. "semantic_weight": 5.0,
  329. "full_text_limit": 200,
  330. "rrf_k": 50,
  331. },
  332. "chunk_settings": {
  333. "enabled": True,
  334. "index_measure": "cosine_distance",
  335. "include_metadata": True,
  336. "probes": 10,
  337. "ef_search": 40,
  338. },
  339. "graph_settings": {
  340. "enabled": True,
  341. "generation_config": GenerationConfig.Config.json_schema_extra,
  342. "max_community_description_length": 65536,
  343. "max_llm_queries_for_global_search": 250,
  344. "limits": {
  345. "entity": 20,
  346. "relationship": 20,
  347. "community": 20,
  348. },
  349. },
  350. }
  351. def __init__(self, **data):
  352. # Handle legacy search_filters field
  353. data["filters"] = {
  354. **data.get("filters", {}),
  355. **data.get("search_filters", {}),
  356. }
  357. super().__init__(**data)
  358. def model_dump(self, *args, **kwargs):
  359. dump = super().model_dump(*args, **kwargs)
  360. return dump
  361. @classmethod
  362. def get_default(cls, mode: str) -> "SearchSettings":
  363. """Return default search settings for a given mode."""
  364. if mode == "basic":
  365. # A simpler search that relies primarily on semantic search.
  366. return cls(
  367. use_semantic_search=True,
  368. use_fulltext_search=False,
  369. use_hybrid_search=False,
  370. search_strategy="vanilla",
  371. # Other relevant defaults can be provided here as needed
  372. )
  373. elif mode == "advanced":
  374. # A more powerful, combined search that leverages both semantic and fulltext.
  375. return cls(
  376. use_semantic_search=True,
  377. use_fulltext_search=True,
  378. use_hybrid_search=True,
  379. search_strategy="hyde",
  380. # Other advanced defaults as needed
  381. )
  382. else:
  383. # For 'custom' or unrecognized modes, return a basic empty config.
  384. return cls()
class SearchMode(str, Enum):
    """Search modes for the search endpoint."""

    # Members are also plain strings (str mixin), so they compare equal to
    # the mode strings that SearchSettings.get_default checks.
    basic = "basic"
    advanced = "advanced"
    custom = "custom"
  390. def select_search_filters(
  391. auth_user: Any,
  392. search_settings: SearchSettings,
  393. ) -> dict[str, Any]:
  394. filters = copy(search_settings.filters)
  395. selected_collections = None
  396. if not auth_user.is_superuser:
  397. user_collections = set(auth_user.collection_ids)
  398. for key in filters.keys():
  399. if "collection_ids" in key:
  400. selected_collections = set(filters[key]["$overlap"])
  401. break
  402. if selected_collections:
  403. allowed_collections = user_collections.intersection(
  404. selected_collections
  405. )
  406. else:
  407. allowed_collections = user_collections
  408. # for non-superusers, we filter by user_id and selected & allowed collections
  409. collection_filters = {
  410. "$or": [
  411. {"owner_id": {"$eq": auth_user.id}},
  412. {"collection_ids": {"$overlap": list(allowed_collections)}},
  413. ] # type: ignore
  414. }
  415. filters.pop("collection_ids", None)
  416. if filters != {}:
  417. filters = {"$and": [collection_filters, filters]} # type: ignore
  418. return filters