# openai.py — OpenAI embedding provider.
  1. import logging
  2. import os
  3. from typing import Any
  4. from openai import AsyncOpenAI, AuthenticationError, OpenAI
  5. from openai._types import NOT_GIVEN
  6. from core.base import (
  7. ChunkSearchResult,
  8. EmbeddingConfig,
  9. EmbeddingProvider,
  10. EmbeddingPurpose,
  11. )
  12. logger = logging.getLogger()
  13. class OpenAIEmbeddingProvider(EmbeddingProvider):
  14. MODEL_TO_TOKENIZER = {
  15. "text-embedding-ada-002": "cl100k_base",
  16. "text-embedding-3-small": "cl100k_base",
  17. "text-embedding-3-large": "cl100k_base",
  18. }
  19. MODEL_TO_DIMENSIONS = {
  20. "text-embedding-ada-002": [1536],
  21. "text-embedding-3-small": [512, 1536],
  22. "text-embedding-3-large": [256, 1024, 3072],
  23. }
  24. def __init__(self, config: EmbeddingConfig):
  25. super().__init__(config)
  26. if not config.provider:
  27. raise ValueError(
  28. "Must set provider in order to initialize OpenAIEmbeddingProvider."
  29. )
  30. if config.provider != "openai":
  31. raise ValueError(
  32. "OpenAIEmbeddingProvider must be initialized with provider `openai`."
  33. )
  34. if not os.getenv("OPENAI_API_KEY"):
  35. raise ValueError(
  36. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  37. )
  38. self.client = OpenAI()
  39. self.async_client = AsyncOpenAI()
  40. if config.rerank_model:
  41. raise ValueError(
  42. "OpenAIEmbeddingProvider does not support separate reranking."
  43. )
  44. if config.base_model and "openai/" in config.base_model:
  45. self.base_model = config.base_model.split("/")[-1]
  46. else:
  47. self.base_model = config.base_model
  48. self.base_dimension = config.base_dimension
  49. if not self.base_model:
  50. raise ValueError(
  51. "Must set base_model in order to initialize OpenAIEmbeddingProvider."
  52. )
  53. if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
  54. raise ValueError(
  55. f"OpenAI embedding model {self.base_model} not supported."
  56. )
  57. if self.base_dimension:
  58. if (
  59. self.base_dimension
  60. not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
  61. self.base_model
  62. ]
  63. ):
  64. raise ValueError(
  65. f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
  66. )
  67. else:
  68. # If base_dimension is not set, use the largest available dimension for the model
  69. self.base_dimension = max(
  70. OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
  71. )
  72. def _get_dimensions(self):
  73. return (
  74. NOT_GIVEN
  75. if self.base_model == "text-embedding-ada-002"
  76. else self.base_dimension
  77. or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
  78. )
  79. def _get_embedding_kwargs(self, **kwargs):
  80. return {
  81. "model": self.base_model,
  82. "dimensions": self._get_dimensions(),
  83. } | kwargs
  84. async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
  85. texts = task["texts"]
  86. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  87. try:
  88. response = await self.async_client.embeddings.create(
  89. input=texts,
  90. **kwargs,
  91. )
  92. return [data.embedding for data in response.data]
  93. except AuthenticationError as e:
  94. raise ValueError(
  95. "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
  96. ) from e
  97. except Exception as e:
  98. error_msg = f"Error getting embeddings: {str(e)}"
  99. logger.error(error_msg)
  100. raise ValueError(error_msg) from e
  101. def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
  102. texts = task["texts"]
  103. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  104. try:
  105. response = self.client.embeddings.create(
  106. input=texts,
  107. **kwargs,
  108. )
  109. return [data.embedding for data in response.data]
  110. except AuthenticationError as e:
  111. raise ValueError(
  112. "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
  113. ) from e
  114. except Exception as e:
  115. error_msg = f"Error getting embeddings: {str(e)}"
  116. logger.error(error_msg)
  117. raise ValueError(error_msg) from e
  118. async def async_get_embedding(
  119. self,
  120. text: str,
  121. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  122. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  123. **kwargs,
  124. ) -> list[float]:
  125. if stage != EmbeddingProvider.PipeStage.BASE:
  126. raise ValueError(
  127. "OpenAIEmbeddingProvider only supports search stage."
  128. )
  129. task = {
  130. "texts": [text],
  131. "stage": stage,
  132. "purpose": purpose,
  133. "kwargs": kwargs,
  134. }
  135. result = await self._execute_with_backoff_async(task)
  136. return result[0]
  137. def get_embedding(
  138. self,
  139. text: str,
  140. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  141. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  142. **kwargs,
  143. ) -> list[float]:
  144. if stage != EmbeddingProvider.PipeStage.BASE:
  145. raise ValueError(
  146. "OpenAIEmbeddingProvider only supports search stage."
  147. )
  148. task = {
  149. "texts": [text],
  150. "stage": stage,
  151. "purpose": purpose,
  152. "kwargs": kwargs,
  153. }
  154. result = self._execute_with_backoff_sync(task)
  155. return result[0]
  156. async def async_get_embeddings(
  157. self,
  158. texts: list[str],
  159. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  160. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  161. **kwargs,
  162. ) -> list[list[float]]:
  163. if stage != EmbeddingProvider.PipeStage.BASE:
  164. raise ValueError(
  165. "OpenAIEmbeddingProvider only supports search stage."
  166. )
  167. task = {
  168. "texts": texts,
  169. "stage": stage,
  170. "purpose": purpose,
  171. "kwargs": kwargs,
  172. }
  173. return await self._execute_with_backoff_async(task)
  174. def get_embeddings(
  175. self,
  176. texts: list[str],
  177. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
  178. purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
  179. **kwargs,
  180. ) -> list[list[float]]:
  181. if stage != EmbeddingProvider.PipeStage.BASE:
  182. raise ValueError(
  183. "OpenAIEmbeddingProvider only supports search stage."
  184. )
  185. task = {
  186. "texts": texts,
  187. "stage": stage,
  188. "purpose": purpose,
  189. "kwargs": kwargs,
  190. }
  191. return self._execute_with_backoff_sync(task)
  192. def rerank(
  193. self,
  194. query: str,
  195. results: list[ChunkSearchResult],
  196. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  197. limit: int = 10,
  198. ):
  199. return results[:limit]
  200. async def arerank(
  201. self,
  202. query: str,
  203. results: list[ChunkSearchResult],
  204. stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
  205. limit: int = 10,
  206. ):
  207. return results[:limit]
  208. def tokenize_string(self, text: str, model: str) -> list[int]:
  209. try:
  210. import tiktoken
  211. except ImportError as e:
  212. raise ValueError(
  213. "Must download tiktoken library to run `tokenize_string`."
  214. ) from e
  215. if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
  216. raise ValueError(f"OpenAI embedding model {model} not supported.")
  217. encoding = tiktoken.get_encoding(
  218. OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
  219. )
  220. return encoding.encode(text)