llm.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. """Abstractions for the LLM model."""
  2. import json
  3. from enum import Enum
  4. from typing import TYPE_CHECKING, Any, ClassVar, Optional
  5. from openai.types.chat import ChatCompletion, ChatCompletionChunk
  6. from pydantic import BaseModel, Field
  7. from .base import R2RSerializable
  8. if TYPE_CHECKING:
  9. from .search import AggregateSearchResult
# Provider-agnostic aliases for OpenAI's chat-completion response types, so the
# rest of the codebase does not import from `openai` directly.
LLMChatCompletion = ChatCompletion
LLMChatCompletionChunk = ChatCompletionChunk
  12. class RAGCompletion:
  13. completion: LLMChatCompletion
  14. search_results: "AggregateSearchResult"
  15. def __init__(
  16. self,
  17. completion: LLMChatCompletion,
  18. search_results: "AggregateSearchResult",
  19. ):
  20. self.completion = completion
  21. self.search_results = search_results
class GenerationConfig(R2RSerializable):
    """Configuration for a single LLM generation call.

    Class-wide defaults live in the mutable ``_defaults`` ClassVar and can be
    changed at runtime via :meth:`set_default`. Every field reads its default
    lazily through a ``default_factory`` lambda, so changes to ``_defaults``
    apply to instances created *after* the change.
    """

    # Mutable class-level default values, keyed by field name. Intentionally
    # shared across the whole process; mutated only by set_default().
    _defaults: ClassVar[dict] = {
        "model": "openai/gpt-4o",
        "temperature": 0.1,
        "top_p": 1.0,
        "max_tokens_to_sample": 1024,
        "stream": False,
        "functions": None,
        "tools": None,
        "add_generation_kwargs": None,
        "api_base": None,
        "response_format": None,
    }

    # Each default_factory re-reads _defaults at instantiation time (rather
    # than using a static default=) so set_default() takes effect.
    model: str = Field(
        default_factory=lambda: GenerationConfig._defaults["model"]
    )
    temperature: float = Field(
        default_factory=lambda: GenerationConfig._defaults["temperature"]
    )
    top_p: float = Field(
        default_factory=lambda: GenerationConfig._defaults["top_p"],
    )
    max_tokens_to_sample: int = Field(
        default_factory=lambda: GenerationConfig._defaults[
            "max_tokens_to_sample"
        ],
    )
    stream: bool = Field(
        default_factory=lambda: GenerationConfig._defaults["stream"]
    )
    functions: Optional[list[dict]] = Field(
        default_factory=lambda: GenerationConfig._defaults["functions"]
    )
    tools: Optional[list[dict]] = Field(
        default_factory=lambda: GenerationConfig._defaults["tools"]
    )
    add_generation_kwargs: Optional[dict] = Field(
        default_factory=lambda: GenerationConfig._defaults[
            "add_generation_kwargs"
        ],
    )
    api_base: Optional[str] = Field(
        default_factory=lambda: GenerationConfig._defaults["api_base"],
    )
    # May be a plain dict (passed through) or a pydantic BaseModel subclass,
    # which __init__ converts to an OpenAI-style json_schema dict.
    response_format: Optional[dict | BaseModel] = None

    @classmethod
    def set_default(cls, **kwargs):
        """Override one or more class-wide defaults in ``_defaults``.

        Raises:
            AttributeError: if a keyword does not name an existing default.
        """
        for key, value in kwargs.items():
            if key in cls._defaults:
                cls._defaults[key] = value
            else:
                raise AttributeError(
                    f"No default attribute '{key}' in GenerationConfig"
                )

    def __init__(self, **data):
        """Build a config, normalizing ``response_format``.

        If ``response_format`` is given as a BaseModel *class* (not an
        instance), it is replaced with an OpenAI structured-output spec:
        ``{"type": "json_schema", "json_schema": {"name": ..., "schema": ...}}``.
        """
        if (
            "response_format" in data
            and isinstance(data["response_format"], type)
            and issubclass(data["response_format"], BaseModel)
        ):
            model_class = data["response_format"]
            data["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": model_class.__name__,
                    "schema": model_class.model_json_schema(),
                },
            }

        # NOTE(review): popping "model" and passing it back explicitly looks
        # equivalent to just calling super().__init__(**data); presumably kept
        # for pydantic kwarg-ordering reasons — confirm before simplifying.
        model = data.pop("model", None)
        if model is not None:
            super().__init__(model=model, **data)
        else:
            super().__init__(**data)

    def __str__(self):
        # Serialize via the R2RSerializable to_dict() contract.
        return json.dumps(self.to_dict())

    class Config:
        populate_by_name = True
        # Example payload surfaced in generated OpenAPI/JSON schema docs.
        json_schema_extra = {
            "model": "openai/gpt-4o",
            "temperature": 0.1,
            "top_p": 1.0,
            "max_tokens_to_sample": 1024,
            "stream": False,
            "functions": None,
            "tools": None,
            "add_generation_kwargs": None,
            "api_base": None,
        }
  110. class MessageType(Enum):
  111. SYSTEM = "system"
  112. USER = "user"
  113. ASSISTANT = "assistant"
  114. FUNCTION = "function"
  115. TOOL = "tool"
  116. def __str__(self):
  117. return self.value
class Message(R2RSerializable):
    """A single chat message exchanged with the LLM."""

    # Accepts either the MessageType enum or its raw string value
    # (e.g. "user"); pydantic validates whichever is supplied.
    role: MessageType | str
    # Text content; may be None for tool/function-call-only messages.
    content: Optional[str] = None
    # Optional sender name (used by function/tool roles in the OpenAI API).
    name: Optional[str] = None
    # Legacy single function-call payload, as returned by the API.
    function_call: Optional[dict[str, Any]] = None
    # Tool-call payloads, as returned by the API.
    tool_calls: Optional[list[dict[str, Any]]] = None

    class Config:
        populate_by_name = True
        # Example payload surfaced in generated OpenAPI/JSON schema docs.
        json_schema_extra = {
            "role": "user",
            "content": "This is a test message.",
            "name": None,
            "function_call": None,
            "tool_calls": None,
        }