import asyncio
import logging
import random
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Generator, Optional

from litellm import AuthenticationError

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    LLMChatCompletionChunk,
)

from .base import Provider, ProviderConfig

logger = logging.getLogger()
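

# Configuration for a chat-completion provider: which backend to use
# ("litellm" or "openai"), the default generation settings, the concurrent
# request cap, and the retry/backoff parameters used by CompletionProvider.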
class CompletionConfig(ProviderConfig):
    provider: Optional[str] = None
    generation_config: GenerationConfig = GenerationConfig()
    concurrent_request_limit: int = 256
    max_retries: int = 8
    initial_backoff: float = 1.0
    max_backoff: float = 64.0

    def validate_config(self) -> None:
        if not self.provider:
            raise ValueError("Provider must be set.")
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        return ["litellm", "openai"]
class CompletionProvider(Provider):
    def __init__(self, config: CompletionConfig) -> None:
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )
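
    # The _execute_with_backoff_* helpers below share one retry policy: up to
    # max_retries attempts, sleeping a random duration in [0, backoff] between
    # attempts (full jitter) and doubling backoff up to max_backoff. The async
    # variants additionally re-raise AuthenticationError immediately instead
    # of retrying it.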
    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)
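
    # Streaming variant: if the stream fails part-way through, the whole
    # request is retried from the start, so chunks yielded before the failure
    # may be emitted again on the next attempt.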
    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)
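
    # Synchronous counterparts of the two helpers above. Unlike the async
    # versions they do not acquire the semaphore, so the concurrency cap only
    # applies to async callers.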
    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                yield from self._execute_task_sync(task)
                return  # Successful completion of the stream
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)
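
    # Backend hooks. Concrete subclasses (e.g. a LiteLLM- or OpenAI-backed
    # provider) implement the actual completion request here; _execute_task
    # is expected to return either a completion object or, when streaming,
    # an iterable of chunks.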
    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        pass
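
    # Public API: build a task dict from the messages, generation config, and
    # any extra kwargs, then delegate to the backoff helpers above.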
    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        if modalities := kwargs.get("modalities"):
            task["modalities"] = modalities
        response = await self._execute_with_backoff_async(task)
        return LLMChatCompletion(**response.dict())
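
    # Note: both streaming variants set stream=True on the caller's
    # generation_config before building the task.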
    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())