- """Abstractions for the LLM model."""
- import json
- from enum import Enum
- from typing import TYPE_CHECKING, Any, ClassVar, Optional
- from openai.types.chat import ChatCompletionChunk
- from pydantic import BaseModel, Field
- from .base import R2RSerializable
- if TYPE_CHECKING:
- from .search import AggregateSearchResult
- from typing_extensions import Literal

class Function(BaseModel):
    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""


class ChatCompletionMessageToolCall(BaseModel):
    id: str
    """The ID of the tool call."""

    function: Function
    """The function that the model called."""

    type: Literal["function"]
    """The type of the tool. Currently, only `function` is supported."""

class FunctionCall(BaseModel):
    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""


class ChatCompletionMessage(BaseModel):
    content: Optional[str] = None
    """The contents of the message."""

    refusal: Optional[str] = None
    """The refusal message generated by the model."""

    role: Literal["assistant"]
    """The role of the author of this message."""

    # audio: Optional[ChatCompletionAudio] = None
    """
    If the audio output modality is requested, this object contains data about the
    audio response from the model.
    [Learn more](https://platform.openai.com/docs/guides/audio).
    """

    function_call: Optional[FunctionCall] = None
    """Deprecated and replaced by `tool_calls`.

    The name and arguments of a function that should be called, as generated by the
    model.
    """

    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
    """The tool calls generated by the model, such as function calls."""

    structured_content: Optional[list[dict]] = None

class Choice(BaseModel):
    finish_reason: Literal[
        "stop",
        "length",
        "tool_calls",
        "content_filter",
        "function_call",
        "max_tokens",
    ]
    """The reason the model stopped generating tokens.

    This will be `stop` if the model hit a natural stop point or a provided stop
    sequence, `length` if the maximum number of tokens specified in the request was
    reached, `content_filter` if content was omitted due to a flag from our content
    filters, `tool_calls` if the model called a tool, or `function_call`
    (deprecated) if the model called a function.
    """

    index: int
    """The index of the choice in the list of choices."""

    # logprobs: Optional[ChoiceLogprobs] = None
    """Log probability information for the choice."""

    message: ChatCompletionMessage
    """A chat completion message generated by the model."""

class LLMChatCompletion(BaseModel):
    id: str
    """A unique identifier for the chat completion."""

    choices: list[Choice]
    """A list of chat completion choices.

    Can be more than one if `n` is greater than 1.
    """

    created: int
    """The Unix timestamp (in seconds) of when the chat completion was created."""

    model: str
    """The model used for the chat completion."""

    object: Literal["chat.completion"]
    """The object type, which is always `chat.completion`."""

    service_tier: Optional[Literal["scale", "default"]] = None
    """The service tier used for processing the request."""

    system_fingerprint: Optional[str] = None
    """This fingerprint represents the backend configuration that the model runs with.

    Can be used in conjunction with the `seed` request parameter to understand when
    backend changes have been made that might impact determinism.
    """

    usage: Optional[Any] = None
    """Usage statistics for the completion request."""


LLMChatCompletionChunk = ChatCompletionChunk

class RAGCompletion:
    completion: LLMChatCompletion
    search_results: "AggregateSearchResult"

    def __init__(
        self,
        completion: LLMChatCompletion,
        search_results: "AggregateSearchResult",
    ):
        self.completion = completion
        self.search_results = search_results
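
# Illustrative sketch (assumed call site, not part of this module): a RAG
# pipeline would pair the model output with the search results that grounded
# it, e.g.
#
#   rag = RAGCompletion(completion=completion, search_results=aggregate_results)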

class GenerationConfig(R2RSerializable):
    _defaults: ClassVar[dict] = {
        "model": "openai/gpt-4.1-mini",
        "temperature": 0.1,
        "top_p": 1.0,
        "max_tokens_to_sample": 102400,
        "stream": False,
        "functions": None,
        "tools": None,
        "add_generation_kwargs": None,
        "api_base": None,
        "response_format": None,
        "extended_thinking": False,
        "thinking_budget": None,
        "reasoning_effort": None,
    }

    model: Optional[str] = Field(
        default_factory=lambda: GenerationConfig._defaults["model"]
    )
    temperature: float = Field(
        default_factory=lambda: GenerationConfig._defaults["temperature"]
    )
    top_p: Optional[float] = Field(
        default_factory=lambda: GenerationConfig._defaults["top_p"],
    )
    max_tokens_to_sample: int = Field(
        default_factory=lambda: GenerationConfig._defaults[
            "max_tokens_to_sample"
        ],
    )
    stream: bool = Field(
        default_factory=lambda: GenerationConfig._defaults["stream"]
    )
    functions: Optional[list[dict]] = Field(
        default_factory=lambda: GenerationConfig._defaults["functions"]
    )
    tools: Optional[list[dict]] = Field(
        default_factory=lambda: GenerationConfig._defaults["tools"]
    )
    add_generation_kwargs: Optional[dict] = Field(
        default_factory=lambda: GenerationConfig._defaults[
            "add_generation_kwargs"
        ],
    )
    api_base: Optional[str] = Field(
        default_factory=lambda: GenerationConfig._defaults["api_base"],
    )
    response_format: Optional[dict | BaseModel] = None
    extended_thinking: bool = Field(
        default=False,
        description="Flag to enable extended thinking mode (for Anthropic providers)",
    )
    thinking_budget: Optional[int] = Field(
        default=None,
        description=(
            "Token budget for internal reasoning when extended thinking mode is enabled. "
            "Must be less than max_tokens_to_sample."
        ),
    )
    reasoning_effort: Optional[str] = Field(
        default=None,
        description=(
            "Effort level for internal reasoning when extended thinking mode is "
            "enabled: `low`, `medium`, or `high`. Only applicable to OpenAI "
            "providers."
        ),
    )

    @classmethod
    def set_default(cls, **kwargs):
        for key, value in kwargs.items():
            if key in cls._defaults:
                cls._defaults[key] = value
            else:
                raise AttributeError(
                    f"No default attribute '{key}' in GenerationConfig"
                )
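
    # Illustrative sketch (assumed usage, not part of the class): override the
    # class-wide defaults once at startup; later instances pick the new values
    # up through the `default_factory` lambdas above.
    #
    #   GenerationConfig.set_default(model="openai/gpt-4.1", temperature=0.0)
    #   cfg = GenerationConfig()       # cfg.model == "openai/gpt-4.1"
    #   GenerationConfig.set_default(unknown=1)  # raises AttributeError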

    def __init__(self, **data):
        # Handle the OpenAI-style `max_tokens` alias for `max_tokens_to_sample`.
        if "max_tokens" in data:
            # Only set max_tokens_to_sample if it's not already provided.
            if "max_tokens_to_sample" not in data:
                data["max_tokens_to_sample"] = data.pop("max_tokens")
            else:
                # If both are provided, max_tokens_to_sample takes precedence.
                data.pop("max_tokens")

        # Convert a pydantic model class passed as `response_format` into an
        # OpenAI-style JSON-schema response format dict.
        if (
            "response_format" in data
            and isinstance(data["response_format"], type)
            and issubclass(data["response_format"], BaseModel)
        ):
            model_class = data["response_format"]
            data["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": model_class.__name__,
                    "schema": model_class.model_json_schema(),
                },
            }

        model = data.pop("model", None)
        if model is not None:
            super().__init__(model=model, **data)
        else:
            super().__init__(**data)

    def __str__(self):
        return json.dumps(self.to_dict())

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "model": "openai/gpt-4.1",
                "temperature": 0.1,
                "top_p": 1.0,
                "max_tokens_to_sample": 1024,
                "stream": False,
                "functions": None,
                "tools": None,
                "add_generation_kwargs": None,
                "api_base": None,
            }
        }
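
# Illustrative sketch (assumed usage, not part of this module): the constructor
# accepts the OpenAI-style `max_tokens` alias and converts a pydantic model
# class passed as `response_format` into a JSON-schema dict. `Answer` is a
# hypothetical model used only for illustration.
#
#   class Answer(BaseModel):
#       text: str
#
#   cfg = GenerationConfig(max_tokens=512, response_format=Answer)
#   assert cfg.max_tokens_to_sample == 512
#   assert cfg.response_format["type"] == "json_schema"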

class MessageType(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    FUNCTION = "function"
    TOOL = "tool"

    def __str__(self):
        return self.value
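
# For example, str(MessageType.TOOL) == "tool", so enum members serialize to
# the same plain strings accepted by the `role` field on Message below.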

class Message(R2RSerializable):
    role: MessageType | str
    content: Optional[Any] = None
    name: Optional[str] = None
    function_call: Optional[dict[str, Any]] = None
    tool_calls: Optional[list[dict[str, Any]]] = None
    tool_call_id: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None
    structured_content: Optional[list[dict]] = None
    image_url: Optional[str] = None  # For URL-based images
    image_data: Optional[dict[str, str]] = None  # For base64 {media_type, data}

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "role": "user",
                "content": "This is a test message.",
                "name": None,
                "function_call": None,
                "tool_calls": None,
            }
        }
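
# Illustrative sketch (assumed usage, not part of this module): a multimodal
# user message carrying a base64 image payload; "<base64>" is a placeholder.
#
#   msg = Message(
#       role=MessageType.USER,
#       content="Describe this image.",
#       image_data={"media_type": "image/png", "data": "<base64>"},
#   )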