"""Structural checks applied to every v3 API router: endpoint decorator, return type hints and rate limiting."""
- import inspect
- from unittest.mock import Mock, create_autospec
- import pytest
- from starlette.responses import FileResponse, StreamingResponse
- from starlette.templating import _TemplateResponse
- from core import R2RProviders
- from core.main.abstractions import R2RServices
- from core.main.api.v3.chunks_router import ChunksRouter
- from core.main.api.v3.collections_router import CollectionsRouter
- from core.main.api.v3.conversations_router import ConversationsRouter
- from core.main.api.v3.documents_router import DocumentsRouter
- from core.main.api.v3.graph_router import GraphRouter
- from core.main.api.v3.indices_router import IndicesRouter
- from core.main.api.v3.prompts_router import PromptsRouter
- from core.main.api.v3.retrieval_router import RetrievalRouter
- from core.main.api.v3.system_router import SystemRouter
- from core.main.api.v3.users_router import UsersRouter
- from core.main.config import R2RConfig
- from core.providers.auth import R2RAuthProvider
- from core.providers.database import PostgresDatabaseProvider
- from core.providers.email import ConsoleMockEmailProvider
- from core.providers.embeddings import OpenAIEmbeddingProvider
- from core.providers.file import PostgresFileProvider
- from core.providers.ingestion import R2RIngestionProvider
- from core.providers.llm import OpenAICompletionProvider
- from core.providers.orchestration import SimpleOrchestrationProvider
- from core.providers.scheduler import APSchedulerProvider
- from core.providers.ocr import MistralOCRProvider
# Router classes exercised by the parametrized ``router`` fixture; each test
# that requests that fixture runs once per entry, in this order.
ROUTERS = [
    UsersRouter, ChunksRouter, CollectionsRouter, ConversationsRouter,
    DocumentsRouter, GraphRouter, IndicesRouter, PromptsRouter,
    RetrievalRouter, SystemRouter,
]
@pytest.fixture
def mock_providers():
    """Build an ``R2RProviders`` bundle populated entirely with autospecced mocks.

    Every provider except ``auth`` gets its ``config`` attribute replaced by a
    plain ``Mock``. The auth mock instead gets an ``auth_wrapper`` mock whose
    call returns a no-op callable.
    """

    def spec_with_config(provider_cls):
        # Autospec the provider class, then override ``config`` with a
        # free-form Mock so any attribute access on it succeeds.
        double = create_autospec(provider_cls)
        double.config = Mock()
        return double

    # The auth provider is autospecced against R2RAuthProvider; its
    # ``config`` is left as whatever the autospec provides.
    auth_double = create_autospec(R2RAuthProvider)
    auth_double.auth_wrapper = Mock(return_value=lambda: None)

    return R2RProviders(
        auth=auth_double,
        completion_embedding=spec_with_config(OpenAIEmbeddingProvider),
        database=spec_with_config(PostgresDatabaseProvider),
        email=spec_with_config(ConsoleMockEmailProvider),
        embedding=spec_with_config(OpenAIEmbeddingProvider),
        file=spec_with_config(PostgresFileProvider),
        ingestion=spec_with_config(R2RIngestionProvider),
        llm=spec_with_config(OpenAICompletionProvider),
        ocr=spec_with_config(MistralOCRProvider),
        orchestration=spec_with_config(SimpleOrchestrationProvider),
        scheduler=spec_with_config(APSchedulerProvider),
    )
@pytest.fixture
def mock_services():
    """Return an ``R2RServices`` container where each service is a bare ``Mock``."""
    service_names = (
        "auth",
        "ingestion",
        "graph",
        "maintenance",
        "management",
        "retrieval",
    )
    # One independent Mock per service keyword.
    return R2RServices(**{name: Mock() for name in service_names})
@pytest.fixture
def mock_config():
    """Return an ``R2RConfig`` built from a minimal all-"mock" settings dict."""

    def provider_only():
        # Fresh dict per section so no two sections share one object.
        return {"provider": "mock"}

    def embedding_section():
        # Settings shared by the two embedding-style sections.
        return {
            "provider": "mock",
            "base_model": "test",
            "base_dimension": 1024,
            "batch_size": 10,
        }

    sections = {
        "app": {},  # AppConfig needs minimal data
        "auth": provider_only(),
        "completion": provider_only(),
        "crypto": provider_only(),
        "database": provider_only(),
        "embedding": embedding_section(),
        "completion_embedding": embedding_section(),
        "email": provider_only(),
        "ingestion": provider_only(),
        "agent": {"generation_config": {}},
        "orchestration": provider_only(),
    }
    return R2RConfig(sections)
@pytest.fixture(params=ROUTERS)
def router(request, mock_providers, mock_services, mock_config):
    """Instantiate the current ``ROUTERS`` entry with the mocked providers, services and config."""
    # ``request.param`` is the router class selected for this parametrized run.
    return request.param(mock_providers, mock_services, mock_config)
def test_all_routes_have_base_endpoint_decorator(router):
    """Assert every checked route's endpoint carries the ``_is_base_endpoint`` marker.

    Routes whose path ends in ``/stream`` or ``/viewer``, and routes whose
    type name mentions "websocket", are skipped.
    """
    exempt_suffixes = ("/stream", "/viewer")
    for route in router.router.routes:
        if route.path.endswith(exempt_suffixes):
            continue
        if "websocket" in str(type(route)).lower():
            continue

        handler = route.endpoint
        assert hasattr(handler, "_is_base_endpoint"), (
            f"Route {route.path} missing @base_endpoint decorator")
def test_all_routes_have_proper_return_type_hints(router):
    """Assert every checked route's endpoint declares an accepted return annotation.

    Accepted annotations are classes whose string form contains
    ``R2RResults`` or ``PaginatedR2RResult``, or that equal ``FileResponse``,
    ``StreamingResponse`` or ``_TemplateResponse``. Routes whose path ends in
    ``/stream`` and routes whose type name mentions "websocket" are skipped.
    """
    for route in router.router.routes:
        if route.path.endswith("/stream"):
            continue
        if "websocket" in str(type(route)).lower():
            continue

        return_type = inspect.signature(route.endpoint).return_annotation

        # Anything that is not a class (including a missing annotation) is
        # rejected before the name/identity checks run.
        if not isinstance(return_type, type):
            is_valid = False
        else:
            # The result wrappers are recognised by name in the string form.
            named_like_result = ("R2RResults" in str(return_type)
                                 or "PaginatedR2RResult" in str(return_type))
            is_valid = (named_like_result
                        or return_type == FileResponse
                        or return_type == StreamingResponse
                        or return_type == _TemplateResponse)

        assert is_valid, (
            f"Route {route.path} has invalid return type: {return_type}, expected R2RResults[...]"
        )
def test_all_routes_have_rate_limiting(router):
    """Warn about each route whose dependencies lack the router's rate-limit dependency.

    This check never fails the test: a missing dependency only produces a
    ``UserWarning``.
    """
    import warnings

    for route in router.router.routes:
        print(f"Checking route: {route.path}")
        print(f"Dependencies: {route.dependencies}")

        # Stop at the first dependency matching the rate limiter.
        rate_limited = False
        for declared in route.dependencies:
            if declared.dependency == router.rate_limit_dependency:
                rate_limited = True
                break
        if rate_limited:
            continue

        # We should require this in the future, but for now just warn
        warnings.warn(
            f"Route {route.path} missing rate limiting - this will be required in the future",
            UserWarning,
        )
|