123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- import asyncio
- import json
- from datetime import datetime
- from enum import Enum
- from typing import Any, Type, TypeVar
- from uuid import UUID
- from pydantic import BaseModel
# Bound type variable so the alternate constructors below are typed as
# returning the subclass they are invoked on.
T = TypeVar("T", bound="R2RSerializable")


class R2RSerializable(BaseModel):
    """Base model adding dict/JSON conversion helpers on top of ``BaseModel``.

    ``to_dict``/``to_json`` dump only the fields that were explicitly set and
    convert UUID, Enum and datetime values (at any nesting depth inside dicts,
    lists and tuples) into plain values that ``json.dumps`` accepts.
    """

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any] | str) -> T:
        """Build an instance from a mapping or from a JSON string.

        Args:
            data: Field values as a dict, or a JSON document that decodes to
                one.

        Returns:
            A new instance of the class this method is called on.

        Raises:
            ValueError: If ``data`` is a string that is not valid JSON.
                Errors raised by the model constructor propagate unchanged.
        """
        if isinstance(data, str):
            try:
                data_dict = json.loads(data)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON string: {e}") from e
        else:
            data_dict = data
        return cls(**data_dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the explicitly-set fields as a dict of serialized values."""
        data = self.model_dump(exclude_unset=True)
        return self._serialize_values(data)

    def to_json(self) -> str:
        """Return ``to_dict()`` encoded as a JSON string."""
        data = self.to_dict()
        return json.dumps(data)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        """Validate ``json_str`` and build an instance from it."""
        return cls.model_validate_json(json_str)

    @staticmethod
    def _serialize_values(data: Any) -> Any:
        """Recursively replace UUID, Enum and datetime values with plain ones.

        Dicts, lists and tuples are walked recursively; the container type is
        kept. Any other value is returned unchanged.

        Args:
            data: An arbitrary (possibly nested) value.

        Returns:
            A structure of the same shape where UUIDs became ``str``, Enum
            members became their ``.value`` and datetimes became ISO strings.
        """
        if isinstance(data, dict):
            return {
                k: R2RSerializable._serialize_values(v)
                for k, v in data.items()
            }
        if isinstance(data, list):
            return [R2RSerializable._serialize_values(v) for v in data]
        if isinstance(data, tuple):
            # Tuples were previously passed through untouched, leaving raw
            # UUID/Enum/datetime members inside them and breaking to_json().
            items = [R2RSerializable._serialize_values(v) for v in data]
            if hasattr(data, "_fields"):
                # namedtuple: rebuild positionally to keep the same type.
                return type(data)(*items)
            return tuple(items)
        if isinstance(data, UUID):
            return str(data)
        if isinstance(data, Enum):
            return data.value
        if isinstance(data, datetime):
            return data.isoformat()
        return data

    # NOTE(review): model_dump/model_validate_json are pydantic v2 APIs, where
    # class-based Config and json_encoders appear to be deprecated in favour of
    # model_config — confirm before migrating; left as-is here.
    class Config:
        arbitrary_types_allowed = True
        json_encoders = {
            UUID: str,
            bytes: lambda v: v.decode("utf-8", errors="ignore"),
        }
class AsyncSyncMeta(type):
    """Metaclass that adds blocking twins for coroutine methods marked syncable.

    For every coroutine function in the class body whose ``_syncable``
    attribute is truthy, a synchronous method is attached to the new class.
    Its name is the coroutine's name with the first character removed
    (``aquery`` -> ``query``).
    """

    _event_loop = None  # Class-level shared event loop

    @classmethod
    def get_event_loop(cls):
        """Return the shared loop, creating a new one if none exists or it is closed."""
        if cls._event_loop is None or cls._event_loop.is_closed():
            cls._event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls._event_loop)
        return cls._event_loop

    def __new__(cls, name, bases, dct):
        """Create the class, then attach a sync wrapper per syncable coroutine."""
        new_cls = super().__new__(cls, name, bases, dct)

        def make_sync_method(sync_method_name, async_method):
            """Build the blocking wrapper that drives ``async_method``."""

            def sync_wrapper(self, *args, **kwargs):
                loop = cls.get_event_loop()
                if loop.is_running():
                    # If there's already a running loop, schedule and execute the coroutine
                    # NOTE(review): this blocks the calling thread on the
                    # future; calling it from the loop's own thread would
                    # presumably never return — confirm intended usage.
                    future = asyncio.run_coroutine_threadsafe(
                        async_method(self, *args, **kwargs), loop
                    )
                    return future.result()

                # The loop is idle: drive it to completion on a helper thread
                # and wait for that thread to finish.
                from threading import Thread

                result = None
                exception = None

                def run():
                    nonlocal result, exception
                    try:
                        asyncio.set_event_loop(loop)
                        result = loop.run_until_complete(
                            async_method(self, *args, **kwargs)
                        )
                    except BaseException as e:
                        # BaseException rather than Exception so that
                        # asyncio.CancelledError is handed back to the caller
                        # instead of dying silently inside the thread.
                        exception = e
                    finally:
                        generation_config = kwargs.get(
                            "rag_generation_config", None
                        )
                        # The loop is left open when the call was made with a
                        # streaming generation config (presumably so the
                        # stream can still be consumed — confirm).
                        if (
                            not generation_config
                            or not generation_config.stream
                        ):
                            try:
                                loop.run_until_complete(
                                    loop.shutdown_asyncgens()
                                )
                            finally:
                                # Close even if async-generator shutdown fails.
                                loop.close()

                thread = Thread(target=run)
                thread.start()
                thread.join()
                if exception is not None:
                    raise exception
                return result

            # Make the generated method identifiable in tracebacks and help().
            sync_wrapper.__name__ = sync_method_name
            sync_wrapper.__doc__ = async_method.__doc__
            return sync_wrapper

        for attr_name, attr_value in dct.items():
            if asyncio.iscoroutinefunction(attr_value) and getattr(
                attr_value, "_syncable", False
            ):
                # Remove leading 'a' for sync method. The first character is
                # dropped unconditionally, whatever it is.
                sync_method_name = attr_name[1:]
                setattr(
                    new_cls,
                    sync_method_name,
                    make_sync_method(sync_method_name, attr_value),
                )
        return new_cls
def syncable(func):
    """Flag ``func`` so that a synchronous wrapper is generated for it.

    Sets ``func._syncable`` to ``True`` and returns the same function object.
    """
    setattr(func, "_syncable", True)
    return func
|