search.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. """Abstractions for search functionality."""
  2. from copy import copy
  3. from enum import Enum
  4. from typing import Any, Optional
  5. from uuid import NAMESPACE_DNS, UUID, uuid5
  6. from pydantic import Field
  7. from .base import R2RSerializable
  8. from .document import DocumentResponse
  9. from .llm import GenerationConfig
  10. from .vector import IndexMeasure
  11. def generate_id_from_label(label) -> UUID:
  12. return uuid5(NAMESPACE_DNS, label)
  13. class ChunkSearchResult(R2RSerializable):
  14. """Result of a search operation."""
  15. id: UUID
  16. document_id: UUID
  17. owner_id: Optional[UUID]
  18. collection_ids: list[UUID]
  19. score: Optional[float] = None
  20. text: str
  21. metadata: dict[str, Any]
  22. def __str__(self) -> str:
  23. if self.score:
  24. return (
  25. f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
  26. )
  27. else:
  28. return f"ChunkSearchResult(text={self.text})"
  29. def __repr__(self) -> str:
  30. return self.__str__()
  31. def as_dict(self) -> dict:
  32. return {
  33. "id": self.id,
  34. "document_id": self.document_id,
  35. "owner_id": self.owner_id,
  36. "collection_ids": self.collection_ids,
  37. "score": self.score,
  38. "text": self.text,
  39. "metadata": self.metadata,
  40. }
  41. class Config:
  42. populate_by_name = True
  43. json_schema_extra = {
  44. "example": {
  45. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  46. "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
  47. "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
  48. "collection_ids": [],
  49. "score": 0.23943702876567796,
  50. "text": "Example text from the document",
  51. "metadata": {
  52. "title": "example_document.pdf",
  53. "associated_query": "What is the capital of France?",
  54. },
  55. }
  56. }
  57. class GraphSearchResultType(str, Enum):
  58. ENTITY = "entity"
  59. RELATIONSHIP = "relationship"
  60. COMMUNITY = "community"
  61. class GraphEntityResult(R2RSerializable):
  62. id: Optional[UUID] = None
  63. name: str
  64. description: str
  65. metadata: Optional[dict[str, Any]] = None
  66. class Config:
  67. json_schema_extra = {
  68. "example": {
  69. "name": "Entity Name",
  70. "description": "Entity Description",
  71. "metadata": {},
  72. }
  73. }
  74. class GraphRelationshipResult(R2RSerializable):
  75. id: Optional[UUID] = None
  76. subject: str
  77. predicate: str
  78. object: str
  79. subject_id: Optional[UUID] = None
  80. object_id: Optional[UUID] = None
  81. metadata: Optional[dict[str, Any]] = None
  82. score: Optional[float] = None
  83. description: str | None = None
  84. class Config:
  85. json_schema_extra = {
  86. "example": {
  87. "name": "Relationship Name",
  88. "description": "Relationship Description",
  89. "metadata": {},
  90. }
  91. }
  92. def __str__(self) -> str:
  93. return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"
  94. class GraphCommunityResult(R2RSerializable):
  95. id: Optional[UUID] = None
  96. name: str
  97. summary: str
  98. metadata: Optional[dict[str, Any]] = None
  99. class Config:
  100. json_schema_extra = {
  101. "example": {
  102. "name": "Community Name",
  103. "summary": "Community Summary",
  104. "rating": 9,
  105. "rating_explanation": "Rating Explanation",
  106. "metadata": {},
  107. }
  108. }
  109. def __str__(self) -> str:
  110. return (
  111. f"GraphCommunityResult(name={self.name}, summary={self.summary})"
  112. )
  113. class GraphSearchResult(R2RSerializable):
  114. content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
  115. result_type: Optional[GraphSearchResultType] = None
  116. chunk_ids: Optional[list[UUID]] = None
  117. metadata: dict[str, Any] = {}
  118. score: Optional[float] = None
  119. id: UUID
  120. def __str__(self) -> str:
  121. return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"
  122. class Config:
  123. populate_by_name = True
  124. json_schema_extra = {
  125. "example": {
  126. "content": {
  127. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  128. "name": "Entity Name",
  129. "description": "Entity Description",
  130. "metadata": {},
  131. },
  132. "result_type": "entity",
  133. "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
  134. "metadata": {
  135. "associated_query": "What is the capital of France?"
  136. },
  137. }
  138. }
  139. class WebPageSearchResult(R2RSerializable):
  140. title: Optional[str] = None
  141. link: Optional[str] = None
  142. snippet: Optional[str] = None
  143. position: int
  144. type: str = "organic"
  145. date: Optional[str] = None
  146. sitelinks: Optional[list[dict]] = None
  147. id: UUID
  148. class Config:
  149. json_schema_extra = {
  150. "example": {
  151. "title": "Page Title",
  152. "link": "https://example.com/page",
  153. "snippet": "Page snippet",
  154. "position": 1,
  155. "date": "2021-01-01",
  156. "sitelinks": [
  157. {
  158. "title": "Sitelink Title",
  159. "link": "https://example.com/sitelink",
  160. }
  161. ],
  162. }
  163. }
  164. def __str__(self) -> str:
  165. return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"
  166. class RelatedSearchResult(R2RSerializable):
  167. query: str
  168. type: str = "related"
  169. id: UUID
  170. class PeopleAlsoAskResult(R2RSerializable):
  171. question: str
  172. snippet: str
  173. link: str
  174. title: str
  175. id: UUID
  176. type: str = "peopleAlsoAsk"
  177. class WebSearchResult(R2RSerializable):
  178. organic_results: list[WebPageSearchResult] = []
  179. related_searches: list[RelatedSearchResult] = []
  180. people_also_ask: list[PeopleAlsoAskResult] = []
  181. @classmethod
  182. def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
  183. organic = []
  184. related = []
  185. paa = []
  186. for result in results:
  187. if result["type"] == "organic":
  188. organic.append(
  189. WebPageSearchResult(
  190. **result, id=generate_id_from_label(result.get("link"))
  191. )
  192. )
  193. elif result["type"] == "relatedSearches":
  194. related.append(
  195. RelatedSearchResult(
  196. **result,
  197. id=generate_id_from_label(result.get("query")),
  198. )
  199. )
  200. elif result["type"] == "peopleAlsoAsk":
  201. paa.append(
  202. PeopleAlsoAskResult(
  203. **result, id=generate_id_from_label(result.get("link"))
  204. )
  205. )
  206. return cls(
  207. organic_results=organic,
  208. related_searches=related,
  209. people_also_ask=paa,
  210. )
  211. class AggregateSearchResult(R2RSerializable):
  212. """Result of an aggregate search operation."""
  213. chunk_search_results: Optional[list[ChunkSearchResult]] = None
  214. graph_search_results: Optional[list[GraphSearchResult]] = None
  215. web_page_search_results: Optional[list[WebPageSearchResult]] = None
  216. web_search_results: Optional[list[WebSearchResult]] = None
  217. document_search_results: Optional[list[DocumentResponse]] = None
  218. generic_tool_result: Optional[Any] = (
  219. None # FIXME: Give this a proper generic type
  220. )
  221. def __str__(self) -> str:
  222. fields = [
  223. f"{field_name}={str(field_value)}"
  224. for field_name, field_value in self.__dict__.items()
  225. ]
  226. return f"AggregateSearchResult({', '.join(fields)})"
  227. def as_dict(self) -> dict:
  228. return {
  229. "chunk_search_results": (
  230. [result.as_dict() for result in self.chunk_search_results]
  231. if self.chunk_search_results
  232. else []
  233. ),
  234. "graph_search_results": (
  235. [result.to_dict() for result in self.graph_search_results]
  236. if self.graph_search_results
  237. else []
  238. ),
  239. "web_page_search_results": (
  240. [result.to_dict() for result in self.web_page_search_results]
  241. if self.web_page_search_results
  242. else []
  243. ),
  244. "web_search_results": (
  245. [result.to_dict() for result in self.web_search_results]
  246. if self.web_search_results
  247. else []
  248. ),
  249. "document_search_results": (
  250. [cdr.to_dict() for cdr in self.document_search_results]
  251. if self.document_search_results
  252. else []
  253. ),
  254. "generic_tool_result": (
  255. [result.to_dict() for result in self.generic_tool_result]
  256. if self.generic_tool_result
  257. else []
  258. ),
  259. }
  260. class Config:
  261. populate_by_name = True
  262. json_schema_extra = {
  263. "example": {
  264. "chunk_search_results": [
  265. {
  266. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  267. "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
  268. "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
  269. "collection_ids": [],
  270. "score": 0.23943702876567796,
  271. "text": "Example text from the document",
  272. "metadata": {
  273. "title": "example_document.pdf",
  274. "associated_query": "What is the capital of France?",
  275. },
  276. }
  277. ],
  278. "graph_search_results": [
  279. {
  280. "content": {
  281. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  282. "name": "Entity Name",
  283. "description": "Entity Description",
  284. "metadata": {},
  285. },
  286. "result_type": "entity",
  287. "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
  288. "metadata": {
  289. "associated_query": "What is the capital of France?"
  290. },
  291. }
  292. ],
  293. "web_page_search_results": [
  294. {
  295. "title": "Page Title",
  296. "link": "https://example.com/page",
  297. "snippet": "Page snippet",
  298. "position": 1,
  299. "date": "2021-01-01",
  300. "sitelinks": [
  301. {
  302. "title": "Sitelink Title",
  303. "link": "https://example.com/sitelink",
  304. }
  305. ],
  306. }
  307. ],
  308. "web_search_results": [
  309. {
  310. "title": "Page Title",
  311. "link": "https://example.com/page",
  312. "snippet": "Page snippet",
  313. "position": 1,
  314. "date": "2021-01-01",
  315. "sitelinks": [
  316. {
  317. "title": "Sitelink Title",
  318. "link": "https://example.com/sitelink",
  319. }
  320. ],
  321. }
  322. ],
  323. "document_search_results": [
  324. {
  325. "document": {
  326. "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
  327. "title": "Document Title",
  328. "chunks": ["Chunk 1", "Chunk 2"],
  329. "metadata": {},
  330. },
  331. }
  332. ],
  333. "generic_tool_result": [
  334. {
  335. "result": "Generic tool result",
  336. "metadata": {"key": "value"},
  337. }
  338. ],
  339. }
  340. }
  341. class HybridSearchSettings(R2RSerializable):
  342. """Settings for hybrid search combining full-text and semantic search."""
  343. full_text_weight: float = Field(
  344. default=1.0, description="Weight to apply to full text search"
  345. )
  346. semantic_weight: float = Field(
  347. default=5.0, description="Weight to apply to semantic search"
  348. )
  349. full_text_limit: int = Field(
  350. default=200,
  351. description="Maximum number of results to return from full text search",
  352. )
  353. rrf_k: int = Field(
  354. default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
  355. )
  356. class ChunkSearchSettings(R2RSerializable):
  357. """Settings specific to chunk/vector search."""
  358. index_measure: IndexMeasure = Field(
  359. default=IndexMeasure.cosine_distance,
  360. description="The distance measure to use for indexing",
  361. )
  362. probes: int = Field(
  363. default=10,
  364. description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
  365. )
  366. ef_search: int = Field(
  367. default=40,
  368. description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
  369. )
  370. enabled: bool = Field(
  371. default=True,
  372. description="Whether to enable chunk search",
  373. )
  374. class GraphSearchSettings(R2RSerializable):
  375. """Settings specific to knowledge graph search."""
  376. limits: dict[str, int] = Field(
  377. default={},
  378. )
  379. enabled: bool = Field(
  380. default=True,
  381. description="Whether to enable graph search",
  382. )
  383. class SearchSettings(R2RSerializable):
  384. """Main search settings class that combines shared settings with
  385. specialized settings for chunks and graph."""
  386. # Search type flags
  387. use_hybrid_search: bool = Field(
  388. default=False,
  389. description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
  390. )
  391. use_semantic_search: bool = Field(
  392. default=True,
  393. description="Whether to use semantic search",
  394. )
  395. use_fulltext_search: bool = Field(
  396. default=False,
  397. description="Whether to use full-text search",
  398. )
  399. # Common search parameters
  400. filters: dict[str, Any] = Field(
  401. default_factory=dict,
  402. description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
  403. Commonly seen filters include operations include the following:
  404. `{"document_id": {"$eq": "9fbe403b-..."}}`
  405. `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
  406. `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
  407. `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
  408. )
  409. limit: int = Field(
  410. default=10,
  411. description="Maximum number of results to return",
  412. ge=1,
  413. le=1_000,
  414. )
  415. offset: int = Field(
  416. default=0,
  417. ge=0,
  418. description="Offset to paginate search results",
  419. )
  420. include_metadatas: bool = Field(
  421. default=True,
  422. description="Whether to include element metadata in the search results",
  423. )
  424. include_scores: bool = Field(
  425. default=True,
  426. description="""Whether to include search score values in the
  427. search results""",
  428. )
  429. # Search strategy and settings
  430. search_strategy: str = Field(
  431. default="vanilla",
  432. description="""Search strategy to use
  433. (e.g., 'vanilla', 'query_fusion', 'hyde')""",
  434. )
  435. hybrid_settings: HybridSearchSettings = Field(
  436. default_factory=HybridSearchSettings,
  437. description="""Settings for hybrid search (only used if
  438. `use_semantic_search` and `use_fulltext_search` are both true)""",
  439. )
  440. # Specialized settings
  441. chunk_settings: ChunkSearchSettings = Field(
  442. default_factory=ChunkSearchSettings,
  443. description="Settings specific to chunk/vector search",
  444. )
  445. graph_settings: GraphSearchSettings = Field(
  446. default_factory=GraphSearchSettings,
  447. description="Settings specific to knowledge graph search",
  448. )
  449. # For HyDE or multi-query:
  450. num_sub_queries: int = Field(
  451. default=5,
  452. description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
  453. )
  454. class Config:
  455. populate_by_name = True
  456. json_encoders = {UUID: str}
  457. json_schema_extra = {
  458. "example": {
  459. "use_semantic_search": True,
  460. "use_fulltext_search": False,
  461. "use_hybrid_search": False,
  462. "filters": {"category": "technology"},
  463. "limit": 20,
  464. "offset": 0,
  465. "search_strategy": "vanilla",
  466. "hybrid_settings": {
  467. "full_text_weight": 1.0,
  468. "semantic_weight": 5.0,
  469. "full_text_limit": 200,
  470. "rrf_k": 50,
  471. },
  472. "chunk_settings": {
  473. "enabled": True,
  474. "index_measure": "cosine_distance",
  475. "include_metadata": True,
  476. "probes": 10,
  477. "ef_search": 40,
  478. },
  479. "graph_settings": {
  480. "enabled": True,
  481. "generation_config": GenerationConfig.Config.json_schema_extra,
  482. "max_community_description_length": 65536,
  483. "max_llm_queries_for_global_search": 250,
  484. "limits": {
  485. "entity": 20,
  486. "relationship": 20,
  487. "community": 20,
  488. },
  489. },
  490. }
  491. }
  492. def __init__(self, **data):
  493. # Handle legacy search_filters field
  494. data["filters"] = {
  495. **data.get("filters", {}),
  496. **data.get("search_filters", {}),
  497. }
  498. super().__init__(**data)
  499. def model_dump(self, *args, **kwargs):
  500. return super().model_dump(*args, **kwargs)
  501. @classmethod
  502. def get_default(cls, mode: str) -> "SearchSettings":
  503. """Return default search settings for a given mode."""
  504. if mode == "basic":
  505. # A simpler search that relies primarily on semantic search.
  506. return cls(
  507. use_semantic_search=True,
  508. use_fulltext_search=False,
  509. use_hybrid_search=False,
  510. search_strategy="vanilla",
  511. # Other relevant defaults can be provided here as needed
  512. )
  513. elif mode == "advanced":
  514. # A more powerful, combined search that leverages both semantic and fulltext.
  515. return cls(
  516. use_semantic_search=True,
  517. use_fulltext_search=True,
  518. use_hybrid_search=True,
  519. search_strategy="hyde",
  520. # Other advanced defaults as needed
  521. )
  522. else:
  523. # For 'custom' or unrecognized modes, return a basic empty config.
  524. return cls()
  525. class SearchMode(str, Enum):
  526. """Search modes for the search endpoint."""
  527. basic = "basic"
  528. advanced = "advanced"
  529. custom = "custom"
  530. def select_search_filters(
  531. auth_user: Any,
  532. search_settings: SearchSettings,
  533. ) -> dict[str, Any]:
  534. filters = copy(search_settings.filters)
  535. selected_collections = None
  536. if not auth_user.is_superuser:
  537. user_collections = set(auth_user.collection_ids)
  538. for key in filters.keys():
  539. if "collection_ids" in key:
  540. selected_collections = set(map(UUID, filters[key]["$overlap"]))
  541. break
  542. if selected_collections:
  543. allowed_collections = user_collections.intersection(
  544. selected_collections
  545. )
  546. else:
  547. allowed_collections = user_collections
  548. # for non-superusers, we filter by user_id and selected & allowed collections
  549. collection_filters = {
  550. "$or": [
  551. {"owner_id": {"$eq": auth_user.id}},
  552. {"collection_ids": {"$overlap": list(allowed_collections)}},
  553. ] # type: ignore
  554. }
  555. filters.pop("collection_ids", None)
  556. if filters != {}:
  557. filters = {"$and": [collection_filters, filters]} # type: ignore
  558. else:
  559. filters = collection_filters
  560. return filters