base_client.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import asyncio
  2. import contextlib
  3. import os
  4. from functools import wraps
  5. from typing import Optional
  6. from shared.abstractions import R2RException
  7. def sync_wrapper(async_func):
  8. """Decorator to convert async methods to sync methods"""
  9. @wraps(async_func)
  10. def wrapper(*args, **kwargs):
  11. loop = asyncio.get_event_loop()
  12. return loop.run_until_complete(async_func(*args, **kwargs))
  13. return wrapper
  14. def sync_generator_wrapper(async_gen_func):
  15. """Decorator to convert async generators to sync generators"""
  16. @wraps(async_gen_func)
  17. def wrapper(*args, **kwargs):
  18. async_gen = async_gen_func(*args, **kwargs)
  19. loop = asyncio.get_event_loop()
  20. with contextlib.suppress(StopAsyncIteration):
  21. while True:
  22. yield loop.run_until_complete(async_gen.__anext__())
  23. return wrapper
  24. class BaseClient:
  25. def __init__(
  26. self,
  27. base_url: str = "https://api.cloud.sciphi.ai",
  28. prefix: str = "/v2",
  29. timeout: float = 300.0,
  30. ):
  31. self.base_url = base_url
  32. self.prefix = prefix
  33. self.timeout = timeout
  34. self.access_token: Optional[str] = None
  35. self._refresh_token: Optional[str] = None
  36. self.api_key: Optional[str] = os.getenv("R2R_API_KEY", None)
  37. def _get_auth_header(self) -> dict[str, str]:
  38. if self.access_token and self.api_key:
  39. raise R2RException(
  40. status_code=400,
  41. message="Cannot have both access token and api key.",
  42. )
  43. if self.access_token:
  44. return {"Authorization": f"Bearer {self.access_token}"}
  45. elif self.api_key:
  46. return {"x-api-key": self.api_key}
  47. else:
  48. return {}
  49. def _ensure_authenticated(self):
  50. if not self.access_token:
  51. raise R2RException(
  52. status_code=401,
  53. message="Not authenticated. Please login first.",
  54. )
  55. def _get_full_url(self, endpoint: str, version: str = "v2") -> str:
  56. return f"{self.base_url}/{version}/{endpoint}"
  57. def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
  58. headers = kwargs.pop("headers", {})
  59. if (self.access_token or self.api_key) and endpoint not in [
  60. "register",
  61. "login",
  62. "verify_email",
  63. ]:
  64. headers.update(self._get_auth_header())
  65. if (
  66. kwargs.get("params", None) == {}
  67. or kwargs.get("params", None) is None
  68. ):
  69. kwargs.pop("params", None)
  70. return {"headers": headers, **kwargs}