test_retrieval.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. import uuid
  2. import pytest
  3. from r2r import Message, R2RClient, R2RException, SearchMode
  4. @pytest.fixture(scope="session")
  5. def config():
  6. class TestConfig:
  7. base_url = "http://localhost:7272"
  8. superuser_email = "admin@example.com"
  9. superuser_password = "change_me_immediately"
  10. return TestConfig()
  11. @pytest.fixture(scope="session")
  12. def client(config):
  13. """Create a client instance and log in as a superuser."""
  14. client = R2RClient(config.base_url)
  15. client.users.login(config.superuser_email, config.superuser_password)
  16. return client
  17. def test_search_basic_mode(client):
  18. resp = client.retrieval.search(query="Aristotle", search_mode="basic")
  19. assert "results" in resp, "No results field in search response"
  20. def test_search_advanced_mode_with_filters(client):
  21. filters = {"metadata.document_type": {"$eq": "txt"}}
  22. resp = client.retrieval.search(
  23. query="Philosophy",
  24. search_mode="advanced",
  25. search_settings={"filters": filters, "limit": 5},
  26. )
  27. assert "results" in resp, "No results in advanced mode search"
  28. def test_search_custom_mode(client):
  29. resp = client.retrieval.search(
  30. query="Greek philosophers",
  31. search_mode="custom",
  32. search_settings={"use_semantic_search": True, "limit": 3},
  33. )
  34. assert "results" in resp, "No results in custom mode search"
  35. def test_rag_query(client):
  36. resp = client.retrieval.rag(
  37. query="Summarize Aristotle's contributions to logic",
  38. rag_generation_config={"stream": False, "max_tokens": 100},
  39. search_settings={"use_semantic_search": True, "limit": 3},
  40. )["results"]
  41. assert "completion" in resp, "RAG response missing 'completion'"
  42. def test_rag_with_filter(client):
  43. # Ensure a doc with metadata.tier='test' is created
  44. # generate a random string
  45. suffix = str(uuid.uuid4())
  46. client.documents.create(
  47. raw_text=f"Aristotle was a Greek philosopher, contributions to philosophy were in logic, {suffix}.",
  48. metadata={"tier": "test"},
  49. )
  50. resp = client.retrieval.rag(
  51. query="What were aristotle's contributions to philosophy?",
  52. rag_generation_config={"stream": False, "max_tokens": 100},
  53. search_settings={
  54. "filters": {"metadata.tier": {"$eq": "test"}},
  55. "use_semantic_search": True,
  56. "limit": 3,
  57. },
  58. )["results"]
  59. assert "completion" in resp, "RAG response missing 'completion'"
  60. def test_rag_stream_query(client):
  61. resp = client.retrieval.rag(
  62. query="Detail the philosophical schools Aristotle influenced",
  63. rag_generation_config={"stream": True, "max_tokens": 50},
  64. search_settings={"use_semantic_search": True, "limit": 2},
  65. )
  66. # Consume a few chunks from the async generator
  67. import asyncio
  68. async def consume_stream():
  69. count = 0
  70. async for chunk in resp:
  71. count += 1
  72. if count > 1:
  73. break
  74. return count
  75. count = asyncio.run(consume_stream())
  76. assert count > 0, "No chunks received from streamed RAG query"
  77. def test_agent_query(client):
  78. msg = Message(role="user", content="What is Aristotle known for?")
  79. resp = client.retrieval.agent(
  80. message=msg,
  81. rag_generation_config={"stream": False, "max_tokens": 100},
  82. search_settings={"use_semantic_search": True, "limit": 3},
  83. )
  84. assert "results" in resp, "Agent response missing 'results'"
  85. assert len(resp["results"]) > 0, "No messages returned by agent"
  86. def test_agent_query_stream(client):
  87. msg = Message(role="user", content="Explain Aristotle's logic in steps.")
  88. resp = client.retrieval.agent(
  89. message=msg,
  90. rag_generation_config={"stream": True, "max_tokens": 50},
  91. search_settings={"use_semantic_search": True, "limit": 3},
  92. )
  93. import asyncio
  94. async def consume_stream():
  95. count = 0
  96. async for chunk in resp:
  97. count += 1
  98. if count > 1:
  99. break
  100. return count
  101. count = asyncio.run(consume_stream())
  102. assert count > 0, "No streaming chunks received from agent"
  103. def test_completion(client):
  104. messages = [
  105. {"role": "system", "content": "You are a helpful assistant."},
  106. {"role": "user", "content": "What is the capital of France?"},
  107. {"role": "assistant", "content": "The capital of France is Paris."},
  108. {"role": "user", "content": "What about Italy?"},
  109. ]
  110. resp = client.retrieval.completion(
  111. messages, generation_config={"max_tokens": 50}
  112. )
  113. assert "results" in resp, "Completion response missing 'results'"
  114. assert "choices" in resp["results"], "No choices in completion result"
  115. def test_embedding(client):
  116. text = "Who is Aristotle?"
  117. resp = client.retrieval.embedding(text=text)["results"]
  118. assert len(resp) > 0, "No embedding vector returned"
  119. def test_error_handling(client):
  120. # Missing query should raise an error
  121. with pytest.raises(R2RException) as exc_info:
  122. client.retrieval.search(query=None) # type: ignore
  123. assert exc_info.value.status_code in [
  124. 400,
  125. 422,
  126. ], "Expected validation error for missing query"
  127. def test_no_results_scenario(client):
  128. resp = client.retrieval.search(
  129. query="aslkfjaldfjal",
  130. search_mode="custom",
  131. search_settings={
  132. "limit": 5,
  133. "use_semantic_search": False,
  134. "use_fulltext_search": True,
  135. },
  136. )
  137. results = resp.get("results", {}).get("chunk_search_results", [])
  138. assert len(results) == 0, "Expected no results for nonsense query"
  139. def test_pagination_limit_one(client):
  140. client.documents.create(
  141. chunks=[
  142. "a" + " " + str(uuid.uuid4()),
  143. "b" + " " + str(uuid.uuid4()),
  144. "c" + " " + str(uuid.uuid4()),
  145. ]
  146. )
  147. resp = client.retrieval.search(
  148. query="Aristotle", search_mode="basic", search_settings={"limit": 1}
  149. )["results"]
  150. assert (
  151. len(resp["chunk_search_results"]) == 1
  152. ), "Expected one result with limit=1"
  153. def test_pagination_offset(client):
  154. resp0 = client.retrieval.search(
  155. query="Aristotle",
  156. search_mode="basic",
  157. search_settings={"limit": 1, "offset": 0},
  158. )["results"]
  159. resp1 = client.retrieval.search(
  160. query="Aristotle",
  161. search_mode="basic",
  162. search_settings={"limit": 1, "offset": 1},
  163. )["results"]
  164. assert (
  165. resp0["chunk_search_results"][0]["text"]
  166. != resp1["chunk_search_results"][0]["text"]
  167. ), "Offset should return different results"
def test_rag_task_prompt_override(client):
    """Pass a custom task prompt to RAG and check the answer ends as told.

    The prompt instructs the model to finish with ``[END-TEST-PROMPT]``; the
    test asserts that marker appears in the first choice's message content.
    """
    # Plain (non-f) string: ``{query}`` and ``{context}`` stay literal
    # placeholders here — presumably substituted by the service; confirm.
    custom_prompt = """
Answer the query given immediately below given the context. End your answer with: [END-TEST-PROMPT]
### Query:
{query}
### Context:
{context}
"""
    resp = client.retrieval.rag(
        query="Tell me about Aristotle",
        rag_generation_config={"stream": False, "max_tokens": 50},
        search_settings={"use_semantic_search": True, "limit": 3},
        task_prompt_override=custom_prompt,
    )
    # Drill down to the text of the first completion choice.
    answer = resp["results"]["completion"]["choices"][0]["message"]["content"]
    assert (
        "[END-TEST-PROMPT]" in answer
    ), "Custom prompt override not reflected in RAG answer"
  186. def test_agent_conversation_id(client):
  187. conversation_id = client.conversations.create()["results"]["id"]
  188. msg = Message(role="user", content="What is Aristotle known for?")
  189. resp = client.retrieval.agent(
  190. message=msg,
  191. rag_generation_config={"stream": False, "max_tokens": 50},
  192. search_settings={"use_semantic_search": True, "limit": 3},
  193. conversation_id=conversation_id,
  194. )
  195. assert (
  196. len(resp.get("results", [])) > 0
  197. ), "No results from agent with conversation_id"
  198. msg2 = Message(role="user", content="Can you elaborate more?")
  199. resp2 = client.retrieval.agent(
  200. message=msg2,
  201. rag_generation_config={"stream": False, "max_tokens": 50},
  202. search_settings={"use_semantic_search": True, "limit": 3},
  203. conversation_id=conversation_id,
  204. )
  205. assert (
  206. len(resp2.get("results", [])) > 0
  207. ), "No results from agent in second turn of conversation"
  208. def test_complex_filters_and_fulltext(client, test_collection):
  209. # collection_id, doc_ids = _setup_collection_with_documents(client)
  210. # rating > 5
  211. me = client.users.me()
  212. # include owner id and collection ids to make robust against other database interactions from other users
  213. filters = {
  214. "rating": {"$gt": 5},
  215. "owner_id": {"$eq": client.users.me()["results"]["id"]},
  216. "collection_ids": {
  217. "$overlap": [str(test_collection["collection_id"])]
  218. },
  219. }
  220. resp = client.retrieval.search(
  221. query="a",
  222. search_mode=SearchMode.custom,
  223. search_settings={"use_semantic_search": True, "filters": filters},
  224. )["results"]
  225. results = resp["chunk_search_results"]
  226. assert (
  227. len(results) == 2
  228. ), f"Expected 2 docs with rating > 5, got {len(results)}"
  229. # category in [ancient, modern]
  230. filters = {
  231. "metadata.category": {"$in": ["ancient", "modern"]},
  232. "owner_id": {"$eq": client.users.me()["results"]["id"]},
  233. "collection_ids": {
  234. "$overlap": [str(test_collection["collection_id"])]
  235. },
  236. }
  237. resp = client.retrieval.search(
  238. query="b",
  239. search_mode=SearchMode.custom,
  240. search_settings={"use_semantic_search": True, "filters": filters},
  241. )["results"]
  242. results = resp["chunk_search_results"]
  243. assert len(results) == 4, f"Expected all 4 docs, got {len(results)}"
  244. # rating > 5 AND category=modern
  245. filters = {
  246. "$and": [
  247. {"metadata.rating": {"$gt": 5}},
  248. {"metadata.category": {"$eq": "modern"}},
  249. {"owner_id": {"$eq": client.users.me()["results"]["id"]}},
  250. {
  251. "collection_ids": {
  252. "$overlap": [str(test_collection["collection_id"])]
  253. }
  254. },
  255. ],
  256. }
  257. resp = client.retrieval.search(
  258. query="d",
  259. search_mode=SearchMode.custom,
  260. search_settings={"filters": filters},
  261. )["results"]
  262. results = resp["chunk_search_results"]
  263. assert (
  264. len(results) == 2
  265. ), f"Expected 2 modern docs with rating>5, got {len(results)}"
  266. # full-text search: "unique_philosopher"
  267. resp = client.retrieval.search(
  268. query="unique_philosopher",
  269. search_mode=SearchMode.custom,
  270. search_settings={
  271. "use_fulltext_search": True,
  272. "use_semantic_search": False,
  273. "filters": {
  274. "owner_id": {"$eq": client.users.me()["results"]["id"]},
  275. "collection_ids": {
  276. "$overlap": [str(test_collection["collection_id"])]
  277. },
  278. },
  279. },
  280. )["results"]
  281. results = resp["chunk_search_results"]
  282. assert (
  283. len(results) == 1
  284. ), f"Expected 1 doc for unique_philosopher, got {len(results)}"
  285. def test_complex_nested_filters(client, test_collection):
  286. # Setup docs
  287. # _setup_collection_with_documents(client)
  288. # ((category=ancient OR rating<5) AND tags contains 'philosophy')
  289. filters = {
  290. "$and": [
  291. {
  292. "$or": [
  293. {"metadata.category": {"$eq": "ancient"}},
  294. {"metadata.rating": {"$lt": 5}},
  295. ]
  296. },
  297. {"metadata.tags": {"$contains": ["philosophy"]}},
  298. {"owner_id": {"$eq": client.users.me()["results"]["id"]}},
  299. {
  300. "collection_ids": {
  301. "$overlap": [str(test_collection["collection_id"])]
  302. }
  303. },
  304. ],
  305. }
  306. resp = client.retrieval.search(
  307. query="complex",
  308. search_mode="custom",
  309. search_settings={"filters": filters},
  310. )["results"]
  311. results = resp["chunk_search_results"]
  312. assert len(results) == 2, f"Expected 2 docs, got {len(results)}"
  313. def test_invalid_operator(client):
  314. filters = {"metadata.category": {"$like": "%ancient%"}}
  315. with pytest.raises(R2RException):
  316. client.retrieval.search(
  317. query="abc",
  318. search_mode="custom",
  319. search_settings={"filters": filters},
  320. )
  321. def test_filters_no_match(client):
  322. filters = {"metadata.category": {"$in": ["nonexistent"]}}
  323. resp = client.retrieval.search(
  324. query="noresults",
  325. search_mode="custom",
  326. search_settings={"filters": filters},
  327. )["results"]
  328. results = resp["chunk_search_results"]
  329. assert len(results) == 0, f"Expected 0 docs, got {len(results)}"
  330. def test_pagination_extremes(client):
  331. chunk_list = client.chunks.list()
  332. total_entries = chunk_list["total_entries"]
  333. offset = total_entries + 100
  334. resp = client.retrieval.search(
  335. query="Aristotle",
  336. search_mode="basic",
  337. search_settings={"limit": 10, "offset": offset},
  338. )["results"]
  339. results = resp["chunk_search_results"]
  340. assert (
  341. len(results) == 0
  342. ), f"Expected no results at large offset, got {len(results)}"
  343. def test_full_text_stopwords(client):
  344. resp = client.retrieval.search(
  345. query="the",
  346. search_mode="custom",
  347. search_settings={
  348. "use_fulltext_search": True,
  349. "use_semantic_search": False,
  350. "limit": 5,
  351. },
  352. )
  353. assert "results" in resp, "No results field in stopword query response"
  354. def test_full_text_non_ascii(client):
  355. resp = client.retrieval.search(
  356. query="Aristotélēs",
  357. search_mode="custom",
  358. search_settings={
  359. "use_fulltext_search": True,
  360. "use_semantic_search": False,
  361. "limit": 3,
  362. },
  363. )
  364. assert "results" in resp, "No results field in non-ASCII query response"
  365. def test_missing_fields(client):
  366. filters = {"metadata.someNonExistentField": {"$eq": "anything"}}
  367. resp = client.retrieval.search(
  368. query="missingfield",
  369. search_mode="custom",
  370. search_settings={"filters": filters},
  371. )["results"]
  372. results = resp["chunk_search_results"]
  373. assert (
  374. len(results) == 0
  375. ), f"Expected 0 docs for a non-existent field, got {len(results)}"
  376. def test_rag_with_large_context(client):
  377. resp = client.retrieval.rag(
  378. query="Explain the contributions of Kant in detail",
  379. rag_generation_config={"stream": False, "max_tokens": 200},
  380. search_settings={"use_semantic_search": True, "limit": 10},
  381. )
  382. results = resp.get("results", {})
  383. assert "completion" in results, "RAG large context missing 'completion'"
  384. completion = results["completion"]["choices"][0]["message"]["content"]
  385. assert len(completion) > 0, "RAG large context returned empty answer"
  386. def test_agent_long_conversation(client):
  387. conversation = client.conversations.create()["results"]
  388. conversation_id = conversation["id"]
  389. msg1 = Message(role="user", content="What were Aristotle's main ideas?")
  390. resp1 = client.retrieval.agent(
  391. message=msg1,
  392. rag_generation_config={"stream": False, "max_tokens": 100},
  393. search_settings={"use_semantic_search": True, "limit": 5},
  394. conversation_id=conversation_id,
  395. )
  396. assert "results" in resp1, "No results in first turn of conversation"
  397. msg2 = Message(
  398. role="user", content="How did these ideas influence modern philosophy?"
  399. )
  400. resp2 = client.retrieval.agent(
  401. message=msg2,
  402. rag_generation_config={"stream": False, "max_tokens": 100},
  403. search_settings={"use_semantic_search": True, "limit": 5},
  404. conversation_id=conversation_id,
  405. )
  406. assert "results" in resp2, "No results in second turn of conversation"
  407. msg3 = Message(role="user", content="Now tell me about Descartes.")
  408. resp3 = client.retrieval.agent(
  409. message=msg3,
  410. rag_generation_config={"stream": False, "max_tokens": 100},
  411. search_settings={"use_semantic_search": True, "limit": 5},
  412. conversation_id=conversation_id,
  413. )
  414. assert "results" in resp3, "No results in third turn of conversation"
  415. def test_filter_by_document_type(client):
  416. random_suffix = str(uuid.uuid4())
  417. client.documents.create(
  418. chunks=[
  419. f"a {random_suffix}",
  420. f"b {random_suffix}",
  421. f"c {random_suffix}",
  422. ]
  423. )
  424. filters = {"document_type": {"$eq": "txt"}}
  425. resp = client.retrieval.search(
  426. query="a", search_settings={"filters": filters}
  427. )["results"]
  428. results = resp["chunk_search_results"]
  429. # Depending on your environment, if no txt documents exist this might fail.
  430. # Adjust accordingly to ensure there's a txt doc available or mock if needed.
  431. assert len(results) > 0, "No results found for filter by document type"