app.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from fastapi import FastAPI, Request
  2. from fastapi.middleware.cors import CORSMiddleware
  3. from fastapi.openapi.utils import get_openapi
  4. from fastapi.responses import JSONResponse
  5. from core.base import R2RException
  6. from core.providers import (
  7. HatchetOrchestrationProvider,
  8. SimpleOrchestrationProvider,
  9. )
  10. from .api.v3.chunks_router import ChunksRouter
  11. from .api.v3.collections_router import CollectionsRouter
  12. from .api.v3.conversations_router import ConversationsRouter
  13. from .api.v3.documents_router import DocumentsRouter
  14. from .api.v3.graph_router import GraphRouter
  15. from .api.v3.indices_router import IndicesRouter
  16. from .api.v3.logs_router import LogsRouter
  17. from .api.v3.prompts_router import PromptsRouter
  18. from .api.v3.retrieval_router import RetrievalRouterV3
  19. from .api.v3.system_router import SystemRouter
  20. from .api.v3.users_router import UsersRouter
  21. from .config import R2RConfig
  22. class R2RApp:
  23. def __init__(
  24. self,
  25. config: R2RConfig,
  26. orchestration_provider: (
  27. HatchetOrchestrationProvider | SimpleOrchestrationProvider
  28. ),
  29. chunks_router: ChunksRouter,
  30. collections_router: CollectionsRouter,
  31. conversations_router: ConversationsRouter,
  32. documents_router: DocumentsRouter,
  33. graph_router: GraphRouter,
  34. indices_router: IndicesRouter,
  35. logs_router: LogsRouter,
  36. prompts_router: PromptsRouter,
  37. retrieval_router_v3: RetrievalRouterV3,
  38. system_router: SystemRouter,
  39. users_router: UsersRouter,
  40. ):
  41. self.config = config
  42. self.chunks_router = chunks_router
  43. self.collections_router = collections_router
  44. self.conversations_router = conversations_router
  45. self.documents_router = documents_router
  46. self.graph_router = graph_router
  47. self.indices_router = indices_router
  48. self.logs_router = logs_router
  49. self.orchestration_provider = orchestration_provider
  50. self.prompts_router = prompts_router
  51. self.retrieval_router_v3 = retrieval_router_v3
  52. self.system_router = system_router
  53. self.users_router = users_router
  54. self.app = FastAPI()
  55. @self.app.exception_handler(R2RException)
  56. async def r2r_exception_handler(request: Request, exc: R2RException):
  57. return JSONResponse(
  58. status_code=exc.status_code,
  59. content={
  60. "message": exc.message,
  61. "error_type": type(exc).__name__,
  62. },
  63. )
  64. self._setup_routes()
  65. self._apply_cors()
  66. def _setup_routes(self):
  67. self.app.include_router(self.chunks_router, prefix="/v3")
  68. self.app.include_router(self.collections_router, prefix="/v3")
  69. self.app.include_router(self.conversations_router, prefix="/v3")
  70. self.app.include_router(self.documents_router, prefix="/v3")
  71. self.app.include_router(self.graph_router, prefix="/v3")
  72. self.app.include_router(self.indices_router, prefix="/v3")
  73. self.app.include_router(self.logs_router, prefix="/v3")
  74. self.app.include_router(self.prompts_router, prefix="/v3")
  75. self.app.include_router(self.retrieval_router_v3, prefix="/v3")
  76. self.app.include_router(self.system_router, prefix="/v3")
  77. self.app.include_router(self.users_router, prefix="/v3")
  78. @self.app.get("/openapi_spec", include_in_schema=False)
  79. async def openapi_spec():
  80. return get_openapi(
  81. title="R2R Application API",
  82. version="1.0.0",
  83. routes=self.app.routes,
  84. )
  85. def _apply_cors(self):
  86. origins = ["*", "http://localhost:3000", "http://localhost:7272"]
  87. self.app.add_middleware(
  88. CORSMiddleware,
  89. allow_origins=origins,
  90. allow_credentials=True,
  91. allow_methods=["*"],
  92. allow_headers=["*"],
  93. )
  94. async def serve(self, host: str = "0.0.0.0", port: int = 7272):
  95. import uvicorn
  96. from core.utils.logging_config import configure_logging
  97. configure_logging()
  98. config = uvicorn.Config(
  99. self.app,
  100. host=host,
  101. port=port,
  102. log_config=None,
  103. )
  104. server = uvicorn.Server(config)
  105. await server.serve()