import logging
import os
from copy import copy
from typing import Any

import litellm
import requests
from aiohttp import ClientError, ClientSession
from litellm import AuthenticationError, aembedding, embedding

from core.base import (
    ChunkSearchResult,
    EmbeddingConfig,
    EmbeddingProvider,
    EmbeddingPurpose,
    R2RException,
)

logger = logging.getLogger()


def _build_rerank_payload(
    query: str, results: list[ChunkSearchResult], rerank_model: str
) -> dict[str, Any]:
    """Build the JSON body posted to the rerank endpoint.

    The model id is whatever follows the first ``huggingface/`` marker in
    ``rerank_model``.
    """
    return {
        "query": query,
        "texts": [result.text for result in results],
        "model-id": rerank_model.split("huggingface/")[1],
    }


def _apply_rerank_scores(
    results: list[ChunkSearchResult],
    reranked_results: list[dict[str, Any]],
    limit: int,
) -> list[ChunkSearchResult]:
    """Return shallow copies of ``results`` in the order given by the reranker.

    Each entry of ``reranked_results`` is read for an ``"index"`` (position in
    ``results``) and a ``"score"`` (written onto the copy). The originals are
    left untouched; at most ``limit`` items are returned.
    """
    scored_results = []
    for rank_info in reranked_results:
        copied_result = copy(results[rank_info["index"]])
        # Inject the reranking score into the copied result object.
        copied_result.score = rank_info["score"]
        scored_results.append(copied_result)
    return scored_results[:limit]


class LiteLLMEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by LiteLLM's ``embedding``/``aembedding``.

    Optionally reranks search results by POSTing to an HTTP endpoint when
    ``config.rerank_model`` is set.
    """

    def __init__(
        self,
        config: EmbeddingConfig,
        *args,
        **kwargs,
    ) -> None:
        """Validate ``config`` and store the model settings.

        Raises:
            ValueError: if the provider is missing or not ``litellm``, if the
                rerank model does not contain ``huggingface``, or if a rerank
                model is configured without a rerank URL.
        """
        super().__init__(config)

        self.litellm_embedding = embedding
        self.litellm_aembedding = aembedding

        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
            )
        if provider != "litellm":
            raise ValueError(
                "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
            )

        self.rerank_url = None
        if config.rerank_model:
            if "huggingface" not in config.rerank_model:
                raise ValueError(
                    "LiteLLMEmbeddingProvider only supports re-ranking via the "
                    "HuggingFace text-embeddings-inference API"
                )

            # The environment variable takes precedence over the config value.
            url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
            if not url:
                raise ValueError(
                    "LiteLLMEmbeddingProvider requires a valid reranking API url "
                    "to be set via `embedding.rerank_url` in the r2r.toml, or via "
                    "the environment variable `HUGGINGFACE_API_BASE`."
                )
            self.rerank_url = url

        self.base_model = config.base_model
        if "amazon" in self.base_model:
            logger.warning("Amazon embedding model detected, dropping params")
            # NOTE: this flips a module-level litellm setting, so it affects
            # every litellm call in the process, not just this provider.
            litellm.drop_params = True
        self.base_dimension = config.base_dimension

    def _get_embedding_kwargs(self, **kwargs):
        """Return the default embedding kwargs overridden by ``kwargs``."""
        embedding_kwargs = {
            "model": self.base_model,
            "dimensions": self.base_dimension,
        }
        embedding_kwargs.update(kwargs)
        return embedding_kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Embed ``task["texts"]`` with the async LiteLLM client.

        Raises:
            AuthenticationError: re-raised unchanged after logging.
            R2RException: (status 400) wrapping any other failure, chained to
                the original exception.
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            response = await self.litellm_aembedding(
                input=texts,
                **kwargs,
            )
            return [data["embedding"] for data in response.data]
        except AuthenticationError:
            logger.error(
                "Authentication error: Invalid API key or credentials."
            )
            raise
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            # Chain the cause so the original traceback is preserved.
            raise R2RException(error_msg, 400) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Embed ``task["texts"]`` with the synchronous LiteLLM client.

        Raises:
            AuthenticationError: re-raised unchanged after logging.
            R2RException: (status 400) wrapping any other failure, chained to
                the original exception.
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            response = self.litellm_embedding(
                input=texts,
                **kwargs,
            )
            return [data["embedding"] for data in response.data]
        except AuthenticationError:
            logger.error(
                "Authentication error: Invalid API key or credentials."
            )
            raise
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            # Chain the cause so the original traceback is preserved.
            raise R2RException(error_msg, 400) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single text and return its vector.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        # One text in, so the first (only) vector is the answer.
        return (await self._execute_with_backoff_async(task))[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Embed a single text synchronously and return its vector.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            # NOTE(review): this message carries a prefix the sibling methods
            # do not use; kept as-is because callers may match on it.
            raise ValueError(
                "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        # One text in, so the first (only) vector is the answer.
        return self._execute_with_backoff_sync(task)[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed several texts, one vector per text.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Embed several texts synchronously, one vector per text.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[ChunkSearchResult]:
        """Rerank search results using the configured rerank model.

        Args:
            query: The search query string
            results: List of ChunkSearchResult objects to rerank
            stage: The pipeline stage (not inspected by this method)
            limit: Maximum number of results to return

        Returns:
            Copies of the results in reranked order with updated scores,
            limited to ``limit``. When no rerank model is configured, or when
            the request or its response cannot be used, the first ``limit``
            original results are returned instead.

        Raises:
            ValueError: if a rerank model is configured but no URL is set.
        """
        if self.config.rerank_model is None:
            return results[:limit]

        if not self.rerank_url:
            raise ValueError(
                "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
            )

        payload = _build_rerank_payload(
            query, results, self.config.rerank_model
        )
        headers = {"Content-Type": "application/json"}

        try:
            # NOTE(review): no timeout is passed, so a stalled endpoint blocks
            # the caller indefinitely; consider adding one.
            response = requests.post(
                self.rerank_url, json=payload, headers=headers
            )
            response.raise_for_status()
            return _apply_rerank_scores(results, response.json(), limit)
        except (
            requests.RequestException,
            # A malformed response body is treated like a transport failure,
            # matching the fallback behavior of `arerank`.
            KeyError,
            IndexError,
            TypeError,
            ValueError,
        ) as e:
            logger.error(f"Error during reranking: {str(e)}")
            # Fall back to returning the original results if reranking fails
            return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[ChunkSearchResult]:
        """
        Asynchronously rerank search results using the configured rerank model.

        Args:
            query: The search query string
            results: List of ChunkSearchResult objects to rerank
            stage: The pipeline stage (not inspected by this method)
            limit: Maximum number of results to return

        Returns:
            List of reranked ChunkSearchResult objects, limited to specified
            count. When no rerank model is configured, or when anything goes
            wrong while calling the endpoint or reading its response, the
            first ``limit`` original results are returned instead.

        Raises:
            ValueError: if a rerank model is configured but no URL is set.
        """
        if self.config.rerank_model is None:
            return results[:limit]

        if not self.rerank_url:
            raise ValueError(
                "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
            )

        payload = _build_rerank_payload(
            query, results, self.config.rerank_model
        )
        headers = {"Content-Type": "application/json"}

        try:
            async with ClientSession() as session:
                async with session.post(
                    self.rerank_url, json=payload, headers=headers
                ) as response:
                    response.raise_for_status()
                    reranked_results = await response.json()
            return _apply_rerank_scores(results, reranked_results, limit)
        except Exception as e:  # deliberately broad; also covers ClientError
            logger.error(f"Error during async reranking: {str(e)}")
            # Fall back to returning the original results if reranking fails
            return results[:limit]