ollama.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import logging
  2. import os
  3. from typing import Any
  4. from ollama import AsyncClient, Client
  5. from core.base import (
  6. ChunkSearchResult,
  7. EmbeddingConfig,
  8. EmbeddingProvider,
  9. EmbeddingPurpose,
  10. R2RException,
  11. )
  12. logger = logging.getLogger()
  13. class OllamaEmbeddingProvider(EmbeddingProvider):
  14. def __init__(self, config: EmbeddingConfig):
  15. super().__init__(config)
  16. provider = config.provider
  17. if not provider:
  18. raise ValueError(
  19. "Must set provider in order to initialize `OllamaEmbeddingProvider`."
  20. )
  21. if provider != "ollama":
  22. raise ValueError(
  23. "OllamaEmbeddingProvider must be initialized with provider `ollama`."
  24. )
  25. if config.rerank_model:
  26. raise ValueError(
  27. "OllamaEmbeddingProvider does not support separate reranking."
  28. )
  29. self.base_model = config.base_model
  30. self.base_dimension = config.base_dimension
  31. self.base_url = os.getenv("OLLAMA_API_BASE")
  32. logger.info(
  33. f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
  34. )
  35. self.client = Client(host=self.base_url)
  36. self.aclient = AsyncClient(host=self.base_url)
  37. self.set_prefixes(config.prefixes or {}, self.base_model)
  38. self.batch_size = config.batch_size or 32
  39. def _get_embedding_kwargs(self, **kwargs):
  40. embedding_kwargs = {
  41. "model": self.base_model,
  42. }
  43. embedding_kwargs.update(kwargs)
  44. return embedding_kwargs
  45. async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
  46. texts = task["texts"]
  47. purpose = task.get("purpose", EmbeddingPurpose.INDEX)
  48. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  49. try:
  50. embeddings = []
  51. for i in range(0, len(texts), self.batch_size):
  52. batch = texts[i : i + self.batch_size]
  53. prefixed_batch = [
  54. self.prefixes.get(purpose, "") + text for text in batch
  55. ]
  56. response = await self.aclient.embed(
  57. input=prefixed_batch, **kwargs
  58. )
  59. embeddings.extend(response["embeddings"])
  60. return embeddings
  61. except Exception as e:
  62. error_msg = f"Error getting embeddings: {str(e)}"
  63. logger.error(error_msg)
  64. raise R2RException(error_msg, 400)
  65. def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
  66. texts = task["texts"]
  67. purpose = task.get("purpose", EmbeddingPurpose.INDEX)
  68. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  69. try:
  70. embeddings = []
  71. for i in range(0, len(texts), self.batch_size):
  72. batch = texts[i : i + self.batch_size]
  73. prefixed_batch = [
  74. self.prefixes.get(purpose, "") + text for text in batch
  75. ]
  76. response = self.client.embed(input=prefixed_batch, **kwargs)
  77. embeddings.extend(response["embeddings"])
  78. return embeddings
  79. except Exception as e:
  80. error_msg = f"Error getting embeddings: {str(e)}"
  81. logger.error(error_msg)
  82. raise R2RException(error_msg, 400)
  83. async def async_get_embedding(
  84. self,
  85. text: str,
  86. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  87. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  88. **kwargs,
  89. ) -> list[float]:
  90. if stage != EmbeddingProvider.PipeStage.BASE:
  91. raise ValueError(
  92. "OllamaEmbeddingProvider only supports search stage."
  93. )
  94. task = {
  95. "texts": [text],
  96. "stage": stage,
  97. "purpose": purpose,
  98. "kwargs": kwargs,
  99. }
  100. result = await self._execute_with_backoff_async(task)
  101. return result[0]
  102. def get_embedding(
  103. self,
  104. text: str,
  105. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  106. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  107. **kwargs,
  108. ) -> list[float]:
  109. if stage != EmbeddingProvider.PipeStage.BASE:
  110. raise ValueError(
  111. "OllamaEmbeddingProvider only supports search stage."
  112. )
  113. task = {
  114. "texts": [text],
  115. "stage": stage,
  116. "purpose": purpose,
  117. "kwargs": kwargs,
  118. }
  119. result = self._execute_with_backoff_sync(task)
  120. return result[0]
  121. async def async_get_embeddings(
  122. self,
  123. texts: list[str],
  124. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  125. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  126. **kwargs,
  127. ) -> list[list[float]]:
  128. if stage != EmbeddingProvider.PipeStage.BASE:
  129. raise ValueError(
  130. "OllamaEmbeddingProvider only supports search stage."
  131. )
  132. task = {
  133. "texts": texts,
  134. "stage": stage,
  135. "purpose": purpose,
  136. "kwargs": kwargs,
  137. }
  138. return await self._execute_with_backoff_async(task)
  139. def get_embeddings(
  140. self,
  141. texts: list[str],
  142. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  143. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  144. **kwargs,
  145. ) -> list[list[float]]:
  146. if stage != EmbeddingProvider.PipeStage.BASE:
  147. raise ValueError(
  148. "OllamaEmbeddingProvider only supports search stage."
  149. )
  150. task = {
  151. "texts": texts,
  152. "stage": stage,
  153. "purpose": purpose,
  154. "kwargs": kwargs,
  155. }
  156. return self._execute_with_backoff_sync(task)
  157. def rerank(
  158. self,
  159. query: str,
  160. results: list[ChunkSearchResult],
  161. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  162. limit: int = 10,
  163. ) -> list[ChunkSearchResult]:
  164. return results[:limit]
  165. async def arerank(
  166. self,
  167. query: str,
  168. results: list[ChunkSearchResult],
  169. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  170. limit: int = 10,
  171. ):
  172. return results[:limit]