test_chunks.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # tests/integration/test_chunks.py
  2. import asyncio
  3. import uuid
  4. from typing import AsyncGenerator, List, Optional, Tuple
  5. import pytest
  6. from r2r import R2RAsyncClient, R2RException
  7. class AsyncR2RTestClient:
  8. """Wrapper to ensure async operations use the correct event loop"""
  9. def __init__(self, base_url: str = "http://localhost:7272"):
  10. self.client = R2RAsyncClient(base_url)
  11. async def create_document(
  12. self, chunks: List[str], run_with_orchestration: bool = False
  13. ) -> Tuple[str, List[dict]]:
  14. response = await self.client.documents.create(
  15. chunks=chunks, run_with_orchestration=run_with_orchestration
  16. )
  17. return response["results"]["document_id"], []
  18. async def delete_document(self, doc_id: str) -> None:
  19. await self.client.documents.delete(id=doc_id)
  20. async def list_chunks(self, doc_id: str) -> List[dict]:
  21. response = await self.client.documents.list_chunks(id=doc_id)
  22. return response["results"]
  23. async def retrieve_chunk(self, chunk_id: str) -> dict:
  24. response = await self.client.chunks.retrieve(id=chunk_id)
  25. return response["results"]
  26. async def update_chunk(
  27. self, chunk_id: str, text: str, metadata: Optional[dict] = None
  28. ) -> dict:
  29. response = await self.client.chunks.update(
  30. {"id": chunk_id, "text": text, "metadata": metadata or {}}
  31. )
  32. return response["results"]
  33. async def delete_chunk(self, chunk_id: str) -> dict:
  34. response = await self.client.chunks.delete(id=chunk_id)
  35. return response["results"]
  36. async def search_chunks(self, query: str, limit: int = 5) -> List[dict]:
  37. response = await self.client.chunks.search(
  38. query=query, search_settings={"limit": limit}
  39. )
  40. return response["results"]
  41. async def register_user(self, email: str, password: str) -> None:
  42. await self.client.users.register(email, password)
  43. async def login_user(self, email: str, password: str) -> None:
  44. await self.client.users.login(email, password)
  45. async def logout_user(self) -> None:
  46. await self.client.users.logout()
  47. @pytest.fixture
  48. async def test_client() -> AsyncGenerator[AsyncR2RTestClient, None]:
  49. """Create a test client."""
  50. client = AsyncR2RTestClient()
  51. yield client
  52. @pytest.fixture
  53. async def test_document(
  54. test_client: AsyncR2RTestClient,
  55. ) -> AsyncGenerator[Tuple[str, List[dict]], None]:
  56. """Create a test document with chunks."""
  57. doc_id, _ = await test_client.create_document(
  58. ["Test chunk 1", "Test chunk 2"]
  59. )
  60. await asyncio.sleep(1) # Wait for ingestion
  61. chunks = await test_client.list_chunks(doc_id)
  62. yield doc_id, chunks
  63. try:
  64. await test_client.delete_document(doc_id)
  65. except R2RException:
  66. pass
  67. class TestChunks:
  68. @pytest.mark.asyncio
  69. async def test_create_and_list_chunks(
  70. self, test_client: AsyncR2RTestClient
  71. ):
  72. # Create document with chunks
  73. doc_id, _ = await test_client.create_document(
  74. ["Hello chunk", "World chunk"]
  75. )
  76. await asyncio.sleep(1) # Wait for ingestion
  77. # List and verify chunks
  78. chunks = await test_client.list_chunks(doc_id)
  79. assert len(chunks) == 2, "Expected 2 chunks in the document"
  80. # Cleanup
  81. await test_client.delete_document(doc_id)
  82. @pytest.mark.asyncio
  83. async def test_retrieve_chunk(
  84. self, test_client: AsyncR2RTestClient, test_document
  85. ):
  86. doc_id, chunks = test_document
  87. chunk_id = chunks[0]["id"]
  88. retrieved = await test_client.retrieve_chunk(chunk_id)
  89. assert retrieved["id"] == chunk_id, "Retrieved wrong chunk ID"
  90. assert retrieved["text"] == "Test chunk 1", "Chunk text mismatch"
  91. @pytest.mark.asyncio
  92. async def test_update_chunk(
  93. self, test_client: AsyncR2RTestClient, test_document
  94. ):
  95. doc_id, chunks = test_document
  96. chunk_id = chunks[0]["id"]
  97. # Update chunk
  98. updated = await test_client.update_chunk(
  99. chunk_id, "Updated text", {"version": 2}
  100. )
  101. assert updated["text"] == "Updated text", "Chunk text not updated"
  102. assert updated["metadata"]["version"] == 2, "Metadata not updated"
  103. @pytest.mark.asyncio
  104. async def test_delete_chunk(
  105. self, test_client: AsyncR2RTestClient, test_document
  106. ):
  107. doc_id, chunks = test_document
  108. chunk_id = chunks[0]["id"]
  109. # Delete and verify
  110. result = await test_client.delete_chunk(chunk_id)
  111. assert result["success"], "Chunk deletion failed"
  112. # Verify deletion
  113. with pytest.raises(R2RException) as exc_info:
  114. await test_client.retrieve_chunk(chunk_id)
  115. assert exc_info.value.status_code == 404
  116. @pytest.mark.asyncio
  117. async def test_search_chunks(self, test_client: AsyncR2RTestClient):
  118. # Create searchable document
  119. doc_id, _ = await test_client.create_document(
  120. ["Aristotle reference", "Another piece of text"]
  121. )
  122. await asyncio.sleep(1) # Wait for indexing
  123. # Search
  124. results = await test_client.search_chunks("Aristotle")
  125. assert len(results) > 0, "No search results found"
  126. # Cleanup
  127. await test_client.delete_document(doc_id)
  128. @pytest.mark.asyncio
  129. async def test_unauthorized_chunk_access(
  130. self, test_client: AsyncR2RTestClient, test_document
  131. ):
  132. doc_id, chunks = test_document
  133. chunk_id = chunks[0]["id"]
  134. # Create and login as different user
  135. non_owner_client = AsyncR2RTestClient()
  136. email = f"test_{uuid.uuid4()}@example.com"
  137. await non_owner_client.register_user(email, "password123")
  138. await non_owner_client.login_user(email, "password123")
  139. # Attempt unauthorized access
  140. with pytest.raises(R2RException) as exc_info:
  141. await non_owner_client.retrieve_chunk(chunk_id)
  142. assert exc_info.value.status_code == 403
  143. if __name__ == "__main__":
  144. pytest.main(["-v", "--asyncio-mode=auto"])