base_client.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import os
  2. from shared.abstractions import R2RClientException
  3. class BaseClient:
  4. def __init__(
  5. self,
  6. base_url: str | None = None,
  7. timeout: float = 300.0,
  8. ):
  9. self.base_url = base_url or os.getenv(
  10. "R2R_API_BASE", "http://localhost:7272"
  11. )
  12. self.timeout = timeout
  13. self.access_token: str | None = None
  14. self._refresh_token: str | None = None
  15. self._user_id: str | None = None
  16. self.api_key: str | None = os.getenv("R2R_API_KEY", None)
  17. self.project_name: str | None = None
  18. def _get_auth_header(self) -> dict[str, str]:
  19. if self.access_token and self.api_key:
  20. raise R2RClientException(
  21. message="Cannot have both access token and api key.",
  22. )
  23. if self.access_token:
  24. return {"Authorization": f"Bearer {self.access_token}"}
  25. elif self.api_key:
  26. return {"x-api-key": self.api_key}
  27. else:
  28. return {}
  29. def _get_full_url(self, endpoint: str, version: str = "v3") -> str:
  30. return f"{self.base_url}/{version}/{endpoint}"
  31. def _prepare_request_args(self, endpoint: str, **kwargs) -> dict:
  32. headers = kwargs.pop("headers", {})
  33. if (self.access_token or self.api_key) and endpoint not in [
  34. "register",
  35. "login",
  36. "verify_email",
  37. ]:
  38. headers.update(self._get_auth_header())
  39. if self.project_name:
  40. headers["x-project-name"] = self.project_name
  41. if (
  42. kwargs.get("params", None) == {}
  43. or kwargs.get("params", None) is None
  44. ):
  45. kwargs.pop("params", None)
  46. return {"headers": headers, **kwargs}