llm.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. """Abstractions for the LLM model."""
  2. import json
  3. from enum import Enum
  4. from typing import TYPE_CHECKING, Any, ClassVar, Optional
  5. from openai.types.chat import ChatCompletionChunk
  6. from pydantic import BaseModel, Field
  7. from .base import R2RSerializable
  8. if TYPE_CHECKING:
  9. from .search import AggregateSearchResult
  10. from typing_extensions import Literal
class Function(BaseModel):
    # Local mirror of the OpenAI SDK's `Function` type attached to a tool call.
    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
class ChatCompletionMessageToolCall(BaseModel):
    # Local mirror of the OpenAI SDK's tool-call message type.
    id: str
    """The ID of the tool call."""

    function: Function
    """The function that the model called."""

    type: Literal["function"]
    """The type of the tool. Currently, only `function` is supported."""
class FunctionCall(BaseModel):
    # Legacy (pre-tool-calls) function-call payload; structurally identical to
    # `Function` but kept as a separate type to mirror the OpenAI SDK.
    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""
class ChatCompletionMessage(BaseModel):
    # An assistant-authored chat completion message, mirroring the OpenAI SDK
    # type with one R2R extension (`structured_content`).
    content: Optional[str] = None
    """The contents of the message."""

    refusal: Optional[str] = None
    """The refusal message generated by the model."""

    role: Literal["assistant"]
    """The role of the author of this message."""

    # audio: Optional[ChatCompletionAudio] = None
    """
    If the audio output modality is requested, this object contains data about the
    audio response from the model.
    [Learn more](https://platform.openai.com/docs/guides/audio).
    """

    function_call: Optional[FunctionCall] = None
    """Deprecated and replaced by `tool_calls`.
    The name and arguments of a function that should be called, as generated by the
    model.
    """

    tool_calls: Optional[list[ChatCompletionMessageToolCall]] = None
    """The tool calls generated by the model, such as function calls."""

    # NOTE(review): not part of the OpenAI message type — presumably an R2R
    # extension for multi-part content blocks; confirm against callers.
    structured_content: Optional[list[dict]] = None
class Choice(BaseModel):
    # One candidate completion inside an LLMChatCompletion response.
    # NOTE(review): "max_tokens" is not an OpenAI finish_reason — it looks like
    # an accommodation for non-OpenAI providers; confirm which backend emits it.
    finish_reason: Literal[
        "stop",
        "length",
        "tool_calls",
        "content_filter",
        "function_call",
        "max_tokens",
    ]
    """The reason the model stopped generating tokens.
    This will be `stop` if the model hit a natural stop point or a provided stop
    sequence, `length` if the maximum number of tokens specified in the request was
    reached, `content_filter` if content was omitted due to a flag from our content
    filters, `tool_calls` if the model called a tool, or `function_call`
    (deprecated) if the model called a function.
    """

    index: int
    """The index of the choice in the list of choices."""

    # logprobs: Optional[ChoiceLogprobs] = None
    """Log probability information for the choice."""

    message: ChatCompletionMessage
    """A chat completion message generated by the model."""
class LLMChatCompletion(BaseModel):
    # Top-level chat completion response, mirroring the OpenAI SDK type but
    # with `usage` left untyped (Optional[Any]) so any provider's stats fit.
    id: str
    """A unique identifier for the chat completion."""

    choices: list[Choice]
    """A list of chat completion choices.
    Can be more than one if `n` is greater than 1.
    """

    created: int
    """The Unix timestamp (in seconds) of when the chat completion was created."""

    model: str
    """The model used for the chat completion."""

    object: Literal["chat.completion"]
    """The object type, which is always `chat.completion`."""

    service_tier: Optional[Literal["scale", "default"]] = None
    """The service tier used for processing the request."""

    system_fingerprint: Optional[str] = None
    """This fingerprint represents the backend configuration that the model runs with.
    Can be used in conjunction with the `seed` request parameter to understand when
    backend changes have been made that might impact determinism.
    """

    usage: Optional[Any] = None
    """Usage statistics for the completion request."""
# Streaming chunks are used as-is from the OpenAI SDK, re-exported under an
# R2R-local alias for symmetry with LLMChatCompletion.
LLMChatCompletionChunk = ChatCompletionChunk
  104. class RAGCompletion:
  105. completion: LLMChatCompletion
  106. search_results: "AggregateSearchResult"
  107. def __init__(
  108. self,
  109. completion: LLMChatCompletion,
  110. search_results: "AggregateSearchResult",
  111. ):
  112. self.completion = completion
  113. self.search_results = search_results
  114. class GenerationConfig(R2RSerializable):
  115. _defaults: ClassVar[dict] = {
  116. "model": "openai/gpt-4.1-mini",
  117. "temperature": 0.1,
  118. "top_p": 1.0,
  119. "max_tokens_to_sample": 102400,
  120. "stream": False,
  121. "functions": None,
  122. "tools": None,
  123. "add_generation_kwargs": None,
  124. "api_base": None,
  125. "response_format": None,
  126. "extended_thinking": False,
  127. "thinking_budget": None,
  128. "reasoning_effort": None,
  129. }
  130. model: Optional[str] = Field(
  131. default_factory=lambda: GenerationConfig._defaults["model"]
  132. )
  133. temperature: float = Field(
  134. default_factory=lambda: GenerationConfig._defaults["temperature"]
  135. )
  136. top_p: Optional[float] = Field(
  137. default_factory=lambda: GenerationConfig._defaults["top_p"],
  138. )
  139. max_tokens_to_sample: int = Field(
  140. default_factory=lambda: GenerationConfig._defaults[
  141. "max_tokens_to_sample"
  142. ],
  143. )
  144. stream: bool = Field(
  145. default_factory=lambda: GenerationConfig._defaults["stream"]
  146. )
  147. functions: Optional[list[dict]] = Field(
  148. default_factory=lambda: GenerationConfig._defaults["functions"]
  149. )
  150. tools: Optional[list[dict]] = Field(
  151. default_factory=lambda: GenerationConfig._defaults["tools"]
  152. )
  153. add_generation_kwargs: Optional[dict] = Field(
  154. default_factory=lambda: GenerationConfig._defaults[
  155. "add_generation_kwargs"
  156. ],
  157. )
  158. api_base: Optional[str] = Field(
  159. default_factory=lambda: GenerationConfig._defaults["api_base"],
  160. )
  161. response_format: Optional[dict | BaseModel] = None
  162. extended_thinking: bool = Field(
  163. default=False,
  164. description="Flag to enable extended thinking mode (for Anthropic providers)",
  165. )
  166. thinking_budget: Optional[int] = Field(
  167. default=None,
  168. description=(
  169. "Token budget for internal reasoning when extended thinking mode is enabled. "
  170. "Must be less than max_tokens_to_sample."
  171. ),
  172. )
  173. reasoning_effort: Optional[str] = Field(
  174. default=None,
  175. description=(
  176. "Effort level for internal reasoning when extended thinking mode is enabled, `low`, `medium`, or `high`."
  177. "Only applicable to OpenAI providers."
  178. ),
  179. )
  180. @classmethod
  181. def set_default(cls, **kwargs):
  182. for key, value in kwargs.items():
  183. if key in cls._defaults:
  184. cls._defaults[key] = value
  185. else:
  186. raise AttributeError(
  187. f"No default attribute '{key}' in GenerationConfig"
  188. )
  189. def __init__(self, **data):
  190. # Handle max_tokens mapping to max_tokens_to_sample
  191. if "max_tokens" in data:
  192. # Only set max_tokens_to_sample if it's not already provided
  193. if "max_tokens_to_sample" not in data:
  194. data["max_tokens_to_sample"] = data.pop("max_tokens")
  195. else:
  196. # If both are provided, max_tokens_to_sample takes precedence
  197. data.pop("max_tokens")
  198. if (
  199. "response_format" in data
  200. and isinstance(data["response_format"], type)
  201. and issubclass(data["response_format"], BaseModel)
  202. ):
  203. model_class = data["response_format"]
  204. data["response_format"] = {
  205. "type": "json_schema",
  206. "json_schema": {
  207. "name": model_class.__name__,
  208. "schema": model_class.model_json_schema(),
  209. },
  210. }
  211. model = data.pop("model", None)
  212. if model is not None:
  213. super().__init__(model=model, **data)
  214. else:
  215. super().__init__(**data)
  216. def __str__(self):
  217. return json.dumps(self.to_dict())
  218. class Config:
  219. populate_by_name = True
  220. json_schema_extra = {
  221. "example": {
  222. "model": "openai/gpt-4.1",
  223. "temperature": 0.1,
  224. "top_p": 1.0,
  225. "max_tokens_to_sample": 1024,
  226. "stream": False,
  227. "functions": None,
  228. "tools": None,
  229. "add_generation_kwargs": None,
  230. "api_base": None,
  231. }
  232. }
  233. class MessageType(Enum):
  234. SYSTEM = "system"
  235. USER = "user"
  236. ASSISTANT = "assistant"
  237. FUNCTION = "function"
  238. TOOL = "tool"
  239. def __str__(self):
  240. return self.value
class Message(R2RSerializable):
    # A single chat message in R2R's own schema. `role` accepts either a
    # MessageType or a raw string, so provider payloads pass through unchanged.
    role: MessageType | str
    content: Optional[Any] = None
    name: Optional[str] = None
    # Legacy single-function-call payload and its tool-calls successor, kept
    # as plain dicts rather than the typed models above.
    function_call: Optional[dict[str, Any]] = None
    tool_calls: Optional[list[dict[str, Any]]] = None
    # ID linking a tool-role message back to the tool call it answers.
    tool_call_id: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None
    # Multi-part content blocks; shape not constrained here.
    structured_content: Optional[list[dict]] = None
    image_url: Optional[str] = None  # For URL-based images
    image_data: Optional[dict[str, str]] = (
        None  # For base64 {media_type, data}
    )

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "role": "user",
                "content": "This is a test message.",
                "name": None,
                "function_call": None,
                "tool_calls": None,
            }
        }