123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import asyncio
- import logging
- import random
- import time
- from abc import abstractmethod
- from enum import Enum
- from typing import Any, Optional
- from litellm import AuthenticationError
- from core.base.abstractions import VectorQuantizationSettings
- from ..abstractions import (
- ChunkSearchResult,
- EmbeddingPurpose,
- default_embedding_prefixes,
- )
- from .base import Provider, ProviderConfig
# Module-scoped logger: records carry this module's dotted name, so they can
# be filtered or levelled per module, and they still propagate to whatever
# handlers are configured on the root logger.
logger = logging.getLogger(__name__)
class EmbeddingConfig(ProviderConfig):
    """Configuration consumed by `EmbeddingProvider`.

    Holds the model selection, request throttling and retry/backoff
    settings, plus optional reranking and prefix options.
    """

    # Backend name; must be one of `supported_providers`.
    provider: str
    base_model: str
    base_dimension: int
    rerank_model: Optional[str] = None
    rerank_url: Optional[str] = None
    batch_size: int = 1
    # Raw purpose-name -> prefix mapping; None means nothing was configured.
    prefixes: Optional[dict[str, str]] = None
    add_title_as_prefix: bool = True
    # Size of the semaphore that bounds in-flight async requests.
    concurrent_request_limit: int = 256
    # Retry policy: number of attempts and the backoff window bounds
    # (presumably seconds, since the values are passed to sleep()).
    max_retries: int = 8
    initial_backoff: float = 1
    max_backoff: float = 64.0
    quantization_settings: VectorQuantizationSettings = (
        VectorQuantizationSettings()
    )

    # Deprecated fields
    rerank_dimension: Optional[int] = None
    rerank_transformer_type: Optional[str] = None

    @property
    def supported_providers(self) -> list[str]:
        """Return the provider names this configuration accepts."""
        return ["litellm", "openai", "ollama"]

    def validate_config(self) -> None:
        """Check that `provider` is supported.

        Raises:
            ValueError: If `provider` is not in `supported_providers`.
        """
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")
class EmbeddingProvider(Provider):
    """Base class for embedding backends with retrying request execution.

    Subclasses supply `_execute_task`, `_execute_task_sync`, `rerank` and
    `arerank`. The public `get_embedding(s)` / `async_get_embedding(s)`
    methods wrap a request in a task dict and run it through the backoff
    helpers, which retry failed requests with jittered exponential backoff.
    """

    class PipeStage(Enum):
        """Stage of the pipeline a request is issued for."""

        BASE = 1
        RERANK = 2

    def __init__(self, config: EmbeddingConfig):
        """Store the config and create the request-limiting semaphore.

        Args:
            config: Embedding configuration; must be an `EmbeddingConfig`.

        Raises:
            ValueError: If `config` is not an `EmbeddingConfig` instance.
        """
        if not isinstance(config, EmbeddingConfig):
            raise ValueError(
                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
            )
        logger.info(f"Initializing EmbeddingProvider with config {config}.")

        super().__init__(config)
        self.config: EmbeddingConfig = config
        # Bounds how many async requests may be in flight at once.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        # NOTE(review): never updated in this class; presumably maintained by
        # subclasses or kept for compatibility -- confirm before removing.
        self.current_requests = 0

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run `_execute_task` under the semaphore, retrying on failure.

        At least one attempt is always made, even when
        `config.max_retries` is zero or negative. After a failed attempt
        the coroutine sleeps for a random duration in `[0, backoff]` and
        doubles `backoff`, capped at `config.max_backoff`.

        Args:
            task: Task dict forwarded unchanged to `_execute_task`.

        Returns:
            Whatever `_execute_task` returns.

        Raises:
            AuthenticationError: Re-raised immediately, without retrying.
            Exception: The last error once `config.max_retries` attempts
                have failed.
        """
        attempts = 0
        backoff = self.config.initial_backoff
        while True:
            try:
                # The semaphore is released before sleeping, so a request
                # that is backing off does not hold a concurrency slot.
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                # Retrying cannot fix bad credentials.
                raise
            except Exception as e:
                attempts += 1
                logger.warning(
                    f"Request failed (attempt {attempts}): {str(e)}"
                )
                if attempts >= self.config.max_retries:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Run `_execute_task_sync`, retrying on failure.

        Blocking counterpart of `_execute_with_backoff_async`; it does not
        use the semaphore. At least one attempt is always made, even when
        `config.max_retries` is zero or negative.

        Args:
            task: Task dict forwarded unchanged to `_execute_task_sync`.

        Returns:
            Whatever `_execute_task_sync` returns.

        Raises:
            AuthenticationError: Re-raised immediately, without retrying.
            Exception: The last error once `config.max_retries` attempts
                have failed.
        """
        attempts = 0
        backoff = self.config.initial_backoff
        while True:
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                # Retrying cannot fix bad credentials.
                raise
            except Exception as e:
                attempts += 1
                logger.warning(
                    f"Request failed (attempt {attempts}): {str(e)}"
                )
                if attempts >= self.config.max_retries:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Perform one request described by `task` (async). Subclass hook."""
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Perform one request described by `task` (blocking). Subclass hook."""
        pass

    async def async_get_embedding(
        self,
        text: str,
        stage: PipeStage = PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Request an embedding for a single text, with retries (async).

        Args:
            text: Text to embed; stored under the task's ``"text"`` key.
            stage: Pipeline stage recorded in the task.
            purpose: Embedding purpose recorded in the task.

        Returns:
            The result of `_execute_task` for the built task.
        """
        task = {
            "text": text,
            "stage": stage,
            "purpose": purpose,
        }
        return await self._execute_with_backoff_async(task)

    def get_embedding(
        self,
        text: str,
        stage: PipeStage = PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Request an embedding for a single text, with retries (blocking).

        Args:
            text: Text to embed; stored under the task's ``"text"`` key.
            stage: Pipeline stage recorded in the task.
            purpose: Embedding purpose recorded in the task.

        Returns:
            The result of `_execute_task_sync` for the built task.
        """
        task = {
            "text": text,
            "stage": stage,
            "purpose": purpose,
        }
        return self._execute_with_backoff_sync(task)

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: PipeStage = PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ):
        """Request embeddings for several texts, with retries (async).

        Args:
            texts: Texts to embed; stored under the task's ``"texts"`` key.
            stage: Pipeline stage recorded in the task.
            purpose: Embedding purpose recorded in the task.

        Returns:
            The result of `_execute_task` for the built task.
        """
        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: PipeStage = PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
    ) -> list[list[float]]:
        """Request embeddings for several texts, with retries (blocking).

        Args:
            texts: Texts to embed; stored under the task's ``"texts"`` key.
            stage: Pipeline stage recorded in the task.
            purpose: Embedding purpose recorded in the task.

        Returns:
            The result of `_execute_task_sync` for the built task.
        """
        task = {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
        }
        return self._execute_with_backoff_sync(task)

    @abstractmethod
    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: PipeStage = PipeStage.RERANK,
        limit: int = 10,
    ):
        """Rerank `results` for `query` (blocking). Subclass hook."""
        pass

    @abstractmethod
    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: PipeStage = PipeStage.RERANK,
        limit: int = 10,
    ):
        """Rerank `results` for `query` (async). Subclass hook."""
        pass

    def set_prefixes(
        self, config_prefixes: Optional[dict[str, str]], base_model: str
    ):
        """Build `self.prefixes` from configured and default prefixes.

        Configured entries take precedence; defaults registered for
        `base_model` in `default_embedding_prefixes` only fill keys that
        were not configured.

        Args:
            config_prefixes: Mapping of purpose name to prefix. Names are
                lower-cased and converted to `EmbeddingPurpose`. `None` is
                treated as an empty mapping.
            base_model: Model name used to look up default prefixes.

        Raises:
            ValueError: If a configured name is not a valid
                `EmbeddingPurpose` value.
        """
        self.prefixes = {}

        # `EmbeddingConfig.prefixes` defaults to None, so tolerate it here
        # instead of failing on `.items()`.
        for name, prefix in (config_prefixes or {}).items():
            self.prefixes[EmbeddingPurpose(name.lower())] = prefix

        if base_model in default_embedding_prefixes:
            for key, prefix in default_embedding_prefixes[base_model].items():
                # Never override an explicitly configured prefix.
                self.prefixes.setdefault(key, prefix)
|