deps.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from typing import AsyncGenerator
  2. from fastapi import Depends, Request
  3. from fastapi.security import APIKeyHeader
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from app.exceptions.exception import AuthenticationError, AuthorizationError, ResourceNotFoundError
  6. from app.models.token import Token
  7. from app.models.token_relation import RelationType, TokenRelationQuery
  8. from app.providers import database
  9. from app.services.token.token import TokenService
  10. from app.services.token.token_relation import TokenRelationService
  11. from config.config import settings
  12. async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
  13. """session生成器 作为fast api的Depends选项"""
  14. async with database.async_session_local() as session:
  15. yield session
  16. class OAuth2Bearer(APIKeyHeader):
  17. """
  18. it use to fetch token from header
  19. """
  20. def __init__(
  21. self, *, name: str, scheme_name: str | None = None, description: str | None = None, auto_error: bool = True
  22. ):
  23. super().__init__(name=name, scheme_name=scheme_name, description=description, auto_error=auto_error)
  24. async def __call__(self, request: Request) -> str:
  25. authorization_header_value = request.headers.get(self.model.name)
  26. if authorization_header_value:
  27. scheme, _, param = authorization_header_value.partition(" ")
  28. if scheme.lower() == "bearer" and param.strip() != "":
  29. return param.strip()
  30. return None
# Shared dependency instance: reads the bearer token from the "Authorization" header.
oauth_token = OAuth2Bearer(name="Authorization")
  32. async def verify_admin_token(token=Depends(oauth_token)) -> Token:
  33. """
  34. admin token authentication
  35. """
  36. if token is None:
  37. raise AuthenticationError()
  38. if settings.AUTH_ADMIN_TOKEN != token:
  39. raise AuthorizationError()
  40. async def get_token(session=Depends(get_async_session), token=Depends(oauth_token)) -> Token:
  41. """
  42. get token info
  43. """
  44. if token and token != "":
  45. try:
  46. return await TokenService.get_token(session=session, token=token)
  47. except ResourceNotFoundError:
  48. pass
  49. return None
  50. async def verfiy_token(token: Token = Depends(get_token)):
  51. if token is None:
  52. raise AuthenticationError()
  53. async def get_token_id(token: Token = Depends(get_token)):
  54. """
  55. Return token_id, which can be considered as user information.
  56. """
  57. return token.id if token is not None else None
  58. def get_param(name: str):
  59. """
  60. extract param from Request
  61. """
  62. async def get_param_from_request(request: Request):
  63. if name in request.path_params:
  64. return request.path_params[name]
  65. if name in request.query_params:
  66. return request.query_params[name]
  67. body = await request.json()
  68. if name in body:
  69. return body[name]
  70. return get_param_from_request
  71. def verify_token_relation(relation_type: RelationType, name: str, ignore_none_relation_id: bool = False):
  72. """
  73. param relation_type: relation type
  74. param name: param name
  75. param ignore_none_relation_id: if ignore_none_relation_id is set, return where relation_id is None, use for copy thread api
  76. """
  77. async def verify_authorization(
  78. session=Depends(get_async_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name))
  79. ):
  80. if token_id and ignore_none_relation_id:
  81. return
  82. if token_id and relation_id:
  83. verify = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id)
  84. if await TokenRelationService.verify_relation(session=session, verify=verify):
  85. return
  86. raise AuthorizationError()
  87. return verify_authorization