app.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.openapi.utils import get_openapi
  4. from fastapi.responses import JSONResponse
  5. from core.base import R2RException
  6. from core.providers import (
  7. HatchetOrchestrationProvider,
  8. SimpleOrchestrationProvider,
  9. )
  10. from core.utils.sentry import init_sentry
  11. from .abstractions import R2RProviders, R2RServices
  12. from .api.v3.chunks_router import ChunksRouter
  13. from .api.v3.collections_router import CollectionsRouter
  14. from .api.v3.conversations_router import ConversationsRouter
  15. from .api.v3.documents_router import DocumentsRouter
  16. from .api.v3.graph_router import GraphRouter
  17. from .api.v3.indices_router import IndicesRouter
  18. from .api.v3.prompts_router import PromptsRouter
  19. from .api.v3.retrieval_router import RetrievalRouter
  20. from .api.v3.system_router import SystemRouter
  21. from .api.v3.users_router import UsersRouter
  22. from .config import R2RConfig
  23. from .middleware.project_schema import ProjectSchemaMiddleware
  class R2RApp:
      """FastAPI application wrapper for the R2R v3 API.

      Construction calls ``init_sentry()``, keeps references to the config,
      services, providers, orchestration provider and the ten v3 routers,
      builds a ``FastAPI`` instance on ``self.app``, registers a JSON handler
      for ``R2RException``, mounts every router under ``/v3`` (plus a hidden
      ``/openapi_spec`` endpoint) and installs CORS and project-schema
      middleware. ``serve()`` runs ``self.app`` with uvicorn.
      """

      def __init__(
          self,
          config: R2RConfig,
          orchestration_provider: (
              HatchetOrchestrationProvider | SimpleOrchestrationProvider
          ),
          services: R2RServices,
          providers: R2RProviders,
          chunks_router: ChunksRouter,
          collections_router: CollectionsRouter,
          conversations_router: ConversationsRouter,
          documents_router: DocumentsRouter,
          graph_router: GraphRouter,
          indices_router: IndicesRouter,
          prompts_router: PromptsRouter,
          retrieval_router: RetrievalRouter,
          system_router: SystemRouter,
          users_router: UsersRouter,
      ):
          """Store every dependency and fully configure ``self.app``.

          Each argument is kept, unchanged, on an attribute of the same name.
          ``init_sentry()`` is called before anything else is set up.
          """
          init_sentry()

          # Core dependencies.
          self.config = config
          self.services = services
          self.providers = providers
          self.orchestration_provider = orchestration_provider

          # v3 routers, mounted under /v3 by _setup_routes().
          # NOTE(review): these are annotated with the *Router classes, while
          # FastAPI.include_router() expects an APIRouter -- presumably the
          # caller passes the underlying APIRouter; confirm against the builder.
          self.chunks_router = chunks_router
          self.collections_router = collections_router
          self.conversations_router = conversations_router
          self.documents_router = documents_router
          self.graph_router = graph_router
          self.indices_router = indices_router
          self.prompts_router = prompts_router
          self.retrieval_router = retrieval_router
          self.system_router = system_router
          self.users_router = users_router

          self.app = FastAPI()
          # Same registration the @self.app.exception_handler(R2RException)
          # decorator performs.
          self.app.add_exception_handler(
              R2RException, self._r2r_exception_handler
          )

          self._setup_routes()
          self._apply_middleware()

      @staticmethod
      async def _r2r_exception_handler(
          request: Request, exc: R2RException
      ) -> JSONResponse:
          """Render an ``R2RException`` as a JSON error response.

          The HTTP status comes from ``exc.status_code``; the body holds the
          exception's ``message`` and its class name under ``error_type``.
          ``request`` is required by the handler signature but unused.
          """
          status = exc.status_code
          payload = {
              "message": exc.message,
              "error_type": type(exc).__name__,
          }
          return JSONResponse(status_code=status, content=payload)

      def _setup_routes(self) -> None:
          """Mount each v3 router under ``/v3`` and add ``GET /openapi_spec``."""
          v3_routers = (
              self.chunks_router,
              self.collections_router,
              self.conversations_router,
              self.documents_router,
              self.graph_router,
              self.indices_router,
              self.prompts_router,
              self.retrieval_router,
              self.system_router,
              self.users_router,
          )
          for router in v3_routers:
              self.app.include_router(router, prefix="/v3")

          @self.app.get("/openapi_spec", include_in_schema=False)
          async def openapi_spec():
              # Built on every request from the app's current route table
              # (no caching); include_in_schema=False keeps this endpoint
              # itself out of the generated spec.
              return get_openapi(
                  routes=self.app.routes,
                  title="R2R Application API",
                  version="1.0.0",
              )

      def _apply_middleware(self) -> None:
          """Install CORS and project-schema middleware on ``self.app``.

          Starlette's ``add_middleware`` prepends to the middleware stack, so
          the one added last (``ProjectSchemaMiddleware``) is the outermost
          layer and sees each request before ``CORSMiddleware`` does.
          """
          # Read before any middleware is added, as in the original ordering.
          default_schema = self.providers.database.project_name

          # NOTE(review): "*" together with allow_credentials=True means that in
          # Starlette, any origin is effectively allowed to make credentialed
          # requests. The FastAPI/Starlette docs say wildcards should not be combined
          # with credentials, and the explicit localhost entries are redundant
          # while "*" is present -- confirm this policy is intended.
          cors_options = {
              "allow_origins": [
                  "*",
                  "http://localhost:3000",
                  "http://localhost:7272",
              ],
              "allow_credentials": True,
              "allow_methods": ["*"],
              "allow_headers": ["*"],
          }
          self.app.add_middleware(CORSMiddleware, **cors_options)
          self.app.add_middleware(
              ProjectSchemaMiddleware,
              default_schema=default_schema,
          )

      async def serve(self, host: str = "0.0.0.0", port: int = 7272):
          """Configure logging, then run ``self.app`` under uvicorn until it exits.

          Args:
              host: Interface to bind; the default binds all IPv4 interfaces.
              port: TCP port to listen on.
          """
          # Imported here rather than at module level, so these modules are
          # only loaded when serve() is actually called.
          import uvicorn
          from core.utils.logging_config import configure_logging

          configure_logging()
          # log_config=None tells uvicorn not to apply its own logging config,
          # presumably so the configure_logging() setup above is preserved.
          server_config = uvicorn.Config(
              self.app,
              host=host,
              port=port,
              log_config=None,
          )
          await uvicorn.Server(server_config).serve()