base_utils.py

import json
import logging
import math
import uuid
from abc import ABCMeta
from copy import deepcopy
from datetime import datetime
from typing import Any, Optional, Tuple, TypeVar
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

import tiktoken

from ..abstractions import (
    AggregateSearchResult,
    AsyncSyncMeta,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

logger = logging.getLogger()


def id_to_shorthand(id: str | UUID):
    return str(id)[:7]
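
# Illustrative usage (not part of the original module): a shorthand ID is simply
# the first seven characters of the UUID's string form.
# >>> id_to_shorthand(UUID("3fa85f64-5717-4562-b3fc-2c963f66afa6"))
# '3fa85f6'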


def format_search_results_for_llm(
    results: AggregateSearchResult,
) -> str:
    """
    Instead of resetting 'source_counter' to 1, we:
      - For each chunk / graph / web / doc in `results`,
      - Find the aggregator index from the collector,
      - Print 'Source [X]:' with that aggregator index.
    """
    lines = []

    # We'll build a quick helper to locate aggregator indices for each object:
    # Or you can rely on the fact that we've added them to the collector
    # in the same order. But let's do a "lookup aggregator index" approach:

    # 1) Chunk search
    if results.chunk_search_results:
        lines.append("Vector Search Results:")
        for c in results.chunk_search_results:
            lines.extend(
                (f"Source ID [{id_to_shorthand(c.id)}]:", (c.text or ""))
            )

    # 2) Graph search
    if results.graph_search_results:
        lines.append("Graph Search Results:")
        for g in results.graph_search_results:
            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
            if isinstance(g.content, GraphCommunityResult):
                lines.extend(
                    (
                        f"Community Name: {g.content.name}",
                        f"ID: {g.content.id}",
                        f"Summary: {g.content.summary}",
                    )
                )
            elif isinstance(g.content, GraphEntityResult):
                lines.extend(
                    (
                        f"Entity Name: {g.content.name}",
                        f"Description: {g.content.description}",
                    )
                )
            elif isinstance(g.content, GraphRelationshipResult):
                lines.append(
                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
                )

    # 3) Web page search results
    if results.web_page_search_results:
        lines.append("Web Page Search Results:")
        for w in results.web_page_search_results:
            lines.extend(
                (
                    f"Source ID [{id_to_shorthand(w.id)}]:",
                    f"Title: {w.title}",
                    f"Link: {w.link}",
                    f"Snippet: {w.snippet}",
                )
            )

    # 4) Web search results
    if results.web_search_results:
        for web_search_result in results.web_search_results:
            lines.append("Web Search Results:")
            for search_result in web_search_result.organic_results:
                lines.extend(
                    (
                        f"Source ID [{id_to_shorthand(search_result.id)}]:",
                        f"Title: {search_result.title}",
                        f"Link: {search_result.link}",
                        f"Snippet: {search_result.snippet}",
                    )
                )

    # 5) Local context docs
    if results.document_search_results:
        lines.append("Local Context Documents:")
        for doc_result in results.document_search_results:
            doc_title = doc_result.title or "Untitled Document"
            doc_id = doc_result.id
            lines.extend(
                (
                    f"Full Document ID: {doc_id}",
                    f"Shortened Document ID: {id_to_shorthand(doc_id)}",
                    f"Document Title: {doc_title}",
                )
            )
            if summary := doc_result.summary:
                lines.append(f"Summary: {summary}")
            if doc_result.chunks:
                # Then each chunk inside:
                lines.extend(
                    f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
                    for chunk in doc_result.chunks
                )

    # 6) Generic tool results
    if results.generic_tool_result:
        lines.extend(
            f"Generic Tool Results: {tool_result}"
            for tool_result in results.generic_tool_result
        )

    return "\n".join(lines)
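
# Illustrative output shape (assuming `results` holds a single chunk result):
#   Vector Search Results:
#   Source ID [3fa85f6]:
#   <chunk text>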


def _generate_id_from_label(label) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """Generates a unique run id."""
    return _generate_id_from_label(
        label if label is not None else str(uuid4())
    )


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """Generates a unique document id from a given filename and user id."""
    safe_filename = filename.replace("/", "_")
    return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")


def generate_extraction_id(
    document_id: UUID, iteration: int = 0, version: str = "0"
) -> UUID:
    """Generates a unique extraction id from a given document id and
    iteration."""
    return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")


def generate_default_user_collection_id(user_id: UUID) -> UUID:
    """Generates a unique collection id from a given user id."""
    return _generate_id_from_label(str(user_id))


def generate_user_id(email: str) -> UUID:
    """Generates a unique user id from a given email."""
    return _generate_id_from_label(email)


def generate_default_prompt_id(prompt_name: str) -> UUID:
    """Generates a unique prompt id."""
    return _generate_id_from_label(prompt_name)


def generate_entity_document_id() -> UUID:
    """Generates a unique document id for inserting entities into a graph."""
    generation_time = datetime.now().isoformat()
    return _generate_id_from_label(f"entity-{generation_time}")


def validate_uuid(uuid_str: str) -> UUID:
    return UUID(uuid_str)
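
# Illustrative property (not in the original module): the uuid5-based IDs above are
# deterministic, so the same inputs always map to the same UUID, while generate_id()
# with no label draws a fresh random label each call.
# >>> generate_user_id("alice@example.com") == generate_user_id("alice@example.com")
# True
# >>> generate_id() == generate_id()
# False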


def update_settings_from_dict(server_settings, settings_dict: dict):
    """Updates a settings object with values from a dictionary."""
    settings = deepcopy(server_settings)
    for key, value in settings_dict.items():
        if value is not None:
            if isinstance(value, dict):
                for k, v in value.items():
                    if isinstance(getattr(settings, key), dict):
                        getattr(settings, key)[k] = v
                    else:
                        setattr(getattr(settings, key), k, v)
            else:
                setattr(settings, key, value)

    return settings
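
# Illustrative usage (hypothetical settings object and keys):
# >>> updated = update_settings_from_dict(
# ...     server_settings, {"chunk_size": 512, "llm": {"model": "gpt-4o"}}
# ... )
# Top-level keys are set directly; nested dicts update the matching attribute
# (or dict entry) on the deep-copied settings, leaving the original untouched.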


def _decorate_vector_type(
    input_str: str,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
) -> str:
    return f"{quantization_type.db_type}{input_str}"


def _get_vector_column_str(
    dimension: int | float, quantization_type: VectorQuantizationType
) -> str:
    """Returns a string representation of a vector column type.

    Explicitly handles the case where the dimension is not a valid number meant
    to support embedding models that do not allow for specifying the dimension.
    """
    if math.isnan(dimension) or dimension <= 0:
        vector_dim = ""  # Allows for Postgres to handle any dimension
    else:
        vector_dim = f"({dimension})"
    return _decorate_vector_type(vector_dim, quantization_type)
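
# Illustrative result (assuming VectorQuantizationType.FP32.db_type == "vector"):
# >>> _get_vector_column_str(768, VectorQuantizationType.FP32)
# 'vector(768)'
# A NaN or non-positive dimension yields just the bare column type, so Postgres
# accepts vectors of any length.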


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
) -> dict[KeyType, Any]:
    """
    Taken from Pydantic v1:
    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
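
# Illustrative merge behaviour: nested dicts are merged recursively, other values
# are overwritten.
# >>> deep_update({"a": {"x": 1}, "b": 2}, {"a": {"y": 3}})
# {'a': {'x': 1, 'y': 3}, 'b': 2}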


def tokens_count_for_message(message, encoding):
    """Return the number of tokens used by a single message."""
    tokens_per_message = 3

    num_tokens = tokens_per_message
    if message.get("function_call"):
        num_tokens += len(encoding.encode(message["function_call"]["name"]))
        num_tokens += len(
            encoding.encode(message["function_call"]["arguments"])
        )
    elif message.get("tool_calls"):
        for tool_call in message["tool_calls"]:
            num_tokens += len(encoding.encode(tool_call["function"]["name"]))
            num_tokens += len(
                encoding.encode(tool_call["function"]["arguments"])
            )
    elif "content" in message:
        num_tokens += len(encoding.encode(message["content"]))

    return num_tokens


def num_tokens_from_messages(messages, model="gpt-4.1"):
    """Return the number of tokens used by a list of messages for both user and assistant."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning("Model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    tokens = 0
    for message_ in messages:
        tokens += tokens_count_for_message(message_, encoding)

    tokens += 3  # every reply is primed with the assistant role
    return tokens
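
# Illustrative usage (exact counts depend on the encoding in use):
# >>> num_tokens_from_messages([{"role": "user", "content": "Hello"}])
# 7  # 3 per message + 1 for "Hello" + 3 reply priming, assuming a cl100k/o200k-style encoding
# Note that this simplified counter ignores keys other than content/function_call/tool_calls.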


class SearchResultsCollector:
    """
    Collects search results in the form (source_type, result_obj).
    Handles both object-oriented and dictionary-based search results.
    """

    def __init__(self):
        # We'll store a list of (source_type, result_obj)
        self._results_in_order = []

    @property
    def results(self):
        """Get the results list."""
        return self._results_in_order

    @results.setter
    def results(self, value):
        """
        Set the results directly, with automatic type detection for 'unknown' items.
        Handles the format: [('unknown', {...}), ('unknown', {...})]
        """
        self._results_in_order = []
        if not isinstance(value, list):
            raise ValueError("Results must be a list")

        for item in value:
            if isinstance(item, tuple) and len(item) == 2:
                source_type, result_obj = item
                # Only auto-detect if the source type is "unknown"
                if source_type == "unknown":
                    detected_type = self._detect_result_type(result_obj)
                    self._results_in_order.append((detected_type, result_obj))
                else:
                    self._results_in_order.append((source_type, result_obj))
            else:
                # If not a tuple, detect and add
                detected_type = self._detect_result_type(item)
                self._results_in_order.append((detected_type, item))

    def add_aggregate_result(self, agg):
        """
        Flatten the chunk_search_results, graph_search_results, web_search_results,
        and document_search_results into the collector, including nested chunks.
        """
        if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
            for c in agg.chunk_search_results:
                self._results_in_order.append(("chunk", c))

        if hasattr(agg, "graph_search_results") and agg.graph_search_results:
            for g in agg.graph_search_results:
                self._results_in_order.append(("graph", g))

        if (
            hasattr(agg, "web_page_search_results")
            and agg.web_page_search_results
        ):
            for w in agg.web_page_search_results:
                self._results_in_order.append(("web", w))

        if hasattr(agg, "web_search_results") and agg.web_search_results:
            for w in agg.web_search_results:
                self._results_in_order.append(("web", w))

        # Add documents and extract their chunks
        if (
            hasattr(agg, "document_search_results")
            and agg.document_search_results
        ):
            for doc in agg.document_search_results:
                # Add the document itself
                self._results_in_order.append(("doc", doc))

                # Extract and add chunks from the document
                chunks = None
                if isinstance(doc, dict):
                    chunks = doc.get("chunks", [])
                elif hasattr(doc, "chunks") and doc.chunks is not None:
                    chunks = doc.chunks

                if chunks:
                    for chunk in chunks:
                        # Ensure each chunk has the minimum required attributes
                        if isinstance(chunk, dict) and "id" in chunk:
                            # Add the chunk directly to results for citation lookup
                            self._results_in_order.append(("chunk", chunk))
                        elif hasattr(chunk, "id"):
                            self._results_in_order.append(("chunk", chunk))

    def add_result(self, result_obj, source_type=None):
        """
        Add a single result object to the collector.
        If source_type is not provided, automatically detect the type.
        """
        if source_type:
            self._results_in_order.append((source_type, result_obj))
            return source_type

        detected_type = self._detect_result_type(result_obj)
        self._results_in_order.append((detected_type, result_obj))
        return detected_type

    def _detect_result_type(self, obj):
        """
        Detect the type of a result object based on its properties.
        Works with both object attributes and dictionary keys.
        """
        # Handle dictionary types first (common for web search results)
        if isinstance(obj, dict):
            # Web search pattern
            if all(k in obj for k in ["title", "link"]) and any(
                k in obj for k in ["snippet", "description"]
            ):
                return "web"

            # Check for graph dictionary patterns
            if "content" in obj and isinstance(obj["content"], dict):
                content = obj["content"]
                if all(k in content for k in ["name", "description"]):
                    return "graph"  # Entity
                if all(
                    k in content for k in ["subject", "predicate", "object"]
                ):
                    return "graph"  # Relationship
                if all(k in content for k in ["name", "summary"]):
                    return "graph"  # Community

            # Chunk pattern
            if all(k in obj for k in ["text", "id"]) and any(
                k in obj for k in ["score", "metadata"]
            ):
                return "chunk"

            # Context document pattern
            if "document" in obj and "chunks" in obj:
                return "doc"

            # Check for explicit type indicator
            if "type" in obj:
                type_val = str(obj["type"]).lower()
                if any(t in type_val for t in ["web", "organic"]):
                    return "web"
                if "graph" in type_val:
                    return "graph"
                if "chunk" in type_val:
                    return "chunk"
                if "document" in type_val:
                    return "doc"

        # Handle object attributes for OOP-style results
        if hasattr(obj, "result_type"):
            result_type = str(obj.result_type).lower()
            if result_type in {"entity", "relationship", "community"}:
                return "graph"

        # Check class name hints
        class_name = obj.__class__.__name__
        if "Graph" in class_name:
            return "graph"
        if "Chunk" in class_name:
            return "chunk"
        if "Web" in class_name:
            return "web"
        if "Document" in class_name:
            return "doc"

        # Check for object attribute patterns
        if hasattr(obj, "content"):
            content = obj.content
            if hasattr(content, "name") and hasattr(content, "description"):
                return "graph"  # Entity
            if hasattr(content, "subject") and hasattr(content, "predicate"):
                return "graph"  # Relationship
            if hasattr(content, "name") and hasattr(content, "summary"):
                return "graph"  # Community

        if (
            hasattr(obj, "text")
            and hasattr(obj, "id")
            and (hasattr(obj, "score") or hasattr(obj, "metadata"))
        ):
            return "chunk"

        if (
            hasattr(obj, "title")
            and hasattr(obj, "link")
            and hasattr(obj, "snippet")
        ):
            return "web"

        if hasattr(obj, "document") and hasattr(obj, "chunks"):
            return "doc"

        # Default when type can't be determined
        return "unknown"

    def find_by_short_id(self, short_id):
        """Find a result by its short ID prefix with better chunk handling."""
        if not short_id:
            return None

        # First try direct lookup using regular iteration
        for _, result_obj in self._results_in_order:
            # Check dictionary objects
            if isinstance(result_obj, dict) and "id" in result_obj:
                result_id = str(result_obj["id"])
                if result_id.startswith(short_id):
                    return result_obj

            # Check object with id attribute
            elif hasattr(result_obj, "id"):
                obj_id = getattr(result_obj, "id", None)
                if obj_id and str(obj_id).startswith(short_id):
                    # Convert to dict if possible
                    if hasattr(result_obj, "as_dict"):
                        return result_obj.as_dict()
                    elif hasattr(result_obj, "model_dump"):
                        return result_obj.model_dump()
                    elif hasattr(result_obj, "dict"):
                        return result_obj.dict()
                    else:
                        return result_obj

        # If not found, look for chunks inside documents that weren't extracted properly
        for source_type, result_obj in self._results_in_order:
            if source_type == "doc":
                # Try various ways to access chunks
                chunks = None
                if isinstance(result_obj, dict) and "chunks" in result_obj:
                    chunks = result_obj["chunks"]
                elif (
                    hasattr(result_obj, "chunks")
                    and result_obj.chunks is not None
                ):
                    chunks = result_obj.chunks

                if chunks:
                    for chunk in chunks:
                        # Try each chunk
                        chunk_id = None
                        if isinstance(chunk, dict) and "id" in chunk:
                            chunk_id = chunk["id"]
                        elif hasattr(chunk, "id"):
                            chunk_id = chunk.id

                        if chunk_id and str(chunk_id).startswith(short_id):
                            return chunk

        return None

    def get_results_by_type(self, type_name):
        """Get all results of a specific type."""
        return [
            result_obj
            for source_type, result_obj in self._results_in_order
            if source_type == type_name
        ]

    def __repr__(self):
        """String representation showing counts by type."""
        type_counts = {}
        for source_type, _ in self._results_in_order:
            type_counts[source_type] = type_counts.get(source_type, 0) + 1

        return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}"

    def get_all_results(self) -> list[Tuple[str, Any]]:
        """
        Return the list of (source_type, result_obj) tuples, in the order appended.
        """
        return self._results_in_order
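
# Illustrative collector usage (hypothetical objects, not part of the original module):
# >>> collector = SearchResultsCollector()
# >>> collector.add_result(
# ...     {"id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", "text": "some chunk", "score": 0.9}
# ... )
# 'chunk'
# >>> collector.find_by_short_id("3fa85f6")["text"]
# 'some chunk'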


def convert_nonserializable_objects(obj):
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    if hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    if hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    if isinstance(obj, dict):
        new_obj = {}
        for key, value in obj.items():
            # Convert key to string if it is a UUID or not already a string.
            new_key = key if isinstance(key, str) else str(key)
            new_obj[new_key] = convert_nonserializable_objects(value)
        return new_obj
    elif isinstance(obj, list):
        return [convert_nonserializable_objects(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_nonserializable_objects(item) for item in obj)
    elif isinstance(obj, set):
        return {convert_nonserializable_objects(item) for item in obj}
    elif isinstance(obj, uuid.UUID):
        return str(obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()  # Convert datetime to ISO formatted string
    else:
        return obj
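
# Illustrative conversion: UUID keys become strings and datetimes become ISO strings,
# recursively through dicts, lists, tuples, and sets.
# >>> convert_nonserializable_objects({uuid.uuid4(): datetime(2024, 1, 1)})
# {'<uuid as string>': '2024-01-01T00:00:00'}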


def dump_obj(obj) -> list[dict[str, Any]]:
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    elif hasattr(obj, "dict"):
        obj = obj.dict()
    elif hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    elif hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    obj = convert_nonserializable_objects(obj)

    return obj


def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
    dumped = []
    for source_type, result_obj in collector.get_all_results():
        # Get the dictionary from the result object
        if hasattr(result_obj, "model_dump"):
            result_dict = result_obj.model_dump()
        elif hasattr(result_obj, "dict"):
            result_dict = result_obj.dict()
        elif hasattr(result_obj, "as_dict"):
            result_dict = result_obj.as_dict()
        elif hasattr(result_obj, "to_dict"):
            result_dict = result_obj.to_dict()
        else:
            result_dict = (
                result_obj  # Fallback if no conversion method is available
            )

        # Use the recursive conversion on the entire dictionary
        result_dict = convert_nonserializable_objects(result_dict)

        dumped.append(
            {
                "source_type": source_type,
                "result": result_dict,
            }
        )
    return dumped


# FIXME: Tiktoken does not support gpt-4.1, so continue using gpt-4o
# https://github.com/openai/tiktoken/issues/395
def num_tokens(text, model="gpt-4o"):
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fallback to a known encoding if model not recognized
        encoding = tiktoken.get_encoding("cl100k_base")

    return len(encoding.encode(text, disallowed_special=()))
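
# Illustrative usage (exact counts depend on the encoding):
# >>> num_tokens("hello world")
# 2  # typically, since "hello" and " world" each encode to one token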


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    pass


async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
    """
    Helper that yields a single SSE event:
        event: event_name
        data: <JSON payload>
        [blank line to end event]

    Note: `chunk_size` is currently unused; the payload is emitted as a single
    `data:` line rather than being split into partial JSON lines.
    """
    # SSE: first the "event: ..."
    yield f"event: {event_name}\n"

    # Convert payload to JSON
    content_str = json.dumps(payload, default=str)

    # data
    yield f"data: {content_str}\n"

    # blank line signals end of SSE event
    yield "\n"
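
# Illustrative wire format produced by yield_sse_event("message", {"id": "msg_1"}):
#   event: message
#   data: {"id": "msg_1"}
#   (blank line)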


class SSEFormatter:
    """
    Enhanced formatter for Server-Sent Events (SSE) with citation tracking.
    Extends the existing SSEFormatter with improved citation handling.
    """

    @staticmethod
    async def yield_citation_event(
        citation_data: dict,
    ):
        """
        Emits a citation event with an optimized payload.

        Expected keys in `citation_data`:
            citation_id: The short ID of the citation (e.g., 'abc1234')
            span: (start, end) position tuple for this occurrence
            payload: Source object (included only for the first occurrence)
            is_new: Whether this is the first time we've seen this citation
            citation_id_counter: Optional counter for citation occurrences

        Yields:
            Formatted SSE event lines
        """
        # Include the full payload only for new citations
        if not citation_data.get("is_new") or "payload" not in citation_data:
            citation_data["payload"] = None

        # Yield the event
        async for line in yield_sse_event("citation", citation_data):
            yield line

    @staticmethod
    async def yield_final_answer_event(
        final_data: dict,
    ):
        # Yield the event
        async for line in yield_sse_event("final_answer", final_data):
            yield line

    # Include other existing SSEFormatter methods for compatibility
    @staticmethod
    async def yield_message_event(text_segment, msg_id=None):
        msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
        msg_payload = {
            "id": msg_id,
            "object": "agent.message.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("message", msg_payload):
            yield line

    @staticmethod
    async def yield_thinking_event(text_segment, thinking_id=None):
        thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
        thinking_data = {
            "id": thinking_id,
            "object": "agent.thinking.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("thinking", thinking_data):
            yield line

    @staticmethod
    def yield_done_event():
        return "event: done\ndata: [DONE]\n\n"

    @staticmethod
    async def yield_error_event(error_message, error_id=None):
        error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
        error_payload = {
            "id": error_id,
            "object": "agent.error",
            "error": {"message": error_message, "type": "agent_error"},
        }
        async for line in yield_sse_event("error", error_payload):
            yield line

    @staticmethod
    async def yield_tool_call_event(tool_call_data):
        from ..api.models.retrieval.responses import ToolCallEvent

        tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
        async for line in yield_sse_event(
            "tool_call", tc_event.dict()["data"]
        ):
            yield line

    # New helper for emitting search results:
    @staticmethod
    async def yield_search_results_event(aggregated_results):
        payload = {
            "id": "search_1",
            "object": "rag.search_results",
            "data": aggregated_results.as_dict(),
        }
        async for line in yield_sse_event("search_results", payload):
            yield line

    @staticmethod
    async def yield_tool_result_event(tool_result_data):
        from ..api.models.retrieval.responses import ToolResultEvent

        tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
        async for line in yield_sse_event(
            "tool_result", tr_event.dict()["data"]
        ):
            yield line
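

# Illustrative streaming loop (hypothetical caller, not part of the original module):
# the formatter's async generators are consumed line by line and forwarded to the
# client; yield_done_event returns a plain string rather than a generator.
#
# async def stream_response():
#     async for line in SSEFormatter.yield_message_event("Hello"):
#         yield line
#     yield SSEFormatter.yield_done_event()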