# sync_client.py
  1. import asyncio
  2. import contextlib
  3. import functools
  4. import inspect
  5. from typing import Any, Callable, Coroutine, TypeVar
  6. from .async_client import R2RAsyncClient
  7. from .v2 import (
  8. SyncAuthMixins,
  9. SyncIngestionMixins,
  10. SyncKGMixins,
  11. SyncManagementMixins,
  12. SyncRetrievalMixins,
  13. SyncServerMixins,
  14. )
  # Result type of a wrapped coroutine: R2RClient._make_sync_method turns a
  # ``Callable[..., Coroutine[Any, Any, T]]`` into a ``Callable[..., T]``.
  T = TypeVar("T")
  class R2RClient(R2RAsyncClient):
      """Blocking (synchronous) client built on top of ``R2RAsyncClient``.

      The client owns a private asyncio event loop. Public coroutine methods on
      the v3 SDK sub-objects are replaced by wrappers that block on that loop
      with ``run_until_complete``. Public methods of the v2 ``Sync*`` mixins
      are re-bound on the instance so that, while they run,
      ``self._make_request`` resolves to the blocking ``_make_sync_request``.

      NOTE(review): ``__init__`` installs the private loop as the current
      thread's event loop, and ``run_until_complete`` cannot be used while
      another loop is running in the same thread (e.g. inside a notebook's
      loop). Confirm this limitation is acceptable for callers.
      """

      def __init__(self, *args: Any, **kwargs: Any):
          super().__init__(*args, **kwargs)
          # Private loop that drives every coroutine this client runs; it is
          # also installed as the current thread's event loop.
          self._loop = asyncio.new_event_loop()
          asyncio.set_event_loop(self._loop)
          # Keep a handle to the async request method before any v2 wrapper
          # temporarily swaps ``_make_request`` for the blocking variant.
          self._async_make_request = self._make_request
          # Only wrap v3 methods since they're already working.
          self._wrap_v3_methods()
          # Override v2 methods with sync versions.
          self._override_v2_methods()

      def _make_sync_request(self, *args, **kwargs):
          """Blocking version of ``_make_request`` used by the v2 methods.

          Runs the original async ``_make_request`` to completion on the
          client's private loop and returns its result.
          """
          return self._loop.run_until_complete(
              self._async_make_request(*args, **kwargs)
          )

      def _override_v2_methods(self) -> None:
          """Replace the async v2 methods with blocking versions.

          This is really ugly, but it's the only way to make it work. Once v2
          is removed, we can resort to the metaclass approach in utils.

          For each public function defined directly on a ``Sync*`` mixin (its
          own ``__dict__``), an instance attribute of the same name is set to
          a bound wrapper. While the wrapped function runs,
          ``self._make_request`` is replaced by ``self._make_sync_request``.
          The previous value is restored in a ``finally``, so the swap is
          undone even if the call raises.
          """
          sync_mixins = (
              SyncAuthMixins,
              SyncIngestionMixins,
              SyncKGMixins,
              SyncManagementMixins,
              SyncRetrievalMixins,
              SyncServerMixins,
          )

          def wrap_method(method):
              # ``method`` is a parameter, so each wrapper captures its own
              # function (no late-binding closure over the loop variable).
              @functools.wraps(method)
              def wrapped(client, *args, **kwargs):
                  # Temporarily swap in the blocking request method.
                  original_make_request = client._make_request
                  client._make_request = client._make_sync_request
                  try:
                      return method(client, *args, **kwargs)
                  finally:
                      # Restore whatever _make_request was before the call.
                      client._make_request = original_make_request

              return wrapped

          for sync_class in sync_mixins:
              for name, method in sync_class.__dict__.items():
                  if name.startswith("_") or not inspect.isfunction(method):
                      continue
                  bound_method = wrap_method(method).__get__(
                      self, self.__class__
                  )
                  setattr(self, name, bound_method)

      def _wrap_v3_methods(self) -> None:
          """Make the coroutine methods of the v3 SDK objects blocking.

          Every public attribute of each SDK object that is a coroutine
          function is replaced, on that object, by a wrapper from
          ``_make_sync_method``.
          """
          sdk_objects = [
              self.chunks,
              self.collections,
              self.conversations,
              self.documents,
              self.graphs,
              self.indices,
              self.prompts,
              self.retrieval,
              self.users,
          ]
          for sdk_obj in sdk_objects:
              for name in dir(sdk_obj):
                  if name.startswith("_"):
                      continue
                  attr = getattr(sdk_obj, name)
                  if inspect.iscoroutinefunction(attr):
                      wrapped = self._make_sync_method(attr)
                      setattr(sdk_obj, name, wrapped)

      def _make_sync_method(
          self, async_method: Callable[..., Coroutine[Any, Any, T]]
      ) -> Callable[..., T]:
          """Return a blocking wrapper around ``async_method``.

          The wrapper calls ``async_method`` with the given arguments and runs
          the resulting coroutine to completion on the client's private loop.
          ``functools.wraps`` keeps the original name and docstring.
          """

          @functools.wraps(async_method)
          def wrapped(*args, **kwargs):
              return self._loop.run_until_complete(async_method(*args, **kwargs))

          return wrapped

      def __del__(self):
          """Best-effort cleanup of the client and its private event loop.

          Every exception is suppressed because a finalizer must not raise.
          If the loop is running, it can be neither driven with
          ``run_until_complete`` nor closed, so only ``_sync_close`` is
          scheduled on it. Otherwise ``_async_close`` is run to completion
          and the loop is closed.
          """
          loop = getattr(self, "_loop", None)
          if loop is None:
              return
          with contextlib.suppress(Exception):
              if loop.is_closed():
                  return
              if loop.is_running():
                  loop.call_soon_threadsafe(self._sync_close)
                  return
              close_coro = self._async_close()
              try:
                  loop.run_until_complete(close_coro)
              except RuntimeError:
                  # The coroutine could not be run (e.g. another loop is
                  # running in this thread). Close it so it is not reported
                  # as "never awaited"; this is a no-op if it already
                  # finished.
                  close_coro.close()
              finally:
                  loop.close()