test_retrieval.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. import uuid
  2. import pytest
  3. from r2r import Message, R2RClient, R2RException, SearchMode
  4. @pytest.fixture(scope="session")
  5. def config():
  6. class TestConfig:
  7. base_url = "http://localhost:7272"
  8. superuser_email = "admin@example.com"
  9. superuser_password = "change_me_immediately"
  10. return TestConfig()
  11. @pytest.fixture(scope="session")
  12. def client(config):
  13. """Create a client instance and log in as a superuser."""
  14. client = R2RClient(config.base_url)
  15. client.users.login(config.superuser_email, config.superuser_password)
  16. return client
  17. def test_search_basic_mode(client):
  18. resp = client.retrieval.search(query="Aristotle", search_mode="basic")
  19. assert "results" in resp, "No results field in search response"
  20. def test_search_advanced_mode_with_filters(client):
  21. filters = {"document_type": {"$eq": "txt"}}
  22. resp = client.retrieval.search(
  23. query="Philosophy",
  24. search_mode="advanced",
  25. search_settings={"filters": filters, "limit": 5},
  26. )
  27. assert "results" in resp, "No results in advanced mode search"
  28. def test_search_custom_mode(client):
  29. resp = client.retrieval.search(
  30. query="Greek philosophers",
  31. search_mode="custom",
  32. search_settings={"use_semantic_search": True, "limit": 3},
  33. )
  34. assert "results" in resp, "No results in custom mode search"
  35. def test_rag_query(client):
  36. resp = client.retrieval.rag(
  37. query="Summarize Aristotle's contributions to logic",
  38. rag_generation_config={"stream": False, "max_tokens": 100},
  39. search_settings={"use_semantic_search": True, "limit": 3},
  40. )["results"]
  41. assert "completion" in resp, "RAG response missing 'completion'"
  42. def test_rag_with_filter(client):
  43. # Ensure a doc with metadata.tier='test' is created
  44. # generate a random string
  45. suffix = str(uuid.uuid4())
  46. client.documents.create(
  47. raw_text=f"Aristotle was a Greek philosopher, contributions to philosophy were in logic, {suffix}.",
  48. metadata={"tier": "test"},
  49. )
  50. resp = client.retrieval.rag(
  51. query="What were aristotle's contributions to philosophy?",
  52. rag_generation_config={"stream": False, "max_tokens": 100},
  53. search_settings={
  54. "filters": {"metadata.tier": {"$eq": "test"}},
  55. "use_semantic_search": True,
  56. "limit": 3,
  57. },
  58. )["results"]
  59. assert "completion" in resp, "RAG response missing 'completion'"
  60. def test_rag_stream_query(client):
  61. resp = client.retrieval.rag(
  62. query="Detail the philosophical schools Aristotle influenced",
  63. rag_generation_config={"stream": True, "max_tokens": 50},
  64. search_settings={"use_semantic_search": True, "limit": 2},
  65. )
  66. # Consume a few chunks from the async generator
  67. import asyncio
  68. async def consume_stream():
  69. count = 0
  70. async for chunk in resp:
  71. count += 1
  72. if count > 1:
  73. break
  74. return count
  75. count = asyncio.run(consume_stream())
  76. assert count > 0, "No chunks received from streamed RAG query"
  77. def test_agent_query(client):
  78. msg = Message(role="user", content="What is Aristotle known for?")
  79. resp = client.retrieval.agent(
  80. message=msg,
  81. rag_generation_config={"stream": False, "max_tokens": 100},
  82. search_settings={"use_semantic_search": True, "limit": 3},
  83. )
  84. assert "results" in resp, "Agent response missing 'results'"
  85. assert len(resp["results"]) > 0, "No messages returned by agent"
  86. def test_agent_query_stream(client):
  87. msg = Message(role="user", content="Explain Aristotle's logic in steps.")
  88. resp = client.retrieval.agent(
  89. message=msg,
  90. rag_generation_config={"stream": True, "max_tokens": 50},
  91. search_settings={"use_semantic_search": True, "limit": 3},
  92. )
  93. import asyncio
  94. async def consume_stream():
  95. count = 0
  96. async for chunk in resp:
  97. count += 1
  98. if count > 1:
  99. break
  100. return count
  101. count = asyncio.run(consume_stream())
  102. assert count > 0, "No streaming chunks received from agent"
  103. def test_completion(client):
  104. messages = [
  105. {"role": "system", "content": "You are a helpful assistant."},
  106. {"role": "user", "content": "What is the capital of France?"},
  107. {"role": "assistant", "content": "The capital of France is Paris."},
  108. {"role": "user", "content": "What about Italy?"},
  109. ]
  110. resp = client.retrieval.completion(
  111. messages, generation_config={"max_tokens": 50}
  112. )
  113. assert "results" in resp, "Completion response missing 'results'"
  114. assert "choices" in resp["results"], "No choices in completion result"
  115. def test_embedding(client):
  116. text = "Who is Aristotle?"
  117. resp = client.retrieval.embedding(text=text)["results"]
  118. assert len(resp) > 0, "No embedding vector returned"
  119. def test_error_handling(client):
  120. # Missing query should raise an error
  121. with pytest.raises(R2RException) as exc_info:
  122. client.retrieval.search(query=None) # type: ignore
  123. assert exc_info.value.status_code in [
  124. 400,
  125. 422,
  126. ], "Expected validation error for missing query"
  127. def test_no_results_scenario(client):
  128. resp = client.retrieval.search(
  129. query="aslkfjaldfjal",
  130. search_mode="custom",
  131. search_settings={
  132. "limit": 5,
  133. "use_semantic_search": False,
  134. "use_fulltext_search": True,
  135. },
  136. )
  137. results = resp.get("results", {}).get("chunk_search_results", [])
  138. assert len(results) == 0, "Expected no results for nonsense query"
  139. def test_pagination_limit_one(client):
  140. client.documents.create(
  141. chunks=[
  142. "a" + " " + str(uuid.uuid4()),
  143. "b" + " " + str(uuid.uuid4()),
  144. "c" + " " + str(uuid.uuid4()),
  145. ]
  146. )
  147. resp = client.retrieval.search(
  148. query="Aristotle", search_mode="basic", search_settings={"limit": 1}
  149. )["results"]
  150. assert (
  151. len(resp["chunk_search_results"]) == 1
  152. ), "Expected one result with limit=1"
  153. def test_pagination_offset(client):
  154. resp0 = client.retrieval.search(
  155. query="Aristotle",
  156. search_mode="basic",
  157. search_settings={"limit": 1, "offset": 0},
  158. )["results"]
  159. resp1 = client.retrieval.search(
  160. query="Aristotle",
  161. search_mode="basic",
  162. search_settings={"limit": 1, "offset": 1},
  163. )["results"]
  164. assert (
  165. resp0["chunk_search_results"][0]["text"]
  166. != resp1["chunk_search_results"][0]["text"]
  167. ), "Offset should return different results"
  168. def test_rag_task_prompt_override(client):
  169. custom_prompt = """
  170. Answer the query given immediately below given the context. End your answer with: [END-TEST-PROMPT]
  171. ### Query:
  172. {query}
  173. ### Context:
  174. {context}
  175. """
  176. resp = client.retrieval.rag(
  177. query="Tell me about Aristotle",
  178. rag_generation_config={"stream": False, "max_tokens": 50},
  179. search_settings={"use_semantic_search": True, "limit": 3},
  180. task_prompt_override=custom_prompt,
  181. )
  182. answer = resp["results"]["completion"]["choices"][0]["message"]["content"]
  183. assert (
  184. "[END-TEST-PROMPT]" in answer
  185. ), "Custom prompt override not reflected in RAG answer"
  186. def test_agent_conversation_id(client):
  187. conversation_id = client.conversations.create()["results"]["id"]
  188. msg = Message(role="user", content="What is Aristotle known for?")
  189. resp = client.retrieval.agent(
  190. message=msg,
  191. rag_generation_config={"stream": False, "max_tokens": 50},
  192. search_settings={"use_semantic_search": True, "limit": 3},
  193. conversation_id=conversation_id,
  194. )
  195. assert (
  196. len(resp.get("results", [])) > 0
  197. ), "No results from agent with conversation_id"
  198. msg2 = Message(role="user", content="Can you elaborate more?")
  199. resp2 = client.retrieval.agent(
  200. message=msg2,
  201. rag_generation_config={"stream": False, "max_tokens": 50},
  202. search_settings={"use_semantic_search": True, "limit": 3},
  203. conversation_id=conversation_id,
  204. )
  205. assert (
  206. len(resp2.get("results", [])) > 0
  207. ), "No results from agent in second turn of conversation"
  208. # def _setup_collection_with_documents(client):
  209. # collection_name = f"Test Collection {uuid.uuid4()}"
  210. # collection_id = client.collections.create(name=collection_name)["results"][
  211. # "id"
  212. # ]
  213. # docs = [
  214. # {
  215. # "text": f"Aristotle was a Greek philosopher who studied under Plato {str(uuid.uuid4())}.",
  216. # "metadata": {
  217. # "rating": 5,
  218. # "tags": ["philosophy", "greek"],
  219. # "category": "ancient",
  220. # },
  221. # },
  222. # {
  223. # "text": f"Socrates is considered a founder of Western philosophy {str(uuid.uuid4())}.",
  224. # "metadata": {
  225. # "rating": 3,
  226. # "tags": ["philosophy", "classical"],
  227. # "category": "ancient",
  228. # },
  229. # },
  230. # {
  231. # "text": f"Rene Descartes was a French philosopher. unique_philosopher {str(uuid.uuid4())}",
  232. # "metadata": {
  233. # "rating": 8,
  234. # "tags": ["rationalism", "french"],
  235. # "category": "modern",
  236. # },
  237. # },
  238. # {
  239. # "text": f"Immanuel Kant, a German philosopher, influenced Enlightenment thought {str(uuid.uuid4())}.",
  240. # "metadata": {
  241. # "rating": 7,
  242. # "tags": ["enlightenment", "german"],
  243. # "category": "modern",
  244. # },
  245. # },
  246. # ]
  247. # doc_ids = []
  248. # for doc in docs:
  249. # result = client.documents.create(
  250. # raw_text=doc["text"], metadata=doc["metadata"]
  251. # )["results"]
  252. # doc_id = result["document_id"]
  253. # doc_ids.append(doc_id)
  254. # client.collections.add_document(collection_id, doc_id)
  255. # return collection_id, doc_ids
  256. def test_complex_filters_and_fulltext(client, test_collection):
  257. # collection_id, doc_ids = _setup_collection_with_documents(client)
  258. # rating > 5
  259. filters = {"rating": {"$gt": 5}}
  260. resp = client.retrieval.search(
  261. query="a",
  262. search_mode=SearchMode.custom,
  263. search_settings={"use_semantic_search": True, "filters": filters},
  264. )["results"]
  265. results = resp["chunk_search_results"]
  266. print("results = ", results)
  267. assert (
  268. len(results) == 2
  269. ), f"Expected 2 docs with rating > 5, got {len(results)}"
  270. # category in [ancient, modern]
  271. filters = {"metadata.category": {"$in": ["ancient", "modern"]}}
  272. resp = client.retrieval.search(
  273. query="b",
  274. search_mode=SearchMode.custom,
  275. search_settings={"use_semantic_search": True, "filters": filters},
  276. )["results"]
  277. results = resp["chunk_search_results"]
  278. assert len(results) == 4, f"Expected all 4 docs, got {len(results)}"
  279. # rating > 5 AND category=modern
  280. filters = {
  281. "$and": [
  282. {"metadata.rating": {"$gt": 5}},
  283. {"metadata.category": {"$eq": "modern"}},
  284. ]
  285. }
  286. resp = client.retrieval.search(
  287. query="d",
  288. search_mode=SearchMode.custom,
  289. search_settings={"filters": filters},
  290. )["results"]
  291. results = resp["chunk_search_results"]
  292. assert (
  293. len(results) == 2
  294. ), f"Expected 2 modern docs with rating>5, got {len(results)}"
  295. # full-text search: "unique_philosopher"
  296. resp = client.retrieval.search(
  297. query="unique_philosopher",
  298. search_mode=SearchMode.custom,
  299. search_settings={
  300. "use_fulltext_search": True,
  301. "use_semantic_search": False,
  302. },
  303. )["results"]
  304. results = resp["chunk_search_results"]
  305. assert (
  306. len(results) == 1
  307. ), f"Expected 1 doc for unique_philosopher, got {len(results)}"
  308. def test_complex_nested_filters(client, test_collection):
  309. # Setup docs
  310. # _setup_collection_with_documents(client)
  311. # ((category=ancient OR rating<5) AND tags contains 'philosophy')
  312. filters = {
  313. "$and": [
  314. {
  315. "$or": [
  316. {"metadata.category": {"$eq": "ancient"}},
  317. {"metadata.rating": {"$lt": 5}},
  318. ]
  319. },
  320. {"metadata.tags": {"$contains": ["philosophy"]}},
  321. ]
  322. }
  323. resp = client.retrieval.search(
  324. query="complex",
  325. search_mode="custom",
  326. search_settings={"filters": filters},
  327. )["results"]
  328. results = resp["chunk_search_results"]
  329. print("results = ", results)
  330. assert len(results) == 2, f"Expected 2 docs, got {len(results)}"
  331. def test_invalid_operator(client):
  332. filters = {"metadata.category": {"$like": "%ancient%"}}
  333. with pytest.raises(R2RException):
  334. client.retrieval.search(
  335. query="abc",
  336. search_mode="custom",
  337. search_settings={"filters": filters},
  338. )
  339. def test_filters_no_match(client):
  340. filters = {"metadata.category": {"$in": ["nonexistent"]}}
  341. resp = client.retrieval.search(
  342. query="noresults",
  343. search_mode="custom",
  344. search_settings={"filters": filters},
  345. )["results"]
  346. results = resp["chunk_search_results"]
  347. assert len(results) == 0, f"Expected 0 docs, got {len(results)}"
  348. def test_pagination_extremes(client):
  349. base_resp = client.retrieval.search(query="Aristotle", search_mode="basic")
  350. total_entries = base_resp.get("page_info", {}).get("total_entries", 0)
  351. offset = total_entries + 100
  352. resp = client.retrieval.search(
  353. query="Aristotle",
  354. search_mode="basic",
  355. search_settings={"limit": 10, "offset": offset},
  356. )["results"]
  357. results = resp["chunk_search_results"]
  358. assert (
  359. len(results) == 0
  360. ), f"Expected no results at large offset, got {len(results)}"
  361. def test_full_text_stopwords(client):
  362. resp = client.retrieval.search(
  363. query="the",
  364. search_mode="custom",
  365. search_settings={
  366. "use_fulltext_search": True,
  367. "use_semantic_search": False,
  368. "limit": 5,
  369. },
  370. )
  371. assert "results" in resp, "No results field in stopword query response"
  372. def test_full_text_non_ascii(client):
  373. resp = client.retrieval.search(
  374. query="Aristotélēs",
  375. search_mode="custom",
  376. search_settings={
  377. "use_fulltext_search": True,
  378. "use_semantic_search": False,
  379. "limit": 3,
  380. },
  381. )
  382. assert "results" in resp, "No results field in non-ASCII query response"
  383. def test_missing_fields(client):
  384. filters = {"metadata.someNonExistentField": {"$eq": "anything"}}
  385. resp = client.retrieval.search(
  386. query="missingfield",
  387. search_mode="custom",
  388. search_settings={"filters": filters},
  389. )["results"]
  390. results = resp["chunk_search_results"]
  391. assert (
  392. len(results) == 0
  393. ), f"Expected 0 docs for a non-existent field, got {len(results)}"
  394. def test_rag_with_large_context(client):
  395. resp = client.retrieval.rag(
  396. query="Explain the contributions of Kant in detail",
  397. rag_generation_config={"stream": False, "max_tokens": 200},
  398. search_settings={"use_semantic_search": True, "limit": 10},
  399. )
  400. results = resp.get("results", {})
  401. assert "completion" in results, "RAG large context missing 'completion'"
  402. completion = results["completion"]["choices"][0]["message"]["content"]
  403. assert len(completion) > 0, "RAG large context returned empty answer"
  404. def test_agent_long_conversation(client):
  405. conversation = client.conversations.create()["results"]
  406. conversation_id = conversation["id"]
  407. msg1 = Message(role="user", content="What were Aristotle's main ideas?")
  408. resp1 = client.retrieval.agent(
  409. message=msg1,
  410. rag_generation_config={"stream": False, "max_tokens": 100},
  411. search_settings={"use_semantic_search": True, "limit": 5},
  412. conversation_id=conversation_id,
  413. )
  414. assert "results" in resp1, "No results in first turn of conversation"
  415. msg2 = Message(
  416. role="user", content="How did these ideas influence modern philosophy?"
  417. )
  418. resp2 = client.retrieval.agent(
  419. message=msg2,
  420. rag_generation_config={"stream": False, "max_tokens": 100},
  421. search_settings={"use_semantic_search": True, "limit": 5},
  422. conversation_id=conversation_id,
  423. )
  424. assert "results" in resp2, "No results in second turn of conversation"
  425. msg3 = Message(role="user", content="Now tell me about Descartes.")
  426. resp3 = client.retrieval.agent(
  427. message=msg3,
  428. rag_generation_config={"stream": False, "max_tokens": 100},
  429. search_settings={"use_semantic_search": True, "limit": 5},
  430. conversation_id=conversation_id,
  431. )
  432. assert "results" in resp3, "No results in third turn of conversation"
  433. def test_filter_by_document_type(client):
  434. client.documents.create(chunks=["a", "b", "c"])
  435. filters = {"document_type": {"$eq": "txt"}}
  436. resp = client.retrieval.search(
  437. query="a", search_settings={"filters": filters}
  438. )["results"]
  439. results = resp["chunk_search_results"]
  440. # Depending on your environment, if no txt documents exist this might fail.
  441. # Adjust accordingly to ensure there's a txt doc available or mock if needed.
  442. assert len(results) > 0, "No results found for filter by document type"