llm.py

import asyncio
import logging
import random
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Generator, Optional

from litellm import AuthenticationError

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    LLMChatCompletionChunk,
)

from .base import Provider, ProviderConfig

logger = logging.getLogger()


class CompletionConfig(ProviderConfig):
    """Configuration for a chat-completion provider, covering retry,
    backoff, concurrency, and timeout behavior."""

    provider: Optional[str] = None
    generation_config: Optional[GenerationConfig] = None
    concurrent_request_limit: int = 256
    max_retries: int = 3
    initial_backoff: float = 1.0
    max_backoff: float = 64.0
    request_timeout: float = 15.0

    def validate_config(self) -> None:
        if not self.provider:
            raise ValueError("Provider must be set.")
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        return ["anthropic", "litellm", "openai", "r2r"]
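
# Illustrative example (not part of the original module): constructing and
# validating a config. This assumes `ProviderConfig` supports keyword
# initialization (e.g. it is a pydantic model), which is not shown in this file.
#
#   config = CompletionConfig(provider="litellm", max_retries=5)
#   config.validate_config()  # raises ValueError for unsupported providers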


class CompletionProvider(Provider):
    """Base class for chat-completion providers. Wraps provider-specific
    request execution with concurrency limiting, retries with exponential
    backoff, and optional per-request timeouts."""

    def __init__(self, config: CompletionConfig) -> None:
        if not isinstance(config, CompletionConfig):
            raise ValueError(
                "CompletionProvider must be initialized with a `CompletionConfig`."
            )
        logger.info(f"Initializing CompletionProvider with config: {config}")
        super().__init__(config)
        self.config: CompletionConfig = config
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        self.thread_pool = ThreadPoolExecutor(
            max_workers=config.concurrent_request_limit
        )

    async def _execute_with_backoff_async(
        self,
        task: dict[str, Any],
        apply_timeout: bool = False,
    ):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                # A semaphore allows us to limit concurrent requests
                async with self.semaphore:
                    if not apply_timeout:
                        return await self._execute_task(task)

                    try:  # Use asyncio.wait_for to set a timeout for the request
                        return await asyncio.wait_for(
                            self._execute_task(task),
                            timeout=self.config.request_timeout,
                        )
                    except asyncio.TimeoutError as e:
                        raise TimeoutError(
                            f"Request timed out after {self.config.request_timeout} seconds"
                        ) from e
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    async def _execute_with_backoff_async_stream(
        self, task: dict[str, Any]
    ) -> AsyncGenerator[Any, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                async with self.semaphore:
                    async for chunk in await self._execute_task(task):
                        yield chunk
                return  # Successful completion of the stream
            except AuthenticationError:
                raise
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(
        self,
        task: dict[str, Any],
        apply_timeout: bool = False,
    ):
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            if not apply_timeout:
                return self._execute_task_sync(task)

            try:
                # Run the task in the thread pool so a timeout can be enforced
                future = self.thread_pool.submit(self._execute_task_sync, task)
                return future.result(timeout=self.config.request_timeout)
            except TimeoutError as e:
                raise TimeoutError(
                    f"Request timed out after {self.config.request_timeout} seconds"
                ) from e
            except Exception as e:
                logger.warning(
                    f"Request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync_stream(
        self, task: dict[str, Any]
    ) -> Generator[Any, None, None]:
        retries = 0
        backoff = self.config.initial_backoff
        while retries < self.config.max_retries:
            try:
                yield from self._execute_task_sync(task)
                return  # Successful completion of the stream
            except Exception as e:
                logger.warning(
                    f"Streaming request failed (attempt {retries + 1}): {str(e)}"
                )
                retries += 1
                if retries == self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Provider-specific asynchronous request execution; implemented by
        concrete subclasses."""
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Provider-specific synchronous request execution; implemented by
        concrete subclasses."""
        pass

    async def aget_completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        apply_timeout: bool = False,
        **kwargs,
    ) -> LLMChatCompletion:
        """Asynchronously fetch a (non-streaming) chat completion, retrying
        with backoff on failure."""
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        response = await self._execute_with_backoff_async(
            task=task, apply_timeout=apply_timeout
        )
        return LLMChatCompletion(**response.dict())

    async def aget_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> AsyncGenerator[LLMChatCompletionChunk, None]:
        """Asynchronously stream chat-completion chunks, normalizing each
        chunk's `finish_reason` before yielding it."""
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        async for chunk in self._execute_with_backoff_async_stream(task):
            if isinstance(chunk, dict):
                yield LLMChatCompletionChunk(**chunk)
                continue

            if chunk.choices and len(chunk.choices) > 0:
                chunk.choices[0].finish_reason = (
                    chunk.choices[0].finish_reason
                    if chunk.choices[0].finish_reason != ""
                    else None
                )  # handle error output conventions
                chunk.choices[0].finish_reason = (
                    chunk.choices[0].finish_reason
                    if chunk.choices[0].finish_reason != "eos"
                    else "stop"
                )  # hardcode `eos` to `stop` for consistency
                try:
                    yield LLMChatCompletionChunk(**(chunk.dict()))
                except Exception as e:
                    logger.error(f"Error parsing chunk: {e}")
                    yield LLMChatCompletionChunk(**(chunk.as_dict()))

    def get_completion_stream(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        **kwargs,
    ) -> Generator[LLMChatCompletionChunk, None, None]:
        """Synchronously stream chat-completion chunks."""
        generation_config.stream = True
        task = {
            "messages": messages,
            "generation_config": generation_config,
            "kwargs": kwargs,
        }
        for chunk in self._execute_with_backoff_sync_stream(task):
            yield LLMChatCompletionChunk(**chunk.dict())
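

# --- Illustrative sketch (not part of the original module) ---
# A hypothetical minimal subclass showing the two hooks every concrete
# provider must implement. A real provider would call its SDK here (for
# example, an OpenAI or litellm client) and return the raw response object,
# which `aget_completion` then wraps in `LLMChatCompletion`.
#
# class ExampleCompletionProvider(CompletionProvider):
#     async def _execute_task(self, task: dict[str, Any]):
#         # `task` carries "messages", "generation_config", and extra "kwargs";
#         # forward them to the underlying SDK and return its response.
#         ...
#
#     def _execute_task_sync(self, task: dict[str, Any]):
#         # Synchronous counterpart used by the *_sync execution paths.
#         ...
#
# A caller would then use it roughly like this (the exact GenerationConfig
# fields are an assumption, not shown in this file):
#
#   provider = ExampleCompletionProvider(CompletionConfig(provider="openai"))
#   completion = await provider.aget_completion(
#       messages=[{"role": "user", "content": "Hello"}],
#       generation_config=GenerationConfig(...),
#   )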