ollama.py

import logging
import os
from typing import Any

from ollama import AsyncClient, Client

from core.base import (
    ChunkSearchResult,
    EmbeddingConfig,
    EmbeddingProvider,
    R2RException,
)

logger = logging.getLogger()
class OllamaEmbeddingProvider(EmbeddingProvider):
    def __init__(self, config: EmbeddingConfig):
        super().__init__(config)
        provider = config.provider
        if not provider:
            raise ValueError(
                "Must set provider in order to initialize `OllamaEmbeddingProvider`."
            )
        if provider != "ollama":
            raise ValueError(
                "OllamaEmbeddingProvider must be initialized with provider `ollama`."
            )
        if config.rerank_model:
            raise ValueError(
                "OllamaEmbeddingProvider does not support separate reranking."
            )

        self.base_model = config.base_model
        self.base_dimension = config.base_dimension
        # Falls back to the default local Ollama endpoint when OLLAMA_API_BASE is unset.
        self.base_url = os.getenv("OLLAMA_API_BASE")
        logger.info(
            f"Using Ollama API base URL: {self.base_url or 'http://127.0.0.1:11434'}"
        )
        self.client = Client(host=self.base_url)
        self.aclient = AsyncClient(host=self.base_url)
        self.batch_size = config.batch_size or 32

    def _get_embedding_kwargs(self, **kwargs):
        embedding_kwargs = {
            "model": self.base_model,
        }
        embedding_kwargs.update(kwargs)
        return embedding_kwargs

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            embeddings = []
            # Embed in batches so large inputs are not sent in a single request.
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i : i + self.batch_size]
                response = await self.aclient.embed(input=batch, **kwargs)
                embeddings.extend(response["embeddings"])
            return embeddings
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise R2RException(error_msg, 400) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))

        try:
            embeddings = []
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i : i + self.batch_size]
                response = self.client.embed(input=batch, **kwargs)
                embeddings.extend(response["embeddings"])
            return embeddings
        except Exception as e:
            error_msg = f"Error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise R2RException(error_msg, 400) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        **kwargs,
    ) -> list[float]:
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "kwargs": kwargs,
        }
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        **kwargs,
    ) -> list[float]:
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": [text],
            "stage": stage,
            "kwargs": kwargs,
        }
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        **kwargs,
    ) -> list[list[float]]:
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "kwargs": kwargs,
        }
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
        **kwargs,
    ) -> list[list[float]]:
        if stage != EmbeddingProvider.Step.BASE:
            raise ValueError(
                "OllamaEmbeddingProvider only supports search stage."
            )

        task = {
            "texts": texts,
            "stage": stage,
            "kwargs": kwargs,
        }
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ) -> list[ChunkSearchResult]:
        # Reranking is not supported; return the top `limit` results unchanged.
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
        limit: int = 10,
    ):
        return results[:limit]
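
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this provider might be driven, assuming
# EmbeddingConfig accepts the fields read above (provider, base_model,
# base_dimension, batch_size) as keyword arguments and that the
# EmbeddingProvider base class supplies the backoff helpers
# (_execute_with_backoff_sync / _execute_with_backoff_async) used by the
# public embedding methods. Adjust model name and dimension to whatever
# embedding model is pulled into your local Ollama instance.
if __name__ == "__main__":
    import asyncio

    config = EmbeddingConfig(
        provider="ollama",
        base_model="mxbai-embed-large",  # assumed example model
        base_dimension=1024,
        batch_size=32,
    )
    ollama_provider = OllamaEmbeddingProvider(config)

    # Synchronous single-text embedding.
    vector = ollama_provider.get_embedding("What does this provider do?")
    print(len(vector))

    # Asynchronous batch embedding.
    vectors = asyncio.run(
        ollama_provider.async_get_embeddings(["first chunk", "second chunk"])
    )
    print(len(vectors), len(vectors[0]))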