deps.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
import secrets
from typing import AsyncGenerator

from fastapi import Depends, Request
from fastapi.security import APIKeyHeader
from sqlalchemy.ext.asyncio import AsyncSession

from app.exceptions.exception import (
    AuthenticationError,
    AuthorizationError,
    ResourceNotFoundError,
)
from app.models.token import Token
from app.models.token_relation import RelationType, TokenRelationQuery
from app.providers import database
from app.services.token.token import TokenService
from app.services.token.token_relation import TokenRelationService
from config.config import settings
  16. async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
  17. """session生成器 作为fast api的Depends选项"""
  18. async with database.async_session_local() as session:
  19. yield session
  20. class OAuth2Bearer(APIKeyHeader):
  21. """
  22. it use to fetch token from header
  23. """
  24. def __init__(
  25. self,
  26. *,
  27. name: str,
  28. scheme_name: str | None = None,
  29. description: str | None = None,
  30. auto_error: bool = True
  31. ):
  32. super().__init__(
  33. name=name,
  34. scheme_name=scheme_name,
  35. description=description,
  36. auto_error=auto_error,
  37. )
  38. async def __call__(self, request: Request) -> str:
  39. authorization_header_value = request.headers.get(self.model.name)
  40. if authorization_header_value:
  41. scheme, _, param = authorization_header_value.partition(" ")
  42. if scheme.lower() == "bearer" and param.strip() != "":
  43. return param.strip()
  44. return None
  45. oauth_token = OAuth2Bearer(name="Authorization")
  46. async def verify_admin_token(token=Depends(oauth_token)) -> Token:
  47. """
  48. admin token authentication
  49. """
  50. if token is None:
  51. raise AuthenticationError()
  52. if settings.AUTH_ADMIN_TOKEN != token:
  53. raise AuthorizationError()
  54. async def get_token(
  55. session=Depends(get_async_session), token=Depends(oauth_token)
  56. ) -> Token:
  57. """
  58. get token info
  59. """
  60. if token and token != "":
  61. try:
  62. return await TokenService.get_token(session=session, token=token)
  63. except ResourceNotFoundError:
  64. pass
  65. return None
  66. async def verfiy_token(token: Token = Depends(get_token)):
  67. if token is None:
  68. raise AuthenticationError()
  69. async def get_token_id(token: Token = Depends(get_token)):
  70. """
  71. Return token_id, which can be considered as user information.
  72. """
  73. return token.id if token is not None else None
  74. def get_param(name: str):
  75. """
  76. extract param from Request
  77. """
  78. async def get_param_from_request(request: Request):
  79. if name in request.path_params:
  80. return request.path_params[name]
  81. if name in request.query_params:
  82. return request.query_params[name]
  83. body = await request.json()
  84. if name in body:
  85. return body[name]
  86. return get_param_from_request
  87. def verify_token_relation(
  88. relation_type: RelationType, name: str, ignore_none_relation_id: bool = False
  89. ):
  90. """
  91. param relation_type: relation type
  92. param name: param name
  93. param ignore_none_relation_id: if ignore_none_relation_id is set, return where relation_id is None, use for copy thread api
  94. """
  95. async def verify_authorization(
  96. session=Depends(get_async_session),
  97. token_id=Depends(get_token_id),
  98. relation_id=Depends(get_param(name)),
  99. ):
  100. if token_id and ignore_none_relation_id:
  101. return
  102. if token_id and relation_id:
  103. verify = TokenRelationQuery(
  104. token_id=token_id, relation_type=relation_type, relation_id=relation_id
  105. )
  106. if await TokenRelationService.verify_relation(
  107. session=session, verify=verify
  108. ):
  109. return
  110. raise AuthorizationError()
  111. return verify_authorization