test_chunks.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # tests/integration/test_chunks.py
  2. import asyncio
  3. import uuid
  4. from typing import AsyncGenerator, Optional, Tuple
  5. import pytest
  6. from r2r import R2RAsyncClient, R2RException
  7. class AsyncR2RTestClient:
  8. """Wrapper to ensure async operations use the correct event loop"""
  9. def __init__(self, base_url: str = "http://localhost:7272"):
  10. self.client = R2RAsyncClient(base_url)
  11. async def create_document(
  12. self, chunks: list[str], run_with_orchestration: bool = False
  13. ) -> Tuple[str, list[dict]]:
  14. response = await self.client.documents.create(
  15. chunks=chunks, run_with_orchestration=run_with_orchestration
  16. )
  17. return response["results"]["document_id"], []
  18. async def delete_document(self, doc_id: str) -> None:
  19. await self.client.documents.delete(id=doc_id)
  20. async def list_chunks(self, doc_id: str) -> list[dict]:
  21. response = await self.client.documents.list_chunks(id=doc_id)
  22. return response["results"]
  23. async def retrieve_chunk(self, chunk_id: str) -> dict:
  24. response = await self.client.chunks.retrieve(id=chunk_id)
  25. return response["results"]
  26. async def update_chunk(
  27. self, chunk_id: str, text: str, metadata: Optional[dict] = None
  28. ) -> dict:
  29. response = await self.client.chunks.update(
  30. {"id": chunk_id, "text": text, "metadata": metadata or {}}
  31. )
  32. return response["results"]
  33. async def delete_chunk(self, chunk_id: str) -> dict:
  34. response = await self.client.chunks.delete(id=chunk_id)
  35. return response["results"]
  36. async def search_chunks(self, query: str, limit: int = 5) -> list[dict]:
  37. response = await self.client.chunks.search(
  38. query=query, search_settings={"limit": limit}
  39. )
  40. return response["results"]
  41. async def register_user(self, email: str, password: str) -> None:
  42. await self.client.users.register(email, password)
  43. async def login_user(self, email: str, password: str) -> None:
  44. await self.client.users.login(email, password)
  45. async def logout_user(self) -> None:
  46. await self.client.users.logout()
  47. @pytest.fixture
  48. async def test_client() -> AsyncGenerator[AsyncR2RTestClient, None]:
  49. """Create a test client."""
  50. client = AsyncR2RTestClient()
  51. yield client
  52. @pytest.fixture
  53. async def test_document(
  54. test_client: AsyncR2RTestClient,
  55. ) -> AsyncGenerator[Tuple[str, list[dict]], None]:
  56. """Create a test document with chunks."""
  57. doc_id, _ = await test_client.create_document(
  58. ["Test chunk 1", "Test chunk 2"]
  59. )
  60. await asyncio.sleep(1) # Wait for ingestion
  61. chunks = await test_client.list_chunks(doc_id)
  62. yield doc_id, chunks
  63. try:
  64. await test_client.delete_document(doc_id)
  65. except R2RException:
  66. pass
  67. class TestChunks:
  68. @pytest.mark.asyncio
  69. async def test_create_and_list_chunks(
  70. self, test_client: AsyncR2RTestClient
  71. ):
  72. # Create document with chunks
  73. doc_id, _ = await test_client.create_document(
  74. ["Hello chunk", "World chunk"]
  75. )
  76. await asyncio.sleep(1) # Wait for ingestion
  77. # List and verify chunks
  78. chunks = await test_client.list_chunks(doc_id)
  79. assert len(chunks) == 2, "Expected 2 chunks in the document"
  80. # Cleanup
  81. await test_client.delete_document(doc_id)
  82. @pytest.mark.asyncio
  83. async def test_retrieve_chunk(
  84. self, test_client: AsyncR2RTestClient, test_document
  85. ):
  86. doc_id, chunks = test_document
  87. chunk_id = chunks[0]["id"]
  88. retrieved = await test_client.retrieve_chunk(chunk_id)
  89. assert retrieved["id"] == chunk_id, "Retrieved wrong chunk ID"
  90. assert retrieved["text"] == "Test chunk 1", "Chunk text mismatch"
  91. @pytest.mark.asyncio
  92. async def test_update_chunk(
  93. self, test_client: AsyncR2RTestClient, test_document
  94. ):
  95. doc_id, chunks = test_document
  96. chunk_id = chunks[0]["id"]
  97. # Update chunk
  98. updated = await test_client.update_chunk(
  99. chunk_id, "Updated text", {"version": 2}
  100. )
  101. assert updated["text"] == "Updated text", "Chunk text not updated"
  102. assert updated["metadata"]["version"] == 2, "Metadata not updated"
  103. @pytest.mark.asyncio
  104. async def test_delete_chunk(
  105. self, test_client: AsyncR2RTestClient, test_document
  106. ):
  107. doc_id, chunks = test_document
  108. chunk_id = chunks[0]["id"]
  109. # Delete and verify
  110. result = await test_client.delete_chunk(chunk_id)
  111. assert result["success"], "Chunk deletion failed"
  112. # Verify deletion
  113. with pytest.raises(R2RException) as exc_info:
  114. await test_client.retrieve_chunk(chunk_id)
  115. assert exc_info.value.status_code == 404
  116. @pytest.mark.asyncio
  117. async def test_search_chunks(self, test_client: AsyncR2RTestClient):
  118. # Create searchable document
  119. doc_id, _ = await test_client.create_document(
  120. ["Aristotle reference", "Another piece of text"]
  121. )
  122. await asyncio.sleep(1) # Wait for indexing
  123. # Search
  124. results = await test_client.search_chunks("Aristotle")
  125. assert len(results) > 0, "No search results found"
  126. # Cleanup
  127. await test_client.delete_document(doc_id)
  128. @pytest.mark.asyncio
  129. async def test_unauthorized_chunk_access(
  130. self, test_client: AsyncR2RTestClient, test_document
  131. ):
  132. doc_id, chunks = test_document
  133. chunk_id = chunks[0]["id"]
  134. # Create and login as different user
  135. non_owner_client = AsyncR2RTestClient()
  136. email = f"test_{uuid.uuid4()}@example.com"
  137. await non_owner_client.register_user(email, "password123")
  138. await non_owner_client.login_user(email, "password123")
  139. # Attempt unauthorized access
  140. with pytest.raises(R2RException) as exc_info:
  141. await non_owner_client.retrieve_chunk(chunk_id)
  142. assert exc_info.value.status_code == 403
  143. @pytest.mark.asyncio
  144. async def test_list_chunks_with_filters(
  145. self, test_client: AsyncR2RTestClient
  146. ):
  147. """Test listing chunks with owner_id filter."""
  148. # Create and login as temporary user
  149. temp_email = f"{uuid.uuid4()}@example.com"
  150. await test_client.register_user(temp_email, "password123")
  151. await test_client.login_user(temp_email, "password123")
  152. try:
  153. # Create a document with chunks
  154. doc_id, _ = await test_client.create_document(
  155. ["Test chunk 1", "Test chunk 2"]
  156. )
  157. await asyncio.sleep(1) # Wait for ingestion
  158. # Test listing chunks (filters automatically applied on server)
  159. response = await test_client.client.chunks.list(offset=0, limit=1)
  160. assert "results" in response, "Expected 'results' in response"
  161. # assert "page_info" in response, "Expected 'page_info' in response"
  162. assert (
  163. len(response["results"]) <= 1
  164. ), "Expected at most 1 result due to limit"
  165. if len(response["results"]) > 0:
  166. # Verify we only get chunks owned by our temp user
  167. chunk = response["results"][0]
  168. chunks = await test_client.list_chunks(doc_id)
  169. assert chunk["owner_id"] in [
  170. c["owner_id"] for c in chunks
  171. ], "Got chunk from wrong owner"
  172. finally:
  173. # Cleanup
  174. try:
  175. await test_client.delete_document(doc_id)
  176. except:
  177. pass
  178. await test_client.logout_user()
  179. @pytest.mark.asyncio
  180. async def test_list_chunks_pagination(
  181. self, test_client: AsyncR2RTestClient
  182. ):
  183. """Test chunk listing with pagination."""
  184. # Create and login as temporary user
  185. temp_email = f"{uuid.uuid4()}@example.com"
  186. await test_client.register_user(temp_email, "password123")
  187. await test_client.login_user(temp_email, "password123")
  188. doc_id = None
  189. try:
  190. # Create a document with multiple chunks
  191. chunks = [f"Test chunk {i}" for i in range(5)]
  192. doc_id, _ = await test_client.create_document(chunks)
  193. await asyncio.sleep(1) # Wait for ingestion
  194. # Test first page
  195. response1 = await test_client.client.chunks.list(offset=0, limit=2)
  196. assert (
  197. len(response1["results"]) == 2
  198. ), "Expected 2 results on first page"
  199. # assert response1["page_info"]["has_next"], "Expected more pages"
  200. # Test second page
  201. response2 = await test_client.client.chunks.list(offset=2, limit=2)
  202. assert (
  203. len(response2["results"]) == 2
  204. ), "Expected 2 results on second page"
  205. # Verify no duplicate results
  206. ids_page1 = {chunk["id"] for chunk in response1["results"]}
  207. ids_page2 = {chunk["id"] for chunk in response2["results"]}
  208. assert not ids_page1.intersection(
  209. ids_page2
  210. ), "Found duplicate chunks across pages"
  211. finally:
  212. # Cleanup
  213. if doc_id:
  214. try:
  215. await test_client.delete_document(doc_id)
  216. except:
  217. pass
  218. await test_client.logout_user()
  219. @pytest.mark.asyncio
  220. async def test_list_chunks_with_multiple_documents(
  221. self, test_client: AsyncR2RTestClient
  222. ):
  223. """Test listing chunks across multiple documents."""
  224. # Create and login as temporary user
  225. temp_email = f"{uuid.uuid4()}@example.com"
  226. await test_client.register_user(temp_email, "password123")
  227. await test_client.login_user(temp_email, "password123")
  228. doc_ids = []
  229. try:
  230. # Create multiple documents
  231. for i in range(2):
  232. doc_id, _ = await test_client.create_document(
  233. [f"Doc {i} chunk 1", f"Doc {i} chunk 2"]
  234. )
  235. doc_ids.append(doc_id)
  236. await asyncio.sleep(1) # Wait for ingestion
  237. # List all chunks
  238. response = await test_client.client.chunks.list(offset=0, limit=10)
  239. assert len(response["results"]) == 4, "Expected 4 total chunks"
  240. # Verify all chunks belong to our documents
  241. chunk_doc_ids = {
  242. chunk["document_id"] for chunk in response["results"]
  243. }
  244. assert all(
  245. str(doc_id) in chunk_doc_ids for doc_id in doc_ids
  246. ), "Got chunks from wrong documents"
  247. finally:
  248. # Cleanup
  249. for doc_id in doc_ids:
  250. try:
  251. await test_client.delete_document(doc_id)
  252. except:
  253. pass
  254. await test_client.logout_user()
  255. if __name__ == "__main__":
  256. pytest.main(["-v", "--asyncio-mode=auto"])