import asyncio
import logging
from abc import ABCMeta
from typing import AsyncGenerator, Generator, Optional

from core.base.abstractions import (
    AsyncSyncMeta,
    LLMChatCompletion,
    LLMChatCompletionChunk,
    Message,
    syncable,
)
from core.base.agent import Agent, Conversation

logger = logging.getLogger()
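

# Combining the two metaclasses lets R2RAgent inherit from the abstract
# Agent (ABCMeta) while AsyncSyncMeta generates synchronous counterparts
# for methods decorated with @syncable; a single combined metaclass
# satisfies both bases.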
class CombinedMeta(AsyncSyncMeta, ABCMeta):
    pass
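

# Drives an async generator from synchronous code: each item is pulled
# by running __anext__ on the event loop, and aclose() is always awaited
# so the generator's cleanup runs even if the consumer stops early.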
def sync_wrapper(async_gen):
    loop = asyncio.get_event_loop()

    def wrapper():
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.run_until_complete(async_gen.aclose())

    return wrapper()


class R2RAgent(Agent, metaclass=CombinedMeta):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._register_tools()
        self._reset()

    def _reset(self):
        self._completed = False
        self.conversation = Conversation()

    @syncable
    async def arun(
        self,
        messages: list[Message],
        system_instruction: Optional[str] = None,
        *args,
        **kwargs,
    ) -> list[dict]:
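        """
        Run the agent loop to completion: seed the conversation with the
        incoming messages, then repeatedly query the LLM and execute any
        requested tool or function calls until the model produces a plain
        assistant reply, which marks the run complete.
        """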
        # TODO - Make this method return a list of messages.
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for message in messages:
                await self.conversation.add_message(message)

        while not self._completed:
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(messages_list[-1])
            response = await self.llm_provider.aget_completion(
                messages_list,
                generation_config,
            )
            await self.process_llm_response(response, *args, **kwargs)

        # Get the output messages
        all_messages: list[dict] = await self.conversation.get_messages()
        all_messages.reverse()

        output_messages = []
        for message_2 in all_messages:
            if (
                message_2.get("content")
                and message_2.get("content") != messages[-1].content
            ):
                output_messages.append(message_2)
            else:
                break
        output_messages.reverse()

        return output_messages

    async def process_llm_response(
        self, response: LLMChatCompletion, *args, **kwargs
    ) -> None:
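        """
        Handle a single (non-streaming) completion: execute a legacy
        function call or each tool call, or record a plain assistant
        reply and mark the run complete.
        """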
        if not self._completed:
            message = response.choices[0].message
            if message.function_call:
                await self.handle_function_or_tool_call(
                    message.function_call.name,
                    message.function_call.arguments,
                    *args,
                    **kwargs,
                )
            elif message.tool_calls:
                for tool_call in message.tool_calls:
                    await self.handle_function_or_tool_call(
                        tool_call.function.name,
                        tool_call.function.arguments,
                        *args,
                        **kwargs,
                    )
            else:
                await self.conversation.add_message(
                    Message(role="assistant", content=message.content)
                )
                self._completed = True


class R2RStreamingAgent(R2RAgent):
    async def arun(  # type: ignore
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
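        """
        Streaming variant of the agent loop: yields tagged string chunks
        as they arrive from the LLM instead of returning the final
        messages once the run completes.
        """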
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for message in messages:
                await self.conversation.add_message(message)

        while not self._completed:
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(
                messages_list[-1], stream=True
            )
            stream = self.llm_provider.aget_completion_stream(
                messages_list,
                generation_config,
            )
            async for proc_chunk in self.process_llm_response(
                stream, *args, **kwargs
            ):
                yield proc_chunk

    def run(
        self, system_instruction, messages, *args, **kwargs
    ) -> Generator[str, None, None]:
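        """Synchronous entry point: drives the async generator from arun."""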
        return sync_wrapper(
            self.arun(system_instruction, messages, *args, **kwargs)
        )

    async def process_llm_response(  # type: ignore
        self,
        stream: AsyncGenerator[LLMChatCompletionChunk, None],
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
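        """
        Re-emit a completion stream as tagged text: tool and function
        calls are wrapped in <tool_call>/<function_call> blocks carrying
        their name, arguments, and results, while plain content is
        wrapped in a <completion> block.
        """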
        function_name = None
        function_arguments = ""
        content_buffer = ""

        async for chunk in stream:
            delta = chunk.choices[0].delta
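            # Tool-calls API: each call arrives complete, so execute it
            # immediately and stream the call and its result.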
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    if not tool_call.function:
                        logger.info("Tool function not found in tool call.")
                        continue
                    name = tool_call.function.name
                    if not name:
                        logger.info("Tool name not found in tool call.")
                        continue
                    arguments = tool_call.function.arguments
                    if not arguments:
                        logger.info("Tool arguments not found in tool call.")
                        continue
                    results = await self.handle_function_or_tool_call(
                        name,
                        arguments,
                        # FIXME: tool_call.id,
                        *args,
                        **kwargs,
                    )
                    yield "<tool_call>"
                    yield f"<name>{name}</name>"
                    yield f"<arguments>{arguments}</arguments>"
                    yield f"<results>{results.llm_formatted_result}</results>"
                    yield "</tool_call>"
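
            # Legacy function-call API: the name and arguments arrive
            # incrementally, so buffer them until finish_reason fires.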
            if delta.function_call:
                if delta.function_call.name:
                    function_name = delta.function_call.name
                if delta.function_call.arguments:
                    function_arguments += delta.function_call.arguments
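            # Plain content: open the <completion> tag on the first token
            # and stream tokens through as they arrive.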
            elif delta.content:
                if content_buffer == "":
                    yield "<completion>"
                content_buffer += delta.content
                yield delta.content
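
            # The buffered legacy function call is complete: emit it,
            # execute the tool, and stream the result.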
            if chunk.choices[0].finish_reason == "function_call":
                if not function_name:
                    logger.info("Function name not found in function call.")
                    continue
                yield "<function_call>"
                yield f"<name>{function_name}</name>"
                yield f"<arguments>{function_arguments}</arguments>"
                tool_result = await self.handle_function_or_tool_call(
                    function_name, function_arguments, *args, **kwargs
                )
                if tool_result.stream_result:
                    yield f"<results>{tool_result.stream_result}</results>"
                else:
                    yield f"<results>{tool_result.llm_formatted_result}</results>"
                yield "</function_call>"
                function_name = None
                function_arguments = ""
            elif chunk.choices[0].finish_reason == "stop":
                if content_buffer:
                    await self.conversation.add_message(
                        Message(role="assistant", content=content_buffer)
                    )
                self._completed = True
                yield "</completion>"

        # Handle any remaining content after the stream ends
        if content_buffer and not self._completed:
            await self.conversation.add_message(
                Message(role="assistant", content=content_buffer)
            )
            self._completed = True
            yield "</completion>"