"""Abstract authentication provider and its configuration model."""

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from fastapi import Security
from fastapi.security import (
    APIKeyHeader,
    HTTPAuthorizationCredentials,
    HTTPBearer,
)

from ..abstractions import R2RException, Token, TokenData
from ..api.models import User
from .base import Provider, ProviderConfig
from .crypto import CryptoProvider

# from .database import DatabaseProvider
from .email import EmailProvider

logger = logging.getLogger()

if TYPE_CHECKING:
    from core.database import PostgresDatabaseProvider

# Optional ``X-API-Key`` header; ``auto_error=False`` means a missing header
# yields ``None`` so ``auth_wrapper`` can decide how to respond.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


class AuthConfig(ProviderConfig):
    """Configuration values consumed by ``AuthProvider`` implementations."""

    secret_key: Optional[str] = None
    require_authentication: bool = False
    require_email_verification: bool = False
    # Placeholder credentials for the default admin account; the password
    # default is intentionally an obvious "must be replaced" value.
    default_admin_email: str = "admin@example.com"
    default_admin_password: str = "change_me_immediately"
    access_token_lifetime_in_minutes: Optional[int] = None
    refresh_token_lifetime_in_days: Optional[int] = None

    @property
    def supported_providers(self) -> list[str]:
        """Return the provider names this configuration accepts."""
        return ["r2r"]

    def validate_config(self) -> None:
        """Validate the configuration; currently performs no checks."""


class AuthProvider(Provider, ABC):
    """Abstract base for authentication back ends.

    Subclasses implement token creation/decoding and the account flows
    (register, login, password reset, ...). This base class supplies
    ``auth_wrapper``, which builds a FastAPI dependency that resolves the
    calling user from a Bearer token or an ``X-API-Key`` header.
    """

    # Bearer scheme with ``auto_error=False`` so missing credentials arrive
    # as ``None`` instead of being rejected before ``auth_wrapper`` runs.
    security = HTTPBearer(auto_error=False)
    crypto_provider: CryptoProvider
    email_provider: EmailProvider
    database_provider: "PostgresDatabaseProvider"

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: "PostgresDatabaseProvider",
        email_provider: EmailProvider,
    ):
        """Store the collaborating providers and the admin credentials.

        Args:
            config: Authentication settings; must be an ``AuthConfig``.
            crypto_provider: Used to verify API keys in ``auth_wrapper``.
            database_provider: Source of user and API-key records.
            email_provider: Stored for use by subclasses.

        Raises:
            ValueError: If ``config`` is not an ``AuthConfig`` instance.
        """
        if not isinstance(config, AuthConfig):
            raise ValueError(
                "AuthProvider must be initialized with an AuthConfig"
            )
        self.config = config
        self.admin_email = config.default_admin_email
        self.admin_password = config.default_admin_password
        self.crypto_provider = crypto_provider
        self.database_provider = database_provider
        self.email_provider = email_provider
        super().__init__(config)
        # NOTE(review): these repeat the assignments above; presumably kept
        # so the attributes hold the narrowed types after the base
        # initializer has run -- confirm against ``Provider.__init__``.
        self.config: AuthConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider

    async def _get_default_admin_user(self) -> User:
        """Look up the user whose email is the configured admin email."""
        return await self.database_provider.users_handler.get_user_by_email(
            self.admin_email
        )

    async def _resolve_api_key_user(self, credential: str) -> Optional[User]:
        """Return the active user owning an API key, or ``None``.

        Args:
            credential: Raw credential string, expected in the form
                ``key_id.raw_api_key``.

        Returns:
            The matching user when the credential contains a ".", a record
            exists for ``key_id``, the raw key verifies against the stored
            hash, and the user exists with ``is_active`` truthy; otherwise
            ``None``.
        """
        if "." not in credential:
            return None

        # Split on the first "." only; the raw key may itself contain dots.
        key_id, raw_api_key = credential.split(".", 1)
        users = self.database_provider.users_handler

        api_key_record = await users.get_api_key_record(key_id)
        if api_key_record is None:
            return None

        hashed_key = api_key_record["hashed_key"]
        if not self.crypto_provider.verify_api_key(raw_api_key, hashed_key):
            return None

        user = await users.get_user_by_id(api_key_record["user_id"])
        if user is not None and user.is_active:
            return user
        return None

    @abstractmethod
    def create_access_token(self, data: dict) -> str:
        """Build an access token string from ``data`` (subclass-defined)."""

    @abstractmethod
    def create_refresh_token(self, data: dict) -> str:
        """Build a refresh token string from ``data`` (subclass-defined)."""

    @abstractmethod
    async def decode_token(self, token: str) -> TokenData:
        """Decode ``token`` into ``TokenData``.

        ``auth_wrapper`` treats an ``R2RException`` raised here as an
        invalid token and falls back to API-key verification.
        """

    @abstractmethod
    async def user(self, token: str) -> User:
        """Return the user associated with ``token`` (subclass-defined)."""

    @abstractmethod
    def get_current_active_user(self, current_user: User) -> User:
        """Return ``current_user`` subject to subclass-defined checks."""

    @abstractmethod
    async def register(self, email: str, password: str) -> User:
        """Create a user from ``email`` and ``password`` (subclass-defined)."""

    @abstractmethod
    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Verify ``email`` using ``verification_code`` (subclass-defined)."""

    @abstractmethod
    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate with ``email``/``password`` and return tokens."""

    @abstractmethod
    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange ``refresh_token`` for new tokens (subclass-defined)."""

    def auth_wrapper(
        self,
        public: bool = False,
    ):
        """Build a FastAPI dependency that resolves the requesting user.

        Args:
            public: When true, a request carrying no credentials resolves
                to the default admin user even if the configuration
                requires authentication.

        Returns:
            An async callable suitable for use as a FastAPI dependency. It
            returns a ``User`` or raises ``R2RException`` with status 401
            (no or invalid credentials) or 400 (both a Bearer token and an
            API key were supplied).
        """

        async def _auth_wrapper(
            auth: Optional[HTTPAuthorizationCredentials] = Security(
                self.security
            ),
            api_key: Optional[str] = Security(api_key_header),
        ) -> User:
            # If authentication is not required (or the route is public)
            # and no credentials are provided, return the default admin
            # user.
            if (
                ((not self.config.require_authentication) or public)
                and auth is None
                and api_key is None
            ):
                return await self._get_default_admin_user()

            if not auth and not api_key:
                raise R2RException(
                    message="No credentials provided",
                    status_code=401,
                )

            if auth and api_key:
                raise R2RException(
                    message="Cannot have both Bearer token and API key",
                    status_code=400,
                )

            if auth is not None:
                credentials = auth.credentials

                # 1. Try to interpret the Bearer credential as a JWT.
                try:
                    token_data = await self.decode_token(credentials)
                    users = self.database_provider.users_handler
                    user = await users.get_user_by_email(token_data.email)
                    if user is not None:
                        return user
                except R2RException:
                    # JWT decoding failed for logical reasons (invalid
                    # token); fall through to the API-key attempt.
                    pass
                except Exception as e:
                    # JWT decoding failed unexpectedly, log and continue.
                    logger.debug("JWT verification failed: %s", e)

                # 2. If JWT failed, try the Bearer credential as an API
                #    key in the form key_id.raw_api_key.
                user = await self._resolve_api_key_user(credentials)
                if user is not None:
                    return user

            # 3. If no Bearer token worked, try the X-API-Key header.
            if api_key is not None:
                user = await self._resolve_api_key_user(api_key)
                if user is not None:
                    return user

            # If we reach here, both JWT and API key auth failed.
            raise R2RException(
                message="Invalid token or API key",
                status_code=401,
            )

        return _auth_wrapper

    @abstractmethod
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change ``user``'s password (subclass-defined)."""

    @abstractmethod
    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start a password reset for ``email`` (subclass-defined)."""

    @abstractmethod
    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using ``reset_token``."""

    @abstractmethod
    async def logout(self, token: str) -> dict[str, str]:
        """Log out the session identified by ``token`` (subclass-defined)."""

    @abstractmethod
    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Send a password reset email to ``email`` (subclass-defined)."""