import base64
import binascii
import logging
import os
from abc import ABC
from datetime import datetime, timezone
from typing import Optional, Tuple, Union

import bcrypt
import jwt
import nacl.encoding
import nacl.exceptions
import nacl.signing
import nacl.utils

from core.base import CryptoConfig, CryptoProvider

logger = logging.getLogger(__name__)

# Hard-coded fallback JWT secret, used only when neither the config nor the
# R2R_SECRET_KEY environment variable supplies one. It is visible to anyone
# who can read this source, so tokens signed with it can be forged.
# Replace it, or load a real secret from the environment / a secrets manager.
DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"

# Version prefixes used to recognise a successfully base64-decoded value as a
# bcrypt hash (as opposed to arbitrary text that happens to be valid base64).
_BCRYPT_HASH_PREFIXES = (b"$2a$", b"$2b$", b"$2y$")


def _bcrypt_hash_b64(secret: str, rounds: int) -> str:
    """Hash ``secret`` with bcrypt and return the hash base64-encoded (new format).

    Args:
        secret: Plain-text value to hash; encoded as UTF-8.
        rounds: Cost parameter passed to ``bcrypt.gensalt``.

    Returns:
        The bcrypt hash wrapped in standard base64, as a ``str``.
    """
    hashed = bcrypt.hashpw(secret.encode("utf-8"), bcrypt.gensalt(rounds=rounds))
    return base64.b64encode(hashed).decode("utf-8")


def _decode_stored_hash(hashed: Union[str, bytes]) -> bytes:
    """Return the bcrypt hash bytes for a stored hash in either format.

    New-format values are base64-wrapped bcrypt hashes (see
    ``_bcrypt_hash_b64``); old-format values are the bcrypt hash itself.
    The base64 form is decoded strictly, and the result is used only if it
    starts with a known bcrypt prefix; otherwise the input is returned
    unchanged (UTF-8 encoded).
    """
    raw = hashed if isinstance(hashed, bytes) else hashed.encode("utf-8")
    try:
        # validate=True rejects "$" and "." (always present in a raw bcrypt
        # hash) instead of silently discarding them and decoding garbage.
        decoded = base64.b64decode(raw, validate=True)
    except binascii.Error:
        return raw
    return decoded if decoded.startswith(_BCRYPT_HASH_PREFIXES) else raw


def _check_bcrypt(secret: str, hashed: Union[str, bytes]) -> bool:
    """Return True if ``secret`` matches the stored bcrypt hash ``hashed``.

    A stored value that bcrypt cannot parse ("Invalid salt") is treated as a
    non-match. Any other ``ValueError`` raised by ``bcrypt.checkpw``
    propagates unchanged.
    """
    try:
        return bcrypt.checkpw(secret.encode("utf-8"), _decode_stored_hash(hashed))
    except ValueError as e:
        if "Invalid salt" in str(e):
            return False
        raise


class BcryptCryptoConfig(CryptoConfig):
    """Configuration for :class:`BCryptCryptoProvider`."""

    provider: str = "bcrypt"
    # Cost parameter passed to bcrypt.gensalt (increasing this makes hashing
    # slower but more secure).
    bcrypt_rounds: int = 12
    # JWT signing secret; when unset the provider falls back to the
    # R2R_SECRET_KEY environment variable, then to DEFAULT_BCRYPT_SECRET_KEY.
    secret_key: Optional[str] = None
    api_key_bytes: int = 32  # Number of random bytes in a raw API key

    @property
    def supported_providers(self) -> list[str]:
        """Provider names this config accepts."""
        return ["bcrypt"]

    def validate_config(self) -> None:
        """Validate the provider name and bcrypt cost.

        Raises:
            ValueError: If the provider is unsupported or ``bcrypt_rounds``
                is outside 4..31.
        """
        super().validate_config()
        if self.provider not in self.supported_providers:
            raise ValueError(f"Unsupported crypto provider: {self.provider}")
        if not 4 <= self.bcrypt_rounds <= 31:
            raise ValueError("bcrypt_rounds must be between 4 and 31")

    def verify_password(
        self, plain_password: str, hashed_password: str
    ) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``.

        Accepts both base64-wrapped (new) and raw (old) bcrypt hashes; a
        malformed stored hash yields False.
        """
        return _check_bcrypt(plain_password, hashed_password)


class BCryptCryptoProvider(CryptoProvider, ABC):
    """Password/API-key hashing (bcrypt), request signing (Ed25519 via NaCl)
    and HS256 JWT tokens."""

    def __init__(self, config: BcryptCryptoConfig):
        """Bind the provider to ``config`` and resolve the JWT secret.

        Raises:
            ValueError: If ``config`` is not a ``BcryptCryptoConfig`` or no
                secret key can be resolved.
        """
        if not isinstance(config, BcryptCryptoConfig):
            raise ValueError(
                "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
            )
        super().__init__(config)
        self.config: BcryptCryptoConfig = config
        # JWT secret resolution order: explicit config, then environment.
        self.secret_key = config.secret_key or os.getenv("R2R_SECRET_KEY")
        if not self.secret_key:
            # The hard-coded fallback keeps existing deployments working, but
            # anyone who knows it can forge tokens, so make it visible.
            logger.warning(
                "No secret key configured for BcryptCryptoProvider (set "
                "secret_key or R2R_SECRET_KEY); falling back to the insecure "
                "built-in default key."
            )
            self.secret_key = DEFAULT_BCRYPT_SECRET_KEY
        if not self.secret_key:
            raise ValueError(
                "No secret key provided for BcryptCryptoProvider."
            )

    def get_password_hash(self, password: str) -> str:
        """Return the base64-wrapped bcrypt hash of ``password``."""
        return _bcrypt_hash_b64(password, self.config.bcrypt_rounds)

    def verify_password(
        self, plain_password: str, hashed_password: str
    ) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``.

        Accepts both base64-wrapped (new) and raw (old) bcrypt hashes. A
        stored hash bcrypt cannot parse yields False; other ``ValueError``s
        from bcrypt propagate.
        """
        return _check_bcrypt(plain_password, hashed_password)

    def generate_verification_code(self, length: int = 32) -> str:
        """Return a random URL-safe base64 string of exactly ``length`` chars."""
        random_bytes = nacl.utils.random(length)
        # urlsafe_b64encode yields >= length chars for length input bytes.
        return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")

    def generate_signing_keypair(self) -> Tuple[str, str, str]:
        """Generate an Ed25519 keypair.

        Returns:
            ``(key_id, private_key, public_key)``: ``key_id`` is ``"sk_"``
            plus 16 random URL-safe-base64 bytes; both keys are standard
            base64 of the raw NaCl key bytes.
        """
        signing_key = nacl.signing.SigningKey.generate()
        verify_key = signing_key.verify_key

        # Generate unique key_id
        key_entropy = nacl.utils.random(16)
        key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"

        private_key = base64.b64encode(bytes(signing_key)).decode()
        public_key = base64.b64encode(bytes(verify_key)).decode()
        return key_id, private_key, public_key

    def sign_request(self, private_key: str, data: str) -> str:
        """Sign ``data`` with a base64 private key; return a base64 signature.

        Raises:
            ValueError: If the key cannot be decoded or signing fails.
        """
        try:
            key_bytes = base64.b64decode(private_key)
            signing_key = nacl.signing.SigningKey(key_bytes)
            signature = signing_key.sign(data.encode())
            return base64.b64encode(signature.signature).decode()
        except Exception as e:
            raise ValueError(
                f"Invalid private key or signing error: {str(e)}"
            ) from e

    def verify_request_signature(
        self, public_key: str, signature: str, data: str
    ) -> bool:
        """Return True if ``signature`` (base64) over ``data`` verifies
        against ``public_key`` (base64); False on a bad signature or
        malformed key/signature."""
        try:
            key_bytes = base64.b64decode(public_key)
            verify_key = nacl.signing.VerifyKey(key_bytes)
            signature_bytes = base64.b64decode(signature)
            verify_key.verify(data.encode(), signature_bytes)
            return True
        # binascii.Error is a ValueError subclass, so bad base64 lands here.
        except (nacl.exceptions.BadSignatureError, ValueError):
            return False

    def generate_api_key(self) -> Tuple[str, str]:
        """Return ``(key_id, raw_api_key)``.

        ``key_id`` is ``"key_"`` plus 16 random bytes; the raw key is
        ``config.api_key_bytes`` random bytes. Both are URL-safe base64.
        """
        key_id_bytes = nacl.utils.random(16)
        key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"

        # Generate raw API key
        raw_api_key = base64.urlsafe_b64encode(
            nacl.utils.random(self.config.api_key_bytes)
        ).decode()
        return key_id, raw_api_key

    def hash_api_key(self, raw_api_key: str) -> str:
        """Return the base64-wrapped bcrypt hash of ``raw_api_key``."""
        return _bcrypt_hash_b64(raw_api_key, self.config.bcrypt_rounds)

    def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
        """Return True if ``raw_api_key`` matches ``hashed_key``.

        Uses the same format handling as ``verify_password``; a malformed
        stored hash yields False instead of raising.
        """
        return _check_bcrypt(raw_api_key, hashed_key)

    def generate_secure_token(self, data: dict, expiry: datetime) -> str:
        """Return an HS256 JWT carrying ``data`` plus exp/iat/nbf/jti/nonce.

        NOTE: ``expiry.timestamp()`` treats a naive datetime as local time;
        pass a timezone-aware ``expiry`` to avoid ambiguity.
        """
        now = datetime.now(timezone.utc)
        to_encode = {
            **data,
            "exp": expiry.timestamp(),
            "iat": now.timestamp(),
            "nbf": now.timestamp(),
            "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
            "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
        }
        return jwt.encode(to_encode, self.secret_key, algorithm="HS256")

    def verify_secure_token(self, token: str) -> Optional[dict]:
        """Return the decoded payload of a valid HS256 token, else None.

        Tokens without an ``exp`` claim, or whose ``exp`` has passed, are
        rejected.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        # ExpiredSignatureError is a subclass of InvalidTokenError.
        except jwt.InvalidTokenError:
            return None
        exp = payload.get("exp")
        if exp is None or datetime.fromtimestamp(
            exp, tz=timezone.utc
        ) < datetime.now(timezone.utc):
            return None
        return payload