# openai.py — OpenAI embedding provider.
  1. import contextlib
  2. import logging
  3. import os
  4. from typing import Any
  5. import tiktoken
  6. from openai import AsyncOpenAI, AuthenticationError, OpenAI
  7. from openai._types import NOT_GIVEN
  8. from core.base import (
  9. ChunkSearchResult,
  10. EmbeddingConfig,
  11. EmbeddingProvider,
  12. )
  13. from .utils import truncate_texts_to_token_limit
  14. logger = logging.getLogger()
  15. class OpenAIEmbeddingProvider(EmbeddingProvider):
  16. MODEL_TO_TOKENIZER = {
  17. "text-embedding-ada-002": "cl100k_base",
  18. "text-embedding-3-small": "cl100k_base",
  19. "text-embedding-3-large": "cl100k_base",
  20. }
  21. MODEL_TO_DIMENSIONS = {
  22. "text-embedding-ada-002": [1536],
  23. "text-embedding-3-small": [512, 1536],
  24. "text-embedding-3-large": [256, 1024, 3072],
  25. }
  26. def __init__(self, config: EmbeddingConfig):
  27. super().__init__(config)
  28. if not config.provider:
  29. raise ValueError(
  30. "Must set provider in order to initialize OpenAIEmbeddingProvider."
  31. )
  32. if config.provider != "openai":
  33. raise ValueError(
  34. "OpenAIEmbeddingProvider must be initialized with provider `openai`."
  35. )
  36. if not os.getenv("OPENAI_API_KEY"):
  37. raise ValueError(
  38. "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
  39. )
  40. #self.client = OpenAI()
  41. #self.async_client = AsyncOpenAI()
  42. self.client = OpenAI(
  43. api_key="sk-j9Uwupu0NPZtdDS_IfEZlRWpX1JgFyZFLZProkesy2QbtqMs16pDnylAozU",
  44. base_url="http://172.16.12.13:3000/v1"
  45. )
  46. self.async_client = AsyncOpenAI(
  47. api_key="sk-j9Uwupu0NPZtdDS_IfEZlRWpX1JgFyZFLZProkesy2QbtqMs16pDnylAozU",
  48. base_url="http://172.16.12.13:3000/v1"
  49. )
  50. if config.rerank_model:
  51. raise ValueError(
  52. "OpenAIEmbeddingProvider does not support separate reranking."
  53. )
  54. if config.base_model and "openai/" in config.base_model:
  55. self.base_model = config.base_model.split("/")[-1]
  56. else:
  57. self.base_model = config.base_model
  58. self.base_dimension = config.base_dimension
  59. if not self.base_model:
  60. raise ValueError(
  61. "Must set base_model in order to initialize OpenAIEmbeddingProvider."
  62. )
  63. if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
  64. raise ValueError(
  65. f"OpenAI embedding model {self.base_model} not supported."
  66. )
  67. if self.base_dimension:
  68. if (
  69. self.base_dimension
  70. not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
  71. self.base_model
  72. ]
  73. ):
  74. raise ValueError(
  75. f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
  76. )
  77. else:
  78. # If base_dimension is not set, use the largest available dimension for the model
  79. self.base_dimension = max(
  80. OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
  81. )
  82. def _get_dimensions(self):
  83. return (
  84. NOT_GIVEN
  85. if self.base_model == "text-embedding-ada-002"
  86. else self.base_dimension
  87. or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
  88. )
  89. def _get_embedding_kwargs(self, **kwargs):
  90. return {
  91. "model": self.base_model,
  92. "dimensions": self._get_dimensions(),
  93. } | kwargs
  94. async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
  95. texts = task["texts"]
  96. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  97. try:
  98. # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
  99. if kwargs.get("model"):
  100. with contextlib.suppress(Exception):
  101. texts = truncate_texts_to_token_limit(
  102. texts, kwargs["model"]
  103. )
  104. response = await self.async_client.embeddings.create(
  105. input=texts,
  106. **kwargs,
  107. )
  108. return [data.embedding for data in response.data]
  109. except AuthenticationError as e:
  110. raise ValueError(
  111. "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
  112. ) from e
  113. except Exception as e:
  114. error_msg = f"Error getting embeddings: {str(e)}"
  115. logger.error(error_msg)
  116. raise ValueError(error_msg) from e
  117. def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
  118. texts = task["texts"]
  119. kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
  120. try:
  121. # Truncate text if it exceeds the model's max input tokens. Some providers do this by default, others do not.
  122. if kwargs.get("model"):
  123. with contextlib.suppress(Exception):
  124. texts = truncate_texts_to_token_limit(
  125. texts, kwargs["model"]
  126. )
  127. response = self.client.embeddings.create(
  128. input=texts,
  129. **kwargs,
  130. )
  131. return [data.embedding for data in response.data]
  132. except AuthenticationError as e:
  133. raise ValueError(
  134. "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
  135. ) from e
  136. except Exception as e:
  137. error_msg = f"Error getting embeddings: {str(e)}"
  138. logger.error(error_msg)
  139. raise ValueError(error_msg) from e
  140. async def async_get_embedding(
  141. self,
  142. text: str,
  143. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  144. **kwargs,
  145. ) -> list[float]:
  146. if stage != EmbeddingProvider.Step.BASE:
  147. raise ValueError(
  148. "OpenAIEmbeddingProvider only supports search stage."
  149. )
  150. task = {
  151. "texts": [text],
  152. "stage": stage,
  153. "kwargs": kwargs,
  154. }
  155. result = await self._execute_with_backoff_async(task)
  156. return result[0]
  157. def get_embedding(
  158. self,
  159. text: str,
  160. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  161. **kwargs,
  162. ) -> list[float]:
  163. if stage != EmbeddingProvider.Step.BASE:
  164. raise ValueError(
  165. "OpenAIEmbeddingProvider only supports search stage."
  166. )
  167. task = {
  168. "texts": [text],
  169. "stage": stage,
  170. "kwargs": kwargs,
  171. }
  172. result = self._execute_with_backoff_sync(task)
  173. return result[0]
  174. async def async_get_embeddings(
  175. self,
  176. texts: list[str],
  177. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  178. **kwargs,
  179. ) -> list[list[float]]:
  180. if stage != EmbeddingProvider.Step.BASE:
  181. raise ValueError(
  182. "OpenAIEmbeddingProvider only supports search stage."
  183. )
  184. task = {
  185. "texts": texts,
  186. "stage": stage,
  187. "kwargs": kwargs,
  188. }
  189. return await self._execute_with_backoff_async(task)
  190. def get_embeddings(
  191. self,
  192. texts: list[str],
  193. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.BASE,
  194. **kwargs,
  195. ) -> list[list[float]]:
  196. if stage != EmbeddingProvider.Step.BASE:
  197. raise ValueError(
  198. "OpenAIEmbeddingProvider only supports search stage."
  199. )
  200. task = {
  201. "texts": texts,
  202. "stage": stage,
  203. "kwargs": kwargs,
  204. }
  205. return self._execute_with_backoff_sync(task)
  206. def rerank(
  207. self,
  208. query: str,
  209. results: list[ChunkSearchResult],
  210. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
  211. limit: int = 10,
  212. ):
  213. return results[:limit]
  214. async def arerank(
  215. self,
  216. query: str,
  217. results: list[ChunkSearchResult],
  218. stage: EmbeddingProvider.Step = EmbeddingProvider.Step.RERANK,
  219. limit: int = 10,
  220. ):
  221. return results[:limit]
  222. def tokenize_string(self, text: str, model: str) -> list[int]:
  223. if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
  224. raise ValueError(f"OpenAI embedding model {model} not supported.")
  225. encoding = tiktoken.get_encoding(
  226. OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
  227. )
  228. return encoding.encode(text)