embedding.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import asyncio
  2. import logging
  3. import random
  4. import time
  5. from abc import abstractmethod
  6. from enum import Enum
  7. from typing import Any, Optional
  8. from litellm import AuthenticationError
  9. from core.base.abstractions import VectorQuantizationSettings
  10. from ..abstractions import (
  11. ChunkSearchResult,
  12. )
  13. from .base import Provider, ProviderConfig
  14. logger = logging.getLogger()
  15. class EmbeddingConfig(ProviderConfig):
  16. provider: str
  17. base_model: str
  18. base_dimension: int | float
  19. rerank_model: Optional[str] = None
  20. rerank_url: Optional[str] = None
  21. batch_size: int = 1
  22. concurrent_request_limit: int = 256
  23. max_retries: int = 3
  24. initial_backoff: float = 1
  25. max_backoff: float = 64.0
  26. quantization_settings: VectorQuantizationSettings = (
  27. VectorQuantizationSettings()
  28. )
  29. def validate_config(self) -> None:
  30. if self.provider not in self.supported_providers:
  31. raise ValueError(f"Provider '{self.provider}' is not supported.")
  32. @property
  33. def supported_providers(self) -> list[str]:
  34. return ["litellm", "openai", "ollama"]
  35. class EmbeddingProvider(Provider):
  36. class Step(Enum):
  37. BASE = 1
  38. RERANK = 2
  39. def __init__(self, config: EmbeddingConfig):
  40. if not isinstance(config, EmbeddingConfig):
  41. raise ValueError(
  42. "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
  43. )
  44. logger.info(f"Initializing EmbeddingProvider with config {config}.")
  45. super().__init__(config)
  46. self.config: EmbeddingConfig = config
  47. self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
  48. self.current_requests = 0
  49. async def _execute_with_backoff_async(self, task: dict[str, Any]):
  50. retries = 0
  51. backoff = self.config.initial_backoff
  52. while retries < self.config.max_retries:
  53. try:
  54. async with self.semaphore:
  55. return await self._execute_task(task)
  56. except AuthenticationError:
  57. raise
  58. except Exception as e:
  59. logger.warning(
  60. f"Request failed (attempt {retries + 1}): {str(e)}"
  61. )
  62. retries += 1
  63. if retries == self.config.max_retries:
  64. raise
  65. await asyncio.sleep(random.uniform(0, backoff))
  66. backoff = min(backoff * 2, self.config.max_backoff)
  67. def _execute_with_backoff_sync(self, task: dict[str, Any]):
  68. retries = 0
  69. backoff = self.config.initial_backoff
  70. while retries < self.config.max_retries:
  71. try:
  72. return self._execute_task_sync(task)
  73. except AuthenticationError:
  74. raise
  75. except Exception as e:
  76. logger.warning(
  77. f"Request failed (attempt {retries + 1}): {str(e)}"
  78. )
  79. retries += 1
  80. if retries == self.config.max_retries:
  81. raise
  82. time.sleep(random.uniform(0, backoff))
  83. backoff = min(backoff * 2, self.config.max_backoff)
  84. @abstractmethod
  85. async def _execute_task(self, task: dict[str, Any]):
  86. pass
  87. @abstractmethod
  88. def _execute_task_sync(self, task: dict[str, Any]):
  89. pass
  90. async def async_get_embedding(
  91. self,
  92. text: str,
  93. stage: Step = Step.BASE,
  94. ):
  95. task = {
  96. "text": text,
  97. "stage": stage,
  98. }
  99. return await self._execute_with_backoff_async(task)
  100. def get_embedding(
  101. self,
  102. text: str,
  103. stage: Step = Step.BASE,
  104. ):
  105. task = {
  106. "text": text,
  107. "stage": stage,
  108. }
  109. return self._execute_with_backoff_sync(task)
  110. async def async_get_embeddings(
  111. self,
  112. texts: list[str],
  113. stage: Step = Step.BASE,
  114. ):
  115. task = {
  116. "texts": texts,
  117. "stage": stage,
  118. }
  119. return await self._execute_with_backoff_async(task)
  120. def get_embeddings(
  121. self,
  122. texts: list[str],
  123. stage: Step = Step.BASE,
  124. ) -> list[list[float]]:
  125. task = {
  126. "texts": texts,
  127. "stage": stage,
  128. }
  129. return self._execute_with_backoff_sync(task)
  130. @abstractmethod
  131. def rerank(
  132. self,
  133. query: str,
  134. results: list[ChunkSearchResult],
  135. stage: Step = Step.RERANK,
  136. limit: int = 10,
  137. ):
  138. pass
  139. @abstractmethod
  140. async def arerank(
  141. self,
  142. query: str,
  143. results: list[ChunkSearchResult],
  144. stage: Step = Step.RERANK,
  145. limit: int = 10,
  146. ):
  147. pass