base_pipe.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import asyncio
  2. import logging
  3. from abc import abstractmethod
  4. from enum import Enum
  5. from typing import Any, AsyncGenerator, Generic, Optional, TypeVar
  6. from uuid import UUID
  7. from pydantic import BaseModel
  8. from core.base.logger.run_manager import RunManager, manage_run
  9. logger = logging.getLogger()
  10. class AsyncState:
  11. """A state object for storing data between pipes."""
  12. def __init__(self):
  13. self.data = {}
  14. self.lock = asyncio.Lock()
  15. async def update(self, outer_key: str, values: dict):
  16. """Update the state with new values."""
  17. async with self.lock:
  18. if not isinstance(values, dict):
  19. raise ValueError("Values must be contained in a dictionary.")
  20. if outer_key not in self.data:
  21. self.data[outer_key] = {}
  22. for inner_key, inner_value in values.items():
  23. self.data[outer_key][inner_key] = inner_value
  24. async def get(self, outer_key: str, inner_key: str, default=None):
  25. """Get a value from the state."""
  26. async with self.lock:
  27. if outer_key not in self.data:
  28. raise ValueError(
  29. f"Key {outer_key} does not exist in the state."
  30. )
  31. if inner_key not in self.data[outer_key]:
  32. return default or {}
  33. return self.data[outer_key][inner_key]
  34. async def delete(self, outer_key: str, inner_key: Optional[str] = None):
  35. """Delete a value from the state."""
  36. async with self.lock:
  37. if outer_key in self.data and not inner_key:
  38. del self.data[outer_key]
  39. else:
  40. if inner_key not in self.data[outer_key]:
  41. raise ValueError(
  42. f"Key {inner_key} does not exist in the state."
  43. )
  44. del self.data[outer_key][inner_key]
  45. T = TypeVar("T")
  46. class AsyncPipe(Generic[T]):
  47. """An asynchronous pipe for processing data with logging capabilities."""
  48. class PipeConfig(BaseModel):
  49. """Configuration for a pipe."""
  50. name: str = "default_pipe"
  51. max_log_queue_size: int = 100
  52. class Config:
  53. extra = "forbid"
  54. arbitrary_types_allowed = True
  55. class Input(BaseModel):
  56. """Input for a pipe."""
  57. message: Any
  58. class Config:
  59. extra = "forbid"
  60. arbitrary_types_allowed = True
  61. def __init__(
  62. self,
  63. config: PipeConfig,
  64. run_manager: Optional[RunManager] = None,
  65. ):
  66. # TODO - Deprecate
  67. self._config = config or self.PipeConfig()
  68. self._run_manager = run_manager or RunManager()
  69. logger.debug(f"Initialized pipe {self.config.name}")
  70. @property
  71. def config(self) -> PipeConfig:
  72. return self._config
  73. async def run(
  74. self,
  75. input: Input,
  76. state: Optional[AsyncState],
  77. run_manager: Optional[RunManager] = None,
  78. *args: Any,
  79. **kwargs: Any,
  80. ) -> AsyncGenerator[T, None]:
  81. """Run the pipe with logging capabilities."""
  82. run_manager = run_manager or self._run_manager
  83. state = state or AsyncState()
  84. async def wrapped_run() -> AsyncGenerator[Any, None]:
  85. async with manage_run(run_manager) as run_id: # type: ignore
  86. async for result in self._run_logic( # type: ignore
  87. input, state, run_id, *args, **kwargs # type: ignore
  88. ):
  89. yield result
  90. return wrapped_run()
  91. @abstractmethod
  92. async def _run_logic(
  93. self,
  94. input: Input,
  95. state: AsyncState,
  96. run_id: UUID,
  97. *args: Any,
  98. **kwargs: Any,
  99. ) -> AsyncGenerator[T, None]:
  100. pass