# r2r_llm.py
  1. import logging
  2. from typing import Any
  3. from core.base.abstractions import GenerationConfig
  4. from core.base.providers.llm import CompletionConfig, CompletionProvider
  5. from .anthropic import AnthropicCompletionProvider
  6. from .azure_foundry import AzureFoundryCompletionProvider
  7. from .litellm import LiteLLMCompletionProvider
  8. from .openai import OpenAICompletionProvider
  9. logger = logging.getLogger(__name__)
  10. class R2RCompletionProvider(CompletionProvider):
  11. """A provider that routes to the right LLM provider (R2R):
  12. - If `generation_config.model` starts with "anthropic/", call AnthropicCompletionProvider.
  13. - If it starts with "azure-foundry/", call AzureFoundryCompletionProvider.
  14. - If it starts with one of the other OpenAI-like prefixes ("openai/", "azure/", "deepseek/", "ollama/", "lmstudio/")
  15. or has no prefix (e.g. "gpt-4", "gpt-3.5"), call OpenAICompletionProvider.
  16. - Otherwise, fallback to LiteLLMCompletionProvider.
  17. """
  18. def __init__(self, config: CompletionConfig, *args, **kwargs) -> None:
  19. """Initialize sub-providers for OpenAI, Anthropic, LiteLLM, and Azure
  20. Foundry."""
  21. super().__init__(config)
  22. self.config = config
  23. logger.info("Initializing R2RCompletionProvider...")
  24. self._openai_provider = OpenAICompletionProvider(
  25. self.config, *args, **kwargs
  26. )
  27. self._anthropic_provider = AnthropicCompletionProvider(
  28. self.config, *args, **kwargs
  29. )
  30. self._litellm_provider = LiteLLMCompletionProvider(
  31. self.config, *args, **kwargs
  32. )
  33. self._azure_foundry_provider = AzureFoundryCompletionProvider(
  34. self.config, *args, **kwargs
  35. ) # New provider
  36. logger.debug(
  37. "R2RCompletionProvider initialized with OpenAI, Anthropic, LiteLLM, and Azure Foundry sub-providers."
  38. )
  39. def _choose_subprovider_by_model(
  40. self, model_name: str, is_streaming: bool = False
  41. ) -> CompletionProvider:
  42. """Decide which underlying sub-provider to call based on the model name
  43. (prefix)."""
  44. # Route to Anthropic if appropriate.
  45. if model_name.startswith("anthropic/"):
  46. return self._anthropic_provider
  47. # Route to Azure Foundry explicitly.
  48. if model_name.startswith("azure-foundry/"):
  49. return self._azure_foundry_provider
  50. # OpenAI-like prefixes.
  51. openai_like_prefixes = [
  52. "openai/",
  53. "azure/",
  54. "deepseek/",
  55. "ollama/",
  56. "lmstudio/",
  57. ]
  58. if (
  59. any(
  60. model_name.startswith(prefix)
  61. for prefix in openai_like_prefixes
  62. )
  63. or "/" not in model_name
  64. ):
  65. return self._openai_provider
  66. # Fallback to LiteLLM.
  67. return self._litellm_provider
  68. async def _execute_task(self, task: dict[str, Any]):
  69. """Pick the sub-provider based on model name and forward the async
  70. call."""
  71. generation_config: GenerationConfig = task["generation_config"]
  72. model_name = generation_config.model
  73. sub_provider = self._choose_subprovider_by_model(model_name or "")
  74. return await sub_provider._execute_task(task)
  75. def _execute_task_sync(self, task: dict[str, Any]):
  76. """Pick the sub-provider based on model name and forward the sync
  77. call."""
  78. generation_config: GenerationConfig = task["generation_config"]
  79. model_name = generation_config.model
  80. sub_provider = self._choose_subprovider_by_model(model_name or "")
  81. return sub_provider._execute_task_sync(task)