auth_provider.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import logging
  2. from fastapi import Depends
  3. from sqlmodel import select
  4. from app.api.deps import verfiy_token, verify_token_relation
  5. from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
  6. from app.services.token.token_relation import TokenRelationService
  7. from config.config import settings
  8. class AuthPolicy(object):
  9. """
  10. default auth policy with nothing to do
  11. """
  12. def enable(self):
  13. """
  14. enable auth policy
  15. """
  16. def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
  17. """
  18. insert a token relation to database when enable token auth policy
  19. """
  20. async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
  21. """
  22. delete token relation when enable token auth policy
  23. """
  24. def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
  25. """
  26. add token filter clause when enable token auth policy
  27. """
  28. return statement
  29. class SimpleTokenAuthPolicy(AuthPolicy):
  30. """
  31. simple token auth policy
  32. """
  33. def enable(self):
  34. """
  35. add auth verify dependents to path router
  36. """
  37. from app.api.v1 import assistant, assistant_file, thread, message, runs, action
  38. verify_assistant_depends = Depends(
  39. verify_token_relation(relation_type=RelationType.Assistant, name="assistant_id")
  40. )
  41. # assistant router
  42. for route in assistant.router.routes:
  43. if route.name == assistant.create_assistant.__name__ or route.name == assistant.list_assistants.__name__:
  44. route.dependencies.append(Depends(verfiy_token))
  45. else:
  46. route.dependencies.append(verify_assistant_depends)
  47. # thread router
  48. verify_thread_depends = Depends(verify_token_relation(relation_type=RelationType.Thread, name="thread_id"))
  49. for route in thread.router.routes:
  50. if route.name == thread.create_thread.__name__:
  51. route.dependencies.append(
  52. Depends(
  53. verify_token_relation(
  54. relation_type=RelationType.Thread, name="thread_id", ignore_none_relation_id=True
  55. )
  56. )
  57. )
  58. else:
  59. route.dependencies.append(verify_thread_depends)
  60. # action router
  61. verify_action_depends = Depends(verify_token_relation(relation_type=RelationType.Action, name="action_id"))
  62. for route in action.router.routes:
  63. if route.name == action.create_actions.__name__ or route.name == action.list_actions.__name__:
  64. route.dependencies.append(Depends(verfiy_token))
  65. else:
  66. route.dependencies.append(verify_action_depends)
  67. self.__append_deps_for_all_routes(assistant_file.router, verify_assistant_depends)
  68. self.__append_deps_for_all_routes(message.router, verify_thread_depends)
  69. self.__append_deps_for_all_routes(runs.router, verify_thread_depends)
  70. def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
  71. if token_id:
  72. relation = TokenRelation(token_id=token_id, relation_type=relation_type, relation_id=str(relation_id))
  73. session.add(relation)
  74. async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
  75. to_delete = TokenRelationDelete(relation_type=relation_type, relation_id=relation_id)
  76. relation = await TokenRelationService.get_relation_to_delete(session=session, delete=to_delete)
  77. await session.delete(relation)
  78. def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
  79. id_subquery = select(TokenRelation.relation_id).where(
  80. TokenRelation.relation_type == relation_type, TokenRelation.token_id == token_id
  81. )
  82. return statement.where(field.in_(id_subquery))
  83. def __append_deps_for_all_routes(self, router, depends):
  84. for route in router.routes:
  85. route.dependencies.append(depends)
  86. auth_policy: AuthPolicy = SimpleTokenAuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
  87. def register(app):
  88. logging.info("use auth polily: %s", auth_policy.__class__.__name__)
  89. auth_policy.enable()