- import asyncio
- import json
- import logging
- import re
- from abc import ABCMeta
- from typing import AsyncGenerator, Optional, Tuple
- from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
- from core.base.agent import Agent, Conversation
- from core.utils import (
- CitationTracker,
- SearchResultsCollector,
- SSEFormatter,
- convert_nonserializable_objects,
- dump_obj,
- find_new_citation_spans,
- )
- logger = logging.getLogger(__name__)
- class CombinedMeta(AsyncSyncMeta, ABCMeta):
- pass
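- # Bridges an async generator into a synchronous one by driving the event
- # loop manually; used by the sync side of the @syncable machinery.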
- def sync_wrapper(async_gen):
- loop = asyncio.get_event_loop()
- def wrapper():
- try:
- while True:
- try:
- yield loop.run_until_complete(async_gen.__anext__())
- except StopAsyncIteration:
- break
- finally:
- loop.run_until_complete(async_gen.aclose())
- return wrapper()
- class R2RAgent(Agent, metaclass=CombinedMeta):
- def __init__(self, *args, **kwargs):
- self.search_results_collector = SearchResultsCollector()
- super().__init__(*args, **kwargs)
- self._reset()
- async def _generate_llm_summary(self, iterations_count: int) -> str:
- """
- Generate a summary of the conversation using the LLM when max iterations are exceeded.
- Args:
- iterations_count: The number of iterations that were completed
- Returns:
- A string containing the LLM-generated summary
- """
- try:
- # Get all messages in the conversation
- all_messages = await self.conversation.get_messages()
- # Create a prompt for the LLM to summarize
- summary_prompt = {
- "role": "user",
- "content": (
- f"The conversation has reached the maximum limit of {iterations_count} iterations "
- f"without completing the task. Please provide a concise summary of: "
- f"1) The key information you've gathered that's relevant to the original query, "
- f"2) What you've attempted so far and why it's incomplete, and "
- f"3) A specific recommendation for how to proceed. "
- f"Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. If it is possible to answer the original user query, then do so now instead."
- f"Start with '⚠️ **Maximum iterations exceeded**'"
- ),
- }
- # Create a new message list with just the conversation history and summary request
- summary_messages = all_messages + [summary_prompt]
- # Get a completion for the summary
- generation_config = self.get_generation_config(summary_prompt)
- response = await self.llm_provider.aget_completion(
- summary_messages,
- generation_config,
- )
- return response.choices[0].message.content
- except Exception as e:
- logger.error(f"Error generating LLM summary: {str(e)}")
- # Fall back to basic summary if LLM generation fails
- return (
- "⚠️ **Maximum iterations exceeded**\n\n"
- "The agent reached the maximum iteration limit without completing the task. "
- "Consider breaking your request into smaller steps or refining your query."
- )
- def _reset(self):
- self._completed = False
- self.conversation = Conversation()
- @syncable
- async def arun(
- self,
- messages: list[Message],
- system_instruction: Optional[str] = None,
- *args,
- **kwargs,
- ) -> list[dict]:
- self._reset()
- await self._setup(system_instruction)
- if messages:
- for message in messages:
- await self.conversation.add_message(message)
- iterations_count = 0
- while (
- not self._completed
- and iterations_count < self.config.max_iterations
- ):
- iterations_count += 1
- messages_list = await self.conversation.get_messages()
- generation_config = self.get_generation_config(messages_list[-1])
- response = await self.llm_provider.aget_completion(
- messages_list,
- generation_config,
- )
- logger.debug(f"R2RAgent response: {response}")
- await self.process_llm_response(response, *args, **kwargs)
- if not self._completed:
- # Generate a summary of the conversation using the LLM
- summary = await self._generate_llm_summary(iterations_count)
- await self.conversation.add_message(
- Message(role="assistant", content=summary)
- )
- # Return final content
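- # Walk the history backwards, keeping only messages generated after the
- # caller's last input message, then restore chronological order.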
- all_messages: list[dict] = await self.conversation.get_messages()
- all_messages.reverse()
- output_messages = []
- for message_2 in all_messages:
- if message_2.get("content") != messages[-1].content:
- output_messages.append(message_2)
- else:
- break
- output_messages.reverse()
- return output_messages
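- # Illustrative usage (constructor arguments are assumptions, not part of this module):
- #   agent = R2RAgent(llm_provider=provider, config=config)
- #   replies = await agent.arun(messages=[Message(role="user", content="Hi")])
- # A synchronous counterpart is generated from arun by @syncable via AsyncSyncMeta.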
- async def process_llm_response(
- self, response: LLMChatCompletion, *args, **kwargs
- ) -> None:
- if not self._completed:
- message = response.choices[0].message
- finish_reason = response.choices[0].finish_reason
- if finish_reason == "stop":
- self._completed = True
- # Determine which provider we're using
- using_anthropic = (
- "anthropic" in self.rag_generation_config.model.lower()
- )
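- # e.g. a model string like "anthropic/claude-3-7-sonnet" selects the Anthropic path (illustrative)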
- # OPENAI HANDLING
- if not using_anthropic:
- if message.tool_calls:
- assistant_msg = Message(
- role="assistant",
- content="",
- tool_calls=[tc.dict() for tc in message.tool_calls],
- )
- await self.conversation.add_message(assistant_msg)
- # If there are multiple tool_calls, call them sequentially here
- for tool_call in message.tool_calls:
- await self.handle_function_or_tool_call(
- tool_call.function.name,
- tool_call.function.arguments,
- tool_id=tool_call.id,
- *args,
- **kwargs,
- )
- else:
- await self.conversation.add_message(
- Message(role="assistant", content=message.content)
- )
- self._completed = True
- else:
- # First handle thinking blocks if present
- if (
- hasattr(message, "structured_content")
- and message.structured_content
- ):
- # Check if structured_content contains any tool_use blocks
- has_tool_use = any(
- block.get("type") == "tool_use"
- for block in message.structured_content
- )
- if not has_tool_use and message.tool_calls:
- # If it has thinking but no tool_use, add a separate message with structured_content
- assistant_msg = Message(
- role="assistant",
- structured_content=message.structured_content, # Use structured_content field
- )
- await self.conversation.add_message(assistant_msg)
- # Add explicit tool_use blocks in a separate message
- tool_uses = []
- for tool_call in message.tool_calls:
- # Safely parse arguments if they're a string
- try:
- if isinstance(
- tool_call.function.arguments, str
- ):
- input_args = json.loads(
- tool_call.function.arguments
- )
- else:
- input_args = tool_call.function.arguments
- except json.JSONDecodeError:
- logger.error(
- f"Failed to parse tool arguments: {tool_call.function.arguments}"
- )
- input_args = {
- "_raw": tool_call.function.arguments
- }
- tool_uses.append(
- {
- "type": "tool_use",
- "id": tool_call.id,
- "name": tool_call.function.name,
- "input": input_args,
- }
- )
- # Add tool_use blocks as a separate assistant message with structured content
- if tool_uses:
- await self.conversation.add_message(
- Message(
- role="assistant",
- structured_content=tool_uses,
- content="",
- )
- )
- else:
- # If it already has tool_use or no tool_calls, preserve original structure
- assistant_msg = Message(
- role="assistant",
- structured_content=message.structured_content,
- )
- await self.conversation.add_message(assistant_msg)
- elif message.content:
- # For regular text content
- await self.conversation.add_message(
- Message(role="assistant", content=message.content)
- )
- # If there are tool calls, add them as structured content
- if message.tool_calls:
- tool_uses = []
- for tool_call in message.tool_calls:
- # Same safe parsing as above
- try:
- if isinstance(
- tool_call.function.arguments, str
- ):
- input_args = json.loads(
- tool_call.function.arguments
- )
- else:
- input_args = tool_call.function.arguments
- except json.JSONDecodeError:
- logger.error(
- f"Failed to parse tool arguments: {tool_call.function.arguments}"
- )
- input_args = {
- "_raw": tool_call.function.arguments
- }
- tool_uses.append(
- {
- "type": "tool_use",
- "id": tool_call.id,
- "name": tool_call.function.name,
- "input": input_args,
- }
- )
- await self.conversation.add_message(
- Message(
- role="assistant", structured_content=tool_uses
- )
- )
- # Handle tool_calls when neither content nor structured_content is present
- elif message.tool_calls:
- # Create tool_uses for the message with only tool_calls
- tool_uses = []
- for tool_call in message.tool_calls:
- try:
- if isinstance(tool_call.function.arguments, str):
- input_args = json.loads(
- tool_call.function.arguments
- )
- else:
- input_args = tool_call.function.arguments
- except json.JSONDecodeError:
- logger.error(
- f"Failed to parse tool arguments: {tool_call.function.arguments}"
- )
- input_args = {"_raw": tool_call.function.arguments}
- tool_uses.append(
- {
- "type": "tool_use",
- "id": tool_call.id,
- "name": tool_call.function.name,
- "input": input_args,
- }
- )
- # Add tool_use blocks as a message before processing tools
- if tool_uses:
- await self.conversation.add_message(
- Message(
- role="assistant",
- structured_content=tool_uses,
- )
- )
- # Process the tool calls
- if message.tool_calls:
- for tool_call in message.tool_calls:
- await self.handle_function_or_tool_call(
- tool_call.function.name,
- tool_call.function.arguments,
- tool_id=tool_call.id,
- *args,
- **kwargs,
- )
- class R2RStreamingAgent(R2RAgent):
- """
- Base class for all streaming agents with core streaming functionality.
- Supports emitting messages, tool calls, and results as SSE events.
- """
- # These regexes detect bracket references and extract candidate short IDs from them.
- BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
- SHORT_ID_PATTERN = re.compile(
- r"[A-Za-z0-9]{7,8}"
- ) # matches 7-8 character alphanumeric short IDs
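- # Illustrative match: "see [abc1234]" -> bracket content "abc1234" -> short ID "abc1234"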
- def __init__(self, *args, **kwargs):
- # Force streaming on
- if hasattr(kwargs.get("config", {}), "stream"):
- kwargs["config"].stream = True
- super().__init__(*args, **kwargs)
- async def arun(
- self,
- system_instruction: str | None = None,
- messages: list[Message] | None = None,
- *args,
- **kwargs,
- ) -> AsyncGenerator[str, None]:
- """
- Main streaming entrypoint: returns an async generator of SSE lines.
- """
- self._reset()
- await self._setup(system_instruction)
- if messages:
- for m in messages:
- await self.conversation.add_message(m)
- # Initialize citation tracker for this run
- citation_tracker = CitationTracker()
- # Dictionary to store citation payloads by ID
- citation_payloads = {}
- # Track all citations emitted during streaming for final persistence
- self.streaming_citations: list[dict] = []
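- # The generator below emits these SSE event types: thinking, message,
- # citation, tool_call, final_answer, done, and error on failure.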
- async def sse_generator() -> AsyncGenerator[str, None]:
- pending_tool_calls = {}
- partial_text_buffer = ""
- iterations_count = 0
- # Initialized up front so the max-iterations fallback below can
- # reference it even when no "stop" turn was ever reached
- consolidated_citations: list[dict] = []
- try:
- # Keep streaming until we complete
- while (
- not self._completed
- and iterations_count < self.config.max_iterations
- ):
- iterations_count += 1
- # 1) Get current messages
- msg_list = await self.conversation.get_messages()
- gen_cfg = self.get_generation_config(
- msg_list[-1], stream=True
- )
- accumulated_thinking = ""
- thinking_signatures = {} # Map thinking content to signatures
- # 2) Start streaming from LLM
- llm_stream = self.llm_provider.aget_completion_stream(
- msg_list, gen_cfg
- )
- async for chunk in llm_stream:
- delta = chunk.choices[0].delta
- finish_reason = chunk.choices[0].finish_reason
- if hasattr(delta, "thinking") and delta.thinking:
- # Accumulate thinking for later use in messages
- accumulated_thinking += delta.thinking
- # Emit SSE "thinking" event
- async for (
- line
- ) in SSEFormatter.yield_thinking_event(
- delta.thinking
- ):
- yield line
- # Add this new handler for thinking signatures
- if hasattr(delta, "thinking_signature"):
- thinking_signatures[accumulated_thinking] = (
- delta.thinking_signature
- )
- accumulated_thinking = ""
- # 3) If new text, accumulate it
- if delta.content:
- partial_text_buffer += delta.content
- # (a) Now emit the newly streamed text as a "message" event
- async for line in SSEFormatter.yield_message_event(
- delta.content
- ):
- yield line
- # (b) Find new citation spans in the accumulated text
- new_citation_spans = find_new_citation_spans(
- partial_text_buffer, citation_tracker
- )
- # Process each new citation span
- for cid, spans in new_citation_spans.items():
- for span in spans:
- # Check if this is the first time we've seen this citation ID
- is_new_citation = (
- citation_tracker.is_new_citation(cid)
- )
- # Get payload if it's a new citation
- payload = None
- if is_new_citation:
- source_obj = self.search_results_collector.find_by_short_id(
- cid
- )
- if source_obj:
- # Store payload for reuse
- payload = dump_obj(source_obj)
- citation_payloads[cid] = payload
- # Create citation event payload
- citation_data = {
- "id": cid,
- "object": "citation",
- "is_new": is_new_citation,
- "span": {
- "start": span[0],
- "end": span[1],
- },
- }
- # Only include full payload for new citations
- if is_new_citation and payload:
- citation_data["payload"] = payload
- # Add to streaming citations for final answer
- self.streaming_citations.append(
- citation_data
- )
- # Emit the citation event
- async for (
- line
- ) in SSEFormatter.yield_citation_event(
- citation_data
- ):
- yield line
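- # 4) Accumulate streamed tool-call deltas, keyed by index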
- if delta.tool_calls:
- for tc in delta.tool_calls:
- idx = tc.index
- if idx not in pending_tool_calls:
- pending_tool_calls[idx] = {
- "id": tc.id,
- "name": tc.function.name or "",
- "arguments": tc.function.arguments
- or "",
- }
- else:
- # Accumulate partial name/arguments
- if tc.function.name:
- pending_tool_calls[idx]["name"] = (
- tc.function.name
- )
- if tc.function.arguments:
- pending_tool_calls[idx][
- "arguments"
- ] += tc.function.arguments
- # 5) If the stream signals we should handle "tool_calls"
- if finish_reason == "tool_calls":
- # Handle thinking if present
- await self._handle_thinking(
- thinking_signatures, accumulated_thinking
- )
- calls_list = []
- for idx in sorted(pending_tool_calls.keys()):
- cinfo = pending_tool_calls[idx]
- calls_list.append(
- {
- "tool_call_id": cinfo["id"]
- or f"call_{idx}",
- "name": cinfo["name"],
- "arguments": cinfo["arguments"],
- }
- )
- # (a) Emit SSE "tool_call" events
- for c in calls_list:
- tc_data = self._create_tool_call_data(c)
- async for (
- line
- ) in SSEFormatter.yield_tool_call_event(
- tc_data
- ):
- yield line
- # (b) Add an assistant message capturing these calls
- await self._add_tool_calls_message(
- calls_list, partial_text_buffer
- )
- # (c) Execute each tool call in parallel
- await asyncio.gather(
- *[
- self.handle_function_or_tool_call(
- c["name"],
- c["arguments"],
- tool_id=c["tool_call_id"],
- )
- for c in calls_list
- ]
- )
- # Reset buffer & calls
- pending_tool_calls.clear()
- partial_text_buffer = ""
- elif finish_reason == "stop":
- # Handle thinking if present
- await self._handle_thinking(
- thinking_signatures, accumulated_thinking
- )
- # 6) The LLM is done. If we have any leftover partial text,
- # finalize it in the conversation
- if partial_text_buffer:
- # Create the final message with metadata including citations
- final_message = Message(
- role="assistant",
- content=partial_text_buffer,
- metadata={
- "citations": self.streaming_citations
- },
- )
- # Add it to the conversation
- await self.conversation.add_message(
- final_message
- )
- # (a) Prepare final answer with optimized citations
- consolidated_citations = []
- # Group citations by ID with all their spans
- for (
- cid,
- spans,
- ) in citation_tracker.get_all_spans().items():
- if cid in citation_payloads:
- consolidated_citations.append(
- {
- "id": cid,
- "object": "citation",
- "spans": [
- {"start": s[0], "end": s[1]}
- for s in spans
- ],
- "payload": citation_payloads[cid],
- }
- )
- # Create final answer payload
- final_evt_payload = {
- "id": "msg_final",
- "object": "agent.final_answer",
- "generated_answer": partial_text_buffer,
- "citations": consolidated_citations,
- }
- # Emit final answer event
- async for (
- line
- ) in SSEFormatter.yield_final_answer_event(
- final_evt_payload
- ):
- yield line
- # (b) Signal the end of the SSE stream
- yield SSEFormatter.yield_done_event()
- self._completed = True
- break
- # If we exit the while loop due to hitting max iterations
- if not self._completed:
- # Generate a summary using the LLM
- summary = await self._generate_llm_summary(
- iterations_count
- )
- # Send the summary as a message event
- async for line in SSEFormatter.yield_message_event(
- summary
- ):
- yield line
- # Add summary to conversation with citations metadata
- await self.conversation.add_message(
- Message(
- role="assistant",
- content=summary,
- metadata={"citations": self.streaming_citations},
- )
- )
- # Create and emit a final answer payload with the summary
- final_evt_payload = {
- "id": "msg_final",
- "object": "agent.final_answer",
- "generated_answer": summary,
- "citations": consolidated_citations,
- }
- async for line in SSEFormatter.yield_final_answer_event(
- final_evt_payload
- ):
- yield line
- # Signal the end of the SSE stream
- yield SSEFormatter.yield_done_event()
- self._completed = True
- except Exception as e:
- logger.error(f"Error in streaming agent: {str(e)}")
- # Emit error event for client
- async for line in SSEFormatter.yield_error_event(
- f"Agent error: {str(e)}"
- ):
- yield line
- # Send done event to close the stream
- yield SSEFormatter.yield_done_event()
- # Delegate to the inner SSE generator, relaying each line to the caller
- async for line in sse_generator():
- yield line
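- # Anthropic extended-thinking blocks must be replayed with a signature;
- # when none was captured we fall back to a placeholder value.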
- async def _handle_thinking(
- self, thinking_signatures, accumulated_thinking
- ):
- """Process any accumulated thinking content"""
- if accumulated_thinking:
- structured_content = [
- {
- "type": "thinking",
- "thinking": accumulated_thinking,
- # Anthropic will validate this in their API
- "signature": "placeholder_signature",
- }
- ]
- assistant_msg = Message(
- role="assistant",
- structured_content=structured_content,
- )
- await self.conversation.add_message(assistant_msg)
- elif thinking_signatures:
- for (
- accumulated_thinking,
- thinking_signature,
- ) in thinking_signatures.items():
- structured_content = [
- {
- "type": "thinking",
- "thinking": accumulated_thinking,
- # Anthropic will validate this in their API
- "signature": thinking_signature,
- }
- ]
- assistant_msg = Message(
- role="assistant",
- structured_content=structured_content,
- )
- await self.conversation.add_message(assistant_msg)
- async def _add_tool_calls_message(self, calls_list, partial_text_buffer):
- """Add a message with tool calls to the conversation"""
- assistant_msg = Message(
- role="assistant",
- content=partial_text_buffer or "",
- tool_calls=[
- {
- "id": c["tool_call_id"],
- "type": "function",
- "function": {
- "name": c["name"],
- "arguments": c["arguments"],
- },
- }
- for c in calls_list
- ],
- )
- await self.conversation.add_message(assistant_msg)
- def _create_tool_call_data(self, call_info):
- """Create tool call data structure from call info"""
- return {
- "tool_call_id": call_info["tool_call_id"],
- "name": call_info["name"],
- "arguments": call_info["arguments"],
- }
- def _create_citation_payload(self, short_id, payload):
- """Create citation payload for a short ID"""
- # This will be overridden in RAG subclasses
- # Normalize the payload to a plain dict if it exposes a conversion method
- if hasattr(payload, "as_dict"):
- payload = payload.as_dict()
- elif hasattr(payload, "dict"):
- payload = payload.dict()
- elif hasattr(payload, "to_dict"):
- payload = payload.to_dict()
- return {
- "id": f"{short_id}",
- "object": "citation",
- "payload": dump_obj(payload), # Will be populated in RAG agents
- }
- def _create_final_answer_payload(self, answer_text, citations):
- """Create the final answer payload"""
- # This will be extended in RAG subclasses
- return {
- "id": "msg_final",
- "object": "agent.final_answer",
- "generated_answer": answer_text,
- "citations": citations,
- }
- class R2RXMLStreamingAgent(R2RStreamingAgent):
- """
- A streaming agent that parses XML-formatted responses with special handling for:
- - <think> or <Thought> blocks for chain-of-thought reasoning
- - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution
- """
- # We treat <think> or <Thought> as the same token boundaries
- THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
- THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
- # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
- ACTION_PATTERN = re.compile(
- r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
- )
- TOOLCALLS_PATTERN = re.compile(
- r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
- )
- TOOLCALL_PATTERN = re.compile(
- r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
- )
- NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
- PARAMS_PATTERN = re.compile(
- r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
- )
- RESPONSE_PATTERN = re.compile(
- r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
- )
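- # Illustrative shape of the XML the agent is expected to emit (the tool name is an example):
- # <Action><ToolCalls><ToolCall><Name>search</Name>
- # <Parameters>{"query": "..."}</Parameters></ToolCall></ToolCalls></Action>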
- async def arun(
- self,
- system_instruction: str | None = None,
- messages: list[Message] | None = None,
- *args,
- **kwargs,
- ) -> AsyncGenerator[str, None]:
- """
- Main streaming entrypoint: returns an async generator of SSE lines.
- """
- self._reset()
- await self._setup(system_instruction)
- if messages:
- for m in messages:
- await self.conversation.add_message(m)
- # Initialize citation tracker for this run
- citation_tracker = CitationTracker()
- # Dictionary to store citation payloads by ID
- citation_payloads = {}
- # Track all citations emitted during streaming for final persistence
- self.streaming_citations: list[dict] = []
- async def sse_generator() -> AsyncGenerator[str, None]:
- iterations_count = 0
- # Initialized up front so the max-iterations fallback below can
- # reference it even when no final answer was ever produced
- consolidated_citations: list[dict] = []
- try:
- # Keep streaming until we complete
- while (
- not self._completed
- and iterations_count < self.config.max_iterations
- ):
- iterations_count += 1
- # 1) Get current messages
- msg_list = await self.conversation.get_messages()
- gen_cfg = self.get_generation_config(
- msg_list[-1], stream=True
- )
- # 2) Start streaming from LLM
- llm_stream = self.llm_provider.aget_completion_stream(
- msg_list, gen_cfg
- )
- # Create state variables for each iteration
- iteration_buffer = ""
- yielded_first_event = False
- in_action_block = False
- is_thinking = False
- accumulated_thinking = ""
- thinking_signatures = {}
- async for chunk in llm_stream:
- delta = chunk.choices[0].delta
- finish_reason = chunk.choices[0].finish_reason
- # Handle thinking if present
- if hasattr(delta, "thinking") and delta.thinking:
- # Accumulate thinking for later use in messages
- accumulated_thinking += delta.thinking
- # Emit SSE "thinking" event
- async for (
- line
- ) in SSEFormatter.yield_thinking_event(
- delta.thinking
- ):
- yield line
- # Add this new handler for thinking signatures
- if hasattr(delta, "thinking_signature"):
- thinking_signatures[accumulated_thinking] = (
- delta.thinking_signature
- )
- accumulated_thinking = ""
- # 3) If new text, accumulate it
- if delta.content:
- iteration_buffer += delta.content
- # Check if we have accumulated enough text for a `<Thought>` block
- if len(iteration_buffer) < len("<Thought>"):
- continue
- # Check if we have yielded the first event
- if not yielded_first_event:
- # Emit the first chunk
- if self.THOUGHT_OPEN.findall(iteration_buffer):
- is_thinking = True
- async for (
- line
- ) in SSEFormatter.yield_thinking_event(
- iteration_buffer
- ):
- yield line
- else:
- async for (
- line
- ) in SSEFormatter.yield_message_event(
- iteration_buffer
- ):
- yield line
- # Mark as yielded
- yielded_first_event = True
- continue
- # Check if we are in a thinking block
- if is_thinking:
- # Still thinking, so keep yielding thinking events
- if not self.THOUGHT_CLOSE.findall(
- iteration_buffer
- ):
- # Emit SSE "thinking" event
- async for (
- line
- ) in SSEFormatter.yield_thinking_event(
- delta.content
- ):
- yield line
- continue
- # Done thinking, so emit the last thinking event
- else:
- is_thinking = False
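- # NOTE: the splits below assume the canonical "</Thought>"/"</think>"
- # casing, unlike the IGNORECASE close-tag regex above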
- thought_text = delta.content.split(
- "</Thought>"
- )[0].split("</think>")[0]
- async for (
- line
- ) in SSEFormatter.yield_thinking_event(
- thought_text
- ):
- yield line
- post_thought_text = delta.content.split(
- "</Thought>"
- )[-1].split("</think>")[-1]
- delta.content = post_thought_text
- # (b) Find new citation spans in the accumulated text
- new_citation_spans = find_new_citation_spans(
- iteration_buffer, citation_tracker
- )
- # Process each new citation span
- for cid, spans in new_citation_spans.items():
- for span in spans:
- # Check if this is the first time we've seen this citation ID
- is_new_citation = (
- citation_tracker.is_new_citation(cid)
- )
- # Get payload if it's a new citation
- payload = None
- if is_new_citation:
- source_obj = self.search_results_collector.find_by_short_id(
- cid
- )
- if source_obj:
- # Store payload for reuse
- payload = dump_obj(source_obj)
- citation_payloads[cid] = payload
- # Create citation event payload
- citation_data = {
- "id": cid,
- "object": "citation",
- "is_new": is_new_citation,
- "span": {
- "start": span[0],
- "end": span[1],
- },
- }
- # Only include full payload for new citations
- if is_new_citation and payload:
- citation_data["payload"] = payload
- # Add to streaming citations for final answer
- self.streaming_citations.append(
- citation_data
- )
- # Emit the citation event
- async for (
- line
- ) in SSEFormatter.yield_citation_event(
- citation_data
- ):
- yield line
- # Now prepare to emit the newly streamed text as a "message" event
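- # (any "<" is treated as a potential <Action> opener, so text is buffered until the block closes)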
- if (
- iteration_buffer.count("<")
- and not in_action_block
- ):
- in_action_block = True
- if (
- in_action_block
- and len(
- self.ACTION_PATTERN.findall(
- iteration_buffer
- )
- )
- < 2
- ):
- continue
- elif in_action_block:
- in_action_block = False
- # Emit the post action block text, if it is there
- post_action_text = iteration_buffer.split(
- "</Action>"
- )[-1]
- if post_action_text:
- async for (
- line
- ) in SSEFormatter.yield_message_event(
- post_action_text
- ):
- yield line
- else:
- async for (
- line
- ) in SSEFormatter.yield_message_event(
- delta.content
- ):
- yield line
- elif finish_reason == "stop":
- break
- # Process any accumulated thinking
- await self._handle_thinking(
- thinking_signatures, accumulated_thinking
- )
- # 6) The LLM is done. If we have any leftover partial text,
- # finalize it in the conversation
- if iteration_buffer:
- # Create the final message with metadata including citations
- final_message = Message(
- role="assistant",
- content=iteration_buffer,
- metadata={"citations": self.streaming_citations},
- )
- # Add it to the conversation
- await self.conversation.add_message(final_message)
- # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed
- action_matches = self.ACTION_PATTERN.findall(
- iteration_buffer
- )
- if action_matches:
- # Process each ToolCall
- xml_toolcalls = "<ToolCalls>"
- for action_block in action_matches:
- tool_calls_text = []
- # Look for ToolCalls wrapper, or use the raw action block
- calls_wrapper = self.TOOLCALLS_PATTERN.findall(
- action_block
- )
- if calls_wrapper:
- for tw in calls_wrapper:
- tool_calls_text.append(tw)
- else:
- tool_calls_text.append(action_block)
- for calls_region in tool_calls_text:
- calls_found = self.TOOLCALL_PATTERN.findall(
- calls_region
- )
- for tc_block in calls_found:
- tool_name, tool_params = (
- self._parse_single_tool_call(tc_block)
- )
- if tool_name:
- # Emit SSE event for tool call
- tool_call_id = (
- f"call_{abs(hash(tc_block))}"
- )
- call_evt_data = {
- "tool_call_id": tool_call_id,
- "name": tool_name,
- "arguments": json.dumps(
- tool_params
- ),
- }
- async for line in (
- SSEFormatter.yield_tool_call_event(
- call_evt_data
- )
- ):
- yield line
- try:
- tool_result = await self.handle_function_or_tool_call(
- tool_name,
- json.dumps(tool_params),
- tool_id=tool_call_id,
- save_messages=False,
- )
- result_content = tool_result.llm_formatted_result
- except Exception as e:
- result_content = f"Error in tool '{tool_name}': {str(e)}"
- xml_toolcalls += (
- f"<ToolCall>"
- f"<Name>{tool_name}</Name>"
- f"<Parameters>{json.dumps(tool_params)}</Parameters>"
- f"<Result>{result_content}</Result>"
- f"</ToolCall>"
- )
- # Emit the SSE tool result event for this call
- result_data = {
- "tool_call_id": tool_call_id,
- "role": "tool",
- "content": json.dumps(
- convert_nonserializable_objects(
- result_content
- )
- ),
- }
- async for line in SSEFormatter.yield_tool_result_event(
- result_data
- ):
- yield line
- xml_toolcalls += "</ToolCalls>"
- pre_action_text = iteration_buffer[
- : iteration_buffer.find(action_block)
- ]
- post_action_text = iteration_buffer[
- iteration_buffer.find(action_block)
- + len(action_block) :
- ]
- iteration_text = (
- pre_action_text + xml_toolcalls + post_action_text
- )
- # Update the conversation with tool results
- await self.conversation.add_message(
- Message(
- role="assistant",
- content=iteration_text,
- metadata={
- "citations": self.streaming_citations
- },
- )
- )
- else:
- # (a) Prepare final answer with optimized citations
- consolidated_citations = []
- # Group citations by ID with all their spans
- for (
- cid,
- spans,
- ) in citation_tracker.get_all_spans().items():
- if cid in citation_payloads:
- consolidated_citations.append(
- {
- "id": cid,
- "object": "citation",
- "spans": [
- {"start": s[0], "end": s[1]}
- for s in spans
- ],
- "payload": citation_payloads[cid],
- }
- )
- # Create final answer payload
- final_evt_payload = {
- "id": "msg_final",
- "object": "agent.final_answer",
- "generated_answer": iteration_buffer,
- "citations": consolidated_citations,
- }
- # Emit final answer event
- async for (
- line
- ) in SSEFormatter.yield_final_answer_event(
- final_evt_payload
- ):
- yield line
- # (b) Signal the end of the SSE stream
- yield SSEFormatter.yield_done_event()
- self._completed = True
- # If we exit the while loop due to hitting max iterations
- if not self._completed:
- # Generate a summary using the LLM
- summary = await self._generate_llm_summary(
- iterations_count
- )
- # Send the summary as a message event
- async for line in SSEFormatter.yield_message_event(
- summary
- ):
- yield line
- # Add summary to conversation with citations metadata
- await self.conversation.add_message(
- Message(
- role="assistant",
- content=summary,
- metadata={"citations": self.streaming_citations},
- )
- )
- # Create and emit a final answer payload with the summary
- final_evt_payload = {
- "id": "msg_final",
- "object": "agent.final_answer",
- "generated_answer": summary,
- "citations": consolidated_citations,
- }
- async for line in SSEFormatter.yield_final_answer_event(
- final_evt_payload
- ):
- yield line
- # Signal the end of the SSE stream
- yield SSEFormatter.yield_done_event()
- self._completed = True
- except Exception as e:
- logger.error(f"Error in streaming agent: {str(e)}")
- # Emit error event for client
- async for line in SSEFormatter.yield_error_event(
- f"Agent error: {str(e)}"
- ):
- yield line
- # Send done event to close the stream
- yield SSEFormatter.yield_done_event()
- # Delegate to the inner SSE generator, relaying each line to the caller
- async for line in sse_generator():
- yield line
- def _parse_single_tool_call(
- self, toolcall_text: str
- ) -> Tuple[Optional[str], dict]:
- """
- Parse a ToolCall block to extract the name and parameters.
- Args:
- toolcall_text: The text content of a ToolCall block
- Returns:
- Tuple of (tool_name, tool_parameters)
- """
- name_match = self.NAME_PATTERN.search(toolcall_text)
- if not name_match:
- return None, {}
- tool_name = name_match.group(1).strip()
- params_match = self.PARAMS_PATTERN.search(toolcall_text)
- if not params_match:
- return tool_name, {}
- raw_params = params_match.group(1).strip()
- try:
- # Handle potential JSON parsing issues
- # First try direct parsing
- tool_params = json.loads(raw_params)
- except json.JSONDecodeError:
- # If that fails, try to clean up the JSON string
- try:
- # Replace escaped quotes that might cause issues
- cleaned_params = raw_params.replace('\\"', '"')
- # Try again with the cleaned string
- tool_params = json.loads(cleaned_params)
- except json.JSONDecodeError:
- # If all else fails, treat as a plain string value
- tool_params = {"value": raw_params}
- return tool_name, tool_params
- class R2RXMLToolsAgent(R2RAgent):
- """
- A non-streaming agent that:
- - parses <think> or <Thought> blocks as chain-of-thought
- - filters out XML tags related to tool calls and actions
- - processes <Action><ToolCalls><ToolCall> blocks
- - properly extracts citations when they appear in the text
- """
- # We treat <think> or <Thought> as the same token boundaries
- THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
- THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)
- # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>, <Parameters>, <Response>
- ACTION_PATTERN = re.compile(
- r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
- )
- TOOLCALLS_PATTERN = re.compile(
- r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
- )
- TOOLCALL_PATTERN = re.compile(
- r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
- )
- NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
- PARAMS_PATTERN = re.compile(
- r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
- )
- RESPONSE_PATTERN = re.compile(
- r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
- )
- async def process_llm_response(self, response, *args, **kwargs):
- """
- Override the base process_llm_response to handle XML structured responses
- including thoughts and tool calls.
- """
- if self._completed:
- return
- message = response.choices[0].message
- finish_reason = response.choices[0].finish_reason
- if not message.content:
- # If there's no content, let the parent class handle the normal tool_calls flow
- return await super().process_llm_response(
- response, *args, **kwargs
- )
- # Get the response content
- content = message.content
- # HACK: strip markdown code fences that some models (e.g. Gemini) emit around action blocks
- content = content.replace("```action", "")
- content = content.replace("```tool_code", "")
- content = content.replace("```", "")
- if (
- not content.startswith("<")
- and "deepseek" in self.rag_generation_config.model
- ): # HACK: deepseek models may omit the opening <think> tag, so restore it
- content = "<think>" + content
- # Process any tool calls in the content
- action_matches = self.ACTION_PATTERN.findall(content)
- if action_matches:
- xml_toolcalls = "<ToolCalls>"
- for action_block in action_matches:
- tool_calls_text = []
- # Look for ToolCalls wrapper, or use the raw action block
- calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block)
- if calls_wrapper:
- for tw in calls_wrapper:
- tool_calls_text.append(tw)
- else:
- tool_calls_text.append(action_block)
- # Process each ToolCall
- for calls_region in tool_calls_text:
- calls_found = self.TOOLCALL_PATTERN.findall(calls_region)
- for tc_block in calls_found:
- tool_name, tool_params = self._parse_single_tool_call(
- tc_block
- )
- if tool_name:
- tool_call_id = f"call_{abs(hash(tc_block))}"
- try:
- tool_result = (
- await self.handle_function_or_tool_call(
- tool_name,
- json.dumps(tool_params),
- tool_id=tool_call_id,
- save_messages=False,
- )
- )
- # Add tool result to XML
- xml_toolcalls += (
- f"<ToolCall>"
- f"<Name>{tool_name}</Name>"
- f"<Parameters>{json.dumps(tool_params)}</Parameters>"
- f"<Result>{tool_result.llm_formatted_result}</Result>"
- f"</ToolCall>"
- )
- except Exception as e:
- logger.error(f"Error in tool call: {str(e)}")
- # Add error to XML
- xml_toolcalls += (
- f"<ToolCall>"
- f"<Name>{tool_name}</Name>"
- f"<Parameters>{json.dumps(tool_params)}</Parameters>"
- f"<Result>Error: {str(e)}</Result>"
- f"</ToolCall>"
- )
- xml_toolcalls += "</ToolCalls>"
- pre_action_text = content[: content.find(action_block)]
- post_action_text = content[
- content.find(action_block) + len(action_block) :
- ]
- iteration_text = pre_action_text + xml_toolcalls + post_action_text
- # Create the assistant message
- await self.conversation.add_message(
- Message(role="assistant", content=iteration_text)
- )
- else:
- # Create an assistant message with the content as-is
- await self.conversation.add_message(
- Message(role="assistant", content=content)
- )
- # Only mark as completed when finish_reason is "stop"; otherwise the
- # agent continues the conversation after tool calls are processed
- if finish_reason == "stop":
- self._completed = True
- def _parse_single_tool_call(
- self, toolcall_text: str
- ) -> Tuple[Optional[str], dict]:
- """
- Parse a ToolCall block to extract the name and parameters.
- Args:
- toolcall_text: The text content of a ToolCall block
- Returns:
- Tuple of (tool_name, tool_parameters)
- """
- name_match = self.NAME_PATTERN.search(toolcall_text)
- if not name_match:
- return None, {}
- tool_name = name_match.group(1).strip()
- params_match = self.PARAMS_PATTERN.search(toolcall_text)
- if not params_match:
- return tool_name, {}
- raw_params = params_match.group(1).strip()
- try:
- # Handle potential JSON parsing issues
- # First try direct parsing
- tool_params = json.loads(raw_params)
- except json.JSONDecodeError:
- # If that fails, try to clean up the JSON string
- try:
- # Replace escaped quotes that might cause issues
- cleaned_params = raw_params.replace('\\"', '"')
- # Try again with the cleaned string
- tool_params = json.loads(cleaned_params)
- except json.JSONDecodeError:
- # If all else fails, treat as a plain string value
- tool_params = {"value": raw_params}
- return tool_name, tool_params
|