base_client.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import asyncio
  2. import contextlib
  3. from functools import wraps
  4. from typing import Optional
  5. from shared.abstractions import R2RException
  6. def sync_wrapper(async_func):
  7. """Decorator to convert async methods to sync methods"""
  8. @wraps(async_func)
  9. def wrapper(*args, **kwargs):
  10. loop = asyncio.get_event_loop()
  11. return loop.run_until_complete(async_func(*args, **kwargs))
  12. return wrapper
  13. def sync_generator_wrapper(async_gen_func):
  14. """Decorator to convert async generators to sync generators"""
  15. @wraps(async_gen_func)
  16. def wrapper(*args, **kwargs):
  17. async_gen = async_gen_func(*args, **kwargs)
  18. loop = asyncio.get_event_loop()
  19. with contextlib.suppress(StopAsyncIteration):
  20. while True:
  21. yield loop.run_until_complete(async_gen.__anext__())
  22. return wrapper
  23. class BaseClient:
  24. def __init__(
  25. self,
  26. base_url: str = "http://localhost:7272",
  27. prefix: str = "/v2",
  28. timeout: float = 300.0,
  29. ):
  30. self.base_url = base_url
  31. self.prefix = prefix
  32. self.timeout = timeout
  33. self.access_token: Optional[str] = None
  34. self._refresh_token: Optional[str] = None
  35. self.api_key: Optional[str] = None
  36. def _get_auth_header(self) -> dict[str, str]:
  37. if self.access_token and self.api_key:
  38. raise R2RException(
  39. status_code=400,
  40. message="Cannot have both access token and api key.",
  41. )
  42. if self.access_token:
  43. return {"Authorization": f"Bearer {self.access_token}"}
  44. elif self.api_key:
  45. return {"x-api-key": self.api_key}
  46. else:
  47. return {}
  48. def _ensure_authenticated(self):
  49. if not self.access_token:
  50. raise R2RException(
  51. status_code=401,
  52. message="Not authenticated. Please login first.",
  53. )
  54. def _get_full_url(self, endpoint: str, version: str = "v2") -> str:
  55. return f"{self.base_url}/{version}/{endpoint}"
  56. def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
  57. headers = kwargs.pop("headers", {})
  58. if (self.access_token or self.api_key) and endpoint not in [
  59. "register",
  60. "login",
  61. "verify_email",
  62. ]:
  63. headers.update(self._get_auth_header())
  64. if (
  65. kwargs.get("params", None) == {}
  66. or kwargs.get("params", None) is None
  67. ):
  68. kwargs.pop("params", None)
  69. return {"headers": headers, **kwargs}