app_entry.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import logging
  2. import os
  3. from contextlib import asynccontextmanager
  4. from typing import Optional
  5. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  6. from fastapi import FastAPI, Request
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from fastapi.responses import JSONResponse
  9. from core.base import R2RException
  10. from core.utils.logging_config import configure_logging
  11. from .app import R2RApp
  12. from .assembly import R2RBuilder, R2RConfig
  13. from .middleware.project_schema import ProjectSchemaMiddleware
# Configure logging at import time. The return value of configure_logging()
# is kept as `log_file` (presumably the log file path — confirm in
# core.utils.logging_config).
log_file = configure_logging()
# Global scheduler: started and shut down by `lifespan`.
scheduler = AsyncIOScheduler()
  17. @asynccontextmanager
  18. async def lifespan(app: FastAPI):
  19. # Startup
  20. r2r_app = await create_r2r_app(
  21. config_name=config_name,
  22. config_path=config_path,
  23. )
  24. # Copy all routes from r2r_app to app
  25. app.router.routes = r2r_app.app.routes
  26. # Copy middleware and exception handlers
  27. app.middleware = r2r_app.app.middleware # type: ignore
  28. app.exception_handlers = r2r_app.app.exception_handlers
  29. # Start the scheduler
  30. scheduler.start()
  31. # Start the Hatchet worker
  32. await r2r_app.orchestration_provider.start_worker()
  33. yield
  34. # # Shutdown
  35. scheduler.shutdown()
  36. async def create_r2r_app(
  37. config_name: Optional[str] = "default",
  38. config_path: Optional[str] = None,
  39. ) -> R2RApp:
  40. config = R2RConfig.load(config_name=config_name, config_path=config_path)
  41. if (
  42. config.embedding.provider == "openai"
  43. and "OPENAI_API_KEY" not in os.environ
  44. ):
  45. raise ValueError(
  46. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  47. )
  48. # Build the R2RApp
  49. builder = R2RBuilder(config=config)
  50. return await builder.build()
  51. config_name = os.getenv("R2R_CONFIG_NAME", None)
  52. config_path = os.getenv("R2R_CONFIG_PATH", None)
  53. if not config_path and not config_name:
  54. config_name = "default"
  55. host = os.getenv("R2R_HOST", os.getenv("HOST", "0.0.0.0"))
  56. port = int(os.getenv("R2R_PORT", "7272"))
  57. config = R2RConfig.load(config_name=config_name, config_path=config_path)
  58. project_name = (
  59. os.getenv("R2R_PROJECT_NAME") or config.app.project_name or "r2r_default"
  60. )
  61. logging.info(
  62. f"Environment R2R_IMAGE: {os.getenv('R2R_IMAGE')}",
  63. )
  64. logging.info(
  65. f"Environment R2R_CONFIG_NAME: {'None' if config_name is None else config_name}"
  66. )
  67. logging.info(
  68. f"Environment R2R_CONFIG_PATH: {'None' if config_path is None else config_path}"
  69. )
  70. logging.info(f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PROJECT_NAME')}")
  71. logging.info(f"Using project name: {project_name}")
  72. logging.info(
  73. f"Environment R2R_POSTGRES_HOST: {os.getenv('R2R_POSTGRES_HOST')}"
  74. )
  75. logging.info(
  76. f"Environment R2R_POSTGRES_DBNAME: {os.getenv('R2R_POSTGRES_DBNAME')}"
  77. )
  78. logging.info(
  79. f"Environment R2R_POSTGRES_PORT: {os.getenv('R2R_POSTGRES_PORT')}"
  80. )
  81. logging.info(
  82. f"Environment R2R_POSTGRES_PASSWORD: {os.getenv('R2R_POSTGRES_PASSWORD')}"
  83. )
  84. # Create the FastAPI app
  85. app = FastAPI(
  86. lifespan=lifespan,
  87. log_config=None,
  88. )
  89. @app.exception_handler(R2RException)
  90. async def r2r_exception_handler(request: Request, exc: R2RException):
  91. return JSONResponse(
  92. status_code=exc.status_code,
  93. content={
  94. "message": exc.message,
  95. "error_type": type(exc).__name__,
  96. },
  97. )
  98. # Add CORS middleware
  99. app.add_middleware(
  100. CORSMiddleware,
  101. allow_origins=["*"],
  102. allow_credentials=True,
  103. allow_methods=["*"],
  104. allow_headers=["*"],
  105. )
  106. app.add_middleware(
  107. ProjectSchemaMiddleware,
  108. default_schema=project_name,
  109. )