nacl.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
import base64
import json
import logging
import os
import secrets
import string
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple

import jwt
import nacl.encoding
import nacl.exceptions
import nacl.pwhash
import nacl.signing
import nacl.utils
from nacl.exceptions import BadSignatureError
from nacl.pwhash import argon2i

from core.base import CryptoConfig, CryptoProvider
# Fallback JWT (HS256) signing secret. NaClCryptoProvider uses it only when
# neither `config.secret_key` nor the R2R_SECRET_KEY environment variable is
# set.
# SECURITY: this value lives in source control, so tokens signed with it can
# be forged by anyone who can read the code.
DEFAULT_NACL_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"  # Replace or load from env or secrets manager
class NaClCryptoConfig(CryptoConfig):
    """Settings for NaClCryptoProvider (Argon2i limits, JWT secret)."""

    provider: str = "nacl"
    # Interactive parameters for password ops (fast)
    ops_limit: int = argon2i.OPSLIMIT_INTERACTIVE
    mem_limit: int = argon2i.MEMLIMIT_INTERACTIVE
    # Sensitive parameters for API key hashing (slow but more secure)
    api_ops_limit: int = argon2i.OPSLIMIT_SENSITIVE
    api_mem_limit: int = argon2i.MEMLIMIT_SENSITIVE
    # Number of random bytes drawn for each raw API key (before base64).
    api_key_bytes: int = 32
    # Secret used to sign JWTs. When unset, the provider falls back to the
    # R2R_SECRET_KEY environment variable, then to DEFAULT_NACL_SECRET_KEY.
    secret_key: Optional[str] = None
  27. class NaClCryptoProvider(CryptoProvider):
  28. def __init__(self, config: NaClCryptoConfig):
  29. if not isinstance(config, NaClCryptoConfig):
  30. raise ValueError(
  31. "NaClCryptoProvider must be initialized with a NaClCryptoConfig"
  32. )
  33. super().__init__(config)
  34. self.config: NaClCryptoConfig = config
  35. # Securely load the secret key for JWT
  36. # Priority: config.secret_key > environment variable > default
  37. self.secret_key = (
  38. config.secret_key
  39. or os.getenv("R2R_SECRET_KEY")
  40. or DEFAULT_NACL_SECRET_KEY
  41. )
  42. def get_password_hash(self, password: str) -> str:
  43. password_bytes = password.encode("utf-8")
  44. hashed = nacl.pwhash.argon2i.str(
  45. password_bytes,
  46. opslimit=self.config.ops_limit,
  47. memlimit=self.config.mem_limit,
  48. )
  49. return base64.b64encode(hashed).decode("utf-8")
  50. def verify_password(
  51. self, plain_password: str, hashed_password: str
  52. ) -> bool:
  53. try:
  54. stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
  55. nacl.pwhash.verify(stored_hash, plain_password.encode("utf-8"))
  56. return True
  57. except nacl.exceptions.InvalidkeyError:
  58. return False
  59. def generate_verification_code(self, length: int = 32) -> str:
  60. random_bytes = nacl.utils.random(length)
  61. return base64.urlsafe_b64encode(random_bytes)[:length].decode("utf-8")
  62. def generate_api_key(self) -> Tuple[str, str]:
  63. # Generate a unique key_id
  64. key_id_bytes = nacl.utils.random(16) # 16 random bytes
  65. key_id = f"key_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
  66. # Generate a high-entropy API key
  67. raw_api_key = base64.urlsafe_b64encode(
  68. nacl.utils.random(self.config.api_key_bytes)
  69. ).decode()
  70. # The caller will store the hashed version in the database
  71. return key_id, raw_api_key
  72. def hash_api_key(self, raw_api_key: str) -> str:
  73. hashed = nacl.pwhash.argon2i.str(
  74. raw_api_key.encode("utf-8"),
  75. opslimit=self.config.api_ops_limit,
  76. memlimit=self.config.api_mem_limit,
  77. )
  78. return base64.b64encode(hashed).decode("utf-8")
  79. def verify_api_key(self, raw_api_key: str, hashed_key: str) -> bool:
  80. try:
  81. stored_hash = base64.b64decode(hashed_key.encode("utf-8"))
  82. nacl.pwhash.verify(stored_hash, raw_api_key.encode("utf-8"))
  83. return True
  84. except nacl.exceptions.InvalidkeyError:
  85. return False
  86. def sign_request(self, private_key: str, data: str) -> str:
  87. try:
  88. key_bytes = base64.b64decode(private_key)
  89. signing_key = nacl.signing.SigningKey(key_bytes)
  90. signature = signing_key.sign(data.encode())
  91. return base64.b64encode(signature.signature).decode()
  92. except Exception as e:
  93. raise ValueError(f"Invalid private key or signing error: {str(e)}")
  94. def verify_request_signature(
  95. self, public_key: str, signature: str, data: str
  96. ) -> bool:
  97. try:
  98. key_bytes = base64.b64decode(public_key)
  99. verify_key = nacl.signing.VerifyKey(key_bytes)
  100. signature_bytes = base64.b64decode(signature)
  101. verify_key.verify(data.encode(), signature_bytes)
  102. return True
  103. except (BadSignatureError, ValueError):
  104. return False
  105. def generate_secure_token(self, data: dict, expiry: datetime) -> str:
  106. """
  107. Generate a secure token using JWT with HS256.
  108. The secret_key is used for symmetrical signing.
  109. """
  110. now = datetime.now(timezone.utc)
  111. to_encode = {
  112. **data,
  113. "exp": expiry.timestamp(),
  114. "iat": now.timestamp(),
  115. "nbf": now.timestamp(),
  116. "jti": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
  117. "nonce": base64.urlsafe_b64encode(nacl.utils.random(16)).decode(),
  118. }
  119. return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
  120. def verify_secure_token(self, token: str) -> Optional[dict]:
  121. """
  122. Verify a secure token using the shared secret_key and JWT.
  123. """
  124. try:
  125. payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
  126. exp = payload.get("exp")
  127. if exp is None or datetime.fromtimestamp(
  128. exp, tz=timezone.utc
  129. ) < datetime.now(timezone.utc):
  130. return None
  131. return payload
  132. except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
  133. return None
  134. def generate_signing_keypair(self) -> Tuple[str, str, str]:
  135. signing_key = nacl.signing.SigningKey.generate()
  136. private_key_b64 = base64.b64encode(signing_key.encode()).decode()
  137. public_key_b64 = base64.b64encode(
  138. signing_key.verify_key.encode()
  139. ).decode()
  140. # Generate a unique key_id
  141. key_id_bytes = nacl.utils.random(16)
  142. key_id = f"sign_{base64.urlsafe_b64encode(key_id_bytes).decode()}"
  143. return (key_id, private_key_b64, public_key_b64)