supabase.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. import logging
  2. import os
  3. from datetime import datetime, timedelta, timezone
  4. from typing import Optional
  5. from uuid import UUID
  6. from fastapi import Depends, HTTPException
  7. from fastapi.security import OAuth2PasswordBearer
  8. from supabase import Client, create_client
  9. from core.base import (
  10. AuthConfig,
  11. AuthProvider,
  12. CryptoProvider,
  13. EmailProvider,
  14. R2RException,
  15. Token,
  16. TokenData,
  17. )
  18. from core.base.api.models import User
  19. from ..database import PostgresDatabaseProvider
  20. logger = logging.getLogger()
  21. logger = logging.getLogger()
  22. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  23. class SupabaseAuthProvider(AuthProvider):
  24. def __init__(
  25. self,
  26. config: AuthConfig,
  27. crypto_provider: CryptoProvider,
  28. database_provider: PostgresDatabaseProvider,
  29. email_provider: EmailProvider,
  30. ):
  31. super().__init__(
  32. config, crypto_provider, database_provider, email_provider
  33. )
  34. self.supabase_url = config.extra_fields.get(
  35. "supabase_url", None
  36. ) or os.getenv("SUPABASE_URL")
  37. self.supabase_key = config.extra_fields.get(
  38. "supabase_key", None
  39. ) or os.getenv("SUPABASE_KEY")
  40. if not self.supabase_url or not self.supabase_key:
  41. raise HTTPException(
  42. status_code=500,
  43. detail="Supabase URL and key must be provided",
  44. )
  45. self.supabase: Client = create_client(
  46. self.supabase_url, self.supabase_key
  47. )
  48. async def initialize(self):
  49. # No initialization needed for Supabase
  50. pass
  51. def create_access_token(self, data: dict) -> str:
  52. raise NotImplementedError(
  53. "create_access_token is not used with Supabase authentication"
  54. )
  55. def create_refresh_token(self, data: dict) -> str:
  56. raise NotImplementedError(
  57. "create_refresh_token is not used with Supabase authentication"
  58. )
  59. async def decode_token(self, token: str) -> TokenData:
  60. try:
  61. # Remove the "Bearer " prefix (if present)
  62. if token.startswith("Bearer "):
  63. token = token[7:]
  64. # Get Supabase token information
  65. auth_response = self.supabase.auth.get_user(token)
  66. if not auth_response or not auth_response.user:
  67. raise R2RException(status_code=401, message="Invalid token")
  68. user = auth_response.user
  69. # Default expiration time
  70. # If Supabase session expire information is not available, use the current time plus 1 hour
  71. expiration_time = datetime.now(timezone.utc) + timedelta(hours=1)
  72. # If Supabase session_expires_at information is available, use it
  73. if hasattr(auth_response, "session") and hasattr(
  74. auth_response.session, "expires_at"
  75. ):
  76. # If expires_at is a timestamp, convert it to a datetime
  77. expiration_time = datetime.fromtimestamp(
  78. auth_response.session.expires_at, timezone.utc
  79. )
  80. # Create TokenData object
  81. return TokenData(
  82. email=user.email,
  83. token_type="access", # Supabase JWT is considered an access token
  84. exp=expiration_time,
  85. )
  86. except Exception as e:
  87. logger.error(f"Token decode error: {str(e)}")
  88. raise R2RException(status_code=401, message="Invalid token") from e
  89. async def register(
  90. self,
  91. email: str,
  92. password: str,
  93. is_verified: bool = False,
  94. name: Optional[str] = None,
  95. bio: Optional[str] = None,
  96. profile_picture: Optional[str] = None,
  97. ) -> User: # type: ignore
  98. # Use Supabase client to create a new user
  99. if self.supabase.auth.sign_up(email=email, password=password):
  100. raise R2RException(
  101. status_code=400,
  102. message="Supabase provider implementation is still under construction",
  103. )
  104. else:
  105. raise R2RException(
  106. status_code=400, message="User registration failed"
  107. )
  108. async def send_verification_email(
  109. self, email: str, user: Optional[User] = None
  110. ) -> tuple[str, datetime]:
  111. raise NotImplementedError(
  112. "send_verification_email is not used with Supabase"
  113. )
  114. async def verify_email(
  115. self, email: str, verification_code: str
  116. ) -> dict[str, str]:
  117. # Use Supabase client to verify email
  118. if self.supabase.auth.verify_email(email, verification_code):
  119. return {"message": "Email verified successfully"}
  120. else:
  121. raise R2RException(
  122. status_code=400, message="Invalid or expired verification code"
  123. )
  124. async def login(self, email: str, password: str) -> dict[str, Token]:
  125. # Use Supabase client to authenticate user and get tokens
  126. try:
  127. response = self.supabase.auth.sign_in_with_password(
  128. {"email": email, "password": password}
  129. )
  130. # Correct access method - token information is found in response.session
  131. if response.session:
  132. access_token = response.session.access_token
  133. refresh_token = response.session.refresh_token
  134. return {
  135. "access_token": Token(
  136. token=access_token, token_type="access"
  137. ),
  138. "refresh_token": Token(
  139. token=refresh_token, token_type="refresh"
  140. ),
  141. }
  142. else:
  143. raise R2RException(
  144. status_code=401, message="Invalid email or password"
  145. )
  146. except Exception as e:
  147. logger.error(f"Login error: {str(e)}")
  148. raise R2RException(
  149. status_code=401, message="Invalid email or password"
  150. ) from e
  151. async def refresh_access_token(
  152. self, refresh_token: str
  153. ) -> dict[str, Token]:
  154. # Use Supabase client to refresh access token
  155. try:
  156. response = self.supabase.auth.refresh_session(refresh_token)
  157. if response.session:
  158. new_access_token = response.session.access_token
  159. new_refresh_token = response.session.refresh_token
  160. return {
  161. "access_token": Token(
  162. token=new_access_token, token_type="access"
  163. ),
  164. "refresh_token": Token(
  165. token=new_refresh_token, token_type="refresh"
  166. ),
  167. }
  168. else:
  169. raise R2RException(
  170. status_code=401, message="Invalid refresh token"
  171. )
  172. except Exception as e:
  173. logger.error(f"Token refresh error: {str(e)}")
  174. raise R2RException(
  175. status_code=401, message="Invalid refresh token"
  176. ) from e
  177. async def user(self, token: str = Depends(oauth2_scheme)) -> User:
  178. # Use Supabase client to get user details from token
  179. try:
  180. auth_response = self.supabase.auth.get_user(token)
  181. if auth_response.user:
  182. user_data = auth_response.user
  183. return User(
  184. id=user_data.id,
  185. email=user_data.email,
  186. is_active=True, # Assuming active if exists in Supabase
  187. is_superuser=False, # Default to False unless explicitly set
  188. created_at=user_data.created_at,
  189. updated_at=user_data.updated_at or user_data.created_at,
  190. is_verified=user_data.email_confirmed_at is not None,
  191. name=user_data.user_metadata.get("name"),
  192. # Set other optional fields if available in user metadata
  193. )
  194. else:
  195. raise R2RException(status_code=401, message="Invalid token")
  196. except Exception as e:
  197. logger.error(f"User lookup error: {str(e)}")
  198. raise R2RException(status_code=401, message="Invalid token") from e
  199. def get_current_active_user(
  200. self, current_user: User = Depends(user)
  201. ) -> User:
  202. # Check if user is active
  203. if not current_user.is_active:
  204. raise R2RException(status_code=400, message="Inactive user")
  205. return current_user
  206. async def change_password(
  207. self, user: User, current_password: str, new_password: str
  208. ) -> dict[str, str]:
  209. # Use Supabase client to update user password
  210. try:
  211. # First, we log in with the current password to verify the user
  212. self.supabase.auth.sign_in_with_password(
  213. {"email": user.email, "password": current_password}
  214. )
  215. # Then we update the password
  216. self.supabase.auth.update_user({"password": new_password})
  217. return {"message": "Password changed successfully"}
  218. except Exception as e:
  219. logger.error(f"Password change error: {str(e)}")
  220. raise R2RException(
  221. status_code=400, message="Failed to change password"
  222. ) from e
  223. async def request_password_reset(self, email: str) -> dict[str, str]:
  224. # Use Supabase client to send password reset email
  225. try:
  226. # Find the base URL from the environment variable
  227. if base_url := os.getenv("R2R_BASE_URL"):
  228. # If R2R_BASE_URL is set, change the port from 7272 to 7273
  229. # Add /auth/login to the end of the URL
  230. # Remove the trailing slash from the URL
  231. if base_url.endswith("/"):
  232. base_url = base_url[:-1]
  233. # Change the port from 7272 to 7273
  234. if ":7272" in base_url:
  235. redirect_url = base_url.replace(":7272", ":7273")
  236. else:
  237. redirect_url = base_url
  238. # Add /auth/login to the end of the URL
  239. if not redirect_url.endswith("/auth/login"):
  240. redirect_url = f"{redirect_url}/auth/login"
  241. else:
  242. # Use the default URL
  243. redirect_url = "https://app.sciphi.ai/auth/login"
  244. # Send the password reset email and use the custom redirect URL
  245. self.supabase.auth.reset_password_for_email(
  246. email, options={"redirect_to": redirect_url}
  247. )
  248. # Return a success message for security reasons
  249. return {
  250. "message": "If the email exists, a reset link has been sent"
  251. }
  252. except Exception as e:
  253. # Even if an error occurs, log the error and return a success message
  254. logger.error(f"Password reset request error: {str(e)}")
  255. return {
  256. "message": "If the email exists, a reset link has been sent"
  257. }
  258. async def confirm_password_reset(
  259. self, reset_token: str, new_password: str
  260. ) -> dict[str, str]:
  261. raise NotImplementedError(
  262. "Password reset confirmation is not implemented with Supabase authentication"
  263. )
  264. async def logout(self, token: str) -> dict[str, str]:
  265. try:
  266. # Logout the user
  267. self.supabase.auth.sign_out()
  268. return {"message": "Logged out successfully"}
  269. except Exception as e:
  270. logger.error(f"Logout error: {str(e)}")
  271. raise R2RException(status_code=400, message="Logout failed") from e
  272. async def clean_expired_blacklisted_tokens(self):
  273. # Not applicable for Supabase, tokens are managed by Supabase
  274. pass
  275. async def send_reset_email(self, email: str) -> dict[str, str]:
  276. raise NotImplementedError("send_reset_email is not used with Supabase")
  277. async def create_user_api_key(
  278. self,
  279. user_id: UUID,
  280. name: Optional[str] = None,
  281. description: Optional[str] = None,
  282. ) -> dict[str, str]:
  283. raise NotImplementedError(
  284. "API key management is not supported with Supabase authentication"
  285. )
  286. async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
  287. raise NotImplementedError(
  288. "API key management is not supported with Supabase authentication"
  289. )
  290. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
  291. raise NotImplementedError(
  292. "API key management is not supported with Supabase authentication"
  293. )
  294. async def oauth_callback_handler(
  295. self, provider: str, oauth_id: str, email: str
  296. ) -> dict[str, Token]:
  297. raise NotImplementedError(
  298. "API key management is not supported with Supabase authentication"
  299. )