import asyncio
import logging
import random
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from typing import Any, AsyncGenerator, Generator, Optional

from litellm import AuthenticationError

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    LLMChatCompletionChunk,
)

from .base import Provider, ProviderConfig

logger = logging.getLogger()


class CompletionConfig(ProviderConfig):
    provider: Optional[str] = None
    generation_config: Optional[GenerationConfig] = None
    concurrent_request_limit: int = 256
    max_retries: int = 3
    initial_backoff: float = 1.0
    max_backoff: float = 64.0
    request_timeout: float = 15.0

    def validate_config(self) -> None:
        if not self.provider:
            raise ValueError("Provider must be set.")
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        return ["anthropic", "litellm", "openai", "r2r"]


class CompletionProvider(Provider):
    def __init__(self, config: CompletionConfig) -> None:
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )
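
    # The _execute_with_backoff_* helpers below share a single retry policy:
    # "full jitter" exponential backoff. Each failed attempt sleeps for a
    # random duration in [0, backoff] seconds, then doubles `backoff` up to
    # config.max_backoff; the final attempt re-raises the underlying
    # exception. The async variants additionally re-raise AuthenticationError
    # immediately rather than retrying it.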

    async def _execute_with_backoff_async(
        self,
        task: dict[str, Any],
        apply_timeout: bool = False,
    ):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                # A semaphore allows us to limit concurrent requests
                async with self.semaphore:
                    if not apply_timeout:
                        return await self._execute_task(task)

                    try:  # Use asyncio.wait_for to set a timeout for the request
                        return await asyncio.wait_for(
                            self._execute_task(task),
                            timeout=self.config.request_timeout,
                        )
                    except asyncio.TimeoutError as e:
                        raise TimeoutError(
                            f"Request timed out after {self.config.request_timeout} seconds"
                        ) from e
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(
        self,
        task: dict[str, Any],
        apply_timeout: bool = False,
    ):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            if not apply_timeout:
                return self._execute_task_sync(task)

            try:
                # Run the task in the thread pool so a timeout can be enforced
                future = self.thread_pool.submit(self._execute_task_sync, task)
                return future.result(timeout=self.config.request_timeout)
            except FuturesTimeoutError as e:
                # concurrent.futures.TimeoutError is not a subclass of the
                # builtin TimeoutError before Python 3.11, so catch it explicitly.
                raise TimeoutError(
                    f"Request timed out after {self.config.request_timeout} seconds"
                ) from e
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                yield from self._execute_task_sync(task)
                return  # Successful completion of the stream
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        pass
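
    # Subclasses implement the two methods above. When the task's
    # generation_config has stream=False, the return value must expose a
    # .dict() compatible with LLMChatCompletion (see aget_completion below).
    # When stream=True, _execute_task must return an async iterable of chunks
    # and _execute_task_sync an iterable of chunks, each convertible to
    # LLMChatCompletionChunk (see the *_completion_stream methods below).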

    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        apply_timeout: bool = False,
        **kwargs,
    ) -> LLMChatCompletion:
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        response = await self._execute_with_backoff_async(
            task=task, apply_timeout=apply_timeout
        )
        return LLMChatCompletion(**response.dict())

    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            if isinstance(chunk, dict):
                yield LLMChatCompletionChunk(**chunk)
                continue

            if chunk.choices and len(chunk.choices) > 0:
                chunk.choices[0].finish_reason = (
                    chunk.choices[0].finish_reason
                    if chunk.choices[0].finish_reason != ""
                    else None
                )  # handle error output conventions
                chunk.choices[0].finish_reason = (
                    chunk.choices[0].finish_reason
                    if chunk.choices[0].finish_reason != "eos"
                    else "stop"
                )  # hardcode `eos` to `stop` for consistency
            try:
                yield LLMChatCompletionChunk(**(chunk.dict()))
            except Exception as e:
                logger.error(f"Error parsing chunk: {e}")
                yield LLMChatCompletionChunk(**(chunk.as_dict()))

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())
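

# Illustrative sketch (hypothetical, not part of this module) of how a
# concrete subclass plugs into the retry/streaming machinery above. The
# class name is a placeholder, and GenerationConfig's `model` field is
# assumed here for the sake of the example.
#
#     class EchoCompletionProvider(CompletionProvider):
#         async def _execute_task(self, task: dict[str, Any]):
#             # stream=False: return an object exposing .dict() compatible
#             # with LLMChatCompletion; stream=True: return an async
#             # iterator of chunk dicts/objects.
#             ...
#
#         def _execute_task_sync(self, task: dict[str, Any]):
#             ...
#
#     provider = EchoCompletionProvider(CompletionConfig(provider="litellm"))
#     completion = await provider.aget_completion(
#         messages=[{"role": "user", "content": "Hello"}],
#         generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
#     )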