# litellm.py (2.6 KB)
  1. import logging
  2. from typing import Any
  3. from core.base.abstractions import GenerationConfig
  4. from core.base.providers.llm import CompletionConfig, CompletionProvider
  5. logger = logging.getLogger()
  6. class LiteLLMCompletionProvider(CompletionProvider):
  7. def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
  8. super().__init__(config)
  9. try:
  10. from litellm import acompletion, completion
  11. self.acompletion = acompletion
  12. self.completion = completion
  13. logger.debug("LiteLLM imported successfully")
  14. except ImportError:
  15. logger.error("Failed to import LiteLLM")
  16. raise ImportError(
  17. "Please install the `litellm` package to use the LiteLLMCompletionProvider."
  18. )
  19. if config.provider != "litellm":
  20. logger.error(f"Invalid provider: {config.provider}")
  21. raise ValueError(
  22. "LiteLLMCompletionProvider must be initialized with config with `litellm` provider."
  23. )
  24. def _get_base_args(self, generation_config: GenerationConfig) -> dict:
  25. args = {
  26. "model": generation_config.model,
  27. "temperature": generation_config.temperature,
  28. "top_p": generation_config.top_p,
  29. "stream": generation_config.stream,
  30. "max_tokens": generation_config.max_tokens_to_sample,
  31. "api_base": generation_config.api_base,
  32. }
  33. if generation_config.functions is not None:
  34. args["functions"] = generation_config.functions
  35. if generation_config.tools is not None:
  36. args["tools"] = generation_config.tools
  37. if generation_config.response_format is not None:
  38. args["response_format"] = generation_config.response_format
  39. return args
  40. async def _execute_task(self, task: dict[str, Any]):
  41. messages = task["messages"]
  42. generation_config = task["generation_config"]
  43. kwargs = task["kwargs"]
  44. args = self._get_base_args(generation_config)
  45. args["messages"] = messages
  46. args = {**args, **kwargs}
  47. return await self.acompletion(**args)
  48. def _execute_task_sync(self, task: dict[str, Any]):
  49. messages = task["messages"]
  50. generation_config = task["generation_config"]
  51. kwargs = task["kwargs"]
  52. args = self._get_base_args(generation_config)
  53. args["messages"] = messages
  54. args = {**args, **kwargs}
  55. try:
  56. return self.completion(**args)
  57. except Exception as e:
  58. logger.error(f"Sync LiteLLM task execution failed: {str(e)}")
  59. raise