# sync_client.py (2.7 KB)
  1. import asyncio
  2. import contextlib
  3. import functools
  4. import inspect
  5. from typing import Any, Callable, Coroutine, TypeVar
  6. from .async_client import R2RAsyncClient
# Result type of a wrapped coroutine: ties the return annotation of the sync
# wrapper built by R2RClient._make_sync_method to that of the async method it
# wraps.
T = TypeVar("T")
  8. class R2RClient(R2RAsyncClient):
  9. def __init__(self, *args: Any, **kwargs: Any):
  10. super().__init__(*args, **kwargs)
  11. self._loop = asyncio.new_event_loop()
  12. asyncio.set_event_loop(self._loop)
  13. # Store async version of _make_request
  14. self._async_make_request = self._make_request
  15. # Only wrap v3 methods since they're already working
  16. self._wrap_v3_methods()
  17. def _make_sync_request(self, *args, **kwargs):
  18. """Sync version of _make_request for v2 methods"""
  19. return self._loop.run_until_complete(
  20. self._async_make_request(*args, **kwargs)
  21. )
  22. def _wrap_v3_methods(self) -> None:
  23. """Wraps only v3 SDK object methods"""
  24. sdk_objects = [
  25. self.chunks,
  26. self.collections,
  27. self.conversations,
  28. self.documents,
  29. self.graphs,
  30. self.indices,
  31. self.prompts,
  32. self.retrieval,
  33. self.users,
  34. self.system,
  35. ]
  36. for sdk_obj in sdk_objects:
  37. for name in dir(sdk_obj):
  38. if name.startswith("_"):
  39. continue
  40. attr = getattr(sdk_obj, name)
  41. if inspect.iscoroutinefunction(attr):
  42. wrapped = self._make_sync_method(attr)
  43. setattr(sdk_obj, name, wrapped)
  44. # def _make_sync_method(self, async_method):
  45. def _make_sync_method(
  46. self, async_method: Callable[..., Coroutine[Any, Any, T]]
  47. ) -> Callable[..., T]:
  48. @functools.wraps(async_method)
  49. def wrapped(*args, **kwargs):
  50. return self._loop.run_until_complete(async_method(*args, **kwargs))
  51. return wrapped
  52. def __del__(self):
  53. if hasattr(self, "_loop") and self._loop is not None:
  54. with contextlib.suppress(Exception):
  55. if not self._loop.is_closed():
  56. try:
  57. self._loop.run_until_complete(self._async_close())
  58. except RuntimeError:
  59. # If the event loop is already running, we can't use run_until_complete
  60. if self._loop.is_running():
  61. self._loop.call_soon_threadsafe(self._sync_close)
  62. else:
  63. asyncio.run_coroutine_threadsafe(
  64. self._async_close(), self._loop
  65. )
  66. finally:
  67. self._loop.close()