123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- import asyncio
- import logging
- import random
- import time
- from abc import abstractmethod
- from enum import Enum
- from typing import Any, Optional
- from litellm import AuthenticationError
- from core.base.abstractions import VectorQuantizationSettings
- from ..abstractions import (
- ChunkSearchResult,
- )
- from .base import Provider, ProviderConfig
# Module-level logger; named after the module so records can be filtered or
# configured per module while still propagating to the root handlers.
logger = logging.getLogger(__name__)
class EmbeddingConfig(ProviderConfig):
    """Configuration for an embedding provider.

    Besides model selection, carries the concurrency and retry/backoff
    settings that ``EmbeddingProvider`` reads.
    """

    provider: str
    base_model: str
    base_dimension: int | float
    rerank_model: Optional[str] = None
    rerank_url: Optional[str] = None
    batch_size: int = 1
    # Size of the asyncio.Semaphore bounding concurrent async requests.
    concurrent_request_limit: int = 256
    # Attempt budget per request used by the provider's retry loops.
    max_retries: int = 3
    # Backoff bounds in seconds: each retry sleeps a random time in
    # [0, backoff], with backoff doubling from initial_backoff up to
    # max_backoff.
    initial_backoff: float = 1.0
    max_backoff: float = 64.0
    quantization_settings: VectorQuantizationSettings = (
        VectorQuantizationSettings()
    )

    def validate_config(self) -> None:
        """Check the provider name and the numeric tuning settings.

        Raises:
            ValueError: if ``provider`` is unsupported, or if a concurrency
                or retry/backoff setting is out of range.
        """
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")
        # A semaphore of size 0 would block every async request forever, and
        # asyncio.Semaphore rejects negative sizes outright.
        if self.concurrent_request_limit < 1:
            raise ValueError(
                "concurrent_request_limit must be >= 1, "
                f"got {self.concurrent_request_limit}."
            )
        if self.max_retries < 0:
            raise ValueError(
                f"max_retries must be >= 0, got {self.max_retries}."
            )
        # time.sleep() raises ValueError on negative durations, which would
        # break the synchronous retry path.
        if self.initial_backoff < 0 or self.max_backoff < 0:
            raise ValueError(
                "initial_backoff and max_backoff must be non-negative."
            )

    @property
    def supported_providers(self) -> list[str]:
        """Provider identifiers accepted by ``validate_config``."""
        return ["litellm", "openai", "ollama"]
class EmbeddingProvider(Provider):
    """Abstract base class for embedding providers.

    Wraps the subclass-implemented ``_execute_task`` / ``_execute_task_sync``
    calls with retries and jittered exponential backoff. Concurrent async
    requests are capped with a semaphore sized by
    ``config.concurrent_request_limit``. Subclasses must also implement
    ``rerank`` and ``arerank``.
    """

    class Step(Enum):
        """Stage tag placed in every task dict (base embedding vs. rerank)."""

        BASE = 1
        RERANK = 2

    def __init__(self, config: EmbeddingConfig):
        """Store ``config`` and create the async concurrency limiter.

        Raises:
            ValueError: if ``config`` is not an ``EmbeddingConfig``.
        """
        if not isinstance(config, EmbeddingConfig):
            raise ValueError(
                "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
            )
        logger.info(f"Initializing EmbeddingProvider with config {config}.")

        super().__init__(config)
        self.config: EmbeddingConfig = config
        # Bounds how many async ``_execute_task`` calls run at once on this
        # instance. Retry sleeps happen outside it, so a request that is
        # backing off does not hold a slot.
        self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
        # Not read or updated within this class.
        self.current_requests = 0

    async def _execute_with_backoff_async(self, task: dict[str, Any]):
        """Run ``_execute_task(task)`` with retries and jittered backoff.

        Makes up to ``config.max_retries`` attempts (always at least one).
        After a failed attempt it sleeps a uniformly random time in
        ``[0, backoff]``. ``backoff`` starts at ``config.initial_backoff`` and
        doubles after each failure, capped at ``config.max_backoff``.

        Raises:
            AuthenticationError: immediately, without retrying.
            Exception: the last error once all attempts are exhausted.
        """
        # ``max_retries`` acts as the total attempt budget. Clamping to at
        # least one attempt stops a non-positive value from skipping the loop
        # and silently returning None.
        max_attempts = max(1, self.config.max_retries)
        backoff = self.config.initial_backoff
        for attempt in range(1, max_attempts + 1):
            try:
                async with self.semaphore:
                    return await self._execute_task(task)
            except AuthenticationError:
                # Authentication failures are treated as non-retryable.
                raise
            except Exception as e:
                logger.warning(f"Request failed (attempt {attempt}): {str(e)}")
                if attempt == max_attempts:
                    raise
                await asyncio.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    def _execute_with_backoff_sync(self, task: dict[str, Any]):
        """Blocking counterpart of ``_execute_with_backoff_async``.

        Runs ``_execute_task_sync(task)`` with the same attempt budget and
        jittered backoff. It does not use the async semaphore.

        Raises:
            AuthenticationError: immediately, without retrying.
            Exception: the last error once all attempts are exhausted.
        """
        # Same clamp as the async path: always make at least one attempt.
        max_attempts = max(1, self.config.max_retries)
        backoff = self.config.initial_backoff
        for attempt in range(1, max_attempts + 1):
            try:
                return self._execute_task_sync(task)
            except AuthenticationError:
                # Authentication failures are treated as non-retryable.
                raise
            except Exception as e:
                logger.warning(f"Request failed (attempt {attempt}): {str(e)}")
                if attempt == max_attempts:
                    raise
                time.sleep(random.uniform(0, backoff))
                backoff = min(backoff * 2, self.config.max_backoff)

    @abstractmethod
    async def _execute_task(self, task: dict[str, Any]):
        """Perform one async request for ``task``.

        ``task`` contains ``"stage"`` plus either ``"text"`` or ``"texts"``.
        """
        pass

    @abstractmethod
    def _execute_task_sync(self, task: dict[str, Any]):
        """Perform one blocking request for ``task``.

        ``task`` has the same keys as for ``_execute_task``.
        """
        pass

    async def async_get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
    ):
        """Embed a single ``text`` asynchronously, with retries.

        Returns whatever the subclass's ``_execute_task`` produces for a
        ``{"text", "stage"}`` task.
        """
        task = {
            "text": text,
            "stage": stage,
        }
        return await self._execute_with_backoff_async(task)

    def get_embedding(
        self,
        text: str,
        stage: Step = Step.BASE,
    ):
        """Embed a single ``text`` synchronously, with retries.

        Returns whatever the subclass's ``_execute_task_sync`` produces for a
        ``{"text", "stage"}`` task.
        """
        task = {
            "text": text,
            "stage": stage,
        }
        return self._execute_with_backoff_sync(task)

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
    ) -> list[list[float]]:
        """Embed a batch of ``texts`` asynchronously, with retries.

        The whole batch is sent as one ``{"texts", "stage"}`` task.
        """
        task = {
            "texts": texts,
            "stage": stage,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: Step = Step.BASE,
    ) -> list[list[float]]:
        """Embed a batch of ``texts`` synchronously, with retries.

        The whole batch is sent as one ``{"texts", "stage"}`` task.
        """
        task = {
            "texts": texts,
            "stage": stage,
        }
        return self._execute_with_backoff_sync(task)

    @abstractmethod
    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Rerank ``results`` against ``query`` (synchronous).

        Presumably returns at most ``limit`` results ordered by relevance;
        the exact contract is defined by subclasses (TODO confirm).
        """
        pass

    @abstractmethod
    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: Step = Step.RERANK,
        limit: int = 10,
    ):
        """Async counterpart of ``rerank``; same parameters and contract."""
        pass
|