# litellm.py — LiteLLM-backed embedding provider for R2R.
  1. import contextlib
  2. import logging
  3. import math
  4. import os
  5. from copy import copy
  6. from typing import Any
  7. import litellm
  8. import requests
  9. from aiohttp import ClientError, ClientSession
  10. from litellm import AuthenticationError, aembedding, embedding
  11. from core.base import (
  12. ChunkSearchResult,
  13. EmbeddingConfig,
  14. EmbeddingProvider,
  15. R2RException,
  16. )
  17. from .utils import truncate_texts_to_token_limit
  18. logger = logging.getLogger()
# NOTE(review): commented-out hardcoded OpenAI credentials were removed here —
# the key was exposed and should be rotated; supply OPENAI_API_KEY /
# OPENAI_API_BASE via the environment instead of in source.
  21. class LiteLLMEmbeddingProvider(EmbeddingProvider):
  22. def __init__(
  23. self,
  24. config: EmbeddingConfig,
  25. *args,
  26. **kwargs,
  27. ) -> None:
  28. super().__init__(config)
  29. self.litellm_embedding = embedding
  30. self.litellm_aembedding = aembedding
  31. provider = config.provider
  32. if not provider:
  33. raise ValueError(
  34. "Must set provider in order to initialize `LiteLLMEmbeddingProvider`."
  35. )
  36. if provider != "litellm":
  37. raise ValueError(
  38. "LiteLLMEmbeddingProvider must be initialized with provider `litellm`."
  39. )
  40. self.rerank_url = None
  41. if config.rerank_model:
  42. if "huggingface" not in config.rerank_model:
  43. raise ValueError(
  44. "LiteLLMEmbeddingProvider only supports re-ranking via the HuggingFace text-embeddings-inference API"
  45. )
  46. if url := os.getenv("HUGGINGFACE_API_BASE") or config.rerank_url:
  47. self.rerank_url = url
  48. else:
  49. raise ValueError(
  50. "LiteLLMEmbeddingProvider requires a valid reranking API url to be set via `embedding.rerank_url` in the r2r.toml, or via the environment variable `HUGGINGFACE_API_BASE`."
  51. )
  52. self.base_model = config.base_model
  53. if "amazon" in self.base_model:
  54. logger.warning("Amazon embedding model detected, dropping params")
  55. litellm.drop_params = True
  56. self.base_dimension = config.base_dimension
  57. def _get_embedding_kwargs(self, **kwargs):
  58. embedding_kwargs = {
  59. "api_base":"http://172.16.12.13:3000/v1",
  60. "api_key":"sk-j9Uwupu0NPZtdDS_IfEZlRWpX1JgFyZFLZProkesy2QbtqMs16pDnylAozU",
  61. "model": self.base_model,
  62. "dimensions": self.base_dimension,
  63. }
  64. embedding_kwargs.update(kwargs)
  65. return embedding_kwargs
  66. async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
  67. texts = task["texts"]
  68. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  69. if "dimensions" in kwargs and math.isnan(kwargs["dimensions"]):
  70. kwargs.pop("dimensions")
  71. logger.warning("Dropping nan dimensions from kwargs")
  72. try:
  73. # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
  74. if kwargs.get("model"):
  75. with contextlib.suppress(Exception):
  76. texts = truncate_texts_to_token_limit(
  77. texts, kwargs["model"]
  78. )
  79. response = await self.litellm_aembedding(
  80. input=texts,
  81. **kwargs,
  82. )
  83. return [data["embedding"] for data in response.data]
  84. except AuthenticationError:
  85. logger.error(
  86. "Authentication error: Invalid API key or credentials."
  87. )
  88. raise
  89. except Exception as e:
  90. error_msg = f"Error getting embeddings: {str(e)}"
  91. logger.error(error_msg)
  92. raise R2RException(error_msg, 400) from e
  93. def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
  94. texts = task["texts"]
  95. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  96. try:
  97. # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
  98. if kwargs.get("model"):
  99. with contextlib.suppress(Exception):
  100. texts = truncate_texts_to_token_limit(
  101. texts, kwargs["model"]
  102. )
  103. response = self.litellm_embedding(
  104. input=texts,
  105. **kwargs,
  106. )
  107. return [data["embedding"] for data in response.data]
  108. except AuthenticationError:
  109. logger.error(
  110. "Authentication error: Invalid API key or credentials."
  111. )
  112. raise
  113. except Exception as e:
  114. error_msg = f"Error getting embeddings: {str(e)}"
  115. logger.error(error_msg)
  116. raise R2RException(error_msg, 400) from e
  117. async def async_get_embedding(
  118. self,
  119. text: str,
  120. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  121. **kwargs,
  122. ) -> list[float]:
  123. if stage != EmbeddingProvider.Step.BASE:
  124. raise ValueError(
  125. "LiteLLMEmbeddingProvider only supports search stage."
  126. )
  127. task = {
  128. "texts": [text],
  129. "stage": stage,
  130. "kwargs": kwargs,
  131. }
  132. return (await self._execute_with_backoff_async(task))[0]
  133. def get_embedding(
  134. self,
  135. text: str,
  136. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  137. **kwargs,
  138. ) -> list[float]:
  139. if stage != EmbeddingProvider.Step.BASE:
  140. raise ValueError(
  141. "Error getting embeddings: LiteLLMEmbeddingProvider only supports search stage."
  142. )
  143. task = {
  144. "texts": [text],
  145. "stage": stage,
  146. "kwargs": kwargs,
  147. }
  148. return self._execute_with_backoff_sync(task)[0]
  149. async def async_get_embeddings(
  150. self,
  151. texts: list[str],
  152. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  153. **kwargs,
  154. ) -> list[list[float]]:
  155. if stage != EmbeddingProvider.Step.BASE:
  156. raise ValueError(
  157. "LiteLLMEmbeddingProvider only supports search stage."
  158. )
  159. task = {
  160. "texts": texts,
  161. "stage": stage,
  162. "kwargs": kwargs,
  163. }
  164. return await self._execute_with_backoff_async(task)
  165. def get_embeddings(
  166. self,
  167. texts: list[str],
  168. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  169. **kwargs,
  170. ) -> list[list[float]]:
  171. if stage != EmbeddingProvider.Step.BASE:
  172. raise ValueError(
  173. "LiteLLMEmbeddingProvider only supports search stage."
  174. )
  175. task = {
  176. "texts": texts,
  177. "stage": stage,
  178. "kwargs": kwargs,
  179. }
  180. return self._execute_with_backoff_sync(task)
  181. def rerank(
  182. self,
  183. query: str,
  184. results: list[ChunkSearchResult],
  185. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
  186. limit: int = 10,
  187. ):
  188. if self.config.rerank_model is not None:
  189. if not self.rerank_url:
  190. raise ValueError(
  191. "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
  192. )
  193. texts = [result.text for result in results]
  194. payload = {
  195. "query": query,
  196. "texts": texts,
  197. "model-id": self.config.rerank_model.split("huggingface/")[1],
  198. }
  199. headers = {"Content-Type": "application/json"}
  200. try:
  201. response = requests.post(
  202. self.rerank_url, json=payload, headers=headers
  203. )
  204. response.raise_for_status()
  205. reranked_results = response.json()
  206. # Copy reranked results into new array
  207. scored_results = []
  208. for rank_info in reranked_results:
  209. original_result = results[rank_info["index"]]
  210. copied_result = copy(original_result)
  211. # Inject the reranking score into the result object
  212. copied_result.score = rank_info["score"]
  213. scored_results.append(copied_result)
  214. # Return only the ChunkSearchResult objects, limited to specified count
  215. return scored_results[:limit]
  216. except requests.RequestException as e:
  217. logger.error(f"Error during reranking: {str(e)}")
  218. # Fall back to returning the original results if reranking fails
  219. return results[:limit]
  220. else:
  221. return results[:limit]
  222. async def arerank(
  223. self,
  224. query: str,
  225. results: list[ChunkSearchResult],
  226. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
  227. limit: int = 10,
  228. ) -> list[ChunkSearchResult]:
  229. """Asynchronously rerank search results using the configured rerank
  230. model.
  231. Args:
  232. query: The search query string
  233. results: List of ChunkSearchResult objects to rerank
  234. limit: Maximum number of results to return
  235. Returns:
  236. List of reranked ChunkSearchResult objects, limited to specified count
  237. """
  238. if self.config.rerank_model is not None:
  239. if not self.rerank_url:
  240. raise ValueError(
  241. "Error, `rerank_url` was expected to be set inside LiteLLMEmbeddingProvider"
  242. )
  243. texts = [result.text for result in results]
  244. payload = {
  245. "query": query,
  246. "texts": texts,
  247. "model-id": self.config.rerank_model.split("huggingface/")[1],
  248. }
  249. headers = {"Content-Type": "application/json"}
  250. try:
  251. async with ClientSession() as session:
  252. async with session.post(
  253. self.rerank_url, json=payload, headers=headers
  254. ) as response:
  255. response.raise_for_status()
  256. reranked_results = await response.json()
  257. # Copy reranked results into new array
  258. scored_results = []
  259. for rank_info in reranked_results:
  260. original_result = results[rank_info["index"]]
  261. copied_result = copy(original_result)
  262. # Inject the reranking score into the result object
  263. copied_result.score = rank_info["score"]
  264. scored_results.append(copied_result)
  265. # Return only the ChunkSearchResult objects, limited to specified count
  266. return scored_results[:limit]
  267. except (ClientError, Exception) as e:
  268. logger.error(f"Error during async reranking: {str(e)}")
  269. # Fall back to returning the original results if reranking fails
  270. return results[:limit]
  271. else:
  272. return results[:limit]