# embedding.py
  1. import asyncio
  2. import logging
  3. import random
  4. import time
  5. from abc import abstractmethod
  6. from enum import Enum
  7. from typing import Any, Optional
  8. from litellm import AuthenticationError
  9. from core.base.abstractions import VectorQuantizationSettings
  10. from ..abstractions import (
  11. ChunkSearchResult,
  12. EmbeddingPurpose,
  13. default_embedding_prefixes,
  14. )
  15. from .base import Provider, ProviderConfig
  16. logger = logging.getLogger()
  17. class EmbeddingConfig(ProviderConfig):
  18. provider: str
  19. base_model: str
  20. base_dimension: int
  21. rerank_model: Optional[str] = None
  22. rerank_url: Optional[str] = None
  23. batch_size: int = 1
  24. prefixes: Optional[dict[str, str]] = None
  25. add_title_as_prefix: bool = True
  26. concurrent_request_limit: int = 256
  27. max_retries: int = 8
  28. initial_backoff: float = 1
  29. max_backoff: float = 64.0
  30. quantization_settings: VectorQuantizationSettings = (
  31. VectorQuantizationSettings()
  32. )
  33. ## deprecated
  34. rerank_dimension: Optional[int] = None
  35. rerank_transformer_type: Optional[str] = None
  36. def validate_config(self) -> None:
  37. if self.provider not in self.supported_providers:
  38. raise ValueError(f"Provider '{self.provider}' is not supported.")
  39. @property
  40. def supported_providers(self) -> list[str]:
  41. return ["litellm", "openai", "ollama"]
  42. class EmbeddingProvider(Provider):
  43. class PipeStage(Enum):
  44. BASE = 1
  45. RERANK = 2
  46. def __init__(self, config: EmbeddingConfig):
  47. if not isinstance(config, EmbeddingConfig):
  48. raise ValueError(
  49. "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
  50. )
  51. logger.info(f"Initializing EmbeddingProvider with config {config}.")
  52. super().__init__(config)
  53. self.config: EmbeddingConfig = config
  54. self.semaphore = asyncio.Semaphore(config.concurrent_request_limit)
  55. self.current_requests = 0
  56. async def _execute_with_backoff_async(self, task: dict[str, Any]):
  57. retries = 0
  58. backoff = self.config.initial_backoff
  59. while retries < self.config.max_retries:
  60. try:
  61. async with self.semaphore:
  62. return await self._execute_task(task)
  63. except AuthenticationError as e:
  64. raise
  65. except Exception as e:
  66. logger.warning(
  67. f"Request failed (attempt {retries + 1}): {str(e)}"
  68. )
  69. retries += 1
  70. if retries == self.config.max_retries:
  71. raise
  72. await asyncio.sleep(random.uniform(0, backoff))
  73. backoff = min(backoff * 2, self.config.max_backoff)
  74. def _execute_with_backoff_sync(self, task: dict[str, Any]):
  75. retries = 0
  76. backoff = self.config.initial_backoff
  77. while retries < self.config.max_retries:
  78. try:
  79. return self._execute_task_sync(task)
  80. except AuthenticationError as e:
  81. raise
  82. except Exception as e:
  83. logger.warning(
  84. f"Request failed (attempt {retries + 1}): {str(e)}"
  85. )
  86. retries += 1
  87. if retries == self.config.max_retries:
  88. raise
  89. time.sleep(random.uniform(0, backoff))
  90. backoff = min(backoff * 2, self.config.max_backoff)
  91. @abstractmethod
  92. async def _execute_task(self, task: dict[str, Any]):
  93. pass
  94. @abstractmethod
  95. def _execute_task_sync(self, task: dict[str, Any]):
  96. pass
  97. async def async_get_embedding(
  98. self,
  99. text: str,
  100. stage: PipeStage = PipeStage.BASE,
  101. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  102. ):
  103. task = {
  104. "text": text,
  105. "stage": stage,
  106. "purpose": purpose,
  107. }
  108. return await self._execute_with_backoff_async(task)
  109. def get_embedding(
  110. self,
  111. text: str,
  112. stage: PipeStage = PipeStage.BASE,
  113. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  114. ):
  115. task = {
  116. "text": text,
  117. "stage": stage,
  118. "purpose": purpose,
  119. }
  120. return self._execute_with_backoff_sync(task)
  121. async def async_get_embeddings(
  122. self,
  123. texts: list[str],
  124. stage: PipeStage = PipeStage.BASE,
  125. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  126. ):
  127. task = {
  128. "texts": texts,
  129. "stage": stage,
  130. "purpose": purpose,
  131. }
  132. return await self._execute_with_backoff_async(task)
  133. def get_embeddings(
  134. self,
  135. texts: list[str],
  136. stage: PipeStage = PipeStage.BASE,
  137. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  138. ) -> list[list[float]]:
  139. task = {
  140. "texts": texts,
  141. "stage": stage,
  142. "purpose": purpose,
  143. }
  144. return self._execute_with_backoff_sync(task)
  145. @abstractmethod
  146. def rerank(
  147. self,
  148. query: str,
  149. results: list[ChunkSearchResult],
  150. stage: PipeStage = PipeStage.RERANK,
  151. limit: int = 10,
  152. ):
  153. pass
  154. @abstractmethod
  155. async def arerank(
  156. self,
  157. query: str,
  158. results: list[ChunkSearchResult],
  159. stage: PipeStage = PipeStage.RERANK,
  160. limit: int = 10,
  161. ):
  162. pass
  163. def set_prefixes(self, config_prefixes: dict[str, str], base_model: str):
  164. self.prefixes = {}
  165. for t, p in config_prefixes.items():
  166. purpose = EmbeddingPurpose(t.lower())
  167. self.prefixes[purpose] = p
  168. if base_model in default_embedding_prefixes:
  169. for t, p in default_embedding_prefixes[base_model].items():
  170. if t not in self.prefixes:
  171. self.prefixes[t] = p