litellm.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. import logging
  2. import os
  3. from copy import copy
  4. from typing import Any
  5. import litellm
  6. import requests
  7. from aiohttp import ClientError, ClientSession
  8. from litellm import AuthenticationError, aembedding, embedding
  9. from core.base import (
  10. ChunkSearchResult,
  11. EmbeddingConfig,
  12. EmbeddingProvider,
  13. EmbeddingPurpose,
  14. R2RException,
  15. )
  16. logger = logging.getLogger()
  17. class LiteLLMEmbeddingProvider(EmbeddingProvider):
  18. def __init__(
  19. self,
  20. config: EmbeddingConfig,
  21. *args,
  22. **kwargs,
  23. ) -> None:
  24. super().__init__(config)
  25. self.litellm_embedding = embedding
  26. self.litellm_aembedding = aembedding
  27. provider = config.provider
  28. if not provider:
  29. raise ValueError(
  30. "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
  31. )
  32. if provider != "litellm":
  33. raise ValueError(
  34. "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
  35. )
  36. self.rerank_url = None
  37. if config.rerank_model:
  38. if "huggingface" not in config.rerank_model:
  39. raise ValueError(
  40. "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
  41. )
  42. url = os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url
  43. if not url:
  44. raise ValueError(
  45. "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
  46. )
  47. self.rerank_url = url
  48. self.base_model = config.base_model
  49. if "amazon" in self.base_model:
  50. logger.warn("Amazon embedding model detected, dropping params")
  51. litellm.drop_params = True
  52. self.base_dimension = config.base_dimension
  53. def _get_embedding_kwargs(self, **kwargs):
  54. embedding_kwargs = {
  55. "model": self.base_model,
  56. "dimensions": self.base_dimension,
  57. }
  58. embedding_kwargs.update(kwargs)
  59. return embedding_kwargs
  60. async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
  61. texts = task["texts"]
  62. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  63. try:
  64. response = await self.litellm_aembedding(
  65. input=texts,
  66. **kwargs,
  67. )
  68. return [data["embedding"] for data in response.data]
  69. except AuthenticationError as e:
  70. logger.error(
  71. "Authentication error: Invalid API key or credentials."
  72. )
  73. raise
  74. except Exception as e:
  75. error_msg = f"Error getting embeddings: {str(e)}"
  76. logger.error(error_msg)
  77. raise R2RException(error_msg, 400)
  78. def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
  79. texts = task["texts"]
  80. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  81. try:
  82. response = self.litellm_embedding(
  83. input=texts,
  84. **kwargs,
  85. )
  86. return [data["embedding"] for data in response.data]
  87. except AuthenticationError as e:
  88. logger.error(
  89. "Authentication error: Invalid API key or credentials."
  90. )
  91. raise
  92. except Exception as e:
  93. error_msg = f"Error getting embeddings: {str(e)}"
  94. logger.error(error_msg)
  95. raise R2RException(error_msg, 400)
  96. async def async_get_embedding(
  97. self,
  98. text: str,
  99. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  100. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  101. **kwargs,
  102. ) -> list[float]:
  103. if stage != EmbeddingProvider.PipeStage.BASE:
  104. raise ValueError(
  105. "LiteLLMEmbeddingProvider only supports search stage."
  106. )
  107. task = {
  108. "texts": [text],
  109. "stage": stage,
  110. "purpose": purpose,
  111. "kwargs": kwargs,
  112. }
  113. return (await self._execute_with_backoff_async(task))[0]
  114. def get_embedding(
  115. self,
  116. text: str,
  117. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  118. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  119. **kwargs,
  120. ) -> list[float]:
  121. if stage != EmbeddingProvider.PipeStage.BASE:
  122. raise ValueError(
  123. "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
  124. )
  125. task = {
  126. "texts": [text],
  127. "stage": stage,
  128. "purpose": purpose,
  129. "kwargs": kwargs,
  130. }
  131. return self._execute_with_backoff_sync(task)[0]
  132. async def async_get_embeddings(
  133. self,
  134. texts: list[str],
  135. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  136. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  137. **kwargs,
  138. ) -> list[list[float]]:
  139. if stage != EmbeddingProvider.PipeStage.BASE:
  140. raise ValueError(
  141. "LiteLLMEmbeddingProvider only supports search stage."
  142. )
  143. task = {
  144. "texts": texts,
  145. "stage": stage,
  146. "purpose": purpose,
  147. "kwargs": kwargs,
  148. }
  149. return await self._execute_with_backoff_async(task)
  150. def get_embeddings(
  151. self,
  152. texts: list[str],
  153. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  154. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  155. **kwargs,
  156. ) -> list[list[float]]:
  157. if stage != EmbeddingProvider.PipeStage.BASE:
  158. raise ValueError(
  159. "LiteLLMEmbeddingProvider only supports search stage."
  160. )
  161. task = {
  162. "texts": texts,
  163. "stage": stage,
  164. "purpose": purpose,
  165. "kwargs": kwargs,
  166. }
  167. return self._execute_with_backoff_sync(task)
  168. def rerank(
  169. self,
  170. query: str,
  171. results: list[ChunkSearchResult],
  172. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  173. limit: int = 10,
  174. ):
  175. if self.config.rerank_model is not None:
  176. if not self.rerank_url:
  177. raise ValueError(
  178. "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
  179. )
  180. texts = [result.text for result in results]
  181. payload = {
  182. "query": query,
  183. "texts": texts,
  184. "model-id": self.config.rerank_model.split("huggingface/")[1],
  185. }
  186. headers = {"Content-Type": "application/json"}
  187. try:
  188. response = requests.post(
  189. self.rerank_url, json=payload, headers=headers
  190. )
  191. response.raise_for_status()
  192. reranked_results = response.json()
  193. # Copy reranked results into new array
  194. scored_results = []
  195. for rank_info in reranked_results:
  196. original_result = results[rank_info["index"]]
  197. copied_result = copy(original_result)
  198. # Inject the reranking score into the result object
  199. copied_result.score = rank_info["score"]
  200. scored_results.append(copied_result)
  201. # Return only the ChunkSearchResult objects, limited to specified count
  202. return scored_results[:limit]
  203. except requests.RequestException as e:
  204. logger.error(f"Error during reranking: {str(e)}")
  205. # Fall back to returning the original results if reranking fails
  206. return results[:limit]
  207. else:
  208. return results[:limit]
  209. async def arerank(
  210. self,
  211. query: str,
  212. results: list[ChunkSearchResult],
  213. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  214. limit: int = 10,
  215. ) -> list[ChunkSearchResult]:
  216. """
  217. Asynchronously rerank search results using the configured rerank model.
  218. Args:
  219. query: The search query string
  220. results: List of ChunkSearchResult objects to rerank
  221. stage: The pipeline stage (must be RERANK)
  222. limit: Maximum number of results to return
  223. Returns:
  224. List of reranked ChunkSearchResult objects, limited to specified count
  225. """
  226. if self.config.rerank_model is not None:
  227. if not self.rerank_url:
  228. raise ValueError(
  229. "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
  230. )
  231. texts = [result.text for result in results]
  232. payload = {
  233. "query": query,
  234. "texts": texts,
  235. "model-id": self.config.rerank_model.split("huggingface/")[1],
  236. }
  237. headers = {"Content-Type": "application/json"}
  238. try:
  239. async with ClientSession() as session:
  240. async with session.post(
  241. self.rerank_url, json=payload, headers=headers
  242. ) as response:
  243. response.raise_for_status()
  244. reranked_results = await response.json()
  245. # Copy reranked results into new array
  246. scored_results = []
  247. for rank_info in reranked_results:
  248. original_result = results[rank_info["index"]]
  249. copied_result = copy(original_result)
  250. # Inject the reranking score into the result object
  251. copied_result.score = rank_info["score"]
  252. scored_results.append(copied_result)
  253. # Return only the ChunkSearchResult objects, limited to specified count
  254. return scored_results[:limit]
  255. except (ClientError, Exception) as e:
  256. logger.error(f"Error during async reranking: {str(e)}")
  257. # Fall back to returning the original results if reranking fails
  258. return results[:limit]
  259. else:
  260. return results[:limit]