123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301 |
- import logging
- import os
- from copy import copy
- from typing import Any
- import litellm
- import requests
- from aiohttp import ClientError, ClientSession
- from litellm import AuthenticationError, aembedding, embedding
- from core.base import (
- ChunkSearchResult,
- EmbeddingConfig,
- EmbeddingProvider,
- EmbeddingPurpose,
- R2RException,
- )
# Module-named logger (rather than the root logger) so log handlers and
# levels can be configured for this provider independently.
logger = logging.getLogger(__name__)
class LiteLLMEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that delegates to LiteLLM's ``embedding`` /
    ``aembedding`` entry points, with optional reranking through a
    HuggingFace text-embeddings-inference (TEI) endpoint."""

    def __init__(
        self,
        config: EmbeddingConfig,
        *args,
        **kwargs,
    ) -> None:
        """Validate the config and capture the LiteLLM call entry points.

        Raises:
            ValueError: if the provider is missing or not ``"litellm"``, or a
                rerank model is configured without a rerank URL.
        """
        super().__init__(config)

        self.litellm_embedding = embedding
        self.litellm_aembedding = aembedding

        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
            )
        if provider != "litellm":
            raise ValueError(
                "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
            )

        self.rerank_url = None
        if config.rerank_model:
            if "huggingface" not in config.rerank_model:
                raise ValueError(
                    "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
                )

            # The environment variable takes precedence over the TOML setting.
            url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
            if not url:
                raise ValueError(
                    "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
                )
            self.rerank_url = url

        self.base_model = config.base_model
        # Guard against a None/empty base_model before substring-testing it.
        if self.base_model and "amazon" in self.base_model:
            # Amazon (Bedrock) embedding models reject unsupported params such
            # as `dimensions`, so let LiteLLM drop them silently.
            # `warning` replaces the deprecated `logger.warn` alias.
            logger.warning("Amazon embedding model detected, dropping params")
            litellm.drop_params = True
        self.base_dimension = config.base_dimension

    def _get_embedding_kwargs(self, **kwargs) -> dict[str, Any]:
        """Merge the configured model/dimensions with per-call overrides.

        Per-call ``kwargs`` win over the configured defaults.
        """
        embedding_kwargs: dict[str, Any] = {
            "model": self.base_model,
            "dimensions": self.base_dimension,
        }
        embedding_kwargs.update(kwargs)
        return embedding_kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Asynchronously embed ``task["texts"]``, one vector per text.

        Raises:
            AuthenticationError: re-raised untouched so callers can tell
                credential problems from other failures.
            R2RException: any other provider error, wrapped with status 400.
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = await self.litellm_aembedding(
                input=texts,
                **kwargs,
            )
            return [data["embedding"] for data in response.data]
        except AuthenticationError:
            logger.error(
                "Authentication error: Invalid API key or credentials."
            )
            raise
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            # Chain the original exception for a complete traceback.
            raise R2RException(error_msg, 400) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Synchronous counterpart of `_execute_task`.

        Raises:
            AuthenticationError: re-raised untouched.
            R2RException: any other provider error, wrapped with status 400.
        """
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = self.litellm_embedding(
                input=texts,
                **kwargs,
            )
            return [data["embedding"] for data in response.data]
        except AuthenticationError:
            logger.error(
                "Authentication error: Invalid API key or credentials."
            )
            raise
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            # Chain the original exception for a complete traceback.
            raise R2RException(error_msg, 400) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single string and return its vector.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return (await self._execute_with_backoff_async(task))[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Embed a single string and return its vector.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed a batch of strings, one vector per input.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Embed a batch of strings, one vector per input.

        Raises:
            ValueError: if ``stage`` is not ``PipeStage.BASE``.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "LiteLLMEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)

    def _build_rerank_payload(
        self, query: str, results: list[ChunkSearchResult]
    ) -> dict[str, Any]:
        """Build the request body expected by the TEI rerank endpoint."""
        return {
            "query": query,
            "texts": [result.text for result in results],
            # TEI expects the bare model id, without the "huggingface/" prefix.
            "model-id": self.config.rerank_model.split("huggingface/")[1],
        }

    def _apply_rerank_scores(
        self,
        results: list[ChunkSearchResult],
        reranked_results: list[dict[str, Any]],
        limit: int,
    ) -> list[ChunkSearchResult]:
        """Copy each ranked result, inject its rerank score, keep ``limit``."""
        scored_results = []
        for rank_info in reranked_results:
            original_result = results[rank_info["index"]]
            copied_result = copy(original_result)
            # Inject the reranking score into the (copied) result object so
            # the caller's originals are left untouched.
            copied_result.score = rank_info["score"]
            scored_results.append(copied_result)
        return scored_results[:limit]

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` against ``query`` via the HuggingFace TEI API.

        Returns the original ordering (truncated to ``limit``) when no rerank
        model is configured or when the HTTP call fails.

        Raises:
            ValueError: if a rerank model is configured but ``rerank_url``
                was never set (should be prevented by ``__init__``).
        """
        if self.config.rerank_model is None:
            return results[:limit]

        if not self.rerank_url:
            raise ValueError(
                "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
            )

        headers = {"Content-Type": "application/json"}
        try:
            response = requests.post(
                self.rerank_url,
                json=self._build_rerank_payload(query, results),
                headers=headers,
                # requests has no default timeout; without one a stuck
                # endpoint would hang this call forever. Timeouts raise
                # requests.RequestException and hit the fallback below.
                timeout=30,
            )
            response.raise_for_status()
            reranked_results = response.json()
            return self._apply_rerank_scores(results, reranked_results, limit)
        except requests.RequestException as e:
            logger.error(f"Error during reranking: {str(e)}")
            # Fall back to returning the original results if reranking fails
            return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ) -> list[ChunkSearchResult]:
        """
        Asynchronously rerank search results using the configured rerank model.

        Args:
            query: The search query string
            results: List of ChunkSearchResult objects to rerank
            stage: The pipeline stage (must be RERANK)
            limit: Maximum number of results to return

        Returns:
            List of reranked ChunkSearchResult objects, limited to specified
            count. Falls back to the original ordering on any request error.
        """
        if self.config.rerank_model is None:
            return results[:limit]

        if not self.rerank_url:
            raise ValueError(
                "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
            )

        headers = {"Content-Type": "application/json"}
        try:
            async with ClientSession() as session:
                async with session.post(
                    self.rerank_url,
                    json=self._build_rerank_payload(query, results),
                    headers=headers,
                ) as response:
                    response.raise_for_status()
                    reranked_results = await response.json()
                    return self._apply_rerank_scores(
                        results, reranked_results, limit
                    )
        # A single `except Exception` already covers aiohttp's ClientError;
        # the original `(ClientError, Exception)` tuple was redundant.
        except Exception as e:
            logger.error(f"Error during async reranking: {str(e)}")
            # Fall back to returning the original results if reranking fails
            return results[:limit]
|