"""Pluggable authorization policies for the API routers."""
import logging

from fastapi import Depends
from sqlmodel import select

from app.api.deps import verfiy_token, verify_token_relation
from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
from app.services.token.token_relation import TokenRelationService
from config.config import settings


class AuthPolicy:
    """
    Default auth policy: every hook is a no-op.

    Subclasses override the hooks to enforce authorization.
    """

    def enable(self):
        """
        Enable the auth policy. The default policy does nothing.
        """

    def insert_token_rel(
        self, session, token_id: str, relation_type: RelationType, relation_id: str
    ):
        """
        Record a token relation when token auth is enabled.
        The default policy does nothing.
        """

    async def delete_token_rel(
        self, session, relation_type: RelationType, relation_id: str
    ):
        """
        Remove a token relation when token auth is enabled.
        The default policy does nothing.
        """

    def token_filter(
        self, statement, field, relation_type: RelationType, token_id: str
    ):
        """
        Add a token filter clause when token auth is enabled.
        The default policy returns ``statement`` unchanged.
        """
        return statement


class SimpleTokenAuthPolicy(AuthPolicy):
    """
    Simple token auth policy backed by ``TokenRelation`` rows.
    """

    def enable(self):
        """
        Append token verification dependencies to the v1 routers.

        Create/list routes of the assistant and action routers only get
        ``verfiy_token``; every other route gets a ``verify_token_relation``
        dependency for the matching relation type.
        """
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid a circular import -- confirm).
        from app.api.v1 import assistant, assistant_file, thread, message, runs, action

        verify_assistant_depends = Depends(
            verify_token_relation(
                relation_type=RelationType.Assistant, name="assistant_id"
            )
        )
        # assistant router
        assistant_token_only = (
            assistant.create_assistant.__name__,
            assistant.list_assistants.__name__,
        )
        for route in assistant.router.routes:
            if route.name in assistant_token_only:
                route.dependencies.append(Depends(verfiy_token))
            else:
                route.dependencies.append(verify_assistant_depends)

        # thread router
        verify_thread_depends = Depends(
            verify_token_relation(relation_type=RelationType.Thread, name="thread_id")
        )
        for route in thread.router.routes:
            if route.name == thread.create_thread.__name__:
                # create_thread is verified with ignore_none_relation_id=True
                route.dependencies.append(
                    Depends(
                        verify_token_relation(
                            relation_type=RelationType.Thread,
                            name="thread_id",
                            ignore_none_relation_id=True,
                        )
                    )
                )
            else:
                route.dependencies.append(verify_thread_depends)

        # action router
        verify_action_depends = Depends(
            verify_token_relation(relation_type=RelationType.Action, name="action_id")
        )
        action_token_only = (
            action.create_actions.__name__,
            action.list_actions.__name__,
        )
        for route in action.router.routes:
            if route.name in action_token_only:
                route.dependencies.append(Depends(verfiy_token))
            else:
                route.dependencies.append(verify_action_depends)

        # Routers whose routes all share a single relation check.
        self.__append_deps_for_all_routes(
            assistant_file.router, verify_assistant_depends
        )
        self.__append_deps_for_all_routes(message.router, verify_thread_depends)
        self.__append_deps_for_all_routes(runs.router, verify_thread_depends)

    def insert_token_rel(
        self, session, token_id: str, relation_type: RelationType, relation_id: str
    ):
        """
        Add a ``TokenRelation`` for the resource to ``session``.

        Nothing is added when ``token_id`` is falsy. The relation is only
        added to the session here; this method does not commit.
        """
        if token_id:
            relation = TokenRelation(
                token_id=token_id,
                relation_type=relation_type,
                relation_id=str(relation_id),
            )
            session.add(relation)

    async def delete_token_rel(
        self, session, relation_type: RelationType, relation_id: str
    ):
        """
        Delete the ``TokenRelation`` matching the relation type and id.

        The relation is looked up through
        ``TokenRelationService.get_relation_to_delete`` and removed with
        ``session.delete``; this method does not commit.
        """
        to_delete = TokenRelationDelete(
            # insert_token_rel stores str(relation_id); normalise the same way
            # so a non-string id is looked up with the value that was stored.
            relation_type=relation_type,
            relation_id=str(relation_id),
        )
        relation = await TokenRelationService.get_relation_to_delete(
            session=session, delete=to_delete
        )
        # Nothing to remove when no relation was returned (assumes the service
        # may return None for a missing row -- confirm against the service).
        if relation is None:
            return
        await session.delete(relation)

    def token_filter(
        self, statement, field, relation_type: RelationType, token_id: str
    ):
        """
        Restrict ``statement`` to rows related to ``token_id``.

        Adds ``field IN (subquery)`` where the subquery selects the
        ``relation_id`` values of ``TokenRelation`` rows with the given
        relation type and token id.
        """
        id_subquery = select(TokenRelation.relation_id).where(
            TokenRelation.relation_type == relation_type,
            TokenRelation.token_id == token_id,
        )
        return statement.where(field.in_(id_subquery))

    def __append_deps_for_all_routes(self, router, depends):
        """Append ``depends`` to the dependencies of every route of ``router``."""
        for route in router.routes:
            route.dependencies.append(depends)


# NOTE(review): both branches build the no-op AuthPolicy, so
# settings.AUTH_ENABLE currently has no effect and SimpleTokenAuthPolicy is
# never selected. The original trailing comment ("##SimpleTokenAuthPolicy()")
# suggests the enabled branch used to be SimpleTokenAuthPolicy(). Confirm
# whether disabling it was intentional before restoring it.
auth_policy: AuthPolicy = (
    AuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
)  ##SimpleTokenAuthPolicy()


def register(app):
    """
    Log the active auth policy's class name and enable it.

    ``app`` is accepted but not used by this function.
    """
    logging.info("use auth polily: %s", auth_policy.__class__.__name__)
    auth_policy.enable()