agent.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # type: ignore
  2. import asyncio
  3. import json
  4. import logging
  5. from abc import ABC, abstractmethod
  6. from datetime import datetime
  7. from json import JSONDecodeError
  8. from typing import Any, AsyncGenerator, Optional, Type
  9. from pydantic import BaseModel
  10. from core.base.abstractions import (
  11. GenerationConfig,
  12. LLMChatCompletion,
  13. Message,
  14. )
  15. from core.base.providers import CompletionProvider, DatabaseProvider
  16. from shared.abstractions.tool import Tool, ToolResult
  17. logger = logging.getLogger()
  18. class Conversation:
  19. def __init__(self):
  20. self.messages: list[Message] = []
  21. self._lock = asyncio.Lock()
  22. async def add_message(self, message):
  23. async with self._lock:
  24. self.messages.append(message)
  25. async def get_messages(self) -> list[dict[str, Any]]:
  26. async with self._lock:
  27. return [
  28. {**msg.model_dump(exclude_none=True), "role": str(msg.role)}
  29. for msg in self.messages
  30. ]
  31. # TODO - Move agents to provider pattern
  32. class AgentConfig(BaseModel):
  33. rag_rag_agent_static_prompt: str = "static_rag_agent"
  34. rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
  35. stream: bool = False
  36. include_tools: bool = True
  37. max_iterations: int = 10
  38. @classmethod
  39. def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
  40. base_args = cls.model_fields.keys()
  41. filtered_kwargs = {
  42. k: v if v != "None" else None
  43. for k, v in kwargs.items()
  44. if k in base_args
  45. }
  46. return cls(**filtered_kwargs) # type: ignore
  47. class Agent(ABC):
  48. def __init__(
  49. self,
  50. llm_provider: CompletionProvider,
  51. database_provider: DatabaseProvider,
  52. config: AgentConfig,
  53. rag_generation_config: GenerationConfig,
  54. ):
  55. self.llm_provider = llm_provider
  56. self.database_provider: DatabaseProvider = database_provider
  57. self.config = config
  58. self.conversation = Conversation()
  59. self._completed = False
  60. self._tools: list[Tool] = []
  61. self.tool_calls: list[dict] = []
  62. self.rag_generation_config = rag_generation_config
  63. # self._register_tools()
  64. @abstractmethod
  65. def _register_tools(self):
  66. pass
  67. async def _setup(
  68. self, system_instruction: Optional[str] = None, *args, **kwargs
  69. ):
  70. await self.conversation.add_message(
  71. Message(
  72. role="system",
  73. content=system_instruction
  74. or (
  75. await self.database_provider.prompts_handler.get_cached_prompt(
  76. self.config.rag_rag_agent_static_prompt,
  77. inputs={
  78. "date": str(datetime.now().strftime("%m/%d/%Y"))
  79. },
  80. )
  81. + f"\n Note,you only have {self.config.max_iterations} iterations or tool calls to reach a conclusion before your operation terminates."
  82. ),
  83. )
  84. )
  85. @property
  86. def tools(self) -> list[Tool]:
  87. return self._tools
  88. @tools.setter
  89. def tools(self, tools: list[Tool]):
  90. self._tools = tools
  91. @abstractmethod
  92. async def arun(
  93. self,
  94. system_instruction: Optional[str] = None,
  95. messages: Optional[list[Message]] = None,
  96. *args,
  97. **kwargs,
  98. ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
  99. pass
  100. @abstractmethod
  101. async def process_llm_response(
  102. self,
  103. response: Any,
  104. *args,
  105. **kwargs,
  106. ) -> None | AsyncGenerator[str, None]:
  107. pass
  108. async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
  109. if tool := next((t for t in self.tools if t.name == tool_name), None):
  110. return await tool.results_function(*args, **kwargs)
  111. else:
  112. return f"Error: Tool {tool_name} not found."
  113. def get_generation_config(
  114. self, last_message: dict, stream: bool = False
  115. ) -> GenerationConfig:
  116. if (
  117. last_message["role"] in ["tool", "function"]
  118. and last_message["content"] != ""
  119. and "ollama" in self.rag_generation_config.model
  120. or not self.config.include_tools
  121. ):
  122. return GenerationConfig(
  123. **self.rag_generation_config.model_dump(
  124. exclude={"functions", "tools", "stream"}
  125. ),
  126. stream=stream,
  127. )
  128. return GenerationConfig(
  129. **self.rag_generation_config.model_dump(
  130. exclude={"functions", "tools", "stream"}
  131. ),
  132. # FIXME: Use tools instead of functions
  133. # TODO - Investigate why `tools` fails with OpenAI+LiteLLM
  134. tools=(
  135. [
  136. {
  137. "function": {
  138. "name": tool.name,
  139. "description": tool.description,
  140. "parameters": tool.parameters,
  141. },
  142. "type": "function",
  143. "name": tool.name,
  144. }
  145. for tool in self.tools
  146. ]
  147. if self.tools
  148. else None
  149. ),
  150. stream=stream,
  151. )
  152. async def handle_function_or_tool_call(
  153. self,
  154. function_name: str,
  155. function_arguments: str,
  156. tool_id: Optional[str] = None,
  157. save_messages: bool = True,
  158. *args,
  159. **kwargs,
  160. ) -> ToolResult:
  161. logger.debug(
  162. f"Calling function: {function_name}, args: {function_arguments}, tool_id: {tool_id}"
  163. )
  164. if tool := next(
  165. (t for t in self.tools if t.name == function_name), None
  166. ):
  167. try:
  168. function_args = json.loads(function_arguments)
  169. except JSONDecodeError as e:
  170. error_message = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with `JSONDecodeError`."
  171. if save_messages:
  172. await self.conversation.add_message(
  173. Message(
  174. role="tool" if tool_id else "function",
  175. content=error_message,
  176. name=function_name,
  177. tool_call_id=tool_id,
  178. )
  179. )
  180. merged_kwargs = {**kwargs, **function_args}
  181. try:
  182. raw_result = await tool.execute(*args, **merged_kwargs)
  183. llm_formatted_result = tool.llm_format_function(raw_result)
  184. except Exception as e:
  185. raw_result = f"Calling the requested tool '{function_name}' with arguments {function_arguments} failed with an exception: {e}."
  186. logger.error(raw_result)
  187. llm_formatted_result = raw_result
  188. tool_result = ToolResult(
  189. raw_result=raw_result,
  190. llm_formatted_result=llm_formatted_result,
  191. )
  192. if tool.stream_function:
  193. tool_result.stream_result = tool.stream_function(raw_result)
  194. if save_messages:
  195. await self.conversation.add_message(
  196. Message(
  197. role="tool" if tool_id else "function",
  198. content=str(tool_result.llm_formatted_result),
  199. name=function_name,
  200. tool_call_id=tool_id,
  201. )
  202. )
  203. # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
  204. logger.debug(
  205. f"Extended thinking - Claude needs a particular message continuation which however breaks other models. Model in use : {self.rag_generation_config.model}"
  206. )
  207. is_anthropic = (
  208. self.rag_generation_config.model
  209. and "anthropic/" in self.rag_generation_config.model
  210. )
  211. if (
  212. self.rag_generation_config.extended_thinking
  213. and is_anthropic
  214. ):
  215. await self.conversation.add_message(
  216. Message(
  217. role="user",
  218. content="Continue...",
  219. )
  220. )
  221. self.tool_calls.append(
  222. {
  223. "name": function_name,
  224. "args": function_arguments,
  225. }
  226. )
  227. return tool_result
  228. # TODO - Move agents to provider pattern
  229. class RAGAgentConfig(AgentConfig):
  230. rag_rag_agent_static_prompt: str = "static_rag_agent"
  231. rag_agent_dynamic_prompt: str = "dynamic_reasoning_rag_agent_prompted"
  232. stream: bool = False
  233. include_tools: bool = True
  234. max_iterations: int = 10
  235. # tools: list[str] = [] # HACK - unused variable.
  236. # Default RAG tools
  237. rag_tools: list[str] = [
  238. "search_file_descriptions",
  239. "search_file_knowledge",
  240. "get_file_content",
  241. # Web search tools - disabled by default
  242. # "web_search",
  243. # "web_scrape",
  244. # "tavily_search",
  245. # "tavily_extract",
  246. ]
  247. # Default Research tools
  248. research_tools: list[str] = [
  249. "rag",
  250. "reasoning",
  251. # DISABLED by default
  252. "critique",
  253. "python_executor",
  254. ]
  255. @classmethod
  256. def create(cls: Type["AgentConfig"], **kwargs: Any) -> "AgentConfig":
  257. base_args = cls.model_fields.keys()
  258. filtered_kwargs = {
  259. k: v if v != "None" else None
  260. for k, v in kwargs.items()
  261. if k in base_args
  262. }
  263. filtered_kwargs["tools"] = kwargs.get("tools", None) or kwargs.get(
  264. "tool_names", None
  265. )
  266. return cls(**filtered_kwargs) # type: ignore