async_client.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import json
  2. from io import BytesIO
  3. from typing import Any, AsyncGenerator
  4. import httpx
  5. from httpx import AsyncClient, ConnectError, RequestError, Response
  6. from shared.abstractions import R2RClientException, R2RException
  7. from .asnyc_methods import (
  8. ChunksSDK,
  9. CollectionsSDK,
  10. ConversationsSDK,
  11. DocumentsSDK,
  12. GraphsSDK,
  13. IndicesSDK,
  14. PromptsSDK,
  15. RetrievalSDK,
  16. SystemSDK,
  17. UsersSDK,
  18. )
  19. from .base.base_client import BaseClient
  20. class R2RAsyncClient(BaseClient):
  21. """Asynchronous client for interacting with the R2R API."""
  22. def __init__(
  23. self,
  24. base_url: str | None = None,
  25. timeout: float = 300.0,
  26. custom_client=None,
  27. ):
  28. super().__init__(base_url, timeout)
  29. self.client = custom_client or AsyncClient(timeout=timeout)
  30. self.chunks = ChunksSDK(self)
  31. self.collections = CollectionsSDK(self)
  32. self.conversations = ConversationsSDK(self)
  33. self.documents = DocumentsSDK(self)
  34. self.graphs = GraphsSDK(self)
  35. self.indices = IndicesSDK(self)
  36. self.prompts = PromptsSDK(self)
  37. self.retrieval = RetrievalSDK(self)
  38. self.system = SystemSDK(self)
  39. self.users = UsersSDK(self)
  40. async def _make_request(
  41. self, method: str, endpoint: str, version: str = "v3", **kwargs
  42. ):
  43. url = self._get_full_url(endpoint, version)
  44. request_args = self._prepare_request_args(endpoint, **kwargs)
  45. try:
  46. response = await self.client.request(method, url, **request_args)
  47. await self._handle_response(response)
  48. if "application/json" in response.headers.get("Content-Type", ""):
  49. return response.json() if response.content else None
  50. else:
  51. return BytesIO(response.content)
  52. except ConnectError as e:
  53. raise R2RClientException(
  54. message="Unable to connect to the server. Check your network connection and the server URL."
  55. ) from e
  56. except RequestError as e:
  57. raise R2RException(
  58. message=f"Request failed: {str(e)}",
  59. status_code=500,
  60. ) from e
  61. async def _make_streaming_request(
  62. self, method: str, endpoint: str, version: str = "v3", **kwargs
  63. ) -> AsyncGenerator[Any, None]:
  64. url = self._get_full_url(endpoint, version)
  65. request_args = self._prepare_request_args(endpoint, **kwargs)
  66. async with httpx.AsyncClient(timeout=self.timeout) as client:
  67. async with client.stream(method, url, **request_args) as response:
  68. await self._handle_response(response)
  69. async for line in response.aiter_lines():
  70. if line.strip(): # Ignore empty lines
  71. try:
  72. yield json.loads(line)
  73. except Exception:
  74. yield line
  75. async def _handle_response(self, response: Response) -> None:
  76. if response.status_code >= 400:
  77. try:
  78. error_content = response.json()
  79. if isinstance(error_content, dict):
  80. message = (
  81. error_content.get("detail", {}).get(
  82. "message", str(error_content)
  83. )
  84. if isinstance(error_content.get("detail"), dict)
  85. else error_content.get("detail", str(error_content))
  86. )
  87. else:
  88. message = str(error_content)
  89. except json.JSONDecodeError:
  90. message = response.text
  91. except Exception as e:
  92. message = str(e)
  93. raise R2RException(
  94. status_code=response.status_code, message=message
  95. )
  96. async def close(self):
  97. await self.client.aclose()
  98. async def __aenter__(self):
  99. return self
  100. async def __aexit__(self, exc_type, exc_val, exc_tb):
  101. await self.close()
  102. def set_api_key(self, api_key: str) -> None:
  103. if self.access_token:
  104. raise ValueError("Cannot have both access token and api key.")
  105. self.api_key = api_key
  106. def unset_api_key(self) -> None:
  107. self.api_key = None
  108. def set_base_url(self, base_url: str) -> None:
  109. self.base_url = base_url
  110. def set_project_name(self, project_name: str | None) -> None:
  111. self.project_name = project_name
  112. def unset_project_name(self) -> None:
  113. self.project_name = None