sync_client.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import json
  2. from io import BytesIO
  3. from typing import Any, Generator
  4. from httpx import Client, ConnectError, RequestError, Response
  5. from shared.abstractions import R2RClientException, R2RException
  6. from .base.base_client import BaseClient
  7. from .sync_methods import (
  8. ChunksSDK,
  9. CollectionsSDK,
  10. ConversationsSDK,
  11. DocumentsSDK,
  12. GraphsSDK,
  13. IndicesSDK,
  14. PromptsSDK,
  15. RetrievalSDK,
  16. SystemSDK,
  17. UsersSDK,
  18. )
  19. class R2RClient(BaseClient):
  20. def __init__(
  21. self,
  22. base_url: str | None = None,
  23. timeout: float = 300.0,
  24. custom_client=None,
  25. ):
  26. super().__init__(base_url, timeout)
  27. self.client = custom_client or Client(timeout=timeout)
  28. self.chunks = ChunksSDK(self)
  29. self.collections = CollectionsSDK(self)
  30. self.conversations = ConversationsSDK(self)
  31. self.documents = DocumentsSDK(self)
  32. self.graphs = GraphsSDK(self)
  33. self.indices = IndicesSDK(self)
  34. self.prompts = PromptsSDK(self)
  35. self.retrieval = RetrievalSDK(self)
  36. self.system = SystemSDK(self)
  37. self.users = UsersSDK(self)
  38. def _make_request(
  39. self, method: str, endpoint: str, version: str = "v3", **kwargs
  40. ) -> dict[str, Any] | BytesIO | None:
  41. url = self._get_full_url(endpoint, version)
  42. request_args = self._prepare_request_args(endpoint, **kwargs)
  43. try:
  44. response = self.client.request(method, url, **request_args)
  45. self._handle_response(response)
  46. if "application/json" in response.headers.get("Content-Type", ""):
  47. return response.json() if response.content else None
  48. else:
  49. return BytesIO(response.content)
  50. except ConnectError as e:
  51. raise R2RClientException(
  52. message="Unable to connect to the server. Check your network connection and the server URL."
  53. ) from e
  54. except RequestError as e:
  55. raise R2RException(
  56. message=f"Request failed: {str(e)}",
  57. status_code=500,
  58. ) from e
  59. def _make_streaming_request(
  60. self, method: str, endpoint: str, version: str = "v3", **kwargs
  61. ) -> Generator[dict[str, str], None, None]:
  62. """
  63. Make a streaming request, parsing Server-Sent Events (SSE) in multiline form.
  64. Yields a dictionary with keys:
  65. - "event": the event type (or "unknown" if not provided)
  66. - "data": the JSON string (possibly spanning multiple lines) accumulated from the event's data lines
  67. """
  68. url = self._get_full_url(endpoint, version)
  69. request_args = self._prepare_request_args(endpoint, **kwargs)
  70. with Client(timeout=self.timeout) as client:
  71. with client.stream(method, url, **request_args) as response:
  72. self._handle_response(response)
  73. sse_event_block: dict[str, Any] = {"event": None, "data": []}
  74. for line in response.iter_lines():
  75. if isinstance(line, bytes):
  76. line = line.decode("utf-8", "replace")
  77. # Blank line -> end of this SSE event
  78. if line == "":
  79. # If there's any accumulated data, yield this event
  80. if sse_event_block["data"]:
  81. data_str = "".join(sse_event_block["data"])
  82. yield {
  83. "event": sse_event_block["event"] or "unknown",
  84. "data": data_str,
  85. }
  86. # Reset the block
  87. sse_event_block = {"event": None, "data": []}
  88. continue
  89. # Otherwise, parse the line
  90. if line.startswith("event:"):
  91. sse_event_block["event"] = line[
  92. len("event:") :
  93. ].lstrip()
  94. elif line.startswith("data:"):
  95. # Accumulate the exact substring after "data:"
  96. # Notice we do *not* strip() the entire line
  97. chunk = line[len("data:") :]
  98. sse_event_block["data"].append(chunk)
  99. # Optionally handle id:, retry:, etc. if needed
  100. # If something remains in the buffer at the end
  101. if sse_event_block["data"]:
  102. data_str = "".join(sse_event_block["data"])
  103. yield {
  104. "event": sse_event_block["event"] or "unknown",
  105. "data": data_str,
  106. }
  107. def _handle_response(self, response: Response) -> None:
  108. if response.status_code >= 400:
  109. try:
  110. error_content = response.json()
  111. if isinstance(error_content, dict):
  112. message = (
  113. error_content.get("detail", {}).get(
  114. "message", str(error_content)
  115. )
  116. if isinstance(error_content.get("detail"), dict)
  117. else error_content.get("detail", str(error_content))
  118. )
  119. else:
  120. message = str(error_content)
  121. except json.JSONDecodeError:
  122. message = response.text
  123. except Exception as e:
  124. message = str(e)
  125. raise R2RException(
  126. status_code=response.status_code, message=message
  127. )
  128. def set_api_key(self, api_key: str) -> None:
  129. if self.access_token:
  130. raise ValueError("Cannot have both access token and api key.")
  131. self.api_key = api_key
  132. def unset_api_key(self) -> None:
  133. self.api_key = None
  134. def set_base_url(self, base_url: str) -> None:
  135. self.base_url = base_url
  136. def set_project_name(self, project_name: str | None) -> None:
  137. self.project_name = project_name
  138. def unset_project_name(self) -> None:
  139. self.project_name = None