# azure_foundry.py

import logging
import os
from typing import Any, Final, Optional

from azure.ai.inference import (
    ChatCompletionsClient as AzureChatCompletionsClient,
)
from azure.ai.inference.aio import (
    ChatCompletionsClient as AsyncAzureChatCompletionsClient,
)
from azure.core.credentials import AzureKeyCredential

from core.base.abstractions import GenerationConfig
from core.base.providers.llm import CompletionConfig, CompletionProvider
  13. logger = logging.getLogger(__name__)
  14. class AzureFoundryCompletionProvider(CompletionProvider):
  15. def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
  16. super().__init__(config)
  17. self.azure_foundry_client: Optional[AzureChatCompletionsClient] = None
  18. self.async_azure_foundry_client: Optional[
  19. AsyncAzureChatCompletionsClient
  20. ] = None
  21. # Initialize Azure Foundry clients if credentials exist.
  22. azure_foundry_api_key = os.getenv("AZURE_FOUNDRY_API_KEY")
  23. azure_foundry_api_endpoint = os.getenv("AZURE_FOUNDRY_API_ENDPOINT")
  24. if azure_foundry_api_key and azure_foundry_api_endpoint:
  25. self.azure_foundry_client = AzureChatCompletionsClient(
  26. endpoint=azure_foundry_api_endpoint,
  27. credential=AzureKeyCredential(azure_foundry_api_key),
  28. api_version=os.getenv(
  29. "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
  30. ),
  31. )
  32. self.async_azure_foundry_client = AsyncAzureChatCompletionsClient(
  33. endpoint=azure_foundry_api_endpoint,
  34. credential=AzureKeyCredential(azure_foundry_api_key),
  35. api_version=os.getenv(
  36. "AZURE_FOUNDRY_API_VERSION", "2024-05-01-preview"
  37. ),
  38. )
  39. logger.debug("Azure Foundry clients initialized successfully")
  40. def _get_base_args(
  41. self, generation_config: GenerationConfig
  42. ) -> dict[str, Any]:
  43. # Construct arguments similar to the other providers.
  44. args: dict[str, Any] = {
  45. "top_p": generation_config.top_p,
  46. "stream": generation_config.stream,
  47. "max_tokens": generation_config.max_tokens_to_sample,
  48. "temperature": generation_config.temperature,
  49. }
  50. if generation_config.functions is not None:
  51. args["functions"] = generation_config.functions
  52. if generation_config.tools is not None:
  53. args["tools"] = generation_config.tools
  54. if generation_config.response_format is not None:
  55. args["response_format"] = generation_config.response_format
  56. return args
  57. async def _execute_task(self, task: dict[str, Any]):
  58. messages = task["messages"]
  59. generation_config = task["generation_config"]
  60. kwargs = task["kwargs"]
  61. args = self._get_base_args(generation_config)
  62. # Azure Foundry does not require a "model" argument; the endpoint is fixed.
  63. args["messages"] = messages
  64. args = {**args, **kwargs}
  65. logger.debug(f"Executing async Azure Foundry task with args: {args}")
  66. try:
  67. if self.async_azure_foundry_client is None:
  68. raise ValueError("Azure Foundry client is not initialized")
  69. response = await self.async_azure_foundry_client.complete(**args)
  70. logger.debug("Async Azure Foundry task executed successfully")
  71. return response
  72. except Exception as e:
  73. logger.error(
  74. f"Async Azure Foundry task execution failed: {str(e)}"
  75. )
  76. raise
  77. def _execute_task_sync(self, task: dict[str, Any]):
  78. messages = task["messages"]
  79. generation_config = task["generation_config"]
  80. kwargs = task["kwargs"]
  81. args = self._get_base_args(generation_config)
  82. args["messages"] = messages
  83. args = {**args, **kwargs}
  84. logger.debug(f"Executing sync Azure Foundry task with args: {args}")
  85. try:
  86. if self.azure_foundry_client is None:
  87. raise ValueError("Azure Foundry client is not initialized")
  88. response = self.azure_foundry_client.complete(**args)
  89. logger.debug("Sync Azure Foundry task executed successfully")
  90. return response
  91. except Exception as e:
  92. logger.error(f"Sync Azure Foundry task execution failed: {str(e)}")
  93. raise