project_schema.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import logging
  2. import re
  3. from fastapi import Request
  4. from fastapi.responses import JSONResponse
  5. from starlette.middleware.base import BaseHTTPMiddleware
  6. from core.utils.context import project_schema_context, set_project_schema
  7. logger = logging.getLogger(__name__)
  8. class ProjectSchemaMiddleware(BaseHTTPMiddleware):
  9. def __init__(
  10. self, app, default_schema: str = "r2r_default", schema_exists_func=None
  11. ):
  12. super().__init__(app)
  13. self.default_schema = default_schema
  14. self.schema_exists_func = schema_exists_func
  15. async def dispatch(self, request: Request, call_next):
  16. # Skip schema check for static files, docs, etc.
  17. if request.url.path.startswith(
  18. ("/docs", "/redoc", "/static", "/openapi.json")
  19. ):
  20. return await call_next(request)
  21. # Get the project name from the x-project-name header or use default
  22. schema_name = request.headers.get(
  23. "x-project-name", self.default_schema
  24. )
  25. # Validate schema name format (prevent SQL injection)
  26. if not re.match(r"^[a-zA-Z0-9_]+$", schema_name):
  27. return JSONResponse(
  28. status_code=400,
  29. content={"detail": "Invalid schema name format"},
  30. )
  31. # Check if schema exists (optional)
  32. if self.schema_exists_func and schema_name != self.default_schema:
  33. try:
  34. schema_exists = await self.schema_exists_func(schema_name)
  35. if not schema_exists:
  36. return JSONResponse(
  37. status_code=403,
  38. content={
  39. "detail": f"Schema '{schema_name}' does not exist"
  40. },
  41. )
  42. except Exception as e:
  43. logger.error(f"Error checking schema existence: {e}")
  44. return JSONResponse(
  45. status_code=500,
  46. content={
  47. "detail": "Internal server error checking schema"
  48. },
  49. )
  50. # Set the project schema in the context for this request
  51. schema_name = schema_name.replace('"', "")
  52. token = set_project_schema(schema_name)
  53. try:
  54. # Process the request with the set schema
  55. return await call_next(request)
  56. finally:
  57. # Reset context when done
  58. project_schema_context.reset(token)