app_entry.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import os
  2. from contextlib import asynccontextmanager
  3. from typing import Optional
  4. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  5. from fastapi import FastAPI, Request
  6. from fastapi.middleware.cors import CORSMiddleware
  7. from fastapi.responses import JSONResponse
  8. from core.base import R2RException
  9. from core.utils.logging_config import configure_logging
  10. from .assembly import R2RBuilder, R2RConfig
  11. logger, log_file = configure_logging()
  12. # Global scheduler
  13. scheduler = AsyncIOScheduler()
  14. @asynccontextmanager
  15. async def lifespan(app: FastAPI):
  16. # Startup
  17. r2r_app = await create_r2r_app(
  18. config_name=config_name,
  19. config_path=config_path,
  20. )
  21. # Copy all routes from r2r_app to app
  22. app.router.routes = r2r_app.app.routes
  23. # Copy middleware and exception handlers
  24. app.middleware = r2r_app.app.middleware # type: ignore
  25. app.exception_handlers = r2r_app.app.exception_handlers
  26. # Start the scheduler
  27. scheduler.start()
  28. # Start the Hatchet worker
  29. await r2r_app.orchestration_provider.start_worker()
  30. yield
  31. # # Shutdown
  32. scheduler.shutdown()
  33. async def create_r2r_app(
  34. config_name: Optional[str] = "default",
  35. config_path: Optional[str] = None,
  36. ):
  37. config = R2RConfig.load(config_name=config_name, config_path=config_path)
  38. if (
  39. config.embedding.provider == "openai"
  40. and "OPENAI_API_KEY" not in os.environ
  41. ):
  42. raise ValueError(
  43. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  44. )
  45. # Build the R2RApp
  46. builder = R2RBuilder(config=config)
  47. return await builder.build()
  48. config_name = os.getenv("R2R_CONFIG_NAME", None)
  49. config_path = os.getenv("R2R_CONFIG_PATH", None)
  50. if not config_path and not config_name:
  51. config_name = "default"
  52. host = os.getenv("R2R_HOST", os.getenv("HOST", "0.0.0.0"))
  53. port = int(os.getenv("R2R_PORT", "7272"))
  54. logger.info(
  55. f"Environment R2R_CONFIG_NAME: {'None' if config_name is None else config_name}"
  56. )
  57. logger.info(
  58. f"Environment R2R_CONFIG_PATH: {'None' if config_path is None else config_path}"
  59. )
  60. logger.info(f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PROJECT_NAME')}")
  61. logger.info(f"Environment R2R_POSTGRES_HOST: {os.getenv('R2R_POSTGRES_HOST')}")
  62. logger.info(
  63. f"Environment R2R_POSTGRES_DBNAME: {os.getenv('R2R_POSTGRES_DBNAME')}"
  64. )
  65. logger.info(f"Environment R2R_POSTGRES_PORT: {os.getenv('R2R_POSTGRES_PORT')}")
  66. logger.info(
  67. f"Environment R2R_POSTGRES_PASSWORD: {os.getenv('R2R_POSTGRES_PASSWORD')}"
  68. )
  69. logger.info(
  70. f"Environment R2R_PROJECT_NAME: {os.getenv('R2R_PR2R_PROJECT_NAME')}"
  71. )
  72. # Create the FastAPI app
  73. app = FastAPI(
  74. lifespan=lifespan,
  75. log_config=None,
  76. )
  77. @app.exception_handler(R2RException)
  78. async def r2r_exception_handler(request: Request, exc: R2RException):
  79. return JSONResponse(
  80. status_code=exc.status_code,
  81. content={
  82. "message": exc.message,
  83. "error_type": type(exc).__name__,
  84. },
  85. )
  86. # Add CORS middleware
  87. app.add_middleware(
  88. CORSMiddleware,
  89. allow_origins=["*"],
  90. allow_credentials=True,
  91. allow_methods=["*"],
  92. allow_headers=["*"],
  93. )