123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- import logging
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING, Optional
- from fastapi import Security
- from fastapi.security import (
- APIKeyHeader,
- HTTPAuthorizationCredentials,
- HTTPBearer,
- )
- from ..abstractions import R2RException, Token, TokenData
- from ..api.models import User
- from .base import Provider, ProviderConfig
- from .crypto import CryptoProvider
- # from .database import DatabaseProvider
- from .email import EmailProvider
# Module-scoped logger (named after this module) so its output can be
# filtered/configured independently instead of logging on the root logger.
logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
- from core.database import PostgresDatabaseProvider
# Extractor for the optional "X-API-Key" request header. auto_error=False so
# a missing header reaches the auth dependency as None (it checks
# `api_key is None`) instead of FastAPI rejecting the request up front.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-Key")
class AuthConfig(ProviderConfig):
    """Settings for authentication providers.

    Holds an optional secret key, the `require_authentication` and
    `require_email_verification` switches, the default admin account's
    email/password, and optional access/refresh token lifetimes.
    `require_authentication` decides whether requests without credentials
    may fall back to the default admin user.
    """

    secret_key: Optional[str] = None
    require_authentication: bool = False
    require_email_verification: bool = False
    default_admin_email: str = "admin@example.com"
    # NOTE: insecure placeholder default; deployments should override it.
    default_admin_password: str = "change_me_immediately"
    access_token_lifetime_in_minutes: Optional[int] = None
    refresh_token_lifetime_in_days: Optional[int] = None

    @property
    def supported_providers(self) -> list[str]:
        """Names of the auth providers this config accepts (only "r2r")."""
        providers = ["r2r"]
        return providers

    def validate_config(self) -> None:
        """Perform no extra validation; every field value is accepted."""
        return None
class AuthProvider(Provider, ABC):
    """Abstract base class for authentication providers.

    Holds the crypto, database and email providers an implementation needs.
    Declares the token and account-management operations subclasses must
    implement. Also supplies `auth_wrapper`, a FastAPI dependency factory
    that resolves the calling `User` from one of three sources: a Bearer
    JWT, an API key sent as the Bearer value, or an `X-API-Key` header.
    """

    # Optional Bearer-credential extractor; auto_error=False so missing
    # credentials reach the dependency as None instead of being rejected.
    security = HTTPBearer(auto_error=False)
    crypto_provider: CryptoProvider
    email_provider: EmailProvider
    database_provider: "PostgresDatabaseProvider"

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: "PostgresDatabaseProvider",
        email_provider: EmailProvider,
    ):
        """Store the auth config and collaborating providers.

        Args:
            config: Auth settings; must be an `AuthConfig` instance.
            crypto_provider: Used to verify API keys against stored hashes.
            database_provider: Provides `users_handler` for user and
                API-key lookups.
            email_provider: Stored for subclasses; not used by this class.

        Raises:
            ValueError: If `config` is not an `AuthConfig`.
        """
        if not isinstance(config, AuthConfig):
            raise ValueError(
                "AuthProvider must be initialized with an AuthConfig"
            )
        self.config = config
        self.admin_email = config.default_admin_email
        self.admin_password = config.default_admin_password
        self.crypto_provider = crypto_provider
        self.database_provider = database_provider
        self.email_provider = email_provider
        super().__init__(config)
        # Re-bind after the base initializer so these typed attributes are
        # guaranteed to be the objects passed in, whatever Provider.__init__
        # assigns.
        self.config: AuthConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider

    async def _get_default_admin_user(self) -> User:
        """Look up the default admin user by the configured admin email."""
        return await self.database_provider.users_handler.get_user_by_email(
            self.admin_email
        )

    @abstractmethod
    def create_access_token(self, data: dict) -> str:
        """Create an access token string from `data` (implementation-defined)."""

    @abstractmethod
    def create_refresh_token(self, data: dict) -> str:
        """Create a refresh token string from `data` (implementation-defined)."""

    @abstractmethod
    async def decode_token(self, token: str) -> TokenData:
        """Decode `token` into `TokenData`.

        `auth_wrapper` treats an `R2RException` raised here as "not a valid
        JWT" and falls back to API-key authentication.
        """

    @abstractmethod
    async def user(self, token: str) -> User:
        """Return the user associated with `token`."""

    @abstractmethod
    def get_current_active_user(self, current_user: User) -> User:
        """Return `current_user`; presumably enforces that it is active."""

    @abstractmethod
    async def register(self, email: str, password: str) -> User:
        """Register a new user with `email` and `password`."""

    @abstractmethod
    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Verify `email` using `verification_code`."""

    @abstractmethod
    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate credentials and return the issued tokens by name."""

    @abstractmethod
    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange `refresh_token` for new tokens, returned by name."""

    async def _get_user_from_api_key(self, candidate: str) -> Optional[User]:
        """Resolve a `key_id.raw_api_key` string to an active user.

        The string is split at its first `.`. Returns None in any of these
        cases:
        - there is no `.`;
        - no API-key record exists for the key id;
        - the raw key does not verify against the stored hash;
        - the owning user is missing or inactive.
        """
        key_id, sep, raw_api_key = candidate.partition(".")
        if not sep:
            return None
        users_handler = self.database_provider.users_handler
        api_key_record = await users_handler.get_api_key_record(key_id)
        if api_key_record is None:
            return None
        if not self.crypto_provider.verify_api_key(
            raw_api_key, api_key_record["hashed_key"]
        ):
            return None
        user = await users_handler.get_user_by_id(api_key_record["user_id"])
        if user is not None and user.is_active:
            return user
        return None

    def auth_wrapper(
        self,
        public: bool = False,
    ):
        """Build a FastAPI dependency that resolves the requesting user.

        Args:
            public: When True, a request without credentials resolves to the
                default admin user even if `config.require_authentication`
                is set.

        Returns:
            An async dependency that returns a `User`. The dependency raises
            `R2RException` in three cases:
            - status 401 when no credentials are sent and anonymous access
              is not allowed;
            - status 400 when both a Bearer token and an `X-API-Key` header
              are sent;
            - status 401 when the supplied credentials do not authenticate.
        """

        async def _auth_wrapper(
            auth: Optional[HTTPAuthorizationCredentials] = Security(
                self.security
            ),
            api_key: Optional[str] = Security(api_key_header),
        ) -> User:
            # Anonymous access: if authentication is optional (or the route
            # is public) and nothing was sent, act as the default admin user.
            if (
                ((not self.config.require_authentication) or public)
                and auth is None
                and api_key is None
            ):
                return await self._get_default_admin_user()

            if not auth and not api_key:
                raise R2RException(
                    message="No credentials provided",
                    status_code=401,
                )

            if auth and api_key:
                raise R2RException(
                    message="Cannot have both Bearer token and API key",
                    status_code=400,
                )

            # 1. Bearer credentials: try them as a JWT first.
            if auth is not None:
                credentials = auth.credentials
                try:
                    token_data = await self.decode_token(credentials)
                    user = await self.database_provider.users_handler.get_user_by_email(
                        token_data.email
                    )
                    if user is not None:
                        # NOTE(review): unlike the API-key paths, this does
                        # not check `user.is_active` — confirm intended.
                        return user
                except R2RException:
                    # Invalid token: deliberately fall through to API-key auth.
                    pass
                except Exception as e:
                    # Unexpected failure: log and fall through to API-key auth.
                    logger.debug("JWT verification failed: %s", e)

                # 2. JWT failed: the Bearer value may itself be an API key
                #    in `key_id.raw_api_key` form.
                user = await self._get_user_from_api_key(credentials)
                if user is not None:
                    return user

            # 3. X-API-Key header. In practice this is only reached without a
            #    Bearer token, since sending both is rejected above.
            if api_key is not None:
                user = await self._get_user_from_api_key(api_key)
                if user is not None:
                    return user

            # Every supplied credential failed to authenticate.
            raise R2RException(
                message="Invalid token or API key",
                status_code=401,
            )

        return _auth_wrapper

    @abstractmethod
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change `user`'s password from `current_password` to `new_password`."""

    @abstractmethod
    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start a password reset for the account identified by `email`."""

    @abstractmethod
    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using `reset_token` and `new_password`."""

    @abstractmethod
    async def logout(self, token: str) -> dict[str, str]:
        """Log out the session identified by `token`."""

    @abstractmethod
    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Send a password-reset email to `email`."""
|