supabase.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import logging
  2. import os
  3. from typing import Optional
  4. from uuid import UUID
  5. from fastapi import Depends, HTTPException
  6. from fastapi.security import OAuth2PasswordBearer
  7. from supabase import Client, create_client
  8. from core.base import (
  9. AuthConfig,
  10. AuthProvider,
  11. CryptoProvider,
  12. DatabaseProvider,
  13. EmailProvider,
  14. R2RException,
  15. Token,
  16. TokenData,
  17. )
  18. from core.base.api.models import User
  # Module logger (root logger, as obtained via ``logging.getLogger()``).
  logger = logging.getLogger()

  # Extracts the bearer token from the Authorization header; used below as the
  # FastAPI dependency default for ``SupabaseAuthProvider.user``. ``tokenUrl``
  # is the relative URL advertised to clients in the OpenAPI schema.
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  class SupabaseAuthProvider(AuthProvider):
      """Auth provider that delegates user management to a Supabase project.

      Token creation/decoding and API-key management are not handled by this
      provider; those methods raise ``NotImplementedError``. Registration is
      still a stub: it calls Supabase but always raises afterwards.

      NOTE(review): the ``self.supabase.auth`` method names and argument
      shapes used below (``sign_up(email=..., password=...)``, ``sign_in``,
      ``verify_email``, ``refresh_access_token``, ``update``,
      ``send_password_reset_email``, ``reset_password_for_email``) should be
      confirmed against the installed supabase-py version -- some of them may
      not exist, or may take different arguments, in current releases.
      NOTE(review): ``create_client`` returns a synchronous client and none of
      its calls are awaited, so each call presumably blocks the event loop
      while it runs inside these ``async`` methods.
      """

      def __init__(
          self,
          config: AuthConfig,
          crypto_provider: CryptoProvider,
          database_provider: DatabaseProvider,
          email_provider: EmailProvider,
      ):
          """Build the Supabase client from config or environment.

          ``supabase_url`` / ``supabase_key`` are read from
          ``config.extra_fields`` first, falling back to the ``SUPABASE_URL``
          / ``SUPABASE_KEY`` environment variables.

          Raises:
              HTTPException: status 500 if either the URL or key is missing.
          """
          super().__init__(
              config, crypto_provider, database_provider, email_provider
          )
          self.supabase_url = config.extra_fields.get(
              "supabase_url", None
          ) or os.getenv("SUPABASE_URL")
          self.supabase_key = config.extra_fields.get(
              "supabase_key", None
          ) or os.getenv("SUPABASE_KEY")
          if not self.supabase_url or not self.supabase_key:
              # Raised at construction time rather than during a request; the
              # exception type is kept since callers may already catch it.
              raise HTTPException(
                  status_code=500,
                  detail="Supabase URL and key must be provided",
              )
          self.supabase: Client = create_client(
              self.supabase_url, self.supabase_key
          )

      async def initialize(self):
          """No-op: the Supabase client is fully created in ``__init__``."""

      def create_access_token(self, data: dict) -> str:
          """Unsupported: access tokens come from Supabase (see ``login``)."""
          raise NotImplementedError(
              "create_access_token is not used with Supabase authentication"
          )

      def create_refresh_token(self, data: dict) -> str:
          """Unsupported: refresh tokens come from Supabase (see ``login``)."""
          raise NotImplementedError(
              "create_refresh_token is not used with Supabase authentication"
          )

      async def decode_token(self, token: str) -> TokenData:
          """Unsupported: tokens are resolved via Supabase in ``user``."""
          raise NotImplementedError(
              "decode_token is not used with Supabase authentication"
          )

      async def register(self, email: str, password: str) -> User:  # type: ignore
          """Sign a user up with Supabase (implementation incomplete).

          Raises:
              R2RException: always, with status 400 -- "User registration
                  failed" when Supabase returns a falsy result, otherwise an
                  "under construction" error because the Supabase user is not
                  yet mapped onto an R2R ``User``.
          """
          signup_result = self.supabase.auth.sign_up(
              email=email, password=password
          )
          if not signup_result:
              raise R2RException(
                  status_code=400, message="User registration failed"
              )
          # TODO: build and return a ``User`` from the Supabase user (id,
          # email, created_at, updated_at, is_verified=False) once this
          # provider is finished, instead of raising.
          raise R2RException(
              status_code=400,
              message="Supabase provider implementation is still under construction",
          )

      async def verify_email(
          self, email: str, verification_code: str
      ) -> dict[str, str]:
          """Confirm ``email`` with ``verification_code`` via Supabase.

          Returns:
              A success message dict.

          Raises:
              R2RException: status 400 if Supabase returns a falsy result.
          """
          if not self.supabase.auth.verify_email(email, verification_code):
              raise R2RException(
                  status_code=400, message="Invalid or expired verification code"
              )
          return {"message": "Email verified successfully"}

      async def login(self, email: str, password: str) -> dict[str, Token]:
          """Authenticate with Supabase and return its access/refresh tokens.

          Returns:
              ``{"access_token": Token, "refresh_token": Token}``.

          Raises:
              R2RException: status 401 if Supabase returns a falsy result.
          """
          session = self.supabase.auth.sign_in(email=email, password=password)
          if not session:
              raise R2RException(
                  status_code=401, message="Invalid email or password"
              )
          access_token = session.access_token
          refresh_token = session.refresh_token
          return {
              "access_token": Token(token=access_token, token_type="access"),
              "refresh_token": Token(
                  token=refresh_token, token_type="refresh"
              ),
          }

      async def refresh_access_token(
          self, refresh_token: str
      ) -> dict[str, Token]:
          """Exchange ``refresh_token`` for a new token pair via Supabase.

          Returns:
              ``{"access_token": Token, "refresh_token": Token}``.

          Raises:
              R2RException: status 401 if Supabase returns a falsy result.
          """
          session = self.supabase.auth.refresh_access_token(refresh_token)
          if not session:
              raise R2RException(
                  status_code=401, message="Invalid refresh token"
              )
          new_access_token = session.access_token
          new_refresh_token = session.refresh_token
          return {
              "access_token": Token(
                  token=new_access_token, token_type="access"
              ),
              "refresh_token": Token(
                  token=new_refresh_token, token_type="refresh"
              ),
          }

      async def user(self, token: str = Depends(oauth2_scheme)) -> User:
          """Resolve a Supabase access token to an R2R ``User``.

          ``is_active`` is always True and ``is_superuser`` always False;
          ``is_verified`` reflects whether ``email_confirmed_at`` is set and
          ``name`` comes from the ``full_name`` key of ``user_metadata``.

          Raises:
              R2RException: status 401 if Supabase yields no user.
          """
          response = self.supabase.auth.get_user(token)
          # Guard against a missing response as well as a missing user so an
          # unusable token produces a 401 instead of an AttributeError.
          supabase_user = response.user if response is not None else None
          if not supabase_user:
              raise R2RException(status_code=401, message="Invalid token")
          # user_metadata may be unset; treat that as empty metadata.
          metadata = supabase_user.user_metadata or {}
          return User(
              id=supabase_user.id,
              email=supabase_user.email,
              is_active=True,  # Assumed active if it exists in Supabase
              is_superuser=False,  # No superuser mapping from Supabase yet
              created_at=supabase_user.created_at,
              updated_at=supabase_user.updated_at,
              is_verified=supabase_user.email_confirmed_at is not None,
              name=metadata.get("full_name"),
          )

      def get_current_active_user(
          self, current_user: User = Depends(user)
      ) -> User:
          """Return ``current_user`` if active.

          Raises:
              R2RException: status 400 if the user is not active.
          """
          if not current_user.is_active:
              raise R2RException(status_code=400, message="Inactive user")
          return current_user

      async def change_password(
          self, user: User, current_password: str, new_password: str
      ) -> dict[str, str]:
          """Change ``user``'s password after re-checking the current one.

          The current password is verified by signing in through ``login``
          with ``user.email`` before the update is sent to Supabase.

          Returns:
              A success message dict.

          Raises:
              R2RException: status 400 if the current password is rejected or
                  if Supabase returns a falsy result for the update.
          """
          # Without this check ``current_password`` would be ignored and any
          # caller holding a User could change the password blindly.
          # NOTE(review): ``login`` may store a session on the shared client;
          # confirm that side effect is acceptable here.
          try:
              await self.login(user.email, current_password)
          except R2RException as err:
              raise R2RException(
                  status_code=400, message="Incorrect current password"
              ) from err
          if not self.supabase.auth.update(
              user.id, {"password": new_password}
          ):
              raise R2RException(
                  status_code=400, message="Failed to change password"
              )
          return {"message": "Password changed successfully"}

      async def request_password_reset(self, email: str) -> dict[str, str]:
          """Ask Supabase to send a password-reset email to ``email``.

          Returns:
              A generic message dict on success.

          Raises:
              R2RException: status 400 if Supabase returns a falsy result.
          """
          if not self.supabase.auth.send_password_reset_email(email):
              raise R2RException(
                  status_code=400, message="Failed to send password reset email"
              )
          return {
              "message": "If the email exists, a reset link has been sent"
          }

      async def confirm_password_reset(
          self, reset_token: str, new_password: str
      ) -> dict[str, str]:
          """Set ``new_password`` using a password-reset token.

          Returns:
              A success message dict.

          Raises:
              R2RException: status 400 if Supabase returns a falsy result.
          """
          # NOTE(review): supabase-py documents ``reset_password_for_email`` as
          # *sending* a reset email (email, options), not consuming a token --
          # confirm this call actually applies ``new_password``.
          if not self.supabase.auth.reset_password_for_email(
              reset_token, new_password
          ):
              raise R2RException(
                  status_code=400, message="Invalid or expired reset token"
              )
          return {"message": "Password reset successfully"}

      async def logout(self, token: str) -> dict[str, str]:
          """Sign ``token``'s session out of Supabase; result is not checked."""
          self.supabase.auth.sign_out(token)
          return {"message": "Logged out successfully"}

      async def clean_expired_blacklisted_tokens(self):
          """No-op: this provider keeps no token blacklist of its own."""

      async def send_reset_email(self, email: str) -> dict[str, str]:
          """Unsupported: use ``request_password_reset`` instead."""
          raise NotImplementedError("send_reset_email is not used with Supabase")

      async def create_user_api_key(
          self, user_id: UUID, name: Optional[str] = None
      ) -> dict[str, str]:
          """Unsupported: API keys are not managed by this provider."""
          raise NotImplementedError(
              "API key management is not supported with Supabase authentication"
          )

      async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
          """Unsupported: API keys are not managed by this provider."""
          raise NotImplementedError(
              "API key management is not supported with Supabase authentication"
          )

      async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
          """Unsupported: API keys are not managed by this provider."""
          raise NotImplementedError(
              "API key management is not supported with Supabase authentication"
          )