auth_provider.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import logging
  2. from fastapi import Depends
  3. from sqlmodel import select
  4. from app.api.deps import verfiy_token, verify_token_relation
  5. from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
  6. from app.services.token.token_relation import TokenRelationService
  7. from config.config import settings
  class AuthPolicy:
      """
      Default authentication policy in which every hook is a no-op.

      Installing this policy adds no route dependencies, records no token
      relations and leaves query statements unfiltered. Token-aware policies
      subclass it and override the hooks.
      """

      def enable(self):
          """
          Install the policy's route dependencies; this policy installs none.
          """
          return None

      def insert_token_rel(
          self, session, token_id: str, relation_type: RelationType, relation_id: str
      ):
          """
          Record that ``token_id`` owns ``relation_id``; no-op for this policy.
          """
          return None

      async def delete_token_rel(
          self, session, relation_type: RelationType, relation_id: str
      ):
          """
          Remove the token relation for ``relation_id``; no-op for this policy.
          """
          return None

      def token_filter(
          self, statement, field, relation_type: RelationType, token_id: str
      ):
          """
          Restrict ``statement`` to rows owned by ``token_id``.

          This policy applies no restriction: ``statement`` is returned as-is.
          """
          return statement
  class SimpleTokenAuthPolicy(AuthPolicy):
      """
      Token-based auth policy.

      ``enable`` attaches token-verification dependencies to the v1 API
      routers. ``insert_token_rel`` / ``delete_token_rel`` maintain
      ``TokenRelation`` rows linking a token to a resource id.
      ``token_filter`` limits a query to resource ids linked to a given token.
      """

      def enable(self):
          """
          Attach auth-verification dependencies to the v1 API routers.

          Create/list endpoints get only the plain ``verfiy_token`` dependency.
          ``create_thread`` gets a thread-relation check that tolerates a
          missing ``thread_id``. Every other assistant/thread/action endpoint
          gets a ``verify_token_relation`` check on its ``*_id`` path
          parameter. The assistant-file, message and runs routers reuse
          their parent resource's (assistant / thread) relation check.

          NOTE(review): this only mutates ``route.dependencies`` on the
          existing route objects; presumably it must run before these
          routers are included into the application for FastAPI to pick the
          dependencies up — confirm against the call site.
          """
          # Imported lazily, presumably to avoid a circular import between the
          # auth provider and the API package — TODO confirm.
          from app.api.v1 import assistant, assistant_file, thread, message, runs, action

          token_only_depends = Depends(verfiy_token)
          verify_assistant_depends = Depends(
              verify_token_relation(
                  relation_type=RelationType.Assistant, name="assistant_id"
              )
          )
          verify_thread_depends = Depends(
              verify_token_relation(relation_type=RelationType.Thread, name="thread_id")
          )
          verify_action_depends = Depends(
              verify_token_relation(relation_type=RelationType.Action, name="action_id")
          )
          # ignore_none_relation_id=True: presumably because a thread being
          # created may not carry a thread_id yet — confirm in verify_token_relation.
          create_thread_depends = Depends(
              verify_token_relation(
                  relation_type=RelationType.Thread,
                  name="thread_id",
                  ignore_none_relation_id=True,
              )
          )

          self.__append_deps_by_route_name(
              assistant.router,
              {assistant.create_assistant.__name__, assistant.list_assistants.__name__},
              token_only_depends,
              verify_assistant_depends,
          )
          self.__append_deps_by_route_name(
              thread.router,
              {thread.create_thread.__name__},
              create_thread_depends,
              verify_thread_depends,
          )
          self.__append_deps_by_route_name(
              action.router,
              {action.create_actions.__name__, action.list_actions.__name__},
              token_only_depends,
              verify_action_depends,
          )
          self.__append_deps_for_all_routes(
              assistant_file.router, verify_assistant_depends
          )
          self.__append_deps_for_all_routes(message.router, verify_thread_depends)
          self.__append_deps_for_all_routes(runs.router, verify_thread_depends)

      def insert_token_rel(
          self, session, token_id: str, relation_type: RelationType, relation_id: str
      ):
          """
          Add a ``TokenRelation`` linking ``token_id`` to ``relation_id``.

          The row is only added to ``session``; no commit is issued here.
          Nothing is added when ``token_id`` is falsy (empty or None).
          ``relation_id`` is stored as ``str(relation_id)``.
          """
          if not token_id:
              return
          session.add(
              TokenRelation(
                  token_id=token_id,
                  relation_type=relation_type,
                  relation_id=str(relation_id),
              )
          )

      async def delete_token_rel(
          self, session, relation_type: RelationType, relation_id: str
      ):
          """
          Delete the ``TokenRelation`` for ``relation_type``/``relation_id``.

          ``relation_id`` is normalized with ``str()`` exactly as
          ``insert_token_rel`` stores it, so non-string ids match the stored
          value. No commit is issued here and lookup errors propagate.
          NOTE(review): behavior when no relation exists is decided by
          ``TokenRelationService.get_relation_to_delete`` — confirm.
          """
          to_delete = TokenRelationDelete(
              relation_type=relation_type, relation_id=str(relation_id)
          )
          relation = await TokenRelationService.get_relation_to_delete(
              session=session, delete=to_delete
          )
          await session.delete(relation)

      def token_filter(
          self, statement, field, relation_type: RelationType, token_id: str
      ):
          """
          Restrict ``statement`` to rows whose ``field`` is among the relation
          ids linked to ``token_id`` for ``relation_type`` (via an IN subquery).
          """
          owned_ids = select(TokenRelation.relation_id).where(
              TokenRelation.relation_type == relation_type,
              TokenRelation.token_id == token_id,
          )
          return statement.where(field.in_(owned_ids))

      def __append_deps_by_route_name(
          self, router, route_names, matched_depends, default_depends
      ):
          """
          Append ``matched_depends`` to routes of ``router`` whose name is in
          ``route_names`` and ``default_depends`` to all of its other routes.
          """
          for route in router.routes:
              route.dependencies.append(
                  matched_depends if route.name in route_names else default_depends
              )

      def __append_deps_for_all_routes(self, router, depends):
          """Append ``depends`` to every route of ``router``."""
          for route in router.routes:
              route.dependencies.append(depends)
  # Active auth policy: token-based auth when settings.AUTH_ENABLE is truthy,
  # otherwise the no-op default policy.
  auth_policy: AuthPolicy = (
      SimpleTokenAuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
  )
  def register(app):
      """
      Log which auth policy is active and enable it.

      :param app: the application being set up; not used by this function.
      """
      # A module logger avoids the implicit logging.basicConfig() call that the
      # root-level logging.info() performs when no handlers are configured.
      logging.getLogger(__name__).info(
          "use auth policy: %s", type(auth_policy).__name__
      )
      auth_policy.enable()