supabase.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import logging
  2. import os
  3. from fastapi import Depends, HTTPException
  4. from fastapi.security import OAuth2PasswordBearer
  5. from supabase import Client, create_client
  6. from core.base import (
  7. AuthConfig,
  8. AuthProvider,
  9. CryptoProvider,
  10. DatabaseProvider,
  11. EmailProvider,
  12. R2RException,
  13. Token,
  14. TokenData,
  15. )
  16. from core.base.api.models import User
  17. logger = logging.getLogger()
  18. logger = logging.getLogger()
  19. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  20. class SupabaseAuthProvider(AuthProvider):
  21. def __init__(
  22. self,
  23. config: AuthConfig,
  24. crypto_provider: CryptoProvider,
  25. database_provider: DatabaseProvider,
  26. email_provider: EmailProvider,
  27. ):
  28. super().__init__(
  29. config, crypto_provider, database_provider, email_provider
  30. )
  31. self.supabase_url = config.extra_fields.get(
  32. "supabase_url", None
  33. ) or os.getenv("SUPABASE_URL")
  34. self.supabase_key = config.extra_fields.get(
  35. "supabase_key", None
  36. ) or os.getenv("SUPABASE_KEY")
  37. if not self.supabase_url or not self.supabase_key:
  38. raise HTTPException(
  39. status_code=500,
  40. detail="Supabase URL and key must be provided",
  41. )
  42. self.supabase: Client = create_client(
  43. self.supabase_url, self.supabase_key
  44. )
  45. async def initialize(self):
  46. # No initialization needed for Supabase
  47. pass
  48. def create_access_token(self, data: dict) -> str:
  49. raise NotImplementedError(
  50. "create_access_token is not used with Supabase authentication"
  51. )
  52. def create_refresh_token(self, data: dict) -> str:
  53. raise NotImplementedError(
  54. "create_refresh_token is not used with Supabase authentication"
  55. )
  56. async def decode_token(self, token: str) -> TokenData:
  57. raise NotImplementedError(
  58. "decode_token is not used with Supabase authentication"
  59. )
  60. async def register(self, email: str, password: str) -> User: # type: ignore
  61. # Use Supabase client to create a new user
  62. if user := self.supabase.auth.sign_up(email=email, password=password):
  63. raise R2RException(
  64. status_code=400,
  65. message="Supabase provider implementation is still under construction",
  66. )
  67. # return User(
  68. # id=user.id,
  69. # email=user.email,
  70. # is_active=True,
  71. # is_superuser=False,
  72. # created_at=user.created_at,
  73. # updated_at=user.updated_at,
  74. # is_verified=False,
  75. # )
  76. else:
  77. raise R2RException(
  78. status_code=400, message="User registration failed"
  79. )
  80. async def verify_email(
  81. self, email: str, verification_code: str
  82. ) -> dict[str, str]:
  83. # Use Supabase client to verify email
  84. if response := self.supabase.auth.verify_email(
  85. email, verification_code
  86. ):
  87. return {"message": "Email verified successfully"}
  88. else:
  89. raise R2RException(
  90. status_code=400, message="Invalid or expired verification code"
  91. )
  92. async def login(self, email: str, password: str) -> dict[str, Token]:
  93. # Use Supabase client to authenticate user and get tokens
  94. if response := self.supabase.auth.sign_in(
  95. email=email, password=password
  96. ):
  97. access_token = response.access_token
  98. refresh_token = response.refresh_token
  99. return {
  100. "access_token": Token(token=access_token, token_type="access"),
  101. "refresh_token": Token(
  102. token=refresh_token, token_type="refresh"
  103. ),
  104. }
  105. else:
  106. raise R2RException(
  107. status_code=401, message="Invalid email or password"
  108. )
  109. async def refresh_access_token(
  110. self, refresh_token: str
  111. ) -> dict[str, Token]:
  112. # Use Supabase client to refresh access token
  113. if response := self.supabase.auth.refresh_access_token(refresh_token):
  114. new_access_token = response.access_token
  115. new_refresh_token = response.refresh_token
  116. return {
  117. "access_token": Token(
  118. token=new_access_token, token_type="access"
  119. ),
  120. "refresh_token": Token(
  121. token=new_refresh_token, token_type="refresh"
  122. ),
  123. }
  124. else:
  125. raise R2RException(
  126. status_code=401, message="Invalid refresh token"
  127. )
  128. async def user(self, token: str = Depends(oauth2_scheme)) -> User:
  129. # Use Supabase client to get user details from token
  130. if user := self.supabase.auth.get_user(token).user:
  131. return User(
  132. id=user.id,
  133. email=user.email,
  134. is_active=True, # Assuming active if exists in Supabase
  135. is_superuser=False, # Default to False unless explicitly set
  136. created_at=user.created_at,
  137. updated_at=user.updated_at,
  138. is_verified=user.email_confirmed_at is not None,
  139. name=user.user_metadata.get("full_name"),
  140. # Set other optional fields if available in user metadata
  141. )
  142. else:
  143. raise R2RException(status_code=401, message="Invalid token")
  144. def get_current_active_user(
  145. self, current_user: User = Depends(user)
  146. ) -> User:
  147. # Check if user is active
  148. if not current_user.is_active:
  149. raise R2RException(status_code=400, message="Inactive user")
  150. return current_user
  151. async def change_password(
  152. self, user: User, current_password: str, new_password: str
  153. ) -> dict[str, str]:
  154. # Use Supabase client to update user password
  155. if response := self.supabase.auth.update(
  156. user.id, {"password": new_password}
  157. ):
  158. return {"message": "Password changed successfully"}
  159. else:
  160. raise R2RException(
  161. status_code=400, message="Failed to change password"
  162. )
  163. async def request_password_reset(self, email: str) -> dict[str, str]:
  164. # Use Supabase client to send password reset email
  165. if response := self.supabase.auth.send_password_reset_email(email):
  166. return {
  167. "message": "If the email exists, a reset link has been sent"
  168. }
  169. else:
  170. raise R2RException(
  171. status_code=400, message="Failed to send password reset email"
  172. )
  173. async def confirm_password_reset(
  174. self, reset_token: str, new_password: str
  175. ) -> dict[str, str]:
  176. # Use Supabase client to reset password with token
  177. if response := self.supabase.auth.reset_password_for_email(
  178. reset_token, new_password
  179. ):
  180. return {"message": "Password reset successfully"}
  181. else:
  182. raise R2RException(
  183. status_code=400, message="Invalid or expired reset token"
  184. )
  185. async def logout(self, token: str) -> dict[str, str]:
  186. # Use Supabase client to logout user and revoke token
  187. self.supabase.auth.sign_out(token)
  188. return {"message": "Logged out successfully"}
  189. async def clean_expired_blacklisted_tokens(self):
  190. # Not applicable for Supabase, tokens are managed by Supabase
  191. pass
  192. async def send_reset_email(self, email: str) -> dict[str, str]:
  193. raise NotImplementedError("send_reset_email is not used with Supabase")