litellm.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import logging
  2. from typing import Any
  3. import litellm
  4. from litellm import acompletion, completion
  5. from core.base.abstractions import GenerationConfig
  6. from core.base.providers.llm import CompletionConfig, CompletionProvider
  7. logger = logging.getLogger()
  8. class LiteLLMCompletionProvider(CompletionProvider):
  9. def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
  10. super().__init__(config)
  11. litellm.modify_params = True
  12. self.acompletion = acompletion
  13. self.completion = completion
  14. # if config.provider != "litellm":
  15. # logger.error(f"Invalid provider: {config.provider}")
  16. # raise ValueError(
  17. # "LiteLLMCompletionProvider must be initialized with config with `litellm` provider."
  18. # )
  19. def _get_base_args(
  20. self, generation_config: GenerationConfig
  21. ) -> dict[str, Any]:
  22. args: dict[str, Any] = {
  23. "model": generation_config.model,
  24. "temperature": generation_config.temperature,
  25. "top_p": generation_config.top_p,
  26. "stream": generation_config.stream,
  27. "max_tokens": generation_config.max_tokens_to_sample,
  28. "api_base": generation_config.api_base,
  29. }
  30. # Fix the type errors by properly typing these assignments
  31. if generation_config.functions is not None:
  32. args["functions"] = generation_config.functions
  33. if generation_config.tools is not None:
  34. args["tools"] = generation_config.tools
  35. if generation_config.response_format is not None:
  36. args["response_format"] = generation_config.response_format
  37. return args
  38. async def _execute_task(self, task: dict[str, Any]):
  39. messages = task["messages"]
  40. generation_config = task["generation_config"]
  41. kwargs = task["kwargs"]
  42. args = self._get_base_args(generation_config)
  43. args["messages"] = messages
  44. args = {**args, **kwargs}
  45. logger.debug(
  46. f"Executing LiteLLM task with generation_config={generation_config}"
  47. )
  48. return await self.acompletion(**args)
  49. def _execute_task_sync(self, task: dict[str, Any]):
  50. messages = task["messages"]
  51. generation_config = task["generation_config"]
  52. kwargs = task["kwargs"]
  53. args = self._get_base_args(generation_config)
  54. args["messages"] = messages
  55. args = {**args, **kwargs}
  56. logger.debug(
  57. f"Executing LiteLLM task with generation_config={generation_config}"
  58. )
  59. try:
  60. return self.completion(**args)
  61. except Exception as e:
  62. logger.error(f"Sync LiteLLM task execution failed: {str(e)}")
  63. raise