jwt.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import logging
  2. import os
  3. from datetime import datetime
  4. from typing import Optional
  5. from uuid import UUID
  6. import jwt
  7. from fastapi import Depends
  8. from core.base import (
  9. AuthConfig,
  10. AuthProvider,
  11. CryptoProvider,
  12. EmailProvider,
  13. R2RException,
  14. Token,
  15. TokenData,
  16. )
  17. from core.base.api.models import User
  18. from ..database import PostgresDatabaseProvider
  19. logger = logging.getLogger()
  20. class JwtAuthProvider(AuthProvider):
  21. def __init__(
  22. self,
  23. config: AuthConfig,
  24. crypto_provider: CryptoProvider,
  25. database_provider: PostgresDatabaseProvider,
  26. email_provider: EmailProvider,
  27. ):
  28. super().__init__(
  29. config, crypto_provider, database_provider, email_provider
  30. )
  31. async def login(self, email: str, password: str) -> dict[str, Token]:
  32. raise NotImplementedError("Not implemented")
  33. async def oauth_callback(self, code: str) -> dict[str, Token]:
  34. raise NotImplementedError("Not implemented")
  35. async def user(self, token: str) -> User:
  36. raise NotImplementedError("Not implemented")
  37. async def change_password(
  38. self, user: User, current_password: str, new_password: str
  39. ) -> dict[str, str]:
  40. raise NotImplementedError("Not implemented")
  41. async def confirm_password_reset(
  42. self, reset_token: str, new_password: str
  43. ) -> dict[str, str]:
  44. raise NotImplementedError("Not implemented")
  45. def create_access_token(self, data: dict) -> str:
  46. raise NotImplementedError("Not implemented")
  47. def create_refresh_token(self, data: dict) -> str:
  48. raise NotImplementedError("Not implemented")
  49. async def decode_token(self, token: str) -> TokenData:
  50. # use JWT library to validate and decode JWT token
  51. jwtSecret = os.getenv("JWT_SECRET")
  52. if jwtSecret is None:
  53. raise R2RException(
  54. status_code=500,
  55. message="JWT_SECRET environment variable is not set",
  56. )
  57. try:
  58. user = jwt.decode(token, jwtSecret, algorithms=["HS256"])
  59. except Exception as e:
  60. logger.info(f"JWT verification failed: {e}")
  61. raise R2RException(
  62. status_code=401, message="Invalid JWT token", detail=e
  63. ) from e
  64. if user:
  65. # Create user in database if not exists
  66. try:
  67. await self.database_provider.users_handler.get_user_by_email(
  68. user.get("email")
  69. )
  70. # TODO do we want to update user info here based on what's in the token?
  71. except Exception:
  72. # user doesn't exist, create in db
  73. logger.debug(f"Creating new user: {user.get('email')}")
  74. try:
  75. await self.database_provider.users_handler.create_user(
  76. email=user.get("email"),
  77. account_type="external",
  78. name=user.get("name"),
  79. )
  80. except Exception as e:
  81. logger.error(f"Error creating user: {e}")
  82. raise R2RException(
  83. status_code=500, message="Failed to create user"
  84. ) from e
  85. return TokenData(
  86. email=user.get("email"),
  87. token_type="bearer",
  88. exp=user.get("exp"),
  89. )
  90. else:
  91. raise R2RException(status_code=401, message="Invalid JWT token")
  92. async def refresh_access_token(
  93. self, refresh_token: str
  94. ) -> dict[str, Token]:
  95. raise NotImplementedError("Not implemented")
  96. def get_current_active_user(
  97. self, current_user: User = Depends(user)
  98. ) -> User:
  99. # Check if user is active
  100. if not current_user.is_active:
  101. raise R2RException(status_code=400, message="Inactive user")
  102. return current_user
  103. async def logout(self, token: str) -> dict[str, str]:
  104. raise NotImplementedError("Not implemented")
  105. async def register(
  106. self,
  107. email: str,
  108. password: str,
  109. is_verified: bool = False,
  110. name: Optional[str] = None,
  111. bio: Optional[str] = None,
  112. profile_picture: Optional[str] = None,
  113. ) -> User: # type: ignore
  114. raise NotImplementedError("Not implemented")
  115. async def request_password_reset(self, email: str) -> dict[str, str]:
  116. raise NotImplementedError("Not implemented")
  117. async def send_reset_email(self, email: str) -> dict[str, str]:
  118. raise NotImplementedError("Not implemented")
  119. async def create_user_api_key(
  120. self,
  121. user_id: UUID,
  122. name: Optional[str] = None,
  123. description: Optional[str] = None,
  124. ) -> dict[str, str]:
  125. raise NotImplementedError("Not implemented")
  126. async def verify_email(
  127. self, email: str, verification_code: str
  128. ) -> dict[str, str]:
  129. raise NotImplementedError("Not implemented")
  130. async def send_verification_email(
  131. self, email: str, user: Optional[User] = None
  132. ) -> tuple[str, datetime]:
  133. raise NotImplementedError("Not implemented")
  134. async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
  135. raise NotImplementedError("Not implemented")
  136. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
  137. raise NotImplementedError("Not implemented")
  138. async def oauth_callback_handler(
  139. self, provider: str, oauth_id: str, email: str
  140. ) -> dict[str, Token]:
  141. raise NotImplementedError("Not implemented")