123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- import logging
- import os
- from typing import Any
- from openai import AsyncOpenAI, AuthenticationError, OpenAI
- from openai._types import NOT_GIVEN
- from core.base import (
- ChunkSearchResult,
- EmbeddingConfig,
- EmbeddingProvider,
- EmbeddingPurpose,
- )
# Module-level logger named after this module (not the root logger), so log
# records are attributable and can be filtered per-module by the host app.
logger = logging.getLogger(__name__)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Supports ``text-embedding-ada-002``, ``text-embedding-3-small`` and
    ``text-embedding-3-large``, validating the configured output dimension
    against each model's supported sizes. Only the BASE pipeline stage is
    supported; reranking is a no-op that truncates results to ``limit``.
    """

    # Tokenizer encoding used by each supported model (see `tokenize_string`).
    MODEL_TO_TOKENIZER = {
        "text-embedding-ada-002": "cl100k_base",
        "text-embedding-3-small": "cl100k_base",
        "text-embedding-3-large": "cl100k_base",
    }
    # Output dimensions each model can be asked to produce.
    MODEL_TO_DIMENSIONS = {
        "text-embedding-ada-002": [1536],
        "text-embedding-3-small": [512, 1536],
        "text-embedding-3-large": [256, 1024, 3072],
    }

    def __init__(self, config: EmbeddingConfig):
        """Validate ``config`` and create the sync and async OpenAI clients.

        Raises:
            ValueError: if the provider is not ``openai``, OPENAI_API_KEY is
                unset, a rerank model is configured, or the base model /
                dimension combination is unsupported.
        """
        super().__init__(config)
        if not config.provider:
            raise ValueError(
                "Must set provider in order to initialize OpenAIEmbeddingProvider."
            )
        if config.provider != "openai":
            raise ValueError(
                "OpenAIEmbeddingProvider must be initialized with provider `openai`."
            )
        if not os.getenv("OPENAI_API_KEY"):
            raise ValueError(
                "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider."
            )
        # Both clients read OPENAI_API_KEY from the environment.
        self.client = OpenAI()
        self.async_client = AsyncOpenAI()

        if config.rerank_model:
            raise ValueError(
                "OpenAIEmbeddingProvider does not support separate reranking."
            )

        # Accept either a bare model name or one prefixed like "openai/...".
        if config.base_model and "openai/" in config.base_model:
            self.base_model = config.base_model.split("/")[-1]
        else:
            self.base_model = config.base_model
        self.base_dimension = config.base_dimension

        if not self.base_model:
            raise ValueError(
                "Must set base_model in order to initialize OpenAIEmbeddingProvider."
            )
        if self.base_model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(
                f"OpenAI embedding model {self.base_model} not supported."
            )
        if self.base_dimension:
            if (
                self.base_dimension
                not in OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[
                    self.base_model
                ]
            ):
                raise ValueError(
                    f"Dimensions {self.base_dimension} for {self.base_model} are not supported"
                )
        else:
            # If base_dimension is not set, use the largest available
            # dimension for the model.
            self.base_dimension = max(
                OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model]
            )

    def _get_dimensions(self):
        """Return the ``dimensions`` argument for the embeddings API.

        ada-002 does not accept a ``dimensions`` parameter, so NOT_GIVEN is
        sent in that case; otherwise the configured dimension (or the model's
        largest supported one) is used.
        """
        return (
            NOT_GIVEN
            if self.base_model == "text-embedding-ada-002"
            else self.base_dimension
            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.base_model][-1]
        )

    def _get_embedding_kwargs(self, **kwargs):
        """Merge provider defaults with caller-supplied kwargs (caller wins)."""
        return {
            "model": self.base_model,
            "dimensions": self._get_dimensions(),
        } | kwargs

    def _build_task(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage,
        purpose: EmbeddingPurpose,
        kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Validate ``stage`` and package a request for the backoff executor.

        Shared by all four public embedding entry points, which previously
        duplicated this validation and dict construction.

        Raises:
            ValueError: if ``stage`` is anything other than BASE.
        """
        if stage != EmbeddingProvider.PipeStage.BASE:
            raise ValueError(
                "OpenAIEmbeddingProvider only supports search stage."
            )
        return {
            "texts": texts,
            "stage": stage,
            "purpose": purpose,
            "kwargs": kwargs,
        }

    @staticmethod
    def _embedding_error(e: Exception) -> ValueError:
        """Translate an API failure into the ValueError raised to callers.

        Authentication failures get a targeted message (and are not logged);
        all other errors are logged before being wrapped.
        """
        if isinstance(e, AuthenticationError):
            return ValueError(
                "Invalid OpenAI API key provided. Please check your OPENAI_API_KEY environment variable."
            )
        error_msg = f"Error getting embeddings: {str(e)}"
        logger.error(error_msg)
        return ValueError(error_msg)

    async def _execute_task(self, task: dict[str, Any]) -> list[list[float]]:
        """Asynchronously embed ``task["texts"]``; one vector per input text."""
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = await self.async_client.embeddings.create(
                input=texts,
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except Exception as e:
            raise self._embedding_error(e) from e

    def _execute_task_sync(self, task: dict[str, Any]) -> list[list[float]]:
        """Synchronously embed ``task["texts"]``; one vector per input text."""
        texts = task["texts"]
        kwargs = self._get_embedding_kwargs(**task.get("kwargs", {}))
        try:
            response = self.client.embeddings.create(
                input=texts,
                **kwargs,
            )
            return [data.embedding for data in response.data]
        except Exception as e:
            raise self._embedding_error(e) from e

    async def async_get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Asynchronously embed a single string, with retry/backoff."""
        task = self._build_task([text], stage, purpose, kwargs)
        result = await self._execute_with_backoff_async(task)
        return result[0]

    def get_embedding(
        self,
        text: str,
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[float]:
        """Synchronously embed a single string, with retry/backoff."""
        task = self._build_task([text], stage, purpose, kwargs)
        result = self._execute_with_backoff_sync(task)
        return result[0]

    async def async_get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Asynchronously embed a batch of strings, with retry/backoff."""
        task = self._build_task(texts, stage, purpose, kwargs)
        return await self._execute_with_backoff_async(task)

    def get_embeddings(
        self,
        texts: list[str],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.BASE,
        purpose: EmbeddingPurpose = EmbeddingPurpose.INDEX,
        **kwargs,
    ) -> list[list[float]]:
        """Synchronously embed a batch of strings, with retry/backoff."""
        task = self._build_task(texts, stage, purpose, kwargs)
        return self._execute_with_backoff_sync(task)

    def rerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ):
        """No-op rerank: return the first ``limit`` results unchanged."""
        return results[:limit]

    async def arerank(
        self,
        query: str,
        results: list[ChunkSearchResult],
        stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
        limit: int = 10,
    ):
        """Async no-op rerank: return the first ``limit`` results unchanged."""
        return results[:limit]

    def tokenize_string(self, text: str, model: str) -> list[int]:
        """Tokenize ``text`` with the tiktoken encoding for ``model``.

        Raises:
            ValueError: if tiktoken is not installed or ``model`` is
                unsupported.
        """
        try:
            # Imported lazily so tiktoken stays an optional dependency.
            import tiktoken
        except ImportError as e:
            raise ValueError(
                "Must download tiktoken library to run `tokenize_string`."
            ) from e
        if model not in OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER:
            raise ValueError(f"OpenAI embedding model {model} not supported.")
        encoding = tiktoken.get_encoding(
            OpenAIEmbeddingProvider.MODEL_TO_TOKENIZER[model]
        )
        return encoding.encode(text)
|