base_client.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import asyncio
  2. import contextlib
  3. from functools import wraps
  4. from typing import Optional
  5. from shared.abstractions import R2RException
  6. def sync_wrapper(async_func):
  7. """Decorator to convert async methods to sync methods"""
  8. @wraps(async_func)
  9. def wrapper(*args, **kwargs):
  10. loop = asyncio.get_event_loop()
  11. return loop.run_until_complete(async_func(*args, **kwargs))
  12. return wrapper
  13. def sync_generator_wrapper(async_gen_func):
  14. """Decorator to convert async generators to sync generators"""
  15. @wraps(async_gen_func)
  16. def wrapper(*args, **kwargs):
  17. async_gen = async_gen_func(*args, **kwargs)
  18. loop = asyncio.get_event_loop()
  19. with contextlib.suppress(StopAsyncIteration):
  20. while True:
  21. yield loop.run_until_complete(async_gen.__anext__())
  22. return wrapper
  23. class BaseClient:
  24. def __init__(
  25. self,
  26. base_url: str = "http://localhost:7272",
  27. prefix: str = "/v2",
  28. timeout: float = 300.0,
  29. ):
  30. self.base_url = base_url
  31. self.prefix = prefix
  32. self.timeout = timeout
  33. self.access_token: Optional[str] = None
  34. self._refresh_token: Optional[str] = None
  35. def _get_auth_header(self) -> dict[str, str]:
  36. if not self.access_token:
  37. return {}
  38. return {"Authorization": f"Bearer {self.access_token}"}
  39. def _ensure_authenticated(self):
  40. if not self.access_token:
  41. raise R2RException(
  42. status_code=401,
  43. message="Not authenticated. Please login first.",
  44. )
  45. def _get_full_url(self, endpoint: str, version: str = "v2") -> str:
  46. return f"{self.base_url}/{version}/{endpoint}"
  47. def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
  48. headers = kwargs.pop("headers", {})
  49. if self.access_token and endpoint not in [
  50. "register",
  51. "login",
  52. "verify_email",
  53. ]:
  54. headers.update(self._get_auth_header())
  55. if (
  56. kwargs.get("params", None) == {}
  57. or kwargs.get("params", None) is None
  58. ):
  59. kwargs.pop("params", None)
  60. return {"headers": headers, **kwargs}