1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import logging
- import os
- from typing import Any
- from openai import AsyncOpenAI, OpenAI
- from core.base.abstractions import GenerationConfig
- from core.base.providers.llm import CompletionConfig, CompletionProvider
# Module-scoped logger: records carry this module's name so they can be
# filtered or routed per module, while still propagating to root handlers.
logger = logging.getLogger(__name__)
class OpenAICompletionProvider(CompletionProvider):
    """Completion provider that issues chat-completion requests to OpenAI.

    Holds one synchronous (``OpenAI``) and one asynchronous (``AsyncOpenAI``)
    client, both constructed without explicit arguments.
    """

    def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
        """Validate the configuration and create the OpenAI clients.

        Args:
            config: Provider configuration; ``config.provider`` must be
                ``"openai"``.
            *args: Accepted but not used by this initializer.
            **kwargs: Accepted but not used by this initializer.

        Raises:
            ValueError: If ``config.provider`` is not ``"openai"``, or if the
                ``OPENAI_API_KEY`` environment variable is unset or empty.
        """
        super().__init__(config)
        if config.provider != "openai":
            logger.error("Invalid provider: %s", config.provider)
            raise ValueError(
                "OpenAICompletionProvider must be initialized with config with `openai` provider."
            )
        if not os.getenv("OPENAI_API_KEY"):
            logger.error("OpenAI API key not found")
            raise ValueError(
                "OpenAI API key not found. Please set the OPENAI_API_KEY environment variable."
            )
        self.async_client = AsyncOpenAI()
        self.client = OpenAI()
        logger.debug("OpenAICompletionProvider initialized successfully")

    def _get_base_args(self, generation_config: GenerationConfig) -> dict[str, Any]:
        """Translate a generation config into chat-completion keyword arguments.

        Args:
            generation_config: Source of the model name, sampling settings,
                stream flag and token limit.

        Returns:
            A dict with ``model``, ``temperature``, ``top_p``, ``stream`` and
            ``max_tokens`` (taken from ``max_tokens_to_sample``), plus
            ``functions``, ``tools`` and ``response_format`` for each of those
            attributes that is not ``None``.
        """
        args: dict[str, Any] = {
            "model": generation_config.model,
            "temperature": generation_config.temperature,
            "top_p": generation_config.top_p,
            "stream": generation_config.stream,
            "max_tokens": generation_config.max_tokens_to_sample,
        }
        # Optional fields are forwarded only when explicitly set.
        for optional_field in ("functions", "tools", "response_format"):
            value = getattr(generation_config, optional_field)
            if value is not None:
                args[optional_field] = value
        return args

    def _build_request_args(self, task: dict[str, Any]) -> dict[str, Any]:
        """Assemble the full keyword arguments for one completion request.

        Shared by the sync and async execution paths so both build requests
        identically.

        Args:
            task: Mapping with ``"messages"`` and ``"generation_config"``
                entries and an optional ``"kwargs"`` mapping of overrides.

        Returns:
            The base arguments plus ``messages``, with any entries from
            ``task["kwargs"]`` applied last so they take precedence.

        Raises:
            KeyError: If ``"messages"`` or ``"generation_config"`` is missing.
        """
        messages = task["messages"]
        request_args = self._get_base_args(task["generation_config"])
        request_args["messages"] = messages
        # A missing or None "kwargs" entry means "no overrides".
        request_args.update(task.get("kwargs") or {})
        return request_args

    async def _execute_task(self, task: dict[str, Any]):
        """Run one chat-completion request with the async client.

        Args:
            task: See ``_build_request_args`` for the expected keys.

        Returns:
            Whatever ``self.async_client.chat.completions.create`` returns.

        Raises:
            Exception: Any error raised by the client call is logged and
                re-raised unchanged.
        """
        request_args = self._build_request_args(task)
        # Lazy %-formatting: the payload is only rendered if DEBUG is enabled.
        logger.debug("Executing async OpenAI task with args: %s", request_args)
        try:
            response = await self.async_client.chat.completions.create(
                **request_args
            )
        except Exception as e:
            logger.error("Async OpenAI task execution failed: %s", e)
            raise
        logger.debug("Async OpenAI task executed successfully")
        return response

    def _execute_task_sync(self, task: dict[str, Any]):
        """Run one chat-completion request with the synchronous client.

        Args:
            task: See ``_build_request_args`` for the expected keys.

        Returns:
            Whatever ``self.client.chat.completions.create`` returns.

        Raises:
            Exception: Any error raised by the client call is logged and
                re-raised unchanged.
        """
        request_args = self._build_request_args(task)
        # Lazy %-formatting: the payload is only rendered if DEBUG is enabled.
        logger.debug("Executing sync OpenAI task with args: %s", request_args)
        try:
            response = self.client.chat.completions.create(**request_args)
        except Exception as e:
            logger.error("Sync OpenAI task execution failed: %s", e)
            raise
        logger.debug("Sync OpenAI task executed successfully")
        return response
|