123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- import asyncio
- import logging
- from abc import abstractmethod
- from enum import Enum
- from typing import Any, AsyncGenerator, Generic, Optional, TypeVar
- from uuid import UUID
- from pydantic import BaseModel
- from core.base.logger.run_manager import RunManager, manage_run
# Module-scoped logger: records carry this module's dotted name rather than
# "root", so log output can be filtered and configured per module. Records
# still propagate to any handlers installed on the root logger.
logger = logging.getLogger(__name__)
class AsyncState:
    """A state object for storing data between pipes.

    Values live in a two-level mapping, ``data[outer_key][inner_key]``.
    Every read and write of that mapping happens while holding ``lock``.
    """

    def __init__(self):
        self.data = {}
        self.lock = asyncio.Lock()

    async def update(self, outer_key: str, values: dict):
        """Merge ``values`` into the mapping stored under ``outer_key``.

        The inner mapping is created if ``outer_key`` is not present yet;
        existing inner keys are overwritten by the entries in ``values``.

        Raises:
            ValueError: If ``values`` is not a dictionary.
        """
        # Validate first: the check does not touch shared state, so there is
        # no reason to hold the lock while doing it.
        if not isinstance(values, dict):
            raise ValueError("Values must be contained in a dictionary.")
        async with self.lock:
            self.data.setdefault(outer_key, {}).update(values)

    async def get(self, outer_key: str, inner_key: str, default=None):
        """Return the value stored at ``data[outer_key][inner_key]``.

        If ``inner_key`` is absent, ``default`` is returned; when ``default``
        is ``None`` (the default) an empty dict is returned instead.

        Raises:
            ValueError: If ``outer_key`` does not exist in the state.
        """
        async with self.lock:
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            inner = self.data[outer_key]
            if inner_key in inner:
                return inner[inner_key]
            # Only a missing (None) default falls back to {}; an explicit
            # falsy default such as 0, "" or False must be honoured as given.
            return {} if default is None else default

    async def delete(self, outer_key: str, inner_key: Optional[str] = None):
        """Delete a value from the state.

        With no ``inner_key`` (``None`` or an empty string) the whole mapping
        under ``outer_key`` is removed; otherwise only that inner entry is.

        Raises:
            ValueError: If ``outer_key`` does not exist, or if ``inner_key``
                is given but is not present under ``outer_key``.
        """
        async with self.lock:
            # Report a missing outer key the same way ``get`` does instead of
            # leaking a bare KeyError from the dictionary lookup below.
            if outer_key not in self.data:
                raise ValueError(
                    f"Key {outer_key} does not exist in the state."
                )
            if not inner_key:
                del self.data[outer_key]
                return
            inner = self.data[outer_key]
            if inner_key not in inner:
                raise ValueError(
                    f"Key {inner_key} does not exist in the state."
                )
            del inner[inner_key]
# Type variable for the items yielded by a pipe's async generator
# (see the ``AsyncGenerator[T, None]`` return annotations on AsyncPipe).
T = TypeVar("T")
class AsyncPipe(Generic[T]):
    """Base class for asynchronous pipes that process data with logging.

    Subclasses provide ``_run_logic``; callers go through ``run``, which
    executes that logic inside a ``manage_run`` context.
    """

    class PipeConfig(BaseModel):
        """Settings for a pipe; unknown fields are rejected."""

        name: str = "default_pipe"
        max_log_queue_size: int = 100

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    class Input(BaseModel):
        """Payload passed to a pipe; unknown fields are rejected."""

        message: Any

        class Config:
            extra = "forbid"
            arbitrary_types_allowed = True

    def __init__(
        self,
        config: PipeConfig,
        run_manager: Optional[RunManager] = None,
    ):
        """Store the pipe configuration and run manager.

        A falsy ``config`` is replaced by a default ``PipeConfig``; a falsy
        ``run_manager`` is replaced by a new ``RunManager``.
        """
        # TODO - Deprecate
        self._config = config if config else self.PipeConfig()
        self._run_manager = run_manager if run_manager else RunManager()
        logger.debug(f"Initialized pipe {self.config.name}")

    @property
    def config(self) -> PipeConfig:
        """The configuration this pipe was initialised with."""
        return self._config

    async def run(
        self,
        input: Input,
        state: Optional[AsyncState],
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[T, None]:
        """Run the pipe with logging capabilities.

        Awaiting this coroutine returns an async generator; nothing is
        executed until that generator is iterated. The generator enters
        ``manage_run`` and relays every item produced by ``_run_logic``.

        Args:
            input: Payload forwarded to ``_run_logic``.
            state: Shared state; a new ``AsyncState`` is used when falsy.
            run_manager: Overrides the manager given at construction time.
            *args: Extra positional arguments forwarded to ``_run_logic``.
            **kwargs: Extra keyword arguments forwarded to ``_run_logic``.
        """
        active_manager = run_manager if run_manager else self._run_manager
        shared_state = state if state else AsyncState()

        async def stream_results() -> AsyncGenerator[Any, None]:
            # The run context stays open for as long as results are consumed.
            async with manage_run(active_manager) as run_id:  # type: ignore
                results = self._run_logic(  # type: ignore
                    input, shared_state, run_id, *args, **kwargs  # type: ignore
                )
                async for item in results:
                    yield item

        return stream_results()

    @abstractmethod
    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[T, None]:
        """Produce the pipe's results; meant to be overridden by subclasses.

        ``run`` iterates the return value with ``async for``, so overrides
        are expected to be async generators.
        """
        # NOTE(review): the class does not derive from abc.ABC, so the
        # @abstractmethod marker is not enforced at instantiation - confirm
        # whether that is intentional.
        pass
|