openai.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import logging
  2. import os
  3. from typing import Any
  4. from openai import AsyncOpenAI, OpenAI
  5. from core.base.abstractions import GenerationConfig
  6. from core.base.providers.llm import CompletionConfig, CompletionProvider
  7. logger = logging.getLogger()
  8. class OpenAICompletionProvider(CompletionProvider):
  9. def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
  10. super().__init__(config)
  11. if config.provider != "openai":
  12. logger.error(f"Invalid provider: {config.provider}")
  13. raise ValueError(
  14. "OpenAICompletionProvider must be initialized with config with `openai` provider."
  15. )
  16. if not os.getenv("OPENAI_API_KEY"):
  17. logger.error("OpenAI API key not found")
  18. raise ValueError(
  19. "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
  20. )
  21. self.async_client = AsyncOpenAI()
  22. self.client = OpenAI()
  23. logger.debug("OpenAICompletionProvider initialized successfully")
  24. def _get_base_args(self, generation_config: GenerationConfig) -> dict:
  25. args = {
  26. "model": generation_config.model,
  27. "temperature": generation_config.temperature,
  28. "top_p": generation_config.top_p,
  29. "stream": generation_config.stream,
  30. "max_tokens": generation_config.max_tokens_to_sample,
  31. }
  32. if generation_config.functions is not None:
  33. args["functions"] = generation_config.functions
  34. if generation_config.tools is not None:
  35. args["tools"] = generation_config.tools
  36. if generation_config.response_format is not None:
  37. args["response_format"] = generation_config.response_format
  38. return args
  39. async def _execute_task(self, task: dict[str, Any]):
  40. messages = task["messages"]
  41. generation_config = task["generation_config"]
  42. kwargs = task["kwargs"]
  43. args = self._get_base_args(generation_config)
  44. args["messages"] = messages
  45. args = {**args, **kwargs}
  46. logger.debug(f"Executing async OpenAI task with args: {args}")
  47. try:
  48. response = await self.async_client.chat.completions.create(**args)
  49. logger.debug("Async OpenAI task executed successfully")
  50. return response
  51. except Exception as e:
  52. logger.error(f"Async OpenAI task execution failed: {str(e)}")
  53. raise
  54. def _execute_task_sync(self, task: dict[str, Any]):
  55. messages = task["messages"]
  56. generation_config = task["generation_config"]
  57. kwargs = task["kwargs"]
  58. args = self._get_base_args(generation_config)
  59. args["messages"] = messages
  60. args = {**args, **kwargs}
  61. logger.debug(f"Executing sync OpenAI task with args: {args}")
  62. try:
  63. response = self.client.chat.completions.create(**args)
  64. logger.debug("Sync OpenAI task executed successfully")
  65. return response
  66. except Exception as e:
  67. logger.error(f"Sync OpenAI task execution failed: {str(e)}")
  68. raise