base.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
import asyncio
import functools
import json
import threading
from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel
  8. T = TypeVar("T", bound="R2RSerializable")
  9. class R2RSerializable(BaseModel):
  10. @classmethod
  11. def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
  12. if isinstance(data, str):
  13. try:
  14. data_dict = json.loads(data)
  15. except json.JSONDecodeError as e:
  16. raise ValueError(f"Invalid JSON string: {e}") from e
  17. else:
  18. data_dict = data
  19. return cls(**data_dict)
  20. def as_dict(self) -> dict[str, Any]:
  21. data = self.model_dump(exclude_unset=True)
  22. return self._serialize_values(data)
  23. def to_dict(self) -> dict[str, Any]:
  24. data = self.model_dump(exclude_unset=True)
  25. return self._serialize_values(data)
  26. def to_json(self) -> str:
  27. data = self.to_dict()
  28. return json.dumps(data)
  29. @classmethod
  30. def from_json(cls: Type[T], json_str: str) -> T:
  31. return cls.model_validate_json(json_str)
  32. @staticmethod
  33. def _serialize_values(data: Any) -> Any:
  34. if isinstance(data, dict):
  35. return {
  36. k: R2RSerializable._serialize_values(v)
  37. for k, v in data.items()
  38. }
  39. elif isinstance(data, list):
  40. return [R2RSerializable._serialize_values(v) for v in data]
  41. elif isinstance(data, UUID):
  42. return str(data)
  43. elif isinstance(data, Enum):
  44. return data.value
  45. elif isinstance(data, datetime):
  46. return data.isoformat()
  47. else:
  48. return data
  49. class Config:
  50. arbitrary_types_allowed = True
  51. json_encoders = {
  52. UUID: str,
  53. bytes: lambda v: v.decode("utf-8", errors="ignore"),
  54. }
  55. class AsyncSyncMeta(type):
  56. _event_loop = None # Class-level shared event loop
  57. @classmethod
  58. def get_event_loop(cls):
  59. if cls._event_loop is None or cls._event_loop.is_closed():
  60. cls._event_loop = asyncio.new_event_loop()
  61. asyncio.set_event_loop(cls._event_loop)
  62. return cls._event_loop
  63. def __new__(cls, name, bases, dct):
  64. new_cls = super().__new__(cls, name, bases, dct)
  65. for attr_name, attr_value in dct.items():
  66. if asyncio.iscoroutinefunction(attr_value) and getattr(
  67. attr_value, "_syncable", False
  68. ):
  69. sync_method_name = attr_name[
  70. 1:
  71. ] # Remove leading 'a' for sync method
  72. async_method = attr_value
  73. def make_sync_method(async_method):
  74. def sync_wrapper(self, *args, **kwargs):
  75. loop = cls.get_event_loop()
  76. if not loop.is_running():
  77. # Setup to run the loop in a background thread if necessary
  78. # to prevent blocking the main thread in a synchronous call environment
  79. from threading import Thread
  80. result = None
  81. exception = None
  82. def run():
  83. nonlocal result, exception
  84. try:
  85. asyncio.set_event_loop(loop)
  86. result = loop.run_until_complete(
  87. async_method(self, *args, **kwargs)
  88. )
  89. except Exception as e:
  90. exception = e
  91. finally:
  92. generation_config = kwargs.get(
  93. "rag_generation_config", None
  94. )
  95. if (
  96. not generation_config
  97. or not generation_config.stream
  98. ):
  99. loop.run_until_complete(
  100. loop.shutdown_asyncgens()
  101. )
  102. loop.close()
  103. thread = Thread(target=run)
  104. thread.start()
  105. thread.join()
  106. if exception:
  107. raise exception
  108. return result
  109. else:
  110. # If there's already a running loop, schedule and execute the coroutine
  111. future = asyncio.run_coroutine_threadsafe(
  112. async_method(self, *args, **kwargs), loop
  113. )
  114. return future.result()
  115. return sync_wrapper
  116. setattr(
  117. new_cls, sync_method_name, make_sync_method(async_method)
  118. )
  119. return new_cls
  120. def syncable(func):
  121. """Decorator to mark methods for synchronous wrapper creation."""
  122. func._syncable = True
  123. return func