import secrets
from typing import AsyncGenerator

from fastapi import Depends, Request
from fastapi.security import APIKeyHeader
from sqlalchemy.ext.asyncio import AsyncSession

from app.exceptions.exception import (
    AuthenticationError,
    AuthorizationError,
    ResourceNotFoundError,
)
from app.models.token import Token
from app.models.token_relation import RelationType, TokenRelationQuery
from app.providers import database
from app.services.token.token import TokenService
from app.services.token.token_relation import TokenRelationService
from config.config import settings


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session; intended for use as a FastAPI ``Depends`` option."""
    async with database.async_session_local() as session:
        yield session


class OAuth2Bearer(APIKeyHeader):
    """
    Fetch a bearer token from a request header.

    Unlike the parent ``APIKeyHeader``, a missing or malformed header never raises here:
    ``__call__`` returns ``None`` instead. ``auto_error`` is accepted and forwarded to the
    parent for interface compatibility, but ``__call__`` does not consult it.
    """

    def __init__(
        self,
        *,
        name: str,
        scheme_name: str | None = None,
        description: str | None = None,
        auto_error: bool = True,
    ):
        super().__init__(
            name=name,
            scheme_name=scheme_name,
            description=description,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> str | None:
        """
        Return the token from a ``Bearer <token>`` header value, or ``None``.

        The scheme comparison is case-insensitive; whitespace around the token is
        stripped and an empty token is treated as absent.
        """
        authorization_header_value = request.headers.get(self.model.name)
        if authorization_header_value:
            scheme, _, param = authorization_header_value.partition(" ")
            token = param.strip()
            if scheme.lower() == "bearer" and token:
                return token
        return None


oauth_token = OAuth2Bearer(name="Authorization")


async def verify_admin_token(token=Depends(oauth_token)) -> None:
    """
    Require the admin token.

    Raises:
        AuthenticationError: no bearer token was supplied.
        AuthorizationError: the token does not match ``settings.AUTH_ADMIN_TOKEN``
            (including when that setting is not a string, e.g. unset).
    """
    if token is None:
        raise AuthenticationError()
    expected = settings.AUTH_ADMIN_TOKEN
    # Constant-time comparison so response timing does not leak the admin secret.
    # Compare as UTF-8 bytes: compare_digest rejects non-ASCII ``str`` arguments.
    if not isinstance(expected, str) or not secrets.compare_digest(
        expected.encode("utf-8"), token.encode("utf-8")
    ):
        raise AuthorizationError()


async def get_token(
    session=Depends(get_async_session), token=Depends(oauth_token)
) -> Token | None:
    """
    Look up the ``Token`` record for the request's bearer token.

    Returns ``None`` when no token was supplied or ``TokenService.get_token`` reports
    that no matching record exists.
    """
    if not token:
        return None
    try:
        return await TokenService.get_token(session=session, token=token)
    except ResourceNotFoundError:
        return None


async def verfiy_token(token: Token = Depends(get_token)) -> None:
    """
    Require a valid token for the request.

    The misspelled name is kept for existing callers; ``verify_token`` is an alias.

    Raises:
        AuthenticationError: no valid token was found for the request.
    """
    if token is None:
        raise AuthenticationError()


# Correctly spelled alias of ``verfiy_token``.
verify_token = verfiy_token


async def get_token_id(token: Token = Depends(get_token)):
    """
    Return the id of the request's token, or ``None`` when there is no valid token.

    The token id can be treated as the caller's identity.
    """
    return token.id if token is not None else None


def get_param(name: str):
    """
    Build a dependency that extracts the parameter ``name`` from the request.

    Lookup order: path parameters, then query parameters, then the top-level keys of a
    JSON object body. The dependency returns ``None`` when the parameter is not found,
    including when the body is empty, is not valid JSON, or is not a JSON object.
    """

    async def get_param_from_request(request: Request):
        if name in request.path_params:
            return request.path_params[name]
        if name in request.query_params:
            return request.query_params[name]
        try:
            body = await request.json()
        except ValueError:
            # Empty or non-JSON body (e.g. a GET request): JSONDecodeError and
            # UnicodeDecodeError are both ValueError subclasses.
            return None
        if isinstance(body, dict):
            return body.get(name)
        return None

    return get_param_from_request


def verify_token_relation(
    relation_type: RelationType, name: str, ignore_none_relation_id: bool = False
):
    """
    Build a dependency that checks the caller's token is related to a resource.

    param relation_type: relation type to verify
    param name: name of the request parameter holding the relation id (see ``get_param``)
    param ignore_none_relation_id: when set, an authenticated request that does not
        supply a relation id is allowed without a relation check (used for the copy
        thread API). A relation id that *is* supplied is still verified.

    The returned dependency raises ``AuthorizationError`` when there is no token, no
    relation id (unless ignored as above), or no matching token relation.
    """

    async def verify_authorization(
        session=Depends(get_async_session),
        token_id=Depends(get_token_id),
        relation_id=Depends(get_param(name)),
    ):
        if not token_id:
            raise AuthorizationError()
        if not relation_id:
            if ignore_none_relation_id:
                # Nothing to verify: the caller did not reference an existing resource.
                return
            raise AuthorizationError()
        verify = TokenRelationQuery(
            token_id=token_id, relation_type=relation_type, relation_id=relation_id
        )
        if await TokenRelationService.verify_relation(session=session, verify=verify):
            return
        raise AuthorizationError()

    return verify_authorization