base.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import asyncio
  2. import json
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Any, Type, TypeVar
  6. from uuid import UUID
  7. from pydantic import BaseModel
  8. T = TypeVar("T", bound="R2RSerializable")
  9. class R2RSerializable(BaseModel):
  10. @classmethod
  11. def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
  12. if isinstance(data, str):
  13. try:
  14. data_dict = json.loads(data)
  15. except json.JSONDecodeError as e:
  16. raise ValueError(f"Invalid JSON string: {e}") from e
  17. else:
  18. data_dict = data
  19. return cls(**data_dict)
  20. def to_dict(self) -> dict[str, Any]:
  21. data = self.model_dump(exclude_unset=True)
  22. return self._serialize_values(data)
  23. def to_json(self) -> str:
  24. data = self.to_dict()
  25. return json.dumps(data)
  26. @classmethod
  27. def from_json(cls: Type[T], json_str: str) -> T:
  28. return cls.model_validate_json(json_str)
  29. @staticmethod
  30. def _serialize_values(data: Any) -> Any:
  31. if isinstance(data, dict):
  32. return {
  33. k: R2RSerializable._serialize_values(v)
  34. for k, v in data.items()
  35. }
  36. elif isinstance(data, list):
  37. return [R2RSerializable._serialize_values(v) for v in data]
  38. elif isinstance(data, UUID):
  39. return str(data)
  40. elif isinstance(data, Enum):
  41. return data.value
  42. elif isinstance(data, datetime):
  43. return data.isoformat()
  44. else:
  45. return data
  46. class Config:
  47. arbitrary_types_allowed = True
  48. json_encoders = {
  49. UUID: str,
  50. bytes: lambda v: v.decode("utf-8", errors="ignore"),
  51. }
  52. class AsyncSyncMeta(type):
  53. _event_loop = None # Class-level shared event loop
  54. @classmethod
  55. def get_event_loop(cls):
  56. if cls._event_loop is None or cls._event_loop.is_closed():
  57. cls._event_loop = asyncio.new_event_loop()
  58. asyncio.set_event_loop(cls._event_loop)
  59. return cls._event_loop
  60. def __new__(cls, name, bases, dct):
  61. new_cls = super().__new__(cls, name, bases, dct)
  62. for attr_name, attr_value in dct.items():
  63. if asyncio.iscoroutinefunction(attr_value) and getattr(
  64. attr_value, "_syncable", False
  65. ):
  66. sync_method_name = attr_name[
  67. 1:
  68. ] # Remove leading 'a' for sync method
  69. async_method = attr_value
  70. def make_sync_method(async_method):
  71. def sync_wrapper(self, *args, **kwargs):
  72. loop = cls.get_event_loop()
  73. if not loop.is_running():
  74. # Setup to run the loop in a background thread if necessary
  75. # to prevent blocking the main thread in a synchronous call environment
  76. from threading import Thread
  77. result = None
  78. exception = None
  79. def run():
  80. nonlocal result, exception
  81. try:
  82. asyncio.set_event_loop(loop)
  83. result = loop.run_until_complete(
  84. async_method(self, *args, **kwargs)
  85. )
  86. except Exception as e:
  87. exception = e
  88. finally:
  89. generation_config = kwargs.get(
  90. "rag_generation_config", None
  91. )
  92. if (
  93. not generation_config
  94. or not generation_config.stream
  95. ):
  96. loop.run_until_complete(
  97. loop.shutdown_asyncgens()
  98. )
  99. loop.close()
  100. thread = Thread(target=run)
  101. thread.start()
  102. thread.join()
  103. if exception:
  104. raise exception
  105. return result
  106. else:
  107. # If there's already a running loop, schedule and execute the coroutine
  108. future = asyncio.run_coroutine_threadsafe(
  109. async_method(self, *args, **kwargs), loop
  110. )
  111. return future.result()
  112. return sync_wrapper
  113. setattr(
  114. new_cls, sync_method_name, make_sync_method(async_method)
  115. )
  116. return new_cls
  117. def syncable(func):
  118. """Decorator to mark methods for synchronous wrapper creation."""
  119. func._syncable = True
  120. return func