# base.py

import asyncio
import json
import logging
import re
from abc import ABCMeta
from typing import AsyncGenerator, Optional, Tuple

from core.base import AsyncSyncMeta, LLMChatCompletion, Message, syncable
from core.base.agent import Agent, Conversation
from core.utils import (
    CitationTracker,
    SearchResultsCollector,
    SSEFormatter,
    convert_nonserializable_objects,
    dump_obj,
    find_new_citation_spans,
)

logger = logging.getLogger()


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    pass
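
# CombinedMeta merges AsyncSyncMeta with ABCMeta so that R2RAgent can inherit
# from the abstract `Agent` base while its `@syncable` coroutines also get
# synchronous counterparts, without a metaclass conflict between the two
# bases. (The exact sync-generation behavior of AsyncSyncMeta is assumed from
# its usage in this module.)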


def sync_wrapper(async_gen):
    loop = asyncio.get_event_loop()

    def wrapper():
        try:
            while True:
                try:
                    yield loop.run_until_complete(async_gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            loop.run_until_complete(async_gen.aclose())

    return wrapper()
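
# A minimal usage sketch for sync_wrapper (illustrative only; `some_agent` is
# a hypothetical streaming agent instance):
#
#     sse_lines = sync_wrapper(some_agent.arun(messages=[...]))
#     for line in sse_lines:  # blocks on the event loop for each item
#         print(line)
#
# Each `next()` call drives the wrapped async generator one step via
# run_until_complete, and the finally-block closes it once iteration stops.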


class R2RAgent(Agent, metaclass=CombinedMeta):
    def __init__(self, *args, **kwargs):
        self.search_results_collector = SearchResultsCollector()
        super().__init__(*args, **kwargs)
        self._reset()

    async def _generate_llm_summary(self, iterations_count: int) -> str:
        """
        Generate a summary of the conversation using the LLM when the maximum
        number of iterations is exceeded.

        Args:
            iterations_count: The number of iterations that were completed

        Returns:
            A string containing the LLM-generated summary
        """
        try:
            # Get all messages in the conversation
            all_messages = await self.conversation.get_messages()

            # Create a prompt for the LLM to summarize
            summary_prompt = {
                "role": "user",
                "content": (
                    f"The conversation has reached the maximum limit of {iterations_count} iterations "
                    "without completing the task. Please provide a concise summary of: "
                    "1) The key information you've gathered that's relevant to the original query, "
                    "2) What you've attempted so far and why it's incomplete, and "
                    "3) A specific recommendation for how to proceed. "
                    "Keep your summary brief (3-4 sentences total) and focused on the most valuable insights. "
                    "If it is possible to answer the original user query, then do so now instead. "
                    "Start with '⚠️ **Maximum iterations exceeded**'"
                ),
            }

            # Create a new message list with just the conversation history and
            # the summary request
            summary_messages = all_messages + [summary_prompt]

            # Get a completion for the summary
            generation_config = self.get_generation_config(summary_prompt)
            response = await self.llm_provider.aget_completion(
                summary_messages,
                generation_config,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating LLM summary: {str(e)}")
            # Fall back to a basic summary if LLM generation fails
            return (
                "⚠️ **Maximum iterations exceeded**\n\n"
                "The agent reached the maximum iteration limit without completing the task. "
                "Consider breaking your request into smaller steps or refining your query."
            )

    def _reset(self):
        self._completed = False
        self.conversation = Conversation()

    @syncable
    async def arun(
        self,
        messages: list[Message],
        system_instruction: Optional[str] = None,
        *args,
        **kwargs,
    ) -> list[dict]:
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for message in messages:
                await self.conversation.add_message(message)

        iterations_count = 0
        while (
            not self._completed
            and iterations_count < self.config.max_iterations
        ):
            iterations_count += 1
            messages_list = await self.conversation.get_messages()
            generation_config = self.get_generation_config(messages_list[-1])
            response = await self.llm_provider.aget_completion(
                messages_list,
                generation_config,
            )
            logger.debug(f"R2RAgent response: {response}")
            await self.process_llm_response(response, *args, **kwargs)

        if not self._completed:
            # Generate a summary of the conversation using the LLM
            summary = await self._generate_llm_summary(iterations_count)
            await self.conversation.add_message(
                Message(role="assistant", content=summary)
            )

        # Return the messages generated during this run: walk backwards until
        # we hit the last input message, then restore chronological order.
        all_messages: list[dict] = await self.conversation.get_messages()
        all_messages.reverse()

        output_messages = []
        for message_2 in all_messages:
            if message_2.get("content") != messages[-1].content:
                output_messages.append(message_2)
            else:
                break
        output_messages.reverse()

        return output_messages
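
    # Illustrative call pattern (assumed API, not part of this module):
    #
    #     agent = R2RAgent(...)  # construction arguments omitted
    #     new_msgs = await agent.arun(
    #         messages=[Message(role="user", content="What is R2R?")]
    #     )
    #
    # @syncable is expected to also expose a blocking variant of arun,
    # generated by AsyncSyncMeta.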

    async def process_llm_response(
        self, response: LLMChatCompletion, *args, **kwargs
    ) -> None:
        if not self._completed:
            message = response.choices[0].message
            finish_reason = response.choices[0].finish_reason

            if finish_reason == "stop":
                self._completed = True

            # Determine which provider we're using
            using_anthropic = (
                "anthropic" in self.rag_generation_config.model.lower()
            )

            # OPENAI HANDLING
            if not using_anthropic:
                if message.tool_calls:
                    assistant_msg = Message(
                        role="assistant",
                        content="",
                        tool_calls=[msg.dict() for msg in message.tool_calls],
                    )
                    await self.conversation.add_message(assistant_msg)

                    # If there are multiple tool_calls, call them sequentially here
                    for tool_call in message.tool_calls:
                        await self.handle_function_or_tool_call(
                            tool_call.function.name,
                            tool_call.function.arguments,
                            tool_id=tool_call.id,
                            *args,
                            **kwargs,
                        )
                else:
                    await self.conversation.add_message(
                        Message(role="assistant", content=message.content)
                    )
                    self._completed = True
            # ANTHROPIC HANDLING
            else:
                # First handle thinking blocks if present
                if (
                    hasattr(message, "structured_content")
                    and message.structured_content
                ):
                    # Check if structured_content contains any tool_use blocks
                    has_tool_use = any(
                        block.get("type") == "tool_use"
                        for block in message.structured_content
                    )

                    if not has_tool_use and message.tool_calls:
                        # If it has thinking but no tool_use, add a separate
                        # message with the structured_content
                        assistant_msg = Message(
                            role="assistant",
                            structured_content=message.structured_content,
                        )
                        await self.conversation.add_message(assistant_msg)

                        # Add explicit tool_use blocks in a separate message
                        tool_uses = []
                        for tool_call in message.tool_calls:
                            # Safely parse arguments if they're a string
                            try:
                                if isinstance(
                                    tool_call.function.arguments, str
                                ):
                                    input_args = json.loads(
                                        tool_call.function.arguments
                                    )
                                else:
                                    input_args = tool_call.function.arguments
                            except json.JSONDecodeError:
                                logger.error(
                                    f"Failed to parse tool arguments: {tool_call.function.arguments}"
                                )
                                input_args = {
                                    "_raw": tool_call.function.arguments
                                }

                            tool_uses.append(
                                {
                                    "type": "tool_use",
                                    "id": tool_call.id,
                                    "name": tool_call.function.name,
                                    "input": input_args,
                                }
                            )

                        # Add the tool_use blocks as a separate assistant
                        # message with structured content
                        if tool_uses:
                            await self.conversation.add_message(
                                Message(
                                    role="assistant",
                                    structured_content=tool_uses,
                                    content="",
                                )
                            )
                    else:
                        # If it already has tool_use or no tool_calls,
                        # preserve the original structure
                        assistant_msg = Message(
                            role="assistant",
                            structured_content=message.structured_content,
                        )
                        await self.conversation.add_message(assistant_msg)
                elif message.content:
                    # For regular text content
                    await self.conversation.add_message(
                        Message(role="assistant", content=message.content)
                    )

                    # If there are tool calls, add them as structured content
                    if message.tool_calls:
                        tool_uses = []
                        for tool_call in message.tool_calls:
                            # Same safe parsing as above
                            try:
                                if isinstance(
                                    tool_call.function.arguments, str
                                ):
                                    input_args = json.loads(
                                        tool_call.function.arguments
                                    )
                                else:
                                    input_args = tool_call.function.arguments
                            except json.JSONDecodeError:
                                logger.error(
                                    f"Failed to parse tool arguments: {tool_call.function.arguments}"
                                )
                                input_args = {
                                    "_raw": tool_call.function.arguments
                                }

                            tool_uses.append(
                                {
                                    "type": "tool_use",
                                    "id": tool_call.id,
                                    "name": tool_call.function.name,
                                    "input": input_args,
                                }
                            )

                        await self.conversation.add_message(
                            Message(
                                role="assistant", structured_content=tool_uses
                            )
                        )
                # Handle tool_calls with no content or structured_content
                elif message.tool_calls:
                    # Create tool_uses for a message that has only tool_calls
                    tool_uses = []
                    for tool_call in message.tool_calls:
                        try:
                            if isinstance(tool_call.function.arguments, str):
                                input_args = json.loads(
                                    tool_call.function.arguments
                                )
                            else:
                                input_args = tool_call.function.arguments
                        except json.JSONDecodeError:
                            logger.error(
                                f"Failed to parse tool arguments: {tool_call.function.arguments}"
                            )
                            input_args = {"_raw": tool_call.function.arguments}

                        tool_uses.append(
                            {
                                "type": "tool_use",
                                "id": tool_call.id,
                                "name": tool_call.function.name,
                                "input": input_args,
                            }
                        )

                    # Add the tool_use blocks as a message before processing tools
                    if tool_uses:
                        await self.conversation.add_message(
                            Message(
                                role="assistant",
                                structured_content=tool_uses,
                            )
                        )

                # Process the tool calls
                if message.tool_calls:
                    for tool_call in message.tool_calls:
                        await self.handle_function_or_tool_call(
                            tool_call.function.name,
                            tool_call.function.arguments,
                            tool_id=tool_call.id,
                            *args,
                            **kwargs,
                        )


class R2RStreamingAgent(R2RAgent):
    """
    Base class for all streaming agents with core streaming functionality.

    Supports emitting messages, tool calls, and results as SSE events.
    """

    # These two regexes detect bracket references and the short IDs inside them.
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(
        r"[A-Za-z0-9]{7,8}"
    )  # 7-8 chars, for example
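
    # For example (illustrative): in the text "...as shown in [abc1234]",
    # BRACKET_PATTERN captures "abc1234" and SHORT_ID_PATTERN confirms it
    # looks like a 7-8 character short citation ID.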

    def __init__(self, *args, **kwargs):
        # Force streaming on
        if hasattr(kwargs.get("config", {}), "stream"):
            kwargs["config"].stream = True
        super().__init__(*args, **kwargs)

    async def arun(
        self,
        system_instruction: str | None = None,
        messages: list[Message] | None = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Main streaming entrypoint: returns an async generator of SSE lines.
        """
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for m in messages:
                await self.conversation.add_message(m)

        # Initialize citation tracker for this run
        citation_tracker = CitationTracker()

        # Dictionary to store citation payloads by ID
        citation_payloads = {}

        # Track all citations emitted during streaming for final persistence
        self.streaming_citations: list[dict] = []

        async def sse_generator() -> AsyncGenerator[str, None]:
            pending_tool_calls = {}
            partial_text_buffer = ""
            iterations_count = 0
            # Defined up front so the max-iterations fallback below can
            # reference it even when no "stop" chunk was ever received.
            consolidated_citations: list[dict] = []

            try:
                # Keep streaming until we complete
                while (
                    not self._completed
                    and iterations_count < self.config.max_iterations
                ):
                    iterations_count += 1

                    # 1) Get current messages
                    msg_list = await self.conversation.get_messages()
                    gen_cfg = self.get_generation_config(
                        msg_list[-1], stream=True
                    )

                    accumulated_thinking = ""
                    thinking_signatures = {}  # Map thinking content to signatures

                    # 2) Start streaming from LLM
                    llm_stream = self.llm_provider.aget_completion_stream(
                        msg_list, gen_cfg
                    )
                    async for chunk in llm_stream:
                        delta = chunk.choices[0].delta
                        finish_reason = chunk.choices[0].finish_reason

                        if hasattr(delta, "thinking") and delta.thinking:
                            # Accumulate thinking for later use in messages
                            accumulated_thinking += delta.thinking

                            # Emit SSE "thinking" event
                            async for (
                                line
                            ) in SSEFormatter.yield_thinking_event(
                                delta.thinking
                            ):
                                yield line

                        # Handle thinking signatures when they arrive
                        if hasattr(delta, "thinking_signature"):
                            thinking_signatures[accumulated_thinking] = (
                                delta.thinking_signature
                            )
                            accumulated_thinking = ""

                        # 3) If new text, accumulate it
                        if delta.content:
                            partial_text_buffer += delta.content

                            # (a) Emit the newly streamed text as a "message" event
                            async for line in SSEFormatter.yield_message_event(
                                delta.content
                            ):
                                yield line

                            # (b) Find new citation spans in the accumulated text
                            new_citation_spans = find_new_citation_spans(
                                partial_text_buffer, citation_tracker
                            )

                            # Process each new citation span
                            for cid, spans in new_citation_spans.items():
                                for span in spans:
                                    # Check if this is the first time we've
                                    # seen this citation ID
                                    is_new_citation = (
                                        citation_tracker.is_new_citation(cid)
                                    )

                                    # Get the payload if it's a new citation
                                    payload = None
                                    if is_new_citation:
                                        source_obj = self.search_results_collector.find_by_short_id(
                                            cid
                                        )
                                        if source_obj:
                                            # Store payload for reuse
                                            payload = dump_obj(source_obj)
                                            citation_payloads[cid] = payload

                                    # Create the citation event payload
                                    citation_data = {
                                        "id": cid,
                                        "object": "citation",
                                        "is_new": is_new_citation,
                                        "span": {
                                            "start": span[0],
                                            "end": span[1],
                                        },
                                    }

                                    # Only include the full payload for new citations
                                    if is_new_citation and payload:
                                        citation_data["payload"] = payload

                                    # Add to streaming citations for the final answer
                                    self.streaming_citations.append(
                                        citation_data
                                    )

                                    # Emit the citation event
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_citation_event(
                                        citation_data
                                    ):
                                        yield line

                        # 4) Accumulate partial tool calls if present
                        if delta.tool_calls:
                            for tc in delta.tool_calls:
                                idx = tc.index
                                if idx not in pending_tool_calls:
                                    pending_tool_calls[idx] = {
                                        "id": tc.id,
                                        "name": tc.function.name or "",
                                        "arguments": tc.function.arguments
                                        or "",
                                    }
                                else:
                                    # Accumulate partial name/arguments
                                    if tc.function.name:
                                        pending_tool_calls[idx]["name"] = (
                                            tc.function.name
                                        )
                                    if tc.function.arguments:
                                        pending_tool_calls[idx][
                                            "arguments"
                                        ] += tc.function.arguments

                        # 5) If the stream signals we should handle "tool_calls"
                        if finish_reason == "tool_calls":
                            # Handle thinking if present
                            await self._handle_thinking(
                                thinking_signatures, accumulated_thinking
                            )

                            calls_list = []
                            for idx in sorted(pending_tool_calls.keys()):
                                cinfo = pending_tool_calls[idx]
                                calls_list.append(
                                    {
                                        "tool_call_id": cinfo["id"]
                                        or f"call_{idx}",
                                        "name": cinfo["name"],
                                        "arguments": cinfo["arguments"],
                                    }
                                )

                            # (a) Emit SSE "tool_call" events
                            for c in calls_list:
                                tc_data = self._create_tool_call_data(c)
                                async for (
                                    line
                                ) in SSEFormatter.yield_tool_call_event(
                                    tc_data
                                ):
                                    yield line

                            # (b) Add an assistant message capturing these calls
                            await self._add_tool_calls_message(
                                calls_list, partial_text_buffer
                            )

                            # (c) Execute each tool call in parallel
                            await asyncio.gather(
                                *[
                                    self.handle_function_or_tool_call(
                                        c["name"],
                                        c["arguments"],
                                        tool_id=c["tool_call_id"],
                                    )
                                    for c in calls_list
                                ]
                            )

                            # Reset buffer & calls
                            pending_tool_calls.clear()
                            partial_text_buffer = ""
                        elif finish_reason == "stop":
                            # Handle thinking if present
                            await self._handle_thinking(
                                thinking_signatures, accumulated_thinking
                            )

                            # 6) The LLM is done. If we have any leftover
                            # partial text, finalize it in the conversation.
                            if partial_text_buffer:
                                # Create the final message with metadata
                                # including citations
                                final_message = Message(
                                    role="assistant",
                                    content=partial_text_buffer,
                                    metadata={
                                        "citations": self.streaming_citations
                                    },
                                )

                                # Add it to the conversation
                                await self.conversation.add_message(
                                    final_message
                                )

                            # (a) Prepare the final answer with consolidated citations
                            consolidated_citations = []
                            # Group citations by ID with all their spans
                            for (
                                cid,
                                spans,
                            ) in citation_tracker.get_all_spans().items():
                                if cid in citation_payloads:
                                    consolidated_citations.append(
                                        {
                                            "id": cid,
                                            "object": "citation",
                                            "spans": [
                                                {"start": s[0], "end": s[1]}
                                                for s in spans
                                            ],
                                            "payload": citation_payloads[cid],
                                        }
                                    )

                            # Create the final answer payload
                            final_evt_payload = {
                                "id": "msg_final",
                                "object": "agent.final_answer",
                                "generated_answer": partial_text_buffer,
                                "citations": consolidated_citations,
                            }

                            # Emit the final answer event
                            async for (
                                line
                            ) in SSEFormatter.yield_final_answer_event(
                                final_evt_payload
                            ):
                                yield line

                            # (b) Signal the end of the SSE stream
                            yield SSEFormatter.yield_done_event()
                            self._completed = True
                            break

                # If we exit the while loop due to hitting max iterations
                if not self._completed:
                    # Generate a summary using the LLM
                    summary = await self._generate_llm_summary(
                        iterations_count
                    )

                    # Send the summary as a message event
                    async for line in SSEFormatter.yield_message_event(
                        summary
                    ):
                        yield line

                    # Add the summary to the conversation with citations metadata
                    await self.conversation.add_message(
                        Message(
                            role="assistant",
                            content=summary,
                            metadata={"citations": self.streaming_citations},
                        )
                    )

                    # Create and emit a final answer payload with the summary
                    final_evt_payload = {
                        "id": "msg_final",
                        "object": "agent.final_answer",
                        "generated_answer": summary,
                        "citations": consolidated_citations,
                    }
                    async for line in SSEFormatter.yield_final_answer_event(
                        final_evt_payload
                    ):
                        yield line

                    # Signal the end of the SSE stream
                    yield SSEFormatter.yield_done_event()
                    self._completed = True
            except Exception as e:
                logger.error(f"Error in streaming agent: {str(e)}")

                # Emit an error event for the client
                async for line in SSEFormatter.yield_error_event(
                    f"Agent error: {str(e)}"
                ):
                    yield line

                # Send a done event to close the stream
                yield SSEFormatter.yield_done_event()

        # Finally, we return the async generator
        async for line in sse_generator():
            yield line
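
    # Consuming the stream (illustrative sketch; the transport wiring and
    # `http_response` writer are hypothetical):
    #
    #     async for sse_line in agent.arun(messages=[...]):
    #         await http_response.write(sse_line)
    #
    # Each yielded item is a pre-formatted SSE line produced by SSEFormatter
    # (message, thinking, citation, tool_call, final_answer, or done events).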

    async def _handle_thinking(
        self, thinking_signatures, accumulated_thinking
    ):
        """Process any accumulated thinking content."""
        if accumulated_thinking:
            structured_content = [
                {
                    "type": "thinking",
                    "thinking": accumulated_thinking,
                    # Anthropic will validate this in their API
                    "signature": "placeholder_signature",
                }
            ]
            assistant_msg = Message(
                role="assistant",
                structured_content=structured_content,
            )
            await self.conversation.add_message(assistant_msg)
        elif thinking_signatures:
            for (
                accumulated_thinking,
                thinking_signature,
            ) in thinking_signatures.items():
                structured_content = [
                    {
                        "type": "thinking",
                        "thinking": accumulated_thinking,
                        # Anthropic will validate this in their API
                        "signature": thinking_signature,
                    }
                ]
                assistant_msg = Message(
                    role="assistant",
                    structured_content=structured_content,
                )
                await self.conversation.add_message(assistant_msg)

    async def _add_tool_calls_message(self, calls_list, partial_text_buffer):
        """Add a message with tool calls to the conversation."""
        assistant_msg = Message(
            role="assistant",
            content=partial_text_buffer or "",
            tool_calls=[
                {
                    "id": c["tool_call_id"],
                    "type": "function",
                    "function": {
                        "name": c["name"],
                        "arguments": c["arguments"],
                    },
                }
                for c in calls_list
            ],
        )
        await self.conversation.add_message(assistant_msg)

    def _create_tool_call_data(self, call_info):
        """Create a tool call data structure from call info."""
        return {
            "tool_call_id": call_info["tool_call_id"],
            "name": call_info["name"],
            "arguments": call_info["arguments"],
        }

    def _create_citation_payload(self, short_id, payload):
        """Create a citation payload for a short ID."""
        # This will be overridden in RAG subclasses.
        # Convert the payload to a plain dict if it exposes a converter.
        if hasattr(payload, "as_dict"):
            payload = payload.as_dict()
        if hasattr(payload, "dict"):
            payload = payload.dict()
        if hasattr(payload, "to_dict"):
            payload = payload.to_dict()

        return {
            "id": f"{short_id}",
            "object": "citation",
            "payload": dump_obj(payload),  # Will be populated in RAG agents
        }

    def _create_final_answer_payload(self, answer_text, citations):
        """Create the final answer payload."""
        # This will be extended in RAG subclasses.
        return {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": answer_text,
            "citations": citations,
        }


class R2RXMLStreamingAgent(R2RStreamingAgent):
    """
    A streaming agent that parses XML-formatted responses with special handling for:

    - <think> or <Thought> blocks for chain-of-thought reasoning
    - <Action>, <ToolCalls>, <ToolCall> blocks for tool execution
    """

    # We treat <think> and <Thought> as the same token boundaries
    THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
    THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)

    # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>,
    # <Parameters>, and <Response> blocks
    ACTION_PATTERN = re.compile(
        r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
    )
    TOOLCALLS_PATTERN = re.compile(
        r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
    )
    TOOLCALL_PATTERN = re.compile(
        r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
    )
    NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
    PARAMS_PATTERN = re.compile(
        r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
    )
    RESPONSE_PATTERN = re.compile(
        r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
    )
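
    # The patterns above target responses shaped roughly like this
    # (illustrative example; the tool name and parameters are hypothetical):
    #
    #     <Thought>I should look this up.</Thought>
    #     <Action>
    #       <ToolCalls>
    #         <ToolCall>
    #           <Name>search</Name>
    #           <Parameters>{"query": "R2R agents"}</Parameters>
    #         </ToolCall>
    #       </ToolCalls>
    #     </Action>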

    async def arun(
        self,
        system_instruction: str | None = None,
        messages: list[Message] | None = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Main streaming entrypoint: returns an async generator of SSE lines.
        """
        self._reset()
        await self._setup(system_instruction)

        if messages:
            for m in messages:
                await self.conversation.add_message(m)

        # Initialize citation tracker for this run
        citation_tracker = CitationTracker()

        # Dictionary to store citation payloads by ID
        citation_payloads = {}

        # Track all citations emitted during streaming for final persistence
        self.streaming_citations: list[dict] = []

        async def sse_generator() -> AsyncGenerator[str, None]:
            iterations_count = 0
            # Defined up front so the max-iterations fallback below can
            # reference it even when no final answer was assembled.
            consolidated_citations: list[dict] = []

            try:
                # Keep streaming until we complete
                while (
                    not self._completed
                    and iterations_count < self.config.max_iterations
                ):
                    iterations_count += 1

                    # 1) Get current messages
                    msg_list = await self.conversation.get_messages()
                    gen_cfg = self.get_generation_config(
                        msg_list[-1], stream=True
                    )

                    # 2) Start streaming from LLM
                    llm_stream = self.llm_provider.aget_completion_stream(
                        msg_list, gen_cfg
                    )

                    # Create state variables for each iteration
                    iteration_buffer = ""
                    yielded_first_event = False
                    in_action_block = False
                    is_thinking = False
                    accumulated_thinking = ""
                    thinking_signatures = {}

                    async for chunk in llm_stream:
                        delta = chunk.choices[0].delta
                        finish_reason = chunk.choices[0].finish_reason

                        # Handle thinking if present
                        if hasattr(delta, "thinking") and delta.thinking:
                            # Accumulate thinking for later use in messages
                            accumulated_thinking += delta.thinking

                            # Emit SSE "thinking" event
                            async for (
                                line
                            ) in SSEFormatter.yield_thinking_event(
                                delta.thinking
                            ):
                                yield line

                        # Handle thinking signatures when they arrive
                        if hasattr(delta, "thinking_signature"):
                            thinking_signatures[accumulated_thinking] = (
                                delta.thinking_signature
                            )
                            accumulated_thinking = ""

                        # 3) If new text, accumulate it
                        if delta.content:
                            iteration_buffer += delta.content

                            # Check if we have accumulated enough text for a
                            # `<Thought>` block
                            if len(iteration_buffer) < len("<Thought>"):
                                continue

                            # Check if we have yielded the first event
                            if not yielded_first_event:
                                # Emit the first chunk
                                if self.THOUGHT_OPEN.findall(iteration_buffer):
                                    is_thinking = True
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_thinking_event(
                                        iteration_buffer
                                    ):
                                        yield line
                                else:
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_message_event(
                                        iteration_buffer
                                    ):
                                        yield line

                                # Mark as yielded
                                yielded_first_event = True
                                continue

                            # Check if we are in a thinking block
                            if is_thinking:
                                # Still thinking, so keep yielding thinking events
                                if not self.THOUGHT_CLOSE.findall(
                                    iteration_buffer
                                ):
                                    # Emit SSE "thinking" event
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_thinking_event(
                                        delta.content
                                    ):
                                        yield line
                                    continue
                                # Done thinking, so emit the last thinking event
                                else:
                                    is_thinking = False
                                    thought_text = delta.content.split(
                                        "</Thought>"
                                    )[0].split("</think>")[0]
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_thinking_event(
                                        thought_text
                                    ):
                                        yield line
                                    post_thought_text = delta.content.split(
                                        "</Thought>"
                                    )[-1].split("</think>")[-1]
                                    delta.content = post_thought_text

                            # (b) Find new citation spans in the accumulated text
                            new_citation_spans = find_new_citation_spans(
                                iteration_buffer, citation_tracker
                            )

                            # Process each new citation span
                            for cid, spans in new_citation_spans.items():
                                for span in spans:
                                    # Check if this is the first time we've
                                    # seen this citation ID
                                    is_new_citation = (
                                        citation_tracker.is_new_citation(cid)
                                    )

                                    # Get the payload if it's a new citation
                                    payload = None
                                    if is_new_citation:
                                        source_obj = self.search_results_collector.find_by_short_id(
                                            cid
                                        )
                                        if source_obj:
                                            # Store payload for reuse
                                            payload = dump_obj(source_obj)
                                            citation_payloads[cid] = payload

                                    # Create the citation event payload
                                    citation_data = {
                                        "id": cid,
                                        "object": "citation",
                                        "is_new": is_new_citation,
                                        "span": {
                                            "start": span[0],
                                            "end": span[1],
                                        },
                                    }

                                    # Only include the full payload for new citations
                                    if is_new_citation and payload:
                                        citation_data["payload"] = payload

                                    # Add to streaming citations for the final answer
                                    self.streaming_citations.append(
                                        citation_data
                                    )

                                    # Emit the citation event
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_citation_event(
                                        citation_data
                                    ):
                                        yield line

                            # Now prepare to emit the newly streamed text as a
                            # "message" event
                            if (
                                iteration_buffer.count("<")
                                and not in_action_block
                            ):
                                in_action_block = True

                            if (
                                in_action_block
                                and len(
                                    self.ACTION_PATTERN.findall(
                                        iteration_buffer
                                    )
                                )
                                < 2
                            ):
                                continue
                            elif in_action_block:
                                in_action_block = False

                                # Emit the post-action text, if there is any
                                post_action_text = iteration_buffer.split(
                                    "</Action>"
                                )[-1]
                                if post_action_text:
                                    async for (
                                        line
                                    ) in SSEFormatter.yield_message_event(
                                        post_action_text
                                    ):
                                        yield line
                            else:
                                async for (
                                    line
                                ) in SSEFormatter.yield_message_event(
                                    delta.content
                                ):
                                    yield line
                        elif finish_reason == "stop":
                            break

                    # Process any accumulated thinking
                    await self._handle_thinking(
                        thinking_signatures, accumulated_thinking
                    )

                    # 6) The LLM is done. If we have any leftover partial
                    # text, finalize it in the conversation.
                    if iteration_buffer:
                        # Create the final message with metadata including citations
                        final_message = Message(
                            role="assistant",
                            content=iteration_buffer,
                            metadata={"citations": self.streaming_citations},
                        )

                        # Add it to the conversation
                        await self.conversation.add_message(final_message)

                    # --- 4) Process any <Action>/<ToolCalls> blocks, or mark completed
                    action_matches = self.ACTION_PATTERN.findall(
                        iteration_buffer
                    )
                    if len(action_matches) > 0:
                        # Process each ToolCall
                        xml_toolcalls = "<ToolCalls>"

                        for action_block in action_matches:
                            tool_calls_text = []
                            # Look for a ToolCalls wrapper, or fall back to
                            # the raw action block
                            calls_wrapper = self.TOOLCALLS_PATTERN.findall(
                                action_block
                            )
                            if calls_wrapper:
                                for tw in calls_wrapper:
                                    tool_calls_text.append(tw)
                            else:
                                tool_calls_text.append(action_block)

                            for calls_region in tool_calls_text:
                                calls_found = self.TOOLCALL_PATTERN.findall(
                                    calls_region
                                )
                                for tc_block in calls_found:
                                    tool_name, tool_params = (
                                        self._parse_single_tool_call(tc_block)
                                    )
                                    if tool_name:
                                        # Emit SSE event for the tool call
                                        tool_call_id = (
                                            f"call_{abs(hash(tc_block))}"
                                        )
                                        call_evt_data = {
                                            "tool_call_id": tool_call_id,
                                            "name": tool_name,
                                            "arguments": json.dumps(
                                                tool_params
                                            ),
                                        }
                                        async for line in (
                                            SSEFormatter.yield_tool_call_event(
                                                call_evt_data
                                            )
                                        ):
                                            yield line

                                        try:
                                            tool_result = await self.handle_function_or_tool_call(
                                                tool_name,
                                                json.dumps(tool_params),
                                                tool_id=tool_call_id,
                                                save_messages=False,
                                            )
                                            result_content = (
                                                tool_result.llm_formatted_result
                                            )
                                        except Exception as e:
                                            result_content = f"Error in tool '{tool_name}': {str(e)}"

                                        xml_toolcalls += (
                                            f"<ToolCall>"
                                            f"<Name>{tool_name}</Name>"
                                            f"<Parameters>{json.dumps(tool_params)}</Parameters>"
                                            f"<Result>{result_content}</Result>"
                                            f"</ToolCall>"
                                        )

                                        # Emit the SSE tool result event
                                        result_data = {
                                            "tool_call_id": tool_call_id,
                                            "role": "tool",
                                            "content": json.dumps(
                                                convert_nonserializable_objects(
                                                    result_content
                                                )
                                            ),
                                        }
                                        async for line in SSEFormatter.yield_tool_result_event(
                                            result_data
                                        ):
                                            yield line

                        xml_toolcalls += "</ToolCalls>"
                        pre_action_text = iteration_buffer[
                            : iteration_buffer.find(action_block)
                        ]
                        post_action_text = iteration_buffer[
                            iteration_buffer.find(action_block)
                            + len(action_block) :
                        ]
                        iteration_text = (
                            pre_action_text + xml_toolcalls + post_action_text
                        )

                        # Update the conversation with tool results
                        await self.conversation.add_message(
                            Message(
                                role="assistant",
                                content=iteration_text,
                                metadata={
                                    "citations": self.streaming_citations
                                },
                            )
                        )
                    else:
                        # (a) Prepare the final answer with consolidated citations
                        consolidated_citations = []
                        # Group citations by ID with all their spans
                        for (
                            cid,
                            spans,
                        ) in citation_tracker.get_all_spans().items():
                            if cid in citation_payloads:
                                consolidated_citations.append(
                                    {
                                        "id": cid,
                                        "object": "citation",
                                        "spans": [
                                            {"start": s[0], "end": s[1]}
                                            for s in spans
                                        ],
                                        "payload": citation_payloads[cid],
                                    }
                                )

                        # Create the final answer payload
                        final_evt_payload = {
                            "id": "msg_final",
                            "object": "agent.final_answer",
                            "generated_answer": iteration_buffer,
                            "citations": consolidated_citations,
                        }

                        # Emit the final answer event
                        async for (
                            line
                        ) in SSEFormatter.yield_final_answer_event(
                            final_evt_payload
                        ):
                            yield line

                        # (b) Signal the end of the SSE stream
                        yield SSEFormatter.yield_done_event()
                        self._completed = True

                # If we exit the while loop due to hitting max iterations
                if not self._completed:
                    # Generate a summary using the LLM
                    summary = await self._generate_llm_summary(
                        iterations_count
                    )

                    # Send the summary as a message event
                    async for line in SSEFormatter.yield_message_event(
                        summary
                    ):
                        yield line

                    # Add the summary to the conversation with citations metadata
                    await self.conversation.add_message(
                        Message(
                            role="assistant",
                            content=summary,
                            metadata={"citations": self.streaming_citations},
                        )
                    )

                    # Create and emit a final answer payload with the summary
                    final_evt_payload = {
                        "id": "msg_final",
                        "object": "agent.final_answer",
                        "generated_answer": summary,
                        "citations": consolidated_citations,
                    }
                    async for line in SSEFormatter.yield_final_answer_event(
                        final_evt_payload
                    ):
                        yield line

                    # Signal the end of the SSE stream
                    yield SSEFormatter.yield_done_event()
                    self._completed = True
            except Exception as e:
                logger.error(f"Error in streaming agent: {str(e)}")

                # Emit an error event for the client
                async for line in SSEFormatter.yield_error_event(
                    f"Agent error: {str(e)}"
                ):
                    yield line

                # Send a done event to close the stream
                yield SSEFormatter.yield_done_event()

        # Finally, we return the async generator
        async for line in sse_generator():
            yield line

    def _parse_single_tool_call(
        self, toolcall_text: str
    ) -> Tuple[Optional[str], dict]:
        """
        Parse a ToolCall block to extract the name and parameters.

        Args:
            toolcall_text: The text content of a ToolCall block

        Returns:
            Tuple of (tool_name, tool_parameters)
        """
        name_match = self.NAME_PATTERN.search(toolcall_text)
        if not name_match:
            return None, {}
        tool_name = name_match.group(1).strip()

        params_match = self.PARAMS_PATTERN.search(toolcall_text)
        if not params_match:
            return tool_name, {}
        raw_params = params_match.group(1).strip()

        try:
            # Handle potential JSON parsing issues: first try direct parsing
            tool_params = json.loads(raw_params)
        except json.JSONDecodeError:
            # If that fails, try to clean up the JSON string
            try:
                # Replace escaped quotes that might cause issues
                cleaned_params = raw_params.replace('\\"', '"')
                # Try again with the cleaned string
                tool_params = json.loads(cleaned_params)
            except json.JSONDecodeError:
                # If all else fails, treat it as a plain string value
                tool_params = {"value": raw_params}

        return tool_name, tool_params
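
    # Worked example (illustrative):
    #
    #     self._parse_single_tool_call(
    #         '<Name>search</Name><Parameters>{"query": "R2R"}</Parameters>'
    #     )
    #     # -> ("search", {"query": "R2R"})
    #
    # Malformed <Parameters> JSON degrades gracefully to {"value": raw_text}.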


class R2RXMLToolsAgent(R2RAgent):
    """
    A non-streaming agent that:

    - parses <think> or <Thought> blocks as chain-of-thought
    - filters out XML tags related to tool calls and actions
    - processes <Action><ToolCalls><ToolCall> blocks
    - properly extracts citations when they appear in the text
    """

    # We treat <think> and <Thought> as the same token boundaries
    THOUGHT_OPEN = re.compile(r"<(Thought|think)>", re.IGNORECASE)
    THOUGHT_CLOSE = re.compile(r"</(Thought|think)>", re.IGNORECASE)

    # Regexes to parse out <Action>, <ToolCalls>, <ToolCall>, <Name>,
    # <Parameters>, and <Response> blocks
    ACTION_PATTERN = re.compile(
        r"<Action>(.*?)</Action>", re.IGNORECASE | re.DOTALL
    )
    TOOLCALLS_PATTERN = re.compile(
        r"<ToolCalls>(.*?)</ToolCalls>", re.IGNORECASE | re.DOTALL
    )
    TOOLCALL_PATTERN = re.compile(
        r"<ToolCall>(.*?)</ToolCall>", re.IGNORECASE | re.DOTALL
    )
    NAME_PATTERN = re.compile(r"<Name>(.*?)</Name>", re.IGNORECASE | re.DOTALL)
    PARAMS_PATTERN = re.compile(
        r"<Parameters>(.*?)</Parameters>", re.IGNORECASE | re.DOTALL
    )
    RESPONSE_PATTERN = re.compile(
        r"<Response>(.*?)</Response>", re.IGNORECASE | re.DOTALL
    )

    async def process_llm_response(self, response, *args, **kwargs):
        """
        Override the base process_llm_response to handle XML-structured
        responses, including thoughts and tool calls.
        """
        if self._completed:
            return

        message = response.choices[0].message
        finish_reason = response.choices[0].finish_reason

        if not message.content:
            # If there's no content, let the parent class handle the normal
            # tool_calls flow
            return await super().process_llm_response(
                response, *args, **kwargs
            )

        # Get the response content
        content = message.content

        # HACK for gemini
        content = content.replace("```action", "")
        content = content.replace("```tool_code", "")
        content = content.replace("```", "")

        if (
            not content.startswith("<")
            and "deepseek" in self.rag_generation_config.model
        ):  # HACK - fix issues with adding `<think>` to the beginning
            content = "<think>" + content

        # Process any tool calls in the content
        action_matches = self.ACTION_PATTERN.findall(content)
        if action_matches:
            xml_toolcalls = "<ToolCalls>"
            for action_block in action_matches:
                tool_calls_text = []
                # Look for a ToolCalls wrapper, or fall back to the raw action block
                calls_wrapper = self.TOOLCALLS_PATTERN.findall(action_block)
                if calls_wrapper:
                    for tw in calls_wrapper:
                        tool_calls_text.append(tw)
                else:
                    tool_calls_text.append(action_block)

                # Process each ToolCall
                for calls_region in tool_calls_text:
                    calls_found = self.TOOLCALL_PATTERN.findall(calls_region)
                    for tc_block in calls_found:
                        tool_name, tool_params = self._parse_single_tool_call(
                            tc_block
                        )
                        if tool_name:
                            tool_call_id = f"call_{abs(hash(tc_block))}"
                            try:
                                tool_result = (
                                    await self.handle_function_or_tool_call(
                                        tool_name,
                                        json.dumps(tool_params),
                                        tool_id=tool_call_id,
                                        save_messages=False,
                                    )
                                )
                                # Add the tool result to the XML
                                xml_toolcalls += (
                                    f"<ToolCall>"
                                    f"<Name>{tool_name}</Name>"
                                    f"<Parameters>{json.dumps(tool_params)}</Parameters>"
                                    f"<Result>{tool_result.llm_formatted_result}</Result>"
                                    f"</ToolCall>"
                                )
                            except Exception as e:
                                logger.error(f"Error in tool call: {str(e)}")
                                # Add the error to the XML
                                xml_toolcalls += (
                                    f"<ToolCall>"
                                    f"<Name>{tool_name}</Name>"
                                    f"<Parameters>{json.dumps(tool_params)}</Parameters>"
                                    f"<Result>Error: {str(e)}</Result>"
                                    f"</ToolCall>"
                                )

            xml_toolcalls += "</ToolCalls>"
            pre_action_text = content[: content.find(action_block)]
            post_action_text = content[
                content.find(action_block) + len(action_block) :
            ]
            iteration_text = pre_action_text + xml_toolcalls + post_action_text

            # Create the assistant message
            await self.conversation.add_message(
                Message(role="assistant", content=iteration_text)
            )
        else:
            # Create an assistant message with the content as-is
            await self.conversation.add_message(
                Message(role="assistant", content=content)
            )

        # Only mark the run as completed when finish_reason is "stop"; this
        # allows the agent to continue the conversation after tool calls have
        # been processed.
        if finish_reason == "stop":
            self._completed = True

    def _parse_single_tool_call(
        self, toolcall_text: str
    ) -> Tuple[Optional[str], dict]:
        """
        Parse a ToolCall block to extract the name and parameters.

        Args:
            toolcall_text: The text content of a ToolCall block

        Returns:
            Tuple of (tool_name, tool_parameters)
        """
        name_match = self.NAME_PATTERN.search(toolcall_text)
        if not name_match:
            return None, {}
        tool_name = name_match.group(1).strip()

        params_match = self.PARAMS_PATTERN.search(toolcall_text)
        if not params_match:
            return tool_name, {}
        raw_params = params_match.group(1).strip()

        try:
            # Handle potential JSON parsing issues: first try direct parsing
            tool_params = json.loads(raw_params)
        except json.JSONDecodeError:
            # If that fails, try to clean up the JSON string
            try:
                # Replace escaped quotes that might cause issues
                cleaned_params = raw_params.replace('\\"', '"')
                # Try again with the cleaned string
                tool_params = json.loads(cleaned_params)
            except json.JSONDecodeError:
                # If all else fails, treat it as a plain string value
                tool_params = {"value": raw_params}

        return tool_name, tool_params