import logging
import os
from typing import Any

from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

from core.base.abstractions import GenerationConfig
from core.base.providers.llm import CompletionConfig, CompletionProvider

from .utils import resize_base64_image

logger = logging.getLogger()

class OpenAICompletionProvider(CompletionProvider):
    def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.openai_client = None
        self.async_openai_client = None
        self.azure_client = None
        self.async_azure_client = None
        self.deepseek_client = None
        self.async_deepseek_client = None
        self.ollama_client = None
        self.async_ollama_client = None
        self.lmstudio_client = None
        self.async_lmstudio_client = None
        # Azure Foundry clients, which use the Azure Inference API
        self.azure_foundry_client = None
        self.async_azure_foundry_client = None
        # Initialize OpenAI clients if credentials exist. The API key is
        # read from the environment by the SDK; credentials must never be
        # hardcoded in source. OPENAI_API_BASE is an optional variable read
        # here (not by the SDK) to point at an OpenAI-compatible gateway.
        if os.getenv("OPENAI_API_KEY"):
            openai_base_url = os.getenv("OPENAI_API_BASE")
            self.openai_client = OpenAI(base_url=openai_base_url)
            self.async_openai_client = AsyncOpenAI(base_url=openai_base_url)
            logger.debug("OpenAI clients initialized successfully")
        # Initialize Azure OpenAI clients if credentials exist. The sync
        # path requires AzureOpenAI; only the async path uses
        # AsyncAzureOpenAI.
        azure_api_key = os.getenv("AZURE_API_KEY")
        azure_api_base = os.getenv("AZURE_API_BASE")
        if azure_api_key and azure_api_base:
            azure_api_version = os.getenv(
                "AZURE_API_VERSION", "2024-02-15-preview"
            )
            self.azure_client = AzureOpenAI(
                api_key=azure_api_key,
                api_version=azure_api_version,
                azure_endpoint=azure_api_base,
            )
            self.async_azure_client = AsyncAzureOpenAI(
                api_key=azure_api_key,
                api_version=azure_api_version,
                azure_endpoint=azure_api_base,
            )
            logger.debug("Azure OpenAI clients initialized successfully")
        # Initialize DeepSeek clients if credentials exist (the API is
        # OpenAI-compatible, so the same SDK clients are reused)
        deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        deepseek_api_base = os.getenv(
            "DEEPSEEK_API_BASE", "https://api.deepseek.com"
        )
        if deepseek_api_key and deepseek_api_base:
            self.deepseek_client = OpenAI(
                api_key=deepseek_api_key,
                base_url=deepseek_api_base,
            )
            self.async_deepseek_client = AsyncOpenAI(
                api_key=deepseek_api_key,
                base_url=deepseek_api_base,
            )
            logger.debug("DeepSeek clients initialized successfully")
        # Initialize Ollama clients with a default API key
        ollama_api_base = os.getenv(
            "OLLAMA_API_BASE", "http://localhost:11434/v1"
        )
        if ollama_api_base:
            self.ollama_client = OpenAI(
                api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
                base_url=ollama_api_base,
            )
            self.async_ollama_client = AsyncOpenAI(
                api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
                base_url=ollama_api_base,
            )
            logger.debug("Ollama OpenAI clients initialized successfully")
        # Initialize LMStudio clients
        lmstudio_api_base = os.getenv(
            "LMSTUDIO_API_BASE", "http://localhost:1234/v1"
        )
        if lmstudio_api_base:
            self.lmstudio_client = OpenAI(
                api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
                base_url=lmstudio_api_base,
            )
            self.async_lmstudio_client = AsyncOpenAI(
                api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
                base_url=lmstudio_api_base,
            )
            logger.debug("LMStudio OpenAI clients initialized successfully")
        # Initialize Azure Foundry clients if credentials exist. These use
        # the Azure AI Inference API; the SDK is imported lazily so it
        # remains an optional dependency.
        azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
        azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
        if azure_foundry_api_key and azure_foundry_api_endpoint:
            from azure.ai.inference import (
                ChatCompletionsClient as AzureChatCompletionsClient,
            )
            from azure.ai.inference.aio import (
                ChatCompletionsClient as AsyncAzureChatCompletionsClient,
            )
            from azure.core.credentials import AzureKeyCredential

            azure_foundry_api_version = os.getenv(
                "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
            )
            self.azure_foundry_client = AzureChatCompletionsClient(
                endpoint=azure_foundry_api_endpoint,
                credential=AzureKeyCredential(azure_foundry_api_key),
                api_version=azure_foundry_api_version,
            )
            self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
                endpoint=azure_foundry_api_endpoint,
                credential=AzureKeyCredential(azure_foundry_api_key),
                api_version=azure_foundry_api_version,
            )
            logger.debug("Azure Foundry clients initialized successfully")
        if not any(
            [
                self.openai_client,
                self.azure_client,
                self.ollama_client,
                self.lmstudio_client,
                self.azure_foundry_client,
            ]
        ):
            raise ValueError(
                "No valid client credentials found. Set OPENAI_API_KEY; "
                "AZURE_API_KEY and AZURE_API_BASE; OLLAMA_API_BASE; "
                "LMSTUDIO_API_BASE; or AZURE_FOUNDRY_API_KEY and "
                "AZURE_FOUNDRY_API_ENDPOINT."
            )
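
    # A minimal sketch of two valid configurations (values are illustrative,
    # not real credentials): a local Ollama-only deployment needs only
    #   OLLAMA_API_BASE=http://localhost:11434/v1
    # while an Azure-only deployment needs, e.g.,
    #   AZURE_API_KEY=<key>
    #   AZURE_API_BASE=https://<resource>.openai.azure.com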

    def _get_client_and_model(self, model: str):
        """Determine which sync client to use based on the model prefix and
        return that client together with the bare model name."""
        if model.startswith("azure/"):
            if not self.azure_client:
                raise ValueError(
                    "Azure OpenAI credentials not configured but azure/ model prefix used"
                )
            return self.azure_client, model[len("azure/") :]
        elif model.startswith("openai/"):
            if not self.openai_client:
                raise ValueError(
                    "OpenAI credentials not configured but openai/ model prefix used"
                )
            return self.openai_client, model[len("openai/") :]
        elif model.startswith("deepseek/"):
            if not self.deepseek_client:
                raise ValueError(
                    "DeepSeek credentials not configured but deepseek/ model prefix used"
                )
            return self.deepseek_client, model[len("deepseek/") :]
        elif model.startswith("ollama/"):
            if not self.ollama_client:
                raise ValueError(
                    "Ollama credentials not configured but ollama/ model prefix used"
                )
            return self.ollama_client, model[len("ollama/") :]
        elif model.startswith("lmstudio/"):
            if not self.lmstudio_client:
                raise ValueError(
                    "LMStudio credentials not configured but lmstudio/ model prefix used"
                )
            return self.lmstudio_client, model[len("lmstudio/") :]
        elif model.startswith("azure-foundry/"):
            if not self.azure_foundry_client:
                raise ValueError(
                    "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
                )
            return self.azure_foundry_client, model[len("azure-foundry/") :]
        else:
            # No prefix: fall back to the first configured client,
            # preferring OpenAI
            if self.openai_client:
                return self.openai_client, model
            elif self.azure_client:
                return self.azure_client, model
            elif self.ollama_client:
                return self.ollama_client, model
            elif self.lmstudio_client:
                return self.lmstudio_client, model
            elif self.azure_foundry_client:
                return self.azure_foundry_client, model
            else:
                raise ValueError("No valid client available for model prefix")

    def _get_async_client_and_model(self, model: str):
        """Get the async client and model name based on the model prefix."""
        if model.startswith("azure/"):
            if not self.async_azure_client:
                raise ValueError(
                    "Azure OpenAI credentials not configured but azure/ model prefix used"
                )
            return self.async_azure_client, model[len("azure/") :]
        elif model.startswith("openai/"):
            if not self.async_openai_client:
                raise ValueError(
                    "OpenAI credentials not configured but openai/ model prefix used"
                )
            return self.async_openai_client, model[len("openai/") :]
        elif model.startswith("deepseek/"):
            if not self.async_deepseek_client:
                raise ValueError(
                    "DeepSeek credentials not configured but deepseek/ model prefix used"
                )
            return self.async_deepseek_client, model[len("deepseek/") :]
        elif model.startswith("ollama/"):
            if not self.async_ollama_client:
                raise ValueError(
                    "Ollama credentials not configured but ollama/ model prefix used"
                )
            return self.async_ollama_client, model[len("ollama/") :]
        elif model.startswith("lmstudio/"):
            if not self.async_lmstudio_client:
                raise ValueError(
                    "LMStudio credentials not configured but lmstudio/ model prefix used"
                )
            return self.async_lmstudio_client, model[len("lmstudio/") :]
        elif model.startswith("azure-foundry/"):
            if not self.async_azure_foundry_client:
                raise ValueError(
                    "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
                )
            return self.async_azure_foundry_client, model[len("azure-foundry/") :]
        else:
            # No prefix: fall back to the first configured async client,
            # preferring OpenAI
            if self.async_openai_client:
                return self.async_openai_client, model
            elif self.async_azure_client:
                return self.async_azure_client, model
            elif self.async_ollama_client:
                return self.async_ollama_client, model
            elif self.async_lmstudio_client:
                return self.async_lmstudio_client, model
            elif self.async_azure_foundry_client:
                return self.async_azure_foundry_client, model
            else:
                raise ValueError(
                    "No valid async client available for model prefix"
                )
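
    # Routing examples (model strings are hypothetical): "azure/gpt-4o"
    # resolves to the Azure client with model name "gpt-4o", "ollama/llama3"
    # to the Ollama client with "llama3", and an unprefixed "gpt-4o-mini"
    # falls back to the first configured client, preferring OpenAI.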

    def _process_messages_with_images(
        self, messages: list[dict]
    ) -> list[dict]:
        """Convert messages carrying image_url or image_data fields into
        OpenAI's content-array format, resizing base64 image data first to
        keep payloads small."""
        processed_messages = []
        for msg in messages:
            if msg.get("role") == "system":
                # System messages don't support content arrays in OpenAI
                processed_messages.append(msg)
                continue

            # Check if the message contains image data
            image_url = msg.pop("image_url", None)
            image_data = msg.pop("image_data", None)
            content = msg.get("content")

            if image_url or image_data:
                # Convert to content-array format
                new_content = []

                # Add image content
                if image_url:
                    new_content.append(
                        {"type": "image_url", "image_url": {"url": image_url}}
                    )
                elif image_data:
                    media_type = image_data.get("media_type", "image/jpeg")
                    data = image_data.get("data", "")

                    # Resize the base64 image data before upload
                    if data:
                        data = resize_base64_image(data)
                        logger.debug(
                            f"Image resized, new size: {len(data)} chars"
                        )

                    # OpenAI expects base64 images in data URL format
                    data_url = f"data:{media_type};base64,{data}"
                    new_content.append(
                        {"type": "image_url", "image_url": {"url": data_url}}
                    )

                # Add text content if present
                if content:
                    new_content.append({"type": "text", "text": content})

                # Update the message
                new_msg = dict(msg)
                new_msg["content"] = new_content
                processed_messages.append(new_msg)
            else:
                processed_messages.append(msg)
        return processed_messages
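
    # Sketch of the transformation above (field names as handled there):
    #   {"role": "user", "content": "describe this",
    #    "image_data": {"media_type": "image/png", "data": "<base64>"}}
    # becomes
    #   {"role": "user", "content": [
    #       {"type": "image_url",
    #        "image_url": {"url": "data:image/png;base64,<base64>"}},
    #       {"type": "text", "text": "describe this"}]}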

    def _process_array_content_with_images(self, content: list) -> list:
        """Process a content array that may contain image items, converting
        Anthropic-style image blocks to OpenAI's image_url format. Used for
        messages whose content is already in array format."""
        if not content or not isinstance(content, list):
            return content
        processed_content = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "image_url":
                    # Already in OpenAI format; pass through
                    processed_content.append(item)
                elif item.get("type") == "image" and item.get("source"):
                    # Convert Anthropic-style blocks to OpenAI-style
                    source = item.get("source", {})
                    if source.get("type") == "base64" and source.get("data"):
                        # Resize the base64 image data before upload
                        resized_data = resize_base64_image(source.get("data"))
                        media_type = source.get("media_type", "image/jpeg")
                        data_url = f"data:{media_type};base64,{resized_data}"
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": data_url},
                            }
                        )
                    elif source.get("type") == "url" and source.get("url"):
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": source.get("url")},
                            }
                        )
                else:
                    # Pass through other item types unchanged
                    processed_content.append(item)
            else:
                processed_content.append(item)
        return processed_content
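
    # Mapping sketch: an Anthropic-style item
    #   {"type": "image", "source": {"type": "base64",
    #    "media_type": "image/jpeg", "data": "<base64>"}}
    # is rewritten to
    #   {"type": "image_url",
    #    "image_url": {"url": "data:image/jpeg;base64,<resized base64>"}}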

    def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
        """Optimize any images embedded in array-format message content
        before sending to the OpenAI API."""
        if not messages or not isinstance(messages, list):
            return messages
        processed_messages = []
        for msg in messages:
            # Skip system messages, as they're handled separately
            if msg.get("role") == "system":
                processed_messages.append(msg)
                continue
            # Process array-format content, which might contain images
            if isinstance(msg.get("content"), list):
                new_msg = dict(msg)
                new_msg["content"] = self._process_array_content_with_images(
                    msg["content"]
                )
                processed_messages.append(new_msg)
            else:
                # Non-array content passes through unchanged
                processed_messages.append(msg)
        return processed_messages

    def _get_base_args(self, generation_config: GenerationConfig) -> dict:
        args: dict[str, Any] = {
            "model": generation_config.model,
            "stream": generation_config.stream,
        }
        model_str = generation_config.model or ""
        # Reasoning models (o1, o3, and gpt-5 families) take
        # max_completion_tokens and reject the sampling parameters below
        if any(
            model_prefix in model_str.lower()
            for model_prefix in ["o1", "o3", "gpt-5"]
        ):
            args["max_completion_tokens"] = (
                generation_config.max_tokens_to_sample
            )
        else:
            args["max_tokens"] = generation_config.max_tokens_to_sample
            args["temperature"] = generation_config.temperature
            args["top_p"] = generation_config.top_p
        if generation_config.reasoning_effort is not None:
            args["reasoning_effort"] = generation_config.reasoning_effort
        if generation_config.functions is not None:
            args["functions"] = generation_config.functions
        if generation_config.tools is not None:
            args["tools"] = generation_config.tools
        if generation_config.response_format is not None:
            args["response_format"] = generation_config.response_format
        return args
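
    # For example, a hypothetical config with model="openai/gpt-5" yields
    # {"model": "openai/gpt-5", "stream": ..., "max_completion_tokens": N},
    # while model="openai/gpt-4o-mini" instead gets "max_tokens",
    # "temperature", and "top_p".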

    async def _execute_task(self, task: dict[str, Any]):
        messages = task["messages"]
        generation_config = task["generation_config"]
        kwargs = task["kwargs"]

        # First preprocess to handle any images in array format, then
        # process messages with direct image_url or image_data fields
        messages = self._preprocess_messages(messages)
        processed_messages = self._process_messages_with_images(messages)

        args = self._get_base_args(generation_config)
        client, model_name = self._get_async_client_and_model(args["model"])
        args["model"] = model_name
        args["messages"] = processed_messages
        args = {**args, **kwargs}

        # Warn if images are present but the model is not a known
        # vision-capable one
        contains_images = any(
            isinstance(msg.get("content"), list)
            and any(
                item.get("type") == "image_url"
                for item in msg.get("content", [])
            )
            for msg in processed_messages
        )
        if contains_images:
            vision_models = ["gpt-4-vision", "gpt-4.1"]
            if not any(
                vision_model in model_name for vision_model in vision_models
            ):
                logger.warning(
                    f"Using model {model_name} with images, but it may not support vision"
                )

        logger.debug(f"Executing async task with args: {args}")
        try:
            # The Azure Foundry endpoint determines the model, so remove
            # it from the args before calling complete()
            if client == self.async_azure_foundry_client:
                args.pop("model")
                response = await client.complete(**args)
            else:
                response = await client.chat.completions.create(**args)
            logger.debug("Async task executed successfully")
            return response
        except Exception as e:
            logger.error(f"Async task execution failed: {str(e)}")
            raise

    def _execute_task_sync(self, task: dict[str, Any]):
        messages = task["messages"]
        generation_config = task["generation_config"]
        kwargs = task["kwargs"]

        # First preprocess to handle any images in array format, then
        # process messages with direct image_url or image_data fields
        messages = self._preprocess_messages(messages)
        processed_messages = self._process_messages_with_images(messages)

        args = self._get_base_args(generation_config)
        client, model_name = self._get_client_and_model(args["model"])
        args["model"] = model_name
        args["messages"] = processed_messages
        args = {**args, **kwargs}

        # Same vision-capability check as in the async version
        contains_images = any(
            isinstance(msg.get("content"), list)
            and any(
                item.get("type") == "image_url"
                for item in msg.get("content", [])
            )
            for msg in processed_messages
        )
        if contains_images:
            vision_models = ["gpt-4-vision", "gpt-4.1"]
            if not any(
                vision_model in model_name for vision_model in vision_models
            ):
                logger.warning(
                    f"Using model {model_name} with images, but it may not support vision"
                )

        logger.debug(f"Executing sync OpenAI task with args: {args}")
        try:
            # The Azure Foundry endpoint determines the model, so remove
            # it from the args before calling complete()
            if client == self.azure_foundry_client:
                args.pop("model")
                response = client.complete(**args)
            else:
                response = client.chat.completions.create(**args)
            logger.debug("Sync task executed successfully")
            return response
        except Exception as e:
            logger.error(f"Sync task execution failed: {str(e)}")
            raise
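
# A minimal usage sketch (assumptions: the environment provides valid
# credentials, CompletionConfig/GenerationConfig accept these fields in this
# codebase, and the model string is illustrative):
#
#   provider = OpenAICompletionProvider(CompletionConfig(provider="openai"))
#   gen_cfg = GenerationConfig(
#       model="openai/gpt-4o-mini",
#       max_tokens_to_sample=256,
#       temperature=0.2,
#       top_p=1.0,
#       stream=False,
#   )
#   response = provider._execute_task_sync(
#       {
#           "messages": [{"role": "user", "content": "Say hello."}],
#           "generation_config": gen_cfg,
#           "kwargs": {},
#       }
#   )
#   print(response.choices[0].message.content)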