123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- import logging
- from fastapi import Depends
- from sqlmodel import select
- from app.api.deps import verfiy_token, verify_token_relation
- from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
- from app.services.token.token_relation import TokenRelationService
- from config.config import settings
class AuthPolicy:
    """
    Default auth policy whose hooks all do nothing.

    Subclasses override these hooks to attach auth dependencies to routers,
    record token relations, and scope queries to a token.
    """

    def enable(self):
        """Activate the policy. The default implementation is a no-op."""
        return None

    def insert_token_rel(
        self, session, token_id: str, relation_type: RelationType, relation_id: str
    ):
        """
        Record that ``token_id`` is linked to ``relation_id``.

        The default implementation stores nothing.
        """
        return None

    async def delete_token_rel(
        self, session, relation_type: RelationType, relation_id: str
    ):
        """
        Remove the token relation for ``relation_id``.

        The default implementation deletes nothing.
        """
        return None

    def token_filter(
        self, statement, field, relation_type: RelationType, token_id: str
    ):
        """
        Scope ``statement`` to resources linked to ``token_id``.

        The default implementation applies no filter and hands ``statement``
        back unchanged.
        """
        unfiltered = statement
        return unfiltered
class SimpleTokenAuthPolicy(AuthPolicy):
    """
    Token-based auth policy.

    Attaches token / token-relation verification dependencies to the v1
    routers, and stores, deletes and filters by ``TokenRelation`` rows.
    """

    def enable(self):
        """
        Append auth verification dependencies to the v1 API routes.

        - assistant / action routers: the create and list routes only verify
          the token; every other route verifies the token's relation to the
          ``assistant_id`` / ``action_id`` value.
        - thread router: ``create_thread`` verifies the ``thread_id`` relation
          with ``ignore_none_relation_id=True``; every other route verifies it
          normally.
        - assistant_file, message and runs routers: every route gets the
          assistant (assistant_file) or thread (message, runs) relation check.
        """
        # Imported lazily rather than at module level; NOTE(review): presumably
        # to avoid an import cycle with the router modules — confirm.
        from app.api.v1 import (
            assistant,
            assistant_file,
            thread,
            message,
            runs,
            action,
        )

        def token_only():
            # A fresh Depends per matching route.
            return Depends(verfiy_token)

        def thread_create_check():
            # create_thread may be called without a thread_id yet, hence
            # ignore_none_relation_id=True.
            return Depends(
                verify_token_relation(
                    relation_type=RelationType.Thread,
                    name="thread_id",
                    ignore_none_relation_id=True,
                )
            )

        assistant_check = Depends(
            verify_token_relation(
                relation_type=RelationType.Assistant, name="assistant_id"
            )
        )
        self.__attach_by_route_name(
            assistant.router,
            {assistant.create_assistant.__name__, assistant.list_assistants.__name__},
            token_only,
            assistant_check,
        )

        thread_check = Depends(
            verify_token_relation(relation_type=RelationType.Thread, name="thread_id")
        )
        self.__attach_by_route_name(
            thread.router,
            {thread.create_thread.__name__},
            thread_create_check,
            thread_check,
        )

        action_check = Depends(
            verify_token_relation(relation_type=RelationType.Action, name="action_id")
        )
        self.__attach_by_route_name(
            action.router,
            {action.create_actions.__name__, action.list_actions.__name__},
            token_only,
            action_check,
        )

        # Sub-resource routers reuse their parent's relation check unconditionally.
        for sub_router, check in (
            (assistant_file.router, assistant_check),
            (message.router, thread_check),
            (runs.router, thread_check),
        ):
            self.__append_deps_for_all_routes(sub_router, check)

    def insert_token_rel(
        self, session, token_id: str, relation_type: RelationType, relation_id: str
    ):
        """
        Add a ``TokenRelation`` linking ``token_id`` to ``relation_id``.

        Does nothing when ``token_id`` is falsy. The row is only added to
        ``session``; this method does not commit it.
        """
        if not token_id:
            return
        session.add(
            TokenRelation(
                token_id=token_id,
                relation_type=relation_type,
                relation_id=str(relation_id),
            )
        )

    async def delete_token_rel(
        self, session, relation_type: RelationType, relation_id: str
    ):
        """
        Look up the relation for (``relation_type``, ``relation_id``) and
        delete it from ``session``.

        The lookup goes through ``TokenRelationService.get_relation_to_delete``.
        """
        criteria = TokenRelationDelete(
            relation_type=relation_type, relation_id=relation_id
        )
        target = await TokenRelationService.get_relation_to_delete(
            session=session, delete=criteria
        )
        await session.delete(target)

    def token_filter(
        self, statement, field, relation_type: RelationType, token_id: str
    ):
        """
        Restrict ``statement`` to rows whose ``field`` is a relation id linked
        to ``token_id`` for ``relation_type``.
        """
        linked_ids = select(TokenRelation.relation_id).where(
            TokenRelation.relation_type == relation_type,
            TokenRelation.token_id == token_id,
        )
        return statement.where(field.in_(linked_ids))

    def __attach_by_route_name(self, router, special_names, make_special, fallback):
        """
        Append exactly one dependency to each route of ``router``.

        Routes named in ``special_names`` get a freshly built ``make_special()``;
        all other routes share ``fallback``.
        """
        for api_route in router.routes:
            chosen = make_special() if api_route.name in special_names else fallback
            api_route.dependencies.append(chosen)

    def __append_deps_for_all_routes(self, router, depends):
        """Append ``depends`` to the dependencies of every route in ``router``."""
        for api_route in router.routes:
            api_route.dependencies.append(depends)
# Select the active policy from configuration: token-based auth when
# AUTH_ENABLE is set, otherwise the no-op default policy. (Previously both
# branches built AuthPolicy(), so AUTH_ENABLE had no effect.)
auth_policy: AuthPolicy = (
    SimpleTokenAuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
)
def register(app):
    """
    Log which auth policy is active and call its ``enable`` hook.

    ``app`` is accepted for the registration interface but is not used by
    this function.
    """
    logging.info("use auth policy: %s", type(auth_policy).__name__)
    auth_policy.enable()
|