bcrypt.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import base64
  2. import os
  3. from abc import ABC
  4. from datetime import datetime, timezone
  5. from typing import Optional, Tuple
  6. import bcrypt
  7. import jwt
  8. import nacl.encoding
  9. import nacl.exceptions
  10. import nacl.signing
  11. import nacl.utils
  12. from core.base import CryptoConfig, CryptoProvider
  13. DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager
  14. class BcryptCryptoConfig(CryptoConfig):
  15. provider: str = "bcrypt"
  16. # Number of rounds for bcrypt (increasing this makes hashing slower but more secure)
  17. bcrypt_rounds: int = 12
  18. secret_key: Optional[str] = None
  19. api_key_bytes: int = 32 # Length of raw API keys
  20. @property
  21. def supported_providers(self) -> list[str]:
  22. return ["bcrypt"]
  23. def validate_config(self) -> None:
  24. super().validate_config()
  25. if self.provider not in self.supported_providers:
  26. raise ValueError(f"Unsupported crypto provider: {self.provider}")
  27. if self.bcrypt_rounds < 4 or self.bcrypt_rounds > 31:
  28. raise ValueError("bcrypt_rounds must be between 4 and 31")
  29. def verify_password(
  30. self, plain_password: str, hashed_password: str
  31. ) -> bool:
  32. try:
  33. # First try to decode as base64 (new format)
  34. stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
  35. except:
  36. # If that fails, treat as raw bcrypt hash (old format)
  37. stored_hash = hashed_password.encode("utf-8")
  38. return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
  39. class BCryptCryptoProvider(CryptoProvider, ABC):
  40. def __init__(self, config: BcryptCryptoConfig):
  41. if not isinstance(config, BcryptCryptoConfig):
  42. raise ValueError(
  43. "BcryptCryptoProvider must be initialized with a BcryptCryptoConfig"
  44. )
  45. super().__init__(config)
  46. self.config: BcryptCryptoConfig = config
  47. # Load the secret key for JWT
  48. # No fallback defaults: fail if not provided
  49. self.secret_key = (
  50. config.secret_key
  51. or os.getenv("R2R_SECRET_KEY")
  52. or DEFAULT_BCRYPT_SECRET_KEY
  53. )
  54. if not self.secret_key:
  55. raise ValueError(
  56. "No secret key provided for BcryptCryptoProvider."
  57. )
  58. def get_password_hash(self, password: str) -> str:
  59. # Bcrypt expects bytes
  60. password_bytes = password.encode("utf-8")
  61. hashed = bcrypt.hashpw(
  62. password_bytes, bcrypt.gensalt(rounds=self.config.bcrypt_rounds)
  63. )
  64. return base64.b64encode(hashed).decode("utf-8")
  65. def verify_password(
  66. self, plain_password: str, hashed_password: str
  67. ) -> bool:
  68. try:
  69. # First try to decode as base64 (new format)
  70. stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
  71. if not stored_hash.startswith(b"$2b$"): # Valid bcrypt hash prefix
  72. stored_hash = hashed_password.encode("utf-8")
  73. except:
  74. # Otherwise raw bcrypt hash (old format)
  75. stored_hash = hashed_password.encode("utf-8")
  76. try:
  77. return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)
  78. except ValueError as e:
  79. if "Invalid salt" in str(e):
  80. # If it's an invalid salt, the hash format is wrong - try the other format
  81. try:
  82. stored_hash = (
  83. hashed_password
  84. if isinstance(hashed_password, bytes)
  85. else hashed_password.encode("utf-8")
  86. )
  87. return bcrypt.checkpw(
  88. plain_password.encode("utf-8"), stored_hash
  89. )
  90. except ValueError:
  91. return False
  92. raise
  93. def generate_verification_code(self, length: int = 32) -> str:
  94. random_bytes = nacl.utils.random(length)
  95. return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
  96. def generate_signing_keypair(self) -> Tuple[str, str, str]:
  97. signing_key = nacl.signing.SigningKey.generate()
  98. verify_key = signing_key.verify_key
  99. # Generate unique key_id
  100. key_entropy = nacl.utils.random(16)
  101. key_id = f"sk_{base64.urlsafe_b64encode(key_entropy).decode()}"
  102. private_key = base64.b64encode(bytes(signing_key)).decode()
  103. public_key = base64.b64encode(bytes(verify_key)).decode()
  104. return key_id, private_key, public_key
  105. def sign_request(self, private_key: str, data: str) -> str:
  106. try:
  107. key_bytes = base64.b64decode(private_key)
  108. signing_key = nacl.signing.SigningKey(key_bytes)
  109. signature = signing_key.sign(data.encode())
  110. return base64.b64encode(signature.signature).decode()
  111. except Exception as e:
  112. raise ValueError(f"Invalid private key or signing error: {str(e)}")
  113. def verify_request_signature(
  114. self, public_key: str, signature: str, data: str
  115. ) -> bool:
  116. try:
  117. key_bytes = base64.b64decode(public_key)
  118. verify_key = nacl.signing.VerifyKey(key_bytes)
  119. signature_bytes = base64.b64decode(signature)
  120. verify_key.verify(data.encode(), signature_bytes)
  121. return True
  122. except (nacl.exceptions.BadSignatureError, ValueError):
  123. return False
  124. def generate_api_key(self) -> Tuple[str, str]:
  125. # Similar approach as with NaCl provider:
  126. key_id_bytes = nacl.utils.random(16)
  127. key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
  128. # Generate raw API key
  129. raw_api_key = base64.urlsafe_b64encode(
  130. nacl.utils.random(self.config.api_key_bytes)
  131. ).decode()
  132. return key_id, raw_api_key
  133. def hash_api_key(self, raw_api_key: str) -> str:
  134. # Hash with bcrypt
  135. hashed = bcrypt.hashpw(
  136. raw_api_key.encode("utf-8"),
  137. bcrypt.gensalt(rounds=self.config.bcrypt_rounds),
  138. )
  139. return base64.b64encode(hashed).decode("utf-8")
  140. def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
  141. stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
  142. return bcrypt.checkpw(raw_api_key.encode("utf-8"), stored_hash)
  143. def generate_secure_token(self, data: dict, expiry: datetime) -> str:
  144. now = datetime.now(timezone.utc)
  145. to_encode = {
  146. **data,
  147. "exp": expiry.timestamp(),
  148. "iat": now.timestamp(),
  149. "nbf": now.timestamp(),
  150. "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
  151. "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
  152. }
  153. return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
  154. def verify_secure_token(self, token: str) -> Optional[dict]:
  155. try:
  156. payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
  157. exp = payload.get("exp")
  158. if exp is None or datetime.fromtimestamp(
  159. exp, tz=timezone.utc
  160. ) < datetime.now(timezone.utc):
  161. return None
  162. return payload
  163. except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
  164. return None