123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- import logging
- from fastapi import Depends
- from sqlmodel import select
- from app.api.deps import verfiy_token, verify_token_relation
- from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
- from app.services.token.token_relation import TokenRelationService
- from config.config import settings
class AuthPolicy:
    """
    No-op auth policy.

    Every hook is a placeholder: ``enable`` and the token-relation hooks do
    nothing, and ``token_filter`` hands the statement back untouched.
    Subclasses override these hooks to enforce a real policy.
    """

    def enable(self):
        """Hook for installing the policy; intentionally does nothing here."""
        return None

    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
        """Hook for recording that ``relation_id`` belongs to ``token_id``; does nothing here."""
        return None

    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
        """Hook for removing a resource's token relation; does nothing here."""
        return None

    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
        """Hook for restricting ``statement`` to rows owned by ``token_id``; returns it unchanged here."""
        return statement
class SimpleTokenAuthPolicy(AuthPolicy):
    """
    Token auth policy that scopes assistants, threads and actions to a token.

    ``enable`` attaches token-verification dependencies to the v1 API routers.
    The relation hooks maintain ``TokenRelation`` rows. ``token_filter``
    restricts queries to ids related to the caller's token.
    """

    def enable(self):
        """
        Attach auth-verification dependencies to the v1 API routes.

        Assistant and action create/list routes only require a valid token.
        ``create_thread`` tolerates a missing ``thread_id``. Every other route
        verifies that the resource id in the request is related to the caller's
        token. Assistant-file, message and run routes reuse the parent
        assistant/thread check.

        NOTE(review): this mutates ``route.dependencies`` on the router objects.
        It presumably has to run before those routers are included in the app —
        confirm against the caller.
        """
        # Imported lazily, presumably to avoid an import cycle with the routers.
        from app.api.v1 import assistant, assistant_file, thread, message, runs, action

        # assistant router: create/list only need a valid token.
        verify_assistant_depends = Depends(
            verify_token_relation(relation_type=RelationType.Assistant, name="assistant_id")
        )
        self._append_deps_by_route_name(
            assistant.router,
            {assistant.create_assistant.__name__, assistant.list_assistants.__name__},
            lambda: Depends(verfiy_token),
            verify_assistant_depends,
        )

        # thread router: create_thread is checked with ignore_none_relation_id=True.
        verify_thread_depends = Depends(verify_token_relation(relation_type=RelationType.Thread, name="thread_id"))
        self._append_deps_by_route_name(
            thread.router,
            {thread.create_thread.__name__},
            lambda: Depends(
                verify_token_relation(
                    relation_type=RelationType.Thread, name="thread_id", ignore_none_relation_id=True
                )
            ),
            verify_thread_depends,
        )

        # action router: create/list only need a valid token.
        verify_action_depends = Depends(verify_token_relation(relation_type=RelationType.Action, name="action_id"))
        self._append_deps_by_route_name(
            action.router,
            {action.create_actions.__name__, action.list_actions.__name__},
            lambda: Depends(verfiy_token),
            verify_action_depends,
        )

        # Child resources reuse the parent resource's ownership check.
        self.__append_deps_for_all_routes(assistant_file.router, verify_assistant_depends)
        self.__append_deps_for_all_routes(message.router, verify_thread_depends)
        self.__append_deps_for_all_routes(runs.router, verify_thread_depends)

    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
        """
        Stage a ``TokenRelation`` row linking ``relation_id`` to ``token_id``.

        Nothing is staged when ``token_id`` is falsy. The row is only added to
        ``session``; committing is left to the caller.
        """
        if token_id:
            relation = TokenRelation(token_id=token_id, relation_type=relation_type, relation_id=str(relation_id))
            session.add(relation)

    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
        """
        Delete the ``TokenRelation`` row for the given resource, if there is one.

        A row can legitimately be absent, because ``insert_token_rel`` skips the
        insert when no token id is supplied. In that case there is nothing to
        delete and the call is a no-op instead of handing ``None`` to
        ``session.delete``.
        """
        to_delete = TokenRelationDelete(relation_type=relation_type, relation_id=relation_id)
        relation = await TokenRelationService.get_relation_to_delete(session=session, delete=to_delete)
        # NOTE(review): assumes get_relation_to_delete can return None for a missing row
        # (it may instead raise) — confirm against TokenRelationService.
        if relation is None:
            return
        await session.delete(relation)

    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
        """
        Restrict ``statement`` to rows whose ``field`` is related to ``token_id``.

        Adds ``field IN (SELECT relation_id FROM token_relation WHERE ...)``,
        matching on both ``relation_type`` and ``token_id``.
        """
        id_subquery = select(TokenRelation.relation_id).where(
            TokenRelation.relation_type == relation_type, TokenRelation.token_id == token_id
        )
        return statement.where(field.in_(id_subquery))

    @staticmethod
    def _append_deps_by_route_name(router, names, make_override, default):
        """
        Append one dependency to every route of ``router``.

        A route whose name is in ``names`` gets a freshly built
        ``make_override()``. All other routes share the ``default`` dependency
        object.
        """
        for route in router.routes:
            if route.name in names:
                route.dependencies.append(make_override())
            else:
                route.dependencies.append(default)

    def __append_deps_for_all_routes(self, router, depends):
        """Append the same ``depends`` object to every route of ``router``."""
        for route in router.routes:
            route.dependencies.append(depends)
# The policy is chosen once, at import time: token auth only when AUTH_ENABLE is set.
auth_policy: AuthPolicy
if settings.AUTH_ENABLE:
    auth_policy = SimpleTokenAuthPolicy()
else:
    auth_policy = AuthPolicy()
def register(app):
    """
    Log which auth policy is active and enable it.

    :param app: the application being set up. It is not used here, because the
        policy's ``enable`` hook does its own wiring.
    """
    logging.info("use auth policy: %s", auth_policy.__class__.__name__)
    auth_policy.enable()
|