llm.py

import asyncio
import logging
import random
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Generator, Optional

from litellm import AuthenticationError

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    LLMChatCompletionChunk,
)

from .base import Provider, ProviderConfig

logger = logging.getLogger()


class CompletionConfig(ProviderConfig):
    provider: Optional[str] = None
    generation_config: GenerationConfig = GenerationConfig()
    concurrent_request_limit: int = 256
    max_retries: int = 8
    initial_backoff: float = 1.0
    max_backoff: float = 64.0

    def validate_config(self) -> None:
        if not self.provider:
            raise ValueError("Provider must be set.")
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        return ["litellm", "openai"]


class CompletionProvider(Provider):
    def __init__(self, config: CompletionConfig) -> None:
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        # Cap the number of concurrent in-flight requests.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                # Authentication failures will not resolve on retry.
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                # Full-jitter exponential backoff.
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                # Authentication failures will not resolve on retry.
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                return self._execute_task_sync(task)
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                yield from self._execute_task_sync(task)
                return  # Successful completion of the stream
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Execute a single completion task asynchronously (implemented by concrete providers)."""
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Execute a single completion task synchronously (implemented by concrete providers)."""
        pass

    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> LLMChatCompletion:
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        if modalities := kwargs.get("modalities"):
            task["modalities"] = modalities
        response = await self._execute_with_backoff_async(task)
        return LLMChatCompletion(**response.dict())

    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())
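
# Hypothetical sketch (not part of the original module): a concrete provider
# only needs to implement the two abstract task executors. The class below
# assumes litellm's `completion`/`acompletion` entry points and that
# `GenerationConfig` exposes `model` and `stream` fields; a real
# implementation would also map the remaining generation parameters
# (temperature, max_tokens, etc.) from the config onto the call.
class ExampleLiteLLMCompletionProvider(CompletionProvider):
    async def _execute_task(self, task: dict[str, Any]):
        from litellm import acompletion

        generation_config = task["generation_config"]
        # Forward the prepared task to litellm; returns a completion object,
        # or an async chunk iterator when stream=True.
        return await acompletion(
            model=generation_config.model,
            messages=task["messages"],
            stream=generation_config.stream,
            **task.get("kwargs", {}),
        )

    def _execute_task_sync(self, task: dict[str, Any]):
        from litellm import completion

        generation_config = task["generation_config"]
        # Synchronous counterpart used by the sync backoff helpers.
        return completion(
            model=generation_config.model,
            messages=task["messages"],
            stream=generation_config.stream,
            **task.get("kwargs", {}),
        )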