import asyncio
import logging
from abc import ABCMeta
from typing import AsyncGenerator, Generator, Optional

from core.base.abstractions import (
    AsyncSyncMeta,
    LLMChatCompletion,
    LLMChatCompletionChunk,
    Message,
    syncable,
)
from core.base.agent import Agent, Conversation

logger = logging.getLogger()


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    """Metaclass that merges ``AsyncSyncMeta`` with ``ABCMeta``.

    Lets a class use both metaclasses at once without a metaclass conflict.
    """

    pass


def sync_wrapper(async_gen):
    """Expose an async generator as a plain (blocking) generator.

    Each item is obtained by driving ``async_gen.__anext__()`` to completion
    on an event loop. The async generator is always closed when the returned
    generator finishes, is closed, or raises.

    Args:
        async_gen: The async generator to drain synchronously.

    Returns:
        A synchronous generator yielding the same items as ``async_gen``.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop in this thread (e.g. a worker thread, or a
        # Python version where the loop is no longer created implicitly):
        # create one and install it so later calls can find it.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    def wrapper():
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Run the async generator's cleanup even if the consumer stops
            # iterating early or an exception propagates.
            loop.run_until_complete(async_gen.aclose())

    return wrapper()


class R2RAgent(Agent, metaclass=CombinedMeta):
    """Agent that loops over LLM completions until a plain answer is produced."""

    def __init__(self, *args, **kwargs):
        """Initialise the base agent, register its tools and reset state."""
        super().__init__(*args, **kwargs)
        self._register_tools()
        self._reset()

    def _reset(self):
        """Clear the completion flag and start a fresh conversation."""
        self._completed = False
        self.conversation = Conversation()

    @syncable
    async def arun(
        self,
        messages: list[Message],
        system_instruction: Optional[str] = None,
        *args,
        **kwargs,
    ) -> list[dict]:
        """Run the agent to completion and return the messages it produced.

        Args:
            messages: Input messages appended to the conversation, in order.
            system_instruction: Forwarded to ``self._setup``.
            *args: Forwarded to ``process_llm_response``.
            **kwargs: Forwarded to ``process_llm_response``.

        Returns:
            The trailing conversation messages (oldest first) that have
            non-empty content, stopping at the first message (walking
            backwards) whose content is empty or equals the content of the
            last input message.
        """
        # TODO - Make this method return a list of messages.
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for message in messages:
                await self.conversation.add_message(message)

        while not self._completed:
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(messages_list[-1])
            response = await self.llm_provider.aget_completion(
                messages_list,
                generation_config,
            )
            await self.process_llm_response(response, *args, **kwargs)

        # Get the output messages
        all_messages: list[dict] = await self.conversation.get_messages()
        all_messages.reverse()

        # Guard against empty/None input: previously ``messages[-1]`` raised
        # here even though the empty case is tolerated above.
        last_input_content = messages[-1].content if messages else None

        output_messages = []
        for message_2 in all_messages:
            content = message_2.get("content")
            if content and content != last_input_content:
                output_messages.append(message_2)
            else:
                break
        output_messages.reverse()

        return output_messages

    async def process_llm_response(
        self, response: LLMChatCompletion, *args, **kwargs
    ) -> None:
        """Act on one completion: dispatch tool calls or record the answer.

        A function call or tool calls are passed to
        ``handle_function_or_tool_call``; otherwise the assistant content is
        appended to the conversation and the run is marked completed.

        Args:
            response: The completion returned by the LLM provider.
            *args: Forwarded to ``handle_function_or_tool_call``.
            **kwargs: Forwarded to ``handle_function_or_tool_call``.
        """
        if not self._completed:
            message = response.choices[0].message
            if message.function_call:
                await self.handle_function_or_tool_call(
                    message.function_call.name,
                    message.function_call.arguments,
                    *args,
                    **kwargs,
                )
            elif message.tool_calls:
                for tool_call in message.tool_calls:
                    await self.handle_function_or_tool_call(
                        tool_call.function.name,
                        tool_call.function.arguments,
                        *args,
                        **kwargs,
                    )
            else:
                await self.conversation.add_message(
                    Message(role="assistant", content=message.content)
                )
                self._completed = True


class R2RStreamingAgent(R2RAgent):
    """Streaming variant of ``R2RAgent`` that yields output incrementally."""

    async def arun(  # type: ignore
        self,
        system_instruction: Optional[str] = None,
        messages: Optional[list[Message]] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Run the agent, yielding string chunks as the LLM streams them.

        Note the parameter order differs from ``R2RAgent.arun``:
        ``system_instruction`` comes first here.

        Args:
            system_instruction: Forwarded to ``self._setup``.
            messages: Input messages appended to the conversation, in order.
            *args: Forwarded to ``process_llm_response``.
            **kwargs: Forwarded to ``process_llm_response``.

        Yields:
            The chunks produced by ``process_llm_response`` for each
            streamed completion, until the run is marked completed.
        """
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for message in messages:
                await self.conversation.add_message(message)

        while not self._completed:
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(
                messages_list[-1], stream=True
            )
            stream = self.llm_provider.aget_completion_stream(
                messages_list,
                generation_config,
            )
            async for proc_chunk in self.process_llm_response(
                stream, *args, **kwargs
            ):
                yield proc_chunk

    def run(
        self, system_instruction, messages, *args, **kwargs
    ) -> Generator[str, None, None]:
        """Blocking counterpart of ``arun``; returns a synchronous generator."""
        return sync_wrapper(
            self.arun(system_instruction, messages, *args, **kwargs)
        )

    async def process_llm_response(  # type: ignore
        self,
        stream: AsyncGenerator[LLMChatCompletionChunk, None],
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Consume one completion stream, yielding output and running tools.

        Tool calls found in a delta are executed immediately. Legacy
        function-call name/arguments are accumulated across chunks and
        executed when a chunk reports ``finish_reason == "function_call"``.
        Plain content is yielded as it arrives and buffered; on
        ``finish_reason == "stop"`` (or when the stream ends with buffered
        content) it is stored as an assistant message and the run is marked
        completed.

        NOTE(review): the empty-string yields bracket each emitted section
        and look like placeholder delimiters — confirm the intended marker
        text with the consumers of this stream.

        Args:
            stream: Async generator of completion chunks from the provider.
            *args: Forwarded to ``handle_function_or_tool_call``.
            **kwargs: Forwarded to ``handle_function_or_tool_call``.

        Yields:
            String fragments: delimiters, tool/function names, arguments and
            results, and assistant content deltas.
        """
        function_name = None
        function_arguments = ""
        content_buffer = ""

        async for chunk in stream:
            delta = chunk.choices[0].delta
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    # Skip incomplete tool calls rather than failing.
                    if not tool_call.function:
                        logger.info("Tool function not found in tool call.")
                        continue
                    name = tool_call.function.name
                    if not name:
                        logger.info("Tool name not found in tool call.")
                        continue
                    arguments = tool_call.function.arguments
                    if not arguments:
                        logger.info("Tool arguments not found in tool call.")
                        continue

                    results = await self.handle_function_or_tool_call(
                        name,
                        arguments,
                        # FIXME: tool_call.id,
                        *args,
                        **kwargs,
                    )

                    yield ""
                    yield f"{name}"
                    yield f"{arguments}"
                    yield f"{results.llm_formatted_result}"
                    yield ""

            if delta.function_call:
                # Function-call name and arguments may be split over chunks.
                if delta.function_call.name:
                    function_name = delta.function_call.name
                if delta.function_call.arguments:
                    function_arguments += delta.function_call.arguments
            elif delta.content:
                if content_buffer == "":
                    # First content fragment of this completion.
                    yield ""
                content_buffer += delta.content
                yield delta.content

            if chunk.choices[0].finish_reason == "function_call":
                if not function_name:
                    logger.info("Function name not found in function call.")
                    continue

                yield ""
                yield f"{function_name}"
                yield f"{function_arguments}"

                tool_result = await self.handle_function_or_tool_call(
                    function_name, function_arguments, *args, **kwargs
                )
                if tool_result.stream_result:
                    yield f"{tool_result.stream_result}"
                else:
                    yield f"{tool_result.llm_formatted_result}"

                yield ""

                # Reset the accumulators for any subsequent function call.
                function_name = None
                function_arguments = ""

            elif chunk.choices[0].finish_reason == "stop":
                if content_buffer:
                    await self.conversation.add_message(
                        Message(role="assistant", content=content_buffer)
                    )
                self._completed = True
                yield ""

        # Handle any remaining content after the stream ends
        if content_buffer and not self._completed:
            await self.conversation.add_message(
                Message(role="assistant", content=content_buffer)
            )
            self._completed = True
            yield ""