agent.py

import asyncio
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional, Type

from pydantic import BaseModel

from core.base.abstractions import (
    GenerationConfig,
    LLMChatCompletion,
    Message,
    MessageType,
)
from core.base.providers import CompletionProvider, DatabaseProvider

from .base import Tool, ToolResult

logger = logging.getLogger()


class Conversation:
    """Async-safe, in-memory message history for a single agent run."""

    def __init__(self):
        self.messages: list[Message] = []
        self._lock = asyncio.Lock()

    async def create_and_add_message(
        self,
        role: MessageType | str,
        content: Optional[str] = None,
        name: Optional[str] = None,
        function_call: Optional[dict[str, Any]] = None,
        tool_calls: Optional[list[dict[str, Any]]] = None,
    ):
        message = Message(
            role=role,
            content=content,
            name=name,
            function_call=function_call,
            tool_calls=tool_calls,
        )
        await self.add_message(message)

    async def add_message(self, message):
        async with self._lock:
            self.messages.append(message)

    async def get_messages(self) -> list[dict[str, Any]]:
        async with self._lock:
            return [
                {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
                for msg in self.messages
            ]
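
# Usage sketch (illustrative only; assumes an async context and that `Message`
# accepts these fields, per the import above):
#
#     conversation = Conversation()
#     await conversation.add_message(Message(role="user", content="Hello"))
#     history = await conversation.get_messages()  # list of role-tagged dicts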


# TODO - Move agents to provider pattern
class AgentConfig(BaseModel):
    system_instruction_name: str = "rag_agent"
    tool_names: list[str] = ["search"]
    generation_config: GenerationConfig = GenerationConfig()
    stream: bool = False

    @classmethod
    def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
        # Keep only known fields and coerce the literal string "None" to None.
        base_args = cls.model_fields.keys()
        filtered_kwargs = {
            k: v if v != "None" else None
            for k, v in kwargs.items()
            if k in base_args
        }
        return cls(**filtered_kwargs)  # type: ignore
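
# Behavior sketch (hypothetical keys shown): unknown keys are dropped and the
# string "None" becomes None, so a config can be hydrated from loosely-typed
# sources such as env vars or TOML:
#
#     config = AgentConfig.create(stream=True, unrecognized_key="ignored")
#     assert config.stream is True and config.tool_names == ["search"]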


class Agent(ABC):
    def __init__(
        self,
        llm_provider: CompletionProvider,
        database_provider: DatabaseProvider,
        config: AgentConfig,
    ):
        self.llm_provider = llm_provider
        self.database_provider: DatabaseProvider = database_provider
        self.config = config
        self.conversation = Conversation()
        self._completed = False
        self._tools: list[Tool] = []
        self._register_tools()

    @abstractmethod
    def _register_tools(self):
        pass

    async def _setup(self, system_instruction: Optional[str] = None):
        # Prefer an explicit instruction; otherwise fall back to the cached
        # prompt named in the config.
        content = system_instruction or (
            await self.database_provider.prompts_handler.get_cached_prompt(
                self.config.system_instruction_name
            )
        )
        await self.conversation.add_message(
            Message(role="system", content=content)
        )

    @property
    def tools(self) -> list[Tool]:
        return self._tools

    @tools.setter
    def tools(self, tools: list[Tool]):
        self._tools = tools

    @abstractmethod
    async def arun(
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
        pass

    @abstractmethod
    async def process_llm_response(
        self,
        response: Any,
        *args,
        **kwargs,
    ) -> None | AsyncGenerator[str, None]:
        pass

    async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
        if tool := next((t for t in self.tools if t.name == tool_name), None):
            return await tool.results_function(*args, **kwargs)
        else:
            return f"Error: Tool {tool_name} not found."

    def get_generation_config(
        self, last_message: dict, stream: bool = False
    ) -> GenerationConfig:
        # After a non-empty tool/function response, call the LLM without
        # re-advertising the tool schemas.
        if (
            last_message["role"] in ["tool", "function"]
            and last_message["content"] != ""
        ):
            return GenerationConfig(
                **self.config.generation_config.model_dump(
                    exclude={"functions", "tools", "stream"}
                ),
                stream=stream,
            )
        return GenerationConfig(
            **self.config.generation_config.model_dump(
                exclude={"functions", "tools", "stream"}
            ),
            # FIXME: Use tools instead of functions
            # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
            # tools=[
            #     {
            #         "function": {
            #             "name": tool.name,
            #             "description": tool.description,
            #             "parameters": tool.parameters,
            #         },
            #         "type": "function",
            #     }
            #     for tool in self.tools
            # ],
            functions=[
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                }
                for tool in self.tools
            ],
            stream=stream,
        )
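
    # Illustrative result (sketch): with a single hypothetical "search" tool,
    # the returned config carries the legacy OpenAI function-calling schema:
    #
    #     functions=[{"name": "search", "description": "...",
    #                 "parameters": {...}}]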

    async def handle_function_or_tool_call(
        self,
        function_name: str,
        function_arguments: str,
        tool_id: Optional[str] = None,
        *args,
        **kwargs,
    ) -> ToolResult:
        # Record the assistant's call, using `tool_calls` when a tool_id is
        # present and the legacy `function_call` field otherwise.
        await self.conversation.add_message(
            Message(
                role="assistant",
                tool_calls=(
                    [
                        {
                            "id": tool_id,
                            "function": {
                                "name": function_name,
                                "arguments": function_arguments,
                            },
                        }
                    ]
                    if tool_id
                    else None
                ),
                function_call=(
                    {
                        "name": function_name,
                        "arguments": function_arguments,
                    }
                    if not tool_id
                    else None
                ),
            )
        )
        if tool := next(
            (t for t in self.tools if t.name == function_name), None
        ):
            merged_kwargs = {**kwargs, **json.loads(function_arguments)}
            raw_result = await tool.results_function(*args, **merged_kwargs)
            llm_formatted_result = tool.llm_format_function(raw_result)
            tool_result = ToolResult(
                raw_result=raw_result,
                llm_formatted_result=llm_formatted_result,
            )
            if tool.stream_function:
                tool_result.stream_result = tool.stream_function(raw_result)
        else:
            error_message = (
                f"The requested tool '{function_name}' is not available. "
                f"Available tools: {', '.join(t.name for t in self.tools)}"
            )
            tool_result = ToolResult(
                raw_result=error_message,
                llm_formatted_result=error_message,
            )
        await self.conversation.add_message(
            Message(
                role="tool" if tool_id else "function",
                content=str(tool_result.llm_formatted_result),
                name=function_name,
            )
        )
        return tool_result
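

# Minimal concrete-agent sketch (hypothetical wiring; the exact `Tool`
# constructor is assumed from the attributes used above, and a runnable
# subclass would also implement `arun` and `process_llm_response`):
#
#     class EchoAgent(Agent):
#         def _register_tools(self):
#             self._tools = [
#                 Tool(
#                     name="echo",
#                     description="Echo the input back.",
#                     parameters={"type": "object", "properties": {}},
#                     results_function=my_async_echo,  # hypothetical coroutine,
#                                                      # awaited by execute_tool
#                     llm_format_function=str,
#                     stream_function=None,
#                 )
#             ]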