base.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import asyncio
  2. import json
  3. from datetime import datetime
  4. from enum import Enum
  5. from typing import Any, Type, TypeVar, Union
  6. from uuid import UUID
  7. from pydantic import BaseModel
  8. T = TypeVar("T", bound="R2RSerializable")
  9. class R2RSerializable(BaseModel):
  10. @classmethod
  11. def from_dict(cls: Type[T], data: Union[dict[str, Any], str]) -> T:
  12. if isinstance(data, str):
  13. data = json.loads(data)
  14. return cls(**data)
  15. def to_dict(self) -> dict[str, Any]:
  16. data = self.model_dump(exclude_unset=True)
  17. return self._serialize_values(data)
  18. def to_json(self) -> str:
  19. data = self.to_dict()
  20. return json.dumps(data)
  21. @classmethod
  22. def from_json(cls: Type[T], json_str: str) -> T:
  23. return cls.model_validate_json(json_str)
  24. @staticmethod
  25. def _serialize_values(data: Any) -> Any:
  26. if isinstance(data, dict):
  27. return {
  28. k: R2RSerializable._serialize_values(v)
  29. for k, v in data.items()
  30. }
  31. elif isinstance(data, list):
  32. return [R2RSerializable._serialize_values(v) for v in data]
  33. elif isinstance(data, UUID):
  34. return str(data)
  35. elif isinstance(data, Enum):
  36. return data.value
  37. elif isinstance(data, datetime):
  38. return data.isoformat()
  39. else:
  40. return data
  41. class Config:
  42. arbitrary_types_allowed = True
  43. json_encoders = {
  44. UUID: str,
  45. bytes: lambda v: v.decode("utf-8", errors="ignore"),
  46. }
  47. class AsyncSyncMeta(type):
  48. _event_loop = None # Class-level shared event loop
  49. @classmethod
  50. def get_event_loop(cls):
  51. if cls._event_loop is None or cls._event_loop.is_closed():
  52. cls._event_loop = asyncio.new_event_loop()
  53. asyncio.set_event_loop(cls._event_loop)
  54. return cls._event_loop
  55. def __new__(cls, name, bases, dct):
  56. new_cls = super().__new__(cls, name, bases, dct)
  57. for attr_name, attr_value in dct.items():
  58. if asyncio.iscoroutinefunction(attr_value) and getattr(
  59. attr_value, "_syncable", False
  60. ):
  61. sync_method_name = attr_name[
  62. 1:
  63. ] # Remove leading 'a' for sync method
  64. async_method = attr_value
  65. def make_sync_method(async_method):
  66. def sync_wrapper(self, *args, **kwargs):
  67. loop = cls.get_event_loop()
  68. if not loop.is_running():
  69. # Setup to run the loop in a background thread if necessary
  70. # to prevent blocking the main thread in a synchronous call environment
  71. from threading import Thread
  72. result = None
  73. exception = None
  74. def run():
  75. nonlocal result, exception
  76. try:
  77. asyncio.set_event_loop(loop)
  78. result = loop.run_until_complete(
  79. async_method(self, *args, **kwargs)
  80. )
  81. except Exception as e:
  82. exception = e
  83. finally:
  84. generation_config = kwargs.get(
  85. "rag_generation_config", None
  86. )
  87. if (
  88. not generation_config
  89. or not generation_config.stream
  90. ):
  91. loop.run_until_complete(
  92. loop.shutdown_asyncgens()
  93. )
  94. loop.close()
  95. thread = Thread(target=run)
  96. thread.start()
  97. thread.join()
  98. if exception:
  99. raise exception
  100. return result
  101. else:
  102. # If there's already a running loop, schedule and execute the coroutine
  103. future = asyncio.run_coroutine_threadsafe(
  104. async_method(self, *args, **kwargs), loop
  105. )
  106. return future.result()
  107. return sync_wrapper
  108. setattr(
  109. new_cls, sync_method_name, make_sync_method(async_method)
  110. )
  111. return new_cls
  112. def syncable(func):
  113. """Decorator to mark methods for synchronous wrapper creation."""
  114. func._syncable = True
  115. return func