test_routes.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import inspect
  2. from unittest.mock import Mock, create_autospec
  3. import pytest
  4. from starlette.responses import FileResponse, StreamingResponse
  5. from starlette.templating import _TemplateResponse
  6. from core import R2RProviders
  7. from core.main.abstractions import R2RServices
  8. from core.main.api.v3.chunks_router import ChunksRouter
  9. from core.main.api.v3.collections_router import CollectionsRouter
  10. from core.main.api.v3.conversations_router import ConversationsRouter
  11. from core.main.api.v3.documents_router import DocumentsRouter
  12. from core.main.api.v3.graph_router import GraphRouter
  13. from core.main.api.v3.indices_router import IndicesRouter
  14. from core.main.api.v3.prompts_router import PromptsRouter
  15. from core.main.api.v3.retrieval_router import RetrievalRouter
  16. from core.main.api.v3.system_router import SystemRouter
  17. from core.main.api.v3.users_router import UsersRouter
  18. from core.main.config import R2RConfig
  19. from core.providers.auth import R2RAuthProvider
  20. from core.providers.database import PostgresDatabaseProvider
  21. from core.providers.email import ConsoleMockEmailProvider
  22. from core.providers.embeddings import OpenAIEmbeddingProvider
  23. from core.providers.file import PostgresFileProvider
  24. from core.providers.ingestion import R2RIngestionProvider
  25. from core.providers.llm import OpenAICompletionProvider
  26. from core.providers.orchestration import SimpleOrchestrationProvider
  27. from core.providers.scheduler import APSchedulerProvider
  28. from core.providers.ocr import MistralOCRProvider
  29. ROUTERS = [
  30. UsersRouter,
  31. ChunksRouter,
  32. CollectionsRouter,
  33. ConversationsRouter,
  34. DocumentsRouter,
  35. GraphRouter,
  36. IndicesRouter,
  37. PromptsRouter,
  38. RetrievalRouter,
  39. SystemRouter,
  40. ]
  41. @pytest.fixture
  42. def mock_providers():
  43. # Create mock auth provider that inherits from the base class
  44. mock_auth = create_autospec(R2RAuthProvider)
  45. # Create other mock providers
  46. mock_db = create_autospec(PostgresDatabaseProvider)
  47. mock_db.config = Mock()
  48. mock_ingestion = create_autospec(R2RIngestionProvider)
  49. mock_ingestion.config = Mock()
  50. mock_embedding = create_autospec(OpenAIEmbeddingProvider)
  51. mock_embedding.config = Mock()
  52. mock_completion_embedding = create_autospec(OpenAIEmbeddingProvider)
  53. mock_completion_embedding.config = Mock()
  54. mock_file = create_autospec(PostgresFileProvider)
  55. mock_file.config = Mock()
  56. mock_llm = create_autospec(OpenAICompletionProvider)
  57. mock_llm.config = Mock()
  58. mock_ocr = create_autospec(MistralOCRProvider)
  59. mock_ocr.config = Mock()
  60. mock_orchestration = create_autospec(SimpleOrchestrationProvider)
  61. mock_orchestration.config = Mock()
  62. mock_email = create_autospec(ConsoleMockEmailProvider)
  63. mock_email.config = Mock()
  64. mock_scheduler = create_autospec(APSchedulerProvider)
  65. mock_scheduler.config = Mock()
  66. # Set up any needed methods
  67. mock_auth.auth_wrapper = Mock(return_value=lambda: None)
  68. return R2RProviders(
  69. auth=mock_auth,
  70. completion_embedding=mock_completion_embedding,
  71. database=mock_db,
  72. email=mock_email,
  73. embedding=mock_embedding,
  74. file=mock_file,
  75. ingestion=mock_ingestion,
  76. llm=mock_llm,
  77. ocr=mock_ocr,
  78. orchestration=mock_orchestration,
  79. scheduler=mock_scheduler,
  80. )
  81. @pytest.fixture
  82. def mock_services():
  83. return R2RServices(
  84. auth=Mock(),
  85. ingestion=Mock(),
  86. graph=Mock(),
  87. maintenance=Mock(),
  88. management=Mock(),
  89. retrieval=Mock(),
  90. )
  91. @pytest.fixture
  92. def mock_config():
  93. config_data = {
  94. "app": {}, # AppConfig needs minimal data
  95. "auth": {
  96. "provider": "mock"
  97. },
  98. "completion": {
  99. "provider": "mock"
  100. },
  101. "crypto": {
  102. "provider": "mock"
  103. },
  104. "database": {
  105. "provider": "mock"
  106. },
  107. "embedding": {
  108. "provider": "mock",
  109. "base_model": "test",
  110. "base_dimension": 1024,
  111. "batch_size": 10,
  112. },
  113. "completion_embedding": {
  114. "provider": "mock",
  115. "base_model": "test",
  116. "base_dimension": 1024,
  117. "batch_size": 10,
  118. },
  119. "email": {
  120. "provider": "mock"
  121. },
  122. "ingestion": {
  123. "provider": "mock"
  124. },
  125. "agent": {
  126. "generation_config": {}
  127. },
  128. "orchestration": {
  129. "provider": "mock"
  130. },
  131. }
  132. return R2RConfig(config_data)
  133. @pytest.fixture(params=ROUTERS)
  134. def router(request, mock_providers, mock_services, mock_config):
  135. router_class = request.param
  136. return router_class(mock_providers, mock_services, mock_config)
  137. def test_all_routes_have_base_endpoint_decorator(router):
  138. for route in router.router.routes:
  139. if (route.path.endswith("/stream") or route.path.endswith("/viewer")
  140. or "websocket" in str(type(route)).lower()):
  141. continue
  142. endpoint = route.endpoint
  143. assert hasattr(endpoint, "_is_base_endpoint"), (
  144. f"Route {route.path} missing @base_endpoint decorator")
  145. def test_all_routes_have_proper_return_type_hints(router):
  146. for route in router.router.routes:
  147. if (route.path.endswith("/stream")
  148. or "websocket" in str(type(route)).lower()):
  149. continue
  150. endpoint = route.endpoint
  151. return_type = inspect.signature(endpoint).return_annotation
  152. # Check if the type is an R2RResults by name
  153. is_valid = isinstance(
  154. return_type, type) and ("R2RResults" in str(return_type)
  155. or "PaginatedR2RResult" in str(return_type)
  156. or return_type == FileResponse
  157. or return_type == StreamingResponse
  158. or return_type == _TemplateResponse)
  159. assert is_valid, (
  160. f"Route {route.path} has invalid return type: {return_type}, expected R2RResults[...]"
  161. )
  162. def test_all_routes_have_rate_limiting(router):
  163. import warnings
  164. for route in router.router.routes:
  165. print(f"Checking route: {route.path}")
  166. print(f"Dependencies: {route.dependencies}")
  167. has_rate_limit = any(dep.dependency == router.rate_limit_dependency
  168. for dep in route.dependencies)
  169. if not has_rate_limit:
  170. # We should require this in the future, but for now just warn
  171. warnings.warn(
  172. f"Route {route.path} missing rate limiting - this will be required in the future",
  173. UserWarning,
  174. )