import asyncio import logging from abc import abstractmethod from enum import Enum from typing import Any, AsyncGenerator, Generic, Optional, TypeVar from uuid import UUID from pydantic import BaseModel from core.base.logger.run_manager import RunManager, manage_run logger = logging.getLogger() class AsyncState: """A state object for storing data between pipes.""" def __init__(self): self.data = {} self.lock = asyncio.Lock() async def update(self, outer_key: str, values: dict): """Update the state with new values.""" async with self.lock: if not isinstance(values, dict): raise ValueError("Values must be contained in a dictionary.") if outer_key not in self.data: self.data[outer_key] = {} for inner_key, inner_value in values.items(): self.data[outer_key][inner_key] = inner_value async def get(self, outer_key: str, inner_key: str, default=None): """Get a value from the state.""" async with self.lock: if outer_key not in self.data: raise ValueError( f"Key {outer_key} does not exist in the state." ) if inner_key not in self.data[outer_key]: return default or {} return self.data[outer_key][inner_key] async def delete(self, outer_key: str, inner_key: Optional[str] = None): """Delete a value from the state.""" async with self.lock: if outer_key in self.data and not inner_key: del self.data[outer_key] else: if inner_key not in self.data[outer_key]: raise ValueError( f"Key {inner_key} does not exist in the state." ) del self.data[outer_key][inner_key] T = TypeVar("T") class AsyncPipe(Generic[T]): """An asynchronous pipe for processing data with logging capabilities.""" class PipeConfig(BaseModel): """Configuration for a pipe.""" name: str = "default_pipe" max_log_queue_size: int = 100 class Config: extra = "forbid" arbitrary_types_allowed = True class Input(BaseModel): """Input for a pipe.""" message: Any class Config: extra = "forbid" arbitrary_types_allowed = True def __init__( self, config: PipeConfig, run_manager: Optional[RunManager] = None, ): # TODO - Deprecate self._config = config or self.PipeConfig() self._run_manager = run_manager or RunManager() logger.debug(f"Initialized pipe {self.config.name}") @property def config(self) -> PipeConfig: return self._config async def run( self, input: Input, state: Optional[AsyncState], run_manager: Optional[RunManager] = None, *args: Any, **kwargs: Any, ) -> AsyncGenerator[T, None]: """Run the pipe with logging capabilities.""" run_manager = run_manager or self._run_manager state = state or AsyncState() async def wrapped_run() -> AsyncGenerator[Any, None]: async with manage_run(run_manager) as run_id: # type: ignore async for result in self._run_logic( # type: ignore input, state, run_id, *args, **kwargs # type: ignore ): yield result return wrapped_run() @abstractmethod async def _run_logic( self, input: Input, state: AsyncState, run_id: UUID, *args: Any, **kwargs: Any, ) -> AsyncGenerator[T, None]: pass