# openai.py

import logging
import os
from typing import Any

from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

from core.base.abstractions import GenerationConfig
from core.base.providers.llm import CompletionConfig, CompletionProvider

from .utils import resize_base64_image

logger = logging.getLogger()


class OpenAICompletionProvider(CompletionProvider):
    def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
        super().__init__(config)
        self.openai_client = None
        self.async_openai_client = None
        self.azure_client = None
        self.async_azure_client = None
        self.deepseek_client = None
        self.async_deepseek_client = None
        self.ollama_client = None
        self.async_ollama_client = None
        self.lmstudio_client = None
        self.async_lmstudio_client = None
        # Azure Foundry clients, which use the Azure AI Inference API
        self.azure_foundry_client = None
        self.async_azure_foundry_client = None
        # Initialize OpenAI clients if credentials exist. Read the key from
        # the environment rather than hard-coding it; the SDK also honors
        # OPENAI_BASE_URL for routing through a custom gateway.
        if os.getenv("OPENAI_API_KEY"):
            self.openai_client = OpenAI()
            self.async_openai_client = AsyncOpenAI()
            logger.debug("OpenAI clients initialized successfully")
        # Initialize Azure OpenAI clients if credentials exist
        azure_api_key = os.getenv("AZURE_API_KEY")
        azure_api_base = os.getenv("AZURE_API_BASE")
        if azure_api_key and azure_api_base:
            azure_api_version = os.getenv(
                "AZURE_API_VERSION", "2024-02-15-preview"
            )
            # Use the sync client here; _execute_task_sync calls it
            # without awaiting
            self.azure_client = AzureOpenAI(
                api_key=azure_api_key,
                api_version=azure_api_version,
                azure_endpoint=azure_api_base,
            )
            self.async_azure_client = AsyncAzureOpenAI(
                api_key=azure_api_key,
                api_version=azure_api_version,
                azure_endpoint=azure_api_base,
            )
            logger.debug("Azure OpenAI clients initialized successfully")
        # Initialize Deepseek clients if credentials exist
        deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        deepseek_api_base = os.getenv(
            "DEEPSEEK_API_BASE", "https://api.deepseek.com"
        )
        if deepseek_api_key and deepseek_api_base:
            self.deepseek_client = OpenAI(
                api_key=deepseek_api_key,
                base_url=deepseek_api_base,
            )
            self.async_deepseek_client = AsyncOpenAI(
                api_key=deepseek_api_key,
                base_url=deepseek_api_base,
            )
            logger.debug("Deepseek OpenAI clients initialized successfully")
        # Initialize Ollama clients with default API key
        ollama_api_base = os.getenv(
            "OLLAMA_API_BASE", "http://localhost:11434/v1"
        )
        if ollama_api_base:
            self.ollama_client = OpenAI(
                api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
                base_url=ollama_api_base,
            )
            self.async_ollama_client = AsyncOpenAI(
                api_key=os.getenv("OLLAMA_API_KEY", "dummy"),
                base_url=ollama_api_base,
            )
            logger.debug("Ollama OpenAI clients initialized successfully")
        # Initialize LMStudio clients
        lmstudio_api_base = os.getenv(
            "LMSTUDIO_API_BASE", "http://localhost:1234/v1"
        )
        if lmstudio_api_base:
            self.lmstudio_client = OpenAI(
                api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
                base_url=lmstudio_api_base,
            )
            self.async_lmstudio_client = AsyncOpenAI(
                api_key=os.getenv("LMSTUDIO_API_KEY", "lm-studio"),
                base_url=lmstudio_api_base,
            )
            logger.debug("LMStudio OpenAI clients initialized successfully")
        # Initialize Azure Foundry clients if credentials exist.
        # These use the Azure AI Inference API rather than the OpenAI SDK,
        # so the imports are deferred until the credentials are present.
        azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
        azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
        if azure_foundry_api_key and azure_foundry_api_endpoint:
            from azure.ai.inference import (
                ChatCompletionsClient as AzureChatCompletionsClient,
            )
            from azure.ai.inference.aio import (
                ChatCompletionsClient as AsyncAzureChatCompletionsClient,
            )
            from azure.core.credentials import AzureKeyCredential

            azure_foundry_api_version = os.getenv(
                "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
            )
            self.azure_foundry_client = AzureChatCompletionsClient(
                endpoint=azure_foundry_api_endpoint,
                credential=AzureKeyCredential(azure_foundry_api_key),
                api_version=azure_foundry_api_version,
            )
            self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
                endpoint=azure_foundry_api_endpoint,
                credential=AzureKeyCredential(azure_foundry_api_key),
                api_version=azure_foundry_api_version,
            )
            logger.debug("Azure Foundry clients initialized successfully")
        if not any(
            [
                self.openai_client,
                self.azure_client,
                self.ollama_client,
                self.lmstudio_client,
                self.azure_foundry_client,
            ]
        ):
            raise ValueError(
                "No valid client credentials found. Please set OPENAI_API_KEY, "
                "both AZURE_API_KEY and AZURE_API_BASE, OLLAMA_API_BASE, "
                "LMSTUDIO_API_BASE, or both AZURE_FOUNDRY_API_KEY and "
                "AZURE_FOUNDRY_API_ENDPOINT."
            )
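
    # Configuration sketch (illustrative placeholders, not real credentials);
    # each group enables the matching client pair above:
    #
    #   OPENAI_API_KEY=sk-...
    #   AZURE_API_KEY=...         AZURE_API_BASE=https://<resource>.openai.azure.com
    #   DEEPSEEK_API_KEY=...      (DEEPSEEK_API_BASE defaults to https://api.deepseek.com)
    #   OLLAMA_API_BASE=http://localhost:11434/v1    (the default)
    #   LMSTUDIO_API_BASE=http://localhost:1234/v1   (the default)
    #   AZURE_FOUNDRY_API_KEY=... AZURE_FOUNDRY_API_ENDPOINT=...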

    def _get_client_and_model(self, model: str):
        """Determine which client to use based on the model prefix and
        return the appropriate client and model name."""
        if model.startswith("azure/"):
            if not self.azure_client:
                raise ValueError(
                    "Azure OpenAI credentials not configured but azure/ model prefix used"
                )
            return self.azure_client, model[6:]  # Strip 'azure/' prefix
        elif model.startswith("openai/"):
            if not self.openai_client:
                raise ValueError(
                    "OpenAI credentials not configured but openai/ model prefix used"
                )
            return self.openai_client, model[7:]  # Strip 'openai/' prefix
        elif model.startswith("deepseek/"):
            if not self.deepseek_client:
                raise ValueError(
                    "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
                )
            return self.deepseek_client, model[9:]  # Strip 'deepseek/' prefix
        elif model.startswith("ollama/"):
            if not self.ollama_client:
                raise ValueError(
                    "Ollama OpenAI credentials not configured but ollama/ model prefix used"
                )
            return self.ollama_client, model[7:]  # Strip 'ollama/' prefix
        elif model.startswith("lmstudio/"):
            if not self.lmstudio_client:
                raise ValueError(
                    "LMStudio credentials not configured but lmstudio/ model prefix used"
                )
            return self.lmstudio_client, model[9:]  # Strip 'lmstudio/' prefix
        elif model.startswith("azure-foundry/"):
            if not self.azure_foundry_client:
                raise ValueError(
                    "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
                )
            return (
                self.azure_foundry_client,
                model[14:],
            )  # Strip 'azure-foundry/' prefix
        else:
            # No prefix: fall through to the first configured client
            if self.openai_client:
                return self.openai_client, model
            elif self.azure_client:
                return self.azure_client, model
            elif self.ollama_client:
                return self.ollama_client, model
            elif self.lmstudio_client:
                return self.lmstudio_client, model
            elif self.azure_foundry_client:
                return self.azure_foundry_client, model
            else:
                raise ValueError("No valid client available for model prefix")
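
    # Routing example (illustrative): "azure/gpt-4o" resolves to
    # (self.azure_client, "gpt-4o"), "lmstudio/qwen2.5" to
    # (self.lmstudio_client, "qwen2.5"), and a bare name such as "gpt-4o"
    # falls through to the first configured client in the order above.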

    def _get_async_client_and_model(self, model: str):
        """Get the async client and model name based on the prefix."""
        if model.startswith("azure/"):
            if not self.async_azure_client:
                raise ValueError(
                    "Azure OpenAI credentials not configured but azure/ model prefix used"
                )
            return self.async_azure_client, model[6:]
        elif model.startswith("openai/"):
            if not self.async_openai_client:
                raise ValueError(
                    "OpenAI credentials not configured but openai/ model prefix used"
                )
            return self.async_openai_client, model[7:]
        elif model.startswith("deepseek/"):
            if not self.async_deepseek_client:
                raise ValueError(
                    "Deepseek OpenAI credentials not configured but deepseek/ model prefix used"
                )
            return self.async_deepseek_client, model[9:].strip()
        elif model.startswith("ollama/"):
            if not self.async_ollama_client:
                raise ValueError(
                    "Ollama OpenAI credentials not configured but ollama/ model prefix used"
                )
            return self.async_ollama_client, model[7:]
        elif model.startswith("lmstudio/"):
            if not self.async_lmstudio_client:
                raise ValueError(
                    "LMStudio credentials not configured but lmstudio/ model prefix used"
                )
            return self.async_lmstudio_client, model[9:]
        elif model.startswith("azure-foundry/"):
            if not self.async_azure_foundry_client:
                raise ValueError(
                    "Azure Foundry credentials not configured but azure-foundry/ model prefix used"
                )
            return self.async_azure_foundry_client, model[14:]
        else:
            if self.async_openai_client:
                return self.async_openai_client, model
            elif self.async_azure_client:
                return self.async_azure_client, model
            elif self.async_ollama_client:
                return self.async_ollama_client, model
            elif self.async_lmstudio_client:
                return self.async_lmstudio_client, model
            elif self.async_azure_foundry_client:
                return self.async_azure_foundry_client, model
            else:
                raise ValueError(
                    "No valid async client available for model prefix"
                )

    def _process_messages_with_images(
        self, messages: list[dict]
    ) -> list[dict]:
        """
        Process messages that may contain image_url or image_data fields.

        Applies aggressive image resizing, mirroring the Anthropic provider.
        """
        processed_messages = []
        for msg in messages:
            if msg.get("role") == "system":
                # System messages don't support content arrays in OpenAI
                processed_messages.append(msg)
                continue

            # Check if the message contains image data
            image_url = msg.pop("image_url", None)
            image_data = msg.pop("image_data", None)
            content = msg.get("content")

            if image_url or image_data:
                # Convert to content array format
                new_content = []

                # Add image content
                if image_url:
                    new_content.append(
                        {"type": "image_url", "image_url": {"url": image_url}}
                    )
                elif image_data:
                    # Resize the base64 image data if available
                    media_type = image_data.get("media_type", "image/jpeg")
                    data = image_data.get("data", "")

                    # Apply image resizing if PIL is available
                    if data:
                        data = resize_base64_image(data)
                        logger.debug(
                            f"Image resized, new size: {len(data)} chars"
                        )

                    # OpenAI expects base64 images in data URL format
                    data_url = f"data:{media_type};base64,{data}"
                    new_content.append(
                        {"type": "image_url", "image_url": {"url": data_url}}
                    )

                # Add text content if present
                if content:
                    new_content.append({"type": "text", "text": content})

                # Update the message
                new_msg = dict(msg)
                new_msg["content"] = new_content
                processed_messages.append(new_msg)
            else:
                processed_messages.append(msg)
        return processed_messages
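
    # Illustrative transformation: {"role": "user", "content": "Describe this",
    # "image_url": "https://example.com/cat.png"} becomes
    # {"role": "user", "content": [
    #     {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    #     {"type": "text", "text": "Describe this"}]}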

    def _process_array_content_with_images(self, content: list) -> list:
        """
        Process a content array that may contain image_url items.

        Used for messages that already have content in array format.
        """
        if not content or not isinstance(content, list):
            return content

        processed_content = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "image_url":
                    # Pass through OpenAI-style image items as-is
                    processed_content.append(item)
                elif item.get("type") == "image" and item.get("source"):
                    # Convert Anthropic-style to OpenAI-style
                    source = item.get("source", {})
                    if source.get("type") == "base64" and source.get("data"):
                        # Resize the base64 image data
                        resized_data = resize_base64_image(source.get("data"))
                        media_type = source.get("media_type", "image/jpeg")
                        data_url = f"data:{media_type};base64,{resized_data}"
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": data_url},
                            }
                        )
                    elif source.get("type") == "url" and source.get("url"):
                        processed_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": source.get("url")},
                            }
                        )
                else:
                    # Pass through other types
                    processed_content.append(item)
            else:
                processed_content.append(item)
        return processed_content
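
    # Illustrative conversion: an Anthropic-style item {"type": "image",
    # "source": {"type": "base64", "media_type": "image/png", "data": "<b64>"}}
    # becomes {"type": "image_url",
    #          "image_url": {"url": "data:image/png;base64,<resized b64>"}}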

    def _preprocess_messages(self, messages: list[dict]) -> list[dict]:
        """
        Preprocess all messages to optimize images before sending to the
        OpenAI API.
        """
        if not messages or not isinstance(messages, list):
            return messages

        processed_messages = []
        for msg in messages:
            # Skip system messages as they're handled separately
            if msg.get("role") == "system":
                processed_messages.append(msg)
                continue

            # Process array-format content (might contain images)
            if isinstance(msg.get("content"), list):
                new_msg = dict(msg)
                new_msg["content"] = self._process_array_content_with_images(
                    msg["content"]
                )
                processed_messages.append(new_msg)
            else:
                # Standard processing for non-array content
                processed_messages.append(msg)
        return processed_messages

    def _get_base_args(self, generation_config: GenerationConfig) -> dict:
        args: dict[str, Any] = {
            "model": generation_config.model,
            "stream": generation_config.stream,
        }
        model_str = generation_config.model or ""
        if any(
            model_prefix in model_str.lower()
            for model_prefix in ["o1", "o3", "gpt-5"]
        ):
            # Reasoning models take max_completion_tokens and reject the
            # legacy max_tokens, temperature, and top_p parameters
            args["max_completion_tokens"] = (
                generation_config.max_tokens_to_sample
            )
        else:
            args["max_tokens"] = generation_config.max_tokens_to_sample
            args["temperature"] = generation_config.temperature
            args["top_p"] = generation_config.top_p
        if generation_config.reasoning_effort is not None:
            args["reasoning_effort"] = generation_config.reasoning_effort
        if generation_config.functions is not None:
            args["functions"] = generation_config.functions
        if generation_config.tools is not None:
            args["tools"] = generation_config.tools
        if generation_config.response_format is not None:
            args["response_format"] = generation_config.response_format
        return args
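
    # Example (illustrative): GenerationConfig(model="o1-mini", stream=False,
    # max_tokens_to_sample=512) yields {"model": "o1-mini", "stream": False,
    # "max_completion_tokens": 512}; a non-reasoning model would instead get
    # "max_tokens", "temperature", and "top_p".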

    async def _execute_task(self, task: dict[str, Any]):
        messages = task["messages"]
        generation_config = task["generation_config"]
        kwargs = task["kwargs"]

        # First preprocess to handle any images in array format
        messages = self._preprocess_messages(messages)
        # Then process messages with direct image_url or image_data fields
        processed_messages = self._process_messages_with_images(messages)

        args = self._get_base_args(generation_config)
        client, model_name = self._get_async_client_and_model(args["model"])
        args["model"] = model_name
        args["messages"] = processed_messages
        args = {**args, **kwargs}

        # Warn if images are present but the model is not a known
        # vision-capable one
        contains_images = any(
            isinstance(msg.get("content"), list)
            and any(
                item.get("type") == "image_url"
                for item in msg.get("content", [])
            )
            for msg in processed_messages
        )
        if contains_images:
            vision_models = ["gpt-4-vision", "gpt-4.1"]
            if not any(
                vision_model in model_name for vision_model in vision_models
            ):
                logger.warning(
                    f"Using model {model_name} with images, but it may not support vision"
                )

        logger.debug(f"Executing async task with args: {args}")
        try:
            if client == self.async_azure_foundry_client:
                # The model is implied by the Foundry endpoint; remove it
                # before calling
                args.pop("model")
                response = await client.complete(**args)
            else:
                response = await client.chat.completions.create(**args)
            logger.debug("Async task executed successfully")
            return response
        except Exception as e:
            logger.error(f"Async task execution failed: {str(e)}")
            raise

    def _execute_task_sync(self, task: dict[str, Any]):
        messages = task["messages"]
        generation_config = task["generation_config"]
        kwargs = task["kwargs"]

        # First preprocess to handle any images in array format
        messages = self._preprocess_messages(messages)
        # Then process messages with direct image_url or image_data fields
        processed_messages = self._process_messages_with_images(messages)

        args = self._get_base_args(generation_config)
        client, model_name = self._get_client_and_model(args["model"])
        args["model"] = model_name
        args["messages"] = processed_messages
        args = {**args, **kwargs}

        # Same vision-model check as in the async version
        contains_images = any(
            isinstance(msg.get("content"), list)
            and any(
                item.get("type") == "image_url"
                for item in msg.get("content", [])
            )
            for msg in processed_messages
        )
        if contains_images:
            vision_models = ["gpt-4-vision", "gpt-4.1"]
            if not any(
                vision_model in model_name for vision_model in vision_models
            ):
                logger.warning(
                    f"Using model {model_name} with images, but it may not support vision"
                )

        logger.debug(f"Executing sync OpenAI task with args: {args}")
        try:
            if client == self.azure_foundry_client:
                # The model is implied by the Foundry endpoint; remove it
                # before calling
                args.pop("model")
                response = client.complete(**args)
            else:
                response = client.chat.completions.create(**args)
            logger.debug("Sync task executed successfully")
            return response
        except Exception as e:
            logger.error(f"Sync task execution failed: {str(e)}")
            raise
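

# Usage sketch (hypothetical; assumes CompletionConfig accepts these fields --
# check core.base.providers.llm for the actual signature):
#
#   provider = OpenAICompletionProvider(CompletionConfig(provider="openai"))
#   task = {
#       "messages": [{"role": "user", "content": "Hello"}],
#       "generation_config": GenerationConfig(model="openai/gpt-4o-mini"),
#       "kwargs": {},
#   }
#   response = provider._execute_task_sync(task)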