bcrypt.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import base64
  2. import logging
  3. import os
  4. from abc import ABC
  5. from datetime import datetime, timezone
  6. from typing import Optional, Tuple
  7. import bcrypt
  8. import jwt
  9. import nacl.encoding
  10. import nacl.exceptions
  11. import nacl.signing
  12. import nacl.utils
  13. from core.base import CryptoConfig, CryptoProvider
  # SECURITY: hard-coded fallback key. BCryptCryptoProvider.__init__ uses it as
  # the JWT (HS256) signing secret when neither `config.secret_key` nor the
  # R2R_SECRET_KEY environment variable is set. Because the value is visible in
  # source, tokens signed with it should be treated as forgeable; replace it or
  # load the secret from the environment / a secrets manager.
  DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"
  class BcryptCryptoConfig(CryptoConfig):
      """Settings for the bcrypt-based crypto provider.

      Only the ``"bcrypt"`` provider name is accepted by ``validate_config``.
      """

      provider: str = "bcrypt"
      # Number of rounds for bcrypt (increasing this makes hashing slower but
      # more secure). validate_config() restricts it to the range 4..31.
      bcrypt_rounds: int = 12
      # Optional secret; the provider reads it as the first choice for its
      # JWT signing key.
      secret_key: Optional[str] = None
      api_key_bytes: int = 32  # Length of raw API keys, in random bytes

      @property
      def supported_providers(self) -> list[str]:
          """Return the provider names this configuration accepts."""
          return ["bcrypt"]

      def validate_config(self) -> None:
          """Validate the configuration.

          Runs the parent validation first, then checks the provider name and
          the bcrypt cost factor.

          Raises:
              ValueError: If ``provider`` is not supported or ``bcrypt_rounds``
                  lies outside 4..31.
          """
          super().validate_config()
          if self.provider not in self.supported_providers:
              raise ValueError(f"Unsupported crypto provider: {self.provider}")
          if not 4 <= self.bcrypt_rounds <= 31:
              raise ValueError("bcrypt_rounds must be between 4 and 31")

      def verify_password(
          self, plain_password: str, hashed_password: str
      ) -> bool:
          """Check ``plain_password`` against a stored bcrypt hash.

          Two storage formats are accepted: the base64-wrapped bcrypt hash
          (new format) and the raw ``$2...`` bcrypt string (old format).

          Args:
              plain_password: Candidate password.
              hashed_password: Stored hash in either format.

          Returns:
              True when the password matches the stored hash.

          Raises:
              ValueError: Propagated from ``bcrypt.checkpw`` when the stored
                  value is not a usable bcrypt hash.
          """
          raw_hash = hashed_password.encode("utf-8")
          stored_hash = raw_hash
          # A raw bcrypt hash always starts with "$2", and "$" is not in the
          # base64 alphabet, so the prefix decides the format unambiguously.
          # Attempting to base64-decode a raw hash is unsafe: the
          # non-validating decoder discards "$" and "." and may return garbage
          # instead of raising, which previously made valid old-format hashes
          # fail with "Invalid salt".
          if not raw_hash.startswith(b"$2"):
              try:
                  decoded = base64.b64decode(raw_hash)
              except ValueError:  # binascii.Error is a ValueError subclass
                  decoded = b""
              # Only trust the decode when it actually yields a bcrypt hash.
              if decoded.startswith(b"$2"):
                  stored_hash = decoded
          return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
  class BCryptCryptoProvider(CryptoProvider, ABC):
      """Crypto provider combining bcrypt, NaCl signing keys and HS256 JWTs.

      bcrypt hashes produced here are stored base64-encoded; verification also
      accepts raw ``$2...`` bcrypt strings for backward compatibility.
      """

      def __init__(self, config: BcryptCryptoConfig):
          """Initialise the provider and resolve the JWT signing secret.

          The secret is taken from ``config.secret_key``, then the
          ``R2R_SECRET_KEY`` environment variable, then the module-level
          ``DEFAULT_BCRYPT_SECRET_KEY``.

          Raises:
              ValueError: If ``config`` is not a ``BcryptCryptoConfig`` or no
                  secret key could be resolved.
          """
          if not isinstance(config, BcryptCryptoConfig):
              raise ValueError(
                  "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
              )
          logging.info("Initializing BcryptCryptoProvider")
          super().__init__(config)
          self.config: BcryptCryptoConfig = config

          # Load the secret key for JWT signing.
          self.secret_key = config.secret_key or os.getenv("R2R_SECRET_KEY")
          if not self.secret_key:
              # The built-in default is public knowledge (it lives in source),
              # so make the fallback visible instead of applying it silently.
              logging.warning(
                  "No secret key configured for BcryptCryptoProvider; falling "
                  "back to the built-in default key. Set `secret_key` or the "
                  "R2R_SECRET_KEY environment variable."
              )
              self.secret_key = DEFAULT_BCRYPT_SECRET_KEY
          # Defensive guard: only reachable if the default constant is empty.
          if not self.secret_key:
              raise ValueError(
                  "No secret key provided for BcryptCryptoProvider."
              )

      def _bcrypt_hash(self, secret: str) -> str:
          """Hash ``secret`` with bcrypt and return the hash base64-encoded."""
          digest = bcrypt.hashpw(
              secret.encode("utf-8"),  # bcrypt expects bytes
              bcrypt.gensalt(rounds=self.config.bcrypt_rounds),
          )
          return base64.b64encode(digest).decode("utf-8")

      @staticmethod
      def _stored_hash_bytes(stored: str) -> bytes:
          """Normalise a stored hash (base64-wrapped or raw) to bcrypt bytes.

          A value starting with ``$2`` is a raw bcrypt hash. Anything else is
          base64-decoded and used only if the result looks like a bcrypt hash;
          otherwise the original bytes are returned so ``bcrypt.checkpw`` can
          reject them.
          """
          raw = stored.encode("utf-8")
          if raw.startswith(b"$2"):
              return raw
          try:
              decoded = base64.b64decode(raw)
          except ValueError:  # binascii.Error is a ValueError subclass
              return raw
          # Accept every "$2?$" variant, not only "$2b$".
          return decoded if decoded.startswith(b"$2") else raw

      def get_password_hash(self, password: str) -> str:
          """Return the base64-encoded bcrypt hash of ``password``."""
          return self._bcrypt_hash(password)

      def verify_password(
          self, plain_password: str, hashed_password: str
      ) -> bool:
          """Check ``plain_password`` against a stored bcrypt hash.

          Args:
              plain_password: Candidate password.
              hashed_password: Stored hash, base64-wrapped (new format) or a
                  raw bcrypt string (old format).

          Returns:
              True on a match; False on a mismatch or when bcrypt reports the
              stored value as an invalid salt.

          Raises:
              ValueError: Any other ``ValueError`` from ``bcrypt.checkpw``.
          """
          stored_hash = self._stored_hash_bytes(hashed_password)
          try:
              return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
          except ValueError as e:
              if "Invalid salt" in str(e):
                  # The stored value is not a bcrypt hash in either format.
                  return False
              raise

      def generate_verification_code(self, length: int = 32) -> str:
          """Return a random URL-safe base64 string truncated to ``length`` chars."""
          random_bytes = nacl.utils.random(length)
          # The encoding of `length` bytes is at least `length` characters long,
          # so the slice yields exactly `length` characters.
          return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")

      def generate_signing_keypair(self) -> Tuple[str, str, str]:
          """Create a new NaCl signing keypair.

          Returns:
              ``(key_id, private_key, public_key)`` where ``key_id`` is
              ``"sk_"`` plus 16 random bytes in URL-safe base64 and both keys
              are standard-base64 encoded.
          """
          signing_key = nacl.signing.SigningKey.generate()
          verify_key = signing_key.verify_key
          # Generate unique key_id
          key_entropy = nacl.utils.random(16)
          key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
          private_key = base64.b64encode(bytes(signing_key)).decode()
          public_key = base64.b64encode(bytes(verify_key)).decode()
          return key_id, private_key, public_key

      def sign_request(self, private_key: str, data: str) -> str:
          """Sign ``data`` with a base64-encoded private signing key.

          Returns:
              The detached signature, base64-encoded.

          Raises:
              ValueError: If the key cannot be decoded/loaded or signing fails;
                  the underlying exception is chained as the cause.
          """
          try:
              key_bytes = base64.b64decode(private_key)
              signing_key = nacl.signing.SigningKey(key_bytes)
              signed = signing_key.sign(data.encode())
              return base64.b64encode(signed.signature).decode()
          except Exception as e:
              raise ValueError(
                  f"Invalid private key or signing error: {str(e)}"
              ) from e

      def verify_request_signature(
          self, public_key: str, signature: str, data: str
      ) -> bool:
          """Verify a base64 signature of ``data`` against a base64 public key.

          Returns:
              True if the signature verifies; False on a bad signature or on a
              ``ValueError`` (which includes malformed base64 input).
          """
          try:
              key_bytes = base64.b64decode(public_key)
              verify_key = nacl.signing.VerifyKey(key_bytes)
              signature_bytes = base64.b64decode(signature)
              verify_key.verify(data.encode(), signature_bytes)
              return True
          except (nacl.exceptions.BadSignatureError, ValueError):
              return False

      def generate_api_key(self) -> Tuple[str, str]:
          """Create a new API key.

          Returns:
              ``(key_id, raw_api_key)``: ``key_id`` is ``"key_"`` plus 16 random
              bytes in URL-safe base64; ``raw_api_key`` is
              ``config.api_key_bytes`` random bytes in URL-safe base64.
          """
          key_id_bytes = nacl.utils.random(16)
          key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
          # Generate raw API key
          raw_api_key = base64.urlsafe_b64encode(
              nacl.utils.random(self.config.api_key_bytes)
          ).decode()
          return key_id, raw_api_key

      def hash_api_key(self, raw_api_key: str) -> str:
          """Return the base64-encoded bcrypt hash of ``raw_api_key``."""
          return self._bcrypt_hash(raw_api_key)

      def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
          """Check ``raw_api_key`` against a stored hash from ``hash_api_key``.

          Returns:
              True on a match; False on a mismatch or when the stored value is
              not a usable bcrypt hash (a malformed value is rejected rather
              than raising, matching ``verify_password``).

          Raises:
              ValueError: Any ``ValueError`` from ``bcrypt.checkpw`` other than
                  an invalid salt.
          """
          stored_hash = self._stored_hash_bytes(hashed_key)
          try:
              return bcrypt.checkpw(raw_api_key.encode("utf-8"), stored_hash)
          except ValueError as e:
              if "Invalid salt" in str(e):
                  return False
              raise

      def generate_secure_token(self, data: dict, expiry: datetime) -> str:
          """Encode ``data`` as an HS256 JWT signed with the provider secret.

          The ``exp``, ``iat``, ``nbf``, ``jti`` and ``nonce`` claims are set
          here and override same-named keys in ``data``.

          Args:
              data: Custom claims to embed.
              expiry: Expiration time; stored as ``expiry.timestamp()``.
                  NOTE(review): a naive datetime is interpreted as local time by
                  ``timestamp()`` - confirm callers pass timezone-aware values.

          Returns:
              The encoded token.
          """
          issued_at = datetime.now(timezone.utc).timestamp()
          to_encode = {
              **data,
              "exp": expiry.timestamp(),
              "iat": issued_at,
              "nbf": issued_at,
              "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
              "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
          }
          return jwt.encode(to_encode, self.secret_key, algorithm="HS256")

      def verify_secure_token(self, token: str) -> Optional[dict]:
          """Decode and validate a token made by ``generate_secure_token``.

          Returns:
              The payload dict, or None if the token is invalid, expired, or
              carries no ``exp`` claim.
          """
          try:
              payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
          except jwt.InvalidTokenError:
              # ExpiredSignatureError is a subclass of InvalidTokenError.
              return None
          exp = payload.get("exp")
          # A token without an expiry is never accepted.
          if exp is None:
              return None
          if datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(
              timezone.utc
          ):
              return None
          return payload