# test_chunks.py (10 KB)
import asyncio
import contextlib
import sys
import uuid
from typing import AsyncGenerator, Optional, Tuple

import pytest
from r2r import R2RAsyncClient, R2RException
  7. class AsyncR2RTestClient:
  8. """Wrapper to ensure async operations use the correct event loop."""
  9. def __init__(self, base_url: str = "http://localhost:7272"):
  10. self.client = R2RAsyncClient(base_url)
  11. async def create_document(self,
  12. chunks: list[str],
  13. run_with_orchestration: bool = False):
  14. response = await self.client.documents.create(
  15. chunks=chunks, run_with_orchestration=run_with_orchestration)
  16. return response.results.document_id, []
  17. async def delete_document(self, doc_id: str):
  18. await self.client.documents.delete(id=doc_id)
  19. async def list_chunks(self, doc_id: str):
  20. response = await self.client.documents.list_chunks(id=doc_id)
  21. return response.results
  22. async def retrieve_chunk(self, chunk_id: str):
  23. response = await self.client.chunks.retrieve(id=chunk_id)
  24. return response.results
  25. async def update_chunk(self,
  26. chunk_id: str,
  27. text: str,
  28. metadata: Optional[dict] = None):
  29. response = await self.client.chunks.update({
  30. "id": chunk_id,
  31. "text": text,
  32. "metadata": metadata or {}
  33. })
  34. return response.results
  35. async def delete_chunk(self, chunk_id: str):
  36. response = await self.client.chunks.delete(id=chunk_id)
  37. return response.results
  38. async def search_chunks(self, query: str, limit: int = 5):
  39. response = await self.client.chunks.search(
  40. query=query, search_settings={"limit": limit})
  41. return response.results
  42. async def register_user(self, email: str, password: str):
  43. await self.client.users.create(email, password)
  44. async def login_user(self, email: str, password: str):
  45. await self.client.users.login(email, password)
  46. async def logout_user(self):
  47. await self.client.users.logout()
  48. @pytest.fixture
  49. async def test_client() -> AsyncGenerator[AsyncR2RTestClient, None]:
  50. """Create a test client."""
  51. yield AsyncR2RTestClient()
  52. @pytest.fixture
  53. async def test_document(
  54. test_client: AsyncR2RTestClient,
  55. ) -> AsyncGenerator[Tuple[str, list[dict]], None]:
  56. """Create a test document with chunks."""
  57. uuid_1 = uuid.uuid4()
  58. uuid_2 = uuid.uuid4()
  59. doc_id, _ = await test_client.create_document(
  60. [f"Test chunk 1_{uuid_1}", f"Test chunk 2_{uuid_2}"])
  61. await asyncio.sleep(1) # Wait for ingestion
  62. chunks = await test_client.list_chunks(str(doc_id))
  63. yield doc_id, chunks
  64. with contextlib.suppress(R2RException):
  65. await test_client.delete_document(str(doc_id))
  66. class TestChunks:
  67. @pytest.mark.asyncio
  68. async def test_create_and_list_chunks(self,
  69. test_client: AsyncR2RTestClient,
  70. cleanup_documents):
  71. # Create document with chunks
  72. doc_id, _ = await test_client.create_document(
  73. ["Hello chunk", "World chunk"])
  74. cleanup_documents(str(doc_id))
  75. await asyncio.sleep(1) # Wait for ingestion
  76. # List and verify chunks
  77. chunks = await test_client.list_chunks(str(doc_id))
  78. assert len(chunks) == 2, "Expected 2 chunks in the document"
  79. @pytest.mark.asyncio
  80. async def test_retrieve_chunk(self, test_client: AsyncR2RTestClient,
  81. test_document):
  82. doc_id, chunks = test_document
  83. chunk_id = chunks[0].id
  84. retrieved = await test_client.retrieve_chunk(chunk_id)
  85. assert str(retrieved.id) == str(chunk_id), "Retrieved wrong chunk ID"
  86. assert retrieved.text.split("_")[0] == "Test chunk 1", (
  87. "Chunk text mismatch")
  88. @pytest.mark.asyncio
  89. async def test_update_chunk(self, test_client: AsyncR2RTestClient,
  90. test_document):
  91. doc_id, chunks = test_document
  92. chunk_id = chunks[0].id
  93. # Update chunk
  94. updated = await test_client.update_chunk(str(chunk_id), "Updated text",
  95. {"version": 2})
  96. assert updated.text == "Updated text", "Chunk text not updated"
  97. assert updated.metadata["version"] == 2, "Metadata not updated"
  98. @pytest.mark.asyncio
  99. async def test_delete_chunk(self, test_client: AsyncR2RTestClient,
  100. test_document):
  101. doc_id, chunks = test_document
  102. chunk_id = chunks[0].id
  103. # Delete and verify
  104. result = await test_client.delete_chunk(str(chunk_id))
  105. assert result.success, "Chunk deletion failed"
  106. # Verify deletion
  107. with pytest.raises(R2RException) as exc_info:
  108. await test_client.retrieve_chunk(str(chunk_id))
  109. assert exc_info.value.status_code == 404
  110. @pytest.mark.asyncio
  111. async def test_search_chunks(self, test_client: AsyncR2RTestClient,
  112. cleanup_documents):
  113. # Create searchable document
  114. random_1 = uuid.uuid4()
  115. random_2 = uuid.uuid4()
  116. doc_id, _ = await test_client.create_document([
  117. f"Aristotle reference {random_1}",
  118. f"Another piece of text {random_2}",
  119. ])
  120. cleanup_documents(doc_id)
  121. await asyncio.sleep(1) # Wait for indexing
  122. # Search
  123. results = await test_client.search_chunks("Aristotle")
  124. assert len(results) > 0, "No search results found"
  125. @pytest.mark.asyncio
  126. async def test_unauthorized_chunk_access(self,
  127. test_client: AsyncR2RTestClient,
  128. test_document):
  129. doc_id, chunks = test_document
  130. chunk_id = chunks[0].id
  131. # Create and login as different user
  132. non_owner_client = AsyncR2RTestClient()
  133. email = f"test_{uuid.uuid4()}@example.com"
  134. await non_owner_client.register_user(email, "password123")
  135. await non_owner_client.login_user(email, "password123")
  136. # Attempt unauthorized access
  137. with pytest.raises(R2RException) as exc_info:
  138. await non_owner_client.retrieve_chunk(str(chunk_id))
  139. assert exc_info.value.status_code == 403
  140. @pytest.mark.asyncio
  141. async def test_list_chunks_with_filters(self,
  142. test_client: AsyncR2RTestClient,
  143. cleanup_documents):
  144. """Test listing chunks with owner_id filter."""
  145. # Create and login as temporary user
  146. temp_email = f"{uuid.uuid4()}@example.com"
  147. await test_client.register_user(temp_email, "password123")
  148. await test_client.login_user(temp_email, "password123")
  149. # Create a document with chunks
  150. doc_id, _ = await test_client.create_document(
  151. ["Test chunk 1", "Test chunk 2"])
  152. cleanup_documents(doc_id)
  153. await asyncio.sleep(1) # Wait for ingestion
  154. @pytest.mark.asyncio
  155. async def test_list_chunks_pagination(self,
  156. test_client: AsyncR2RTestClient):
  157. """Test chunk listing with pagination."""
  158. # Create and login as temporary user
  159. temp_email = f"{uuid.uuid4()}@example.com"
  160. await test_client.register_user(temp_email, "password123")
  161. await test_client.login_user(temp_email, "password123")
  162. doc_id = None
  163. try:
  164. # Create a document with multiple chunks
  165. chunks = [f"Test chunk {i}" for i in range(5)]
  166. doc_id, _ = await test_client.create_document(chunks)
  167. await asyncio.sleep(1) # Wait for ingestion
  168. # Test first page
  169. response1 = await test_client.client.chunks.list(offset=0, limit=2)
  170. assert len(
  171. response1.results) == 2, ("Expected 2 results on first page")
  172. # Test second page
  173. response2 = await test_client.client.chunks.list(offset=2, limit=2)
  174. assert len(
  175. response2.results) == 2, ("Expected 2 results on second page")
  176. # Verify no duplicate results
  177. ids_page1 = {str(chunk.id) for chunk in response1.results}
  178. ids_page2 = {str(chunk.id) for chunk in response2.results}
  179. assert not ids_page1.intersection(ids_page2), (
  180. "Found duplicate chunks across pages")
  181. finally:
  182. # Cleanup
  183. if doc_id:
  184. try:
  185. await test_client.delete_document(doc_id)
  186. except:
  187. pass
  188. await test_client.logout_user()
  189. @pytest.mark.asyncio
  190. async def test_list_chunks_with_multiple_documents(
  191. self, test_client: AsyncR2RTestClient):
  192. """Test listing chunks across multiple documents."""
  193. # Create and login as temporary user
  194. temp_email = f"{uuid.uuid4()}@example.com"
  195. await test_client.register_user(temp_email, "password123")
  196. await test_client.login_user(temp_email, "password123")
  197. doc_ids = []
  198. try:
  199. # Create multiple documents
  200. for i in range(2):
  201. doc_id, _ = await test_client.create_document(
  202. [f"Doc {i} chunk 1", f"Doc {i} chunk 2"])
  203. doc_ids.append(doc_id)
  204. await asyncio.sleep(1) # Wait for ingestion
  205. # List all chunks
  206. response = await test_client.client.chunks.list(offset=0, limit=10)
  207. assert len(response.results) == 4, "Expected 4 total chunks"
  208. chunk_doc_ids = {
  209. str(chunk.document_id)
  210. for chunk in response.results
  211. }
  212. assert all(
  213. str(doc_id) in chunk_doc_ids
  214. for doc_id in doc_ids), ("Got chunks from wrong documents")
  215. finally:
  216. # Cleanup
  217. for doc_id in doc_ids:
  218. try:
  219. await test_client.delete_document(doc_id)
  220. except:
  221. pass
  222. await test_client.logout_user()
  223. @pytest.fixture
  224. async def cleanup_documents(test_client: AsyncR2RTestClient):
  225. doc_ids = []
  226. def _track_document(doc_id: str) -> str:
  227. doc_ids.append(doc_id)
  228. return doc_id
  229. yield _track_document
  230. # Cleanup all documents
  231. for doc_id in doc_ids:
  232. with contextlib.suppress(R2RException):
  233. await test_client.delete_document(doc_id)
  234. if __name__ == "__main__":
  235. pytest.main(["-v", "--asyncio-mode=auto"])