import json
import logging
import math
import uuid
from abc import ABCMeta
from copy import deepcopy
from datetime import datetime
from typing import Any, Optional, Tuple, TypeVar
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

import tiktoken

from ..abstractions import (
    AggregateSearchResult,
    AsyncSyncMeta,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

logger = logging.getLogger(__name__)


def id_to_shorthand(id: str | UUID) -> str:
    return str(id)[:7]
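
# Example: the shorthand is just the first seven characters of the UUID
# string, matching the `Source ID [abc1234]` labels produced below:
#
#     >>> id_to_shorthand(UUID("12345678-1234-5678-1234-567812345678"))
#     '1234567'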


def format_search_results_for_llm(
    results: AggregateSearchResult,
) -> str:
    """Render an AggregateSearchResult as a plain-text block for the LLM.

    Rather than numbering sources with a per-call counter that resets to 1,
    each result is labeled with its shortened source ID, so citations remain
    stable across calls.
    """
    lines = []

    # 1) Chunk search results
    if results.chunk_search_results:
        lines.append("Vector Search Results:")
        for c in results.chunk_search_results:
            lines.extend(
                (f"Source ID [{id_to_shorthand(c.id)}]:", c.text or "")
            )

    # 2) Graph search results
    if results.graph_search_results:
        lines.append("Graph Search Results:")
        for g in results.graph_search_results:
            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
            if isinstance(g.content, GraphCommunityResult):
                lines.extend(
                    (
                        f"Community Name: {g.content.name}",
                        f"ID: {g.content.id}",
                        f"Summary: {g.content.summary}",
                    )
                )
            elif isinstance(g.content, GraphEntityResult):
                lines.extend(
                    (
                        f"Entity Name: {g.content.name}",
                        f"Description: {g.content.description}",
                    )
                )
            elif isinstance(g.content, GraphRelationshipResult):
                lines.append(
                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
                )

    # 3) Web page search results
    if results.web_page_search_results:
        lines.append("Web Page Search Results:")
        for w in results.web_page_search_results:
            lines.extend(
                (
                    f"Source ID [{id_to_shorthand(w.id)}]:",
                    f"Title: {w.title}",
                    f"Link: {w.link}",
                    f"Snippet: {w.snippet}",
                )
            )

    # 4) Web search (organic) results
    if results.web_search_results:
        lines.append("Web Search Results:")
        for web_search_result in results.web_search_results:
            for search_result in web_search_result.organic_results:
                lines.extend(
                    (
                        f"Source ID [{id_to_shorthand(search_result.id)}]:",
                        f"Title: {search_result.title}",
                        f"Link: {search_result.link}",
                        f"Snippet: {search_result.snippet}",
                    )
                )

    # 5) Local context documents
    if results.document_search_results:
        lines.append("Local Context Documents:")
        for doc_result in results.document_search_results:
            doc_title = doc_result.title or "Untitled Document"
            doc_id = doc_result.id
            lines.extend(
                (
                    f"Full Document ID: {doc_id}",
                    f"Shortened Document ID: {id_to_shorthand(doc_id)}",
                    f"Document Title: {doc_title}",
                )
            )
            if summary := doc_result.summary:
                lines.append(f"Summary: {summary}")

            if doc_result.chunks:
                # Then each chunk inside the document:
                lines.extend(
                    f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
                    for chunk in doc_result.chunks
                )

    # 6) Generic tool results
    if results.generic_tool_result:
        lines.extend(
            f"Generic Tool Results: {tool_result}"
            for tool_result in results.generic_tool_result
        )

    return "\n".join(lines)


def _generate_id_from_label(label) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """Generates a unique run id."""
    return _generate_id_from_label(
        label if label is not None else str(uuid4())
    )


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """Generates a unique document id from a given filename and user id."""
    safe_filename = filename.replace("/", "_")
    return _generate_id_from_label(f"{safe_filename}-{str(user_id)}")


def generate_extraction_id(
    document_id: UUID, iteration: int = 0, version: str = "0"
) -> UUID:
    """Generates a unique extraction id from a given document id and
    iteration."""
    return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")


def generate_default_user_collection_id(user_id: UUID) -> UUID:
    """Generates a unique collection id from a given user id."""
    return _generate_id_from_label(str(user_id))


def generate_user_id(email: str) -> UUID:
    """Generates a unique user id from a given email."""
    return _generate_id_from_label(email)


def generate_default_prompt_id(prompt_name: str) -> UUID:
    """Generates a unique prompt id."""
    return _generate_id_from_label(prompt_name)


def generate_entity_document_id() -> UUID:
    """Generates a unique document id for inserting entities into a graph."""
    generation_time = datetime.now().isoformat()
    return _generate_id_from_label(f"entity-{generation_time}")


def validate_uuid(uuid_str: str) -> UUID:
    return UUID(uuid_str)


def update_settings_from_dict(server_settings, settings_dict: dict):
    """Updates a settings object with values from a dictionary."""
    settings = deepcopy(server_settings)
    for key, value in settings_dict.items():
        if value is not None:
            if isinstance(value, dict):
                for k, v in value.items():
                    if isinstance(getattr(settings, key), dict):
                        getattr(settings, key)[k] = v
                    else:
                        setattr(getattr(settings, key), k, v)
            else:
                setattr(settings, key, value)

    return settings
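
# Minimal usage sketch (illustrative; uses a SimpleNamespace stand-in rather
# than the real server-settings class): nested dict values update fields on
# the sub-object instead of replacing it, and the input object is left
# untouched thanks to the deepcopy:
#
#     >>> from types import SimpleNamespace
#     >>> s = SimpleNamespace(app=SimpleNamespace(name="r2r", debug=False))
#     >>> s2 = update_settings_from_dict(s, {"app": {"debug": True}})
#     >>> (s2.app.debug, s2.app.name, s.app.debug)
#     (True, 'r2r', False)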


def _decorate_vector_type(
    input_str: str,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
) -> str:
    return f"{quantization_type.db_type}{input_str}"


def _get_vector_column_str(
    dimension: int | float, quantization_type: VectorQuantizationType
) -> str:
    """Returns a string representation of a vector column type.

    Explicitly handles the case where the dimension is not a valid number,
    in order to support embedding models that do not allow specifying the
    dimension.
    """
    if math.isnan(dimension) or dimension <= 0:
        vector_dim = ""  # Allows Postgres to handle any dimension
    else:
        vector_dim = f"({dimension})"
    return _decorate_vector_type(vector_dim, quantization_type)
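
# Illustrative example (the expected output assumes a pgvector-style
# `db_type` such as "vector"; the actual prefix comes from the
# VectorQuantizationType enum, hence the doctest skips):
#
#     >>> _get_vector_column_str(512, VectorQuantizationType.FP32)  # doctest: +SKIP
#     'vector(512)'
#     >>> _get_vector_column_str(float("nan"), VectorQuantizationType.FP32)  # doctest: +SKIP
#     'vector'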


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
) -> dict[KeyType, Any]:
    """Recursively merge `updating_mappings` into `mapping`.

    Taken from Pydantic v1:
    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
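
# Illustrative example: unlike `dict.update`, nested dicts are merged key by
# key rather than replaced wholesale:
#
#     >>> deep_update({"a": {"x": 1}}, {"a": {"y": 2}})
#     {'a': {'x': 1, 'y': 2}}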


def tokens_count_for_message(message, encoding):
    """Return the number of tokens used by a single message."""
    tokens_per_message = 3

    num_tokens = tokens_per_message
    if message.get("function_call"):
        num_tokens += len(encoding.encode(message["function_call"]["name"]))
        num_tokens += len(
            encoding.encode(message["function_call"]["arguments"])
        )
    elif message.get("tool_calls"):
        for tool_call in message["tool_calls"]:
            num_tokens += len(encoding.encode(tool_call["function"]["name"]))
            num_tokens += len(
                encoding.encode(tool_call["function"]["arguments"])
            )
    elif message.get("content"):
        # `content` may be None (e.g. on pure tool-call messages), so only
        # encode it when it is present and non-empty.
        num_tokens += len(encoding.encode(message["content"]))

    return num_tokens


def num_tokens_from_messages(messages, model="gpt-4.1"):
    """Return the number of tokens used by a list of messages, for both user
    and assistant turns."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning("Model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    tokens = 0
    for message_ in messages:
        tokens += tokens_count_for_message(message_, encoding)

    tokens += 3  # every reply is primed with assistant
    return tokens
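
# Illustrative sketch (exact counts depend on the encoding, and
# `encoding_for_model` may fetch the vocabulary on first use, hence the
# skip): a single short user message costs its content tokens plus 3
# per-message tokens, plus 3 for priming the assistant reply:
#
#     >>> msgs = [{"role": "user", "content": "Hello"}]
#     >>> num_tokens_from_messages(msgs)  # doctest: +SKIP
#     7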


class SearchResultsCollector:
    """
    Collects search results in the form (source_type, result_obj).

    Handles both object-oriented and dictionary-based search results.
    """

    def __init__(self):
        # We store a list of (source_type, result_obj) tuples
        self._results_in_order = []

    @property
    def results(self):
        """Get the list of collected results."""
        return self._results_in_order

    @results.setter
    def results(self, value):
        """
        Set the results directly, with automatic type detection for
        'unknown' items. Handles the format:
        [('unknown', {...}), ('unknown', {...})]
        """
        self._results_in_order = []
        if not isinstance(value, list):
            raise ValueError("Results must be a list")

        for item in value:
            if isinstance(item, tuple) and len(item) == 2:
                source_type, result_obj = item
                # Only auto-detect if the source type is "unknown"
                if source_type == "unknown":
                    detected_type = self._detect_result_type(result_obj)
                    self._results_in_order.append((detected_type, result_obj))
                else:
                    self._results_in_order.append((source_type, result_obj))
            else:
                # If not a tuple, detect the type and add the bare object
                detected_type = self._detect_result_type(item)
                self._results_in_order.append((detected_type, item))

    def add_aggregate_result(self, agg):
        """
        Flatten the chunk_search_results, graph_search_results,
        web_search_results, and document_search_results into the collector,
        including chunks nested inside documents.
        """
        if hasattr(agg, "chunk_search_results") and agg.chunk_search_results:
            for c in agg.chunk_search_results:
                self._results_in_order.append(("chunk", c))

        if hasattr(agg, "graph_search_results") and agg.graph_search_results:
            for g in agg.graph_search_results:
                self._results_in_order.append(("graph", g))

        if (
            hasattr(agg, "web_page_search_results")
            and agg.web_page_search_results
        ):
            for w in agg.web_page_search_results:
                self._results_in_order.append(("web", w))

        if hasattr(agg, "web_search_results") and agg.web_search_results:
            for w in agg.web_search_results:
                self._results_in_order.append(("web", w))

        # Add documents and extract their chunks
        if (
            hasattr(agg, "document_search_results")
            and agg.document_search_results
        ):
            for doc in agg.document_search_results:
                # Add the document itself
                self._results_in_order.append(("doc", doc))

                # Extract and add chunks from the document
                chunks = None
                if isinstance(doc, dict):
                    chunks = doc.get("chunks", [])
                elif hasattr(doc, "chunks") and doc.chunks is not None:
                    chunks = doc.chunks

                if chunks:
                    for chunk in chunks:
                        # Only add chunks that carry the minimum required
                        # attribute (an "id", needed for citation lookup)
                        if isinstance(chunk, dict) and "id" in chunk:
                            self._results_in_order.append(("chunk", chunk))
                        elif hasattr(chunk, "id"):
                            self._results_in_order.append(("chunk", chunk))

    def add_result(self, result_obj, source_type=None):
        """
        Add a single result object to the collector.

        If source_type is not provided, the type is detected automatically.
        Returns the source type under which the result was stored.
        """
        if source_type:
            self._results_in_order.append((source_type, result_obj))
            return source_type

        detected_type = self._detect_result_type(result_obj)
        self._results_in_order.append((detected_type, result_obj))
        return detected_type

    def _detect_result_type(self, obj):
        """
        Detect the type of a result object based on its properties.

        Works with both object attributes and dictionary keys.
        """
        # Handle dictionary types first (common for web search results)
        if isinstance(obj, dict):
            # Web search pattern
            if all(k in obj for k in ["title", "link"]) and any(
                k in obj for k in ["snippet", "description"]
            ):
                return "web"

            # Graph dictionary patterns
            if "content" in obj and isinstance(obj["content"], dict):
                content = obj["content"]
                if all(k in content for k in ["name", "description"]):
                    return "graph"  # Entity
                if all(
                    k in content for k in ["subject", "predicate", "object"]
                ):
                    return "graph"  # Relationship
                if all(k in content for k in ["name", "summary"]):
                    return "graph"  # Community

            # Chunk pattern
            if all(k in obj for k in ["text", "id"]) and any(
                k in obj for k in ["score", "metadata"]
            ):
                return "chunk"

            # Context document pattern
            if "document" in obj and "chunks" in obj:
                return "doc"

            # Explicit type indicator
            if "type" in obj:
                type_val = str(obj["type"]).lower()
                if any(t in type_val for t in ["web", "organic"]):
                    return "web"
                if "graph" in type_val:
                    return "graph"
                if "chunk" in type_val:
                    return "chunk"
                if "document" in type_val:
                    return "doc"

        # Handle object attributes for OOP-style results
        if hasattr(obj, "result_type"):
            result_type = str(obj.result_type).lower()
            if result_type in {"entity", "relationship", "community"}:
                return "graph"

        # Class name hints
        class_name = obj.__class__.__name__
        if "Graph" in class_name:
            return "graph"
        if "Chunk" in class_name:
            return "chunk"
        if "Web" in class_name:
            return "web"
        if "Document" in class_name:
            return "doc"

        # Object attribute patterns
        if hasattr(obj, "content"):
            content = obj.content
            if hasattr(content, "name") and hasattr(content, "description"):
                return "graph"  # Entity
            if hasattr(content, "subject") and hasattr(content, "predicate"):
                return "graph"  # Relationship
            if hasattr(content, "name") and hasattr(content, "summary"):
                return "graph"  # Community

        if (
            hasattr(obj, "text")
            and hasattr(obj, "id")
            and (hasattr(obj, "score") or hasattr(obj, "metadata"))
        ):
            return "chunk"

        if (
            hasattr(obj, "title")
            and hasattr(obj, "link")
            and hasattr(obj, "snippet")
        ):
            return "web"

        if hasattr(obj, "document") and hasattr(obj, "chunks"):
            return "doc"

        # Default when the type can't be determined
        return "unknown"

    def find_by_short_id(self, short_id):
        """Find a result by its short ID prefix, including chunks nested
        inside documents."""
        if not short_id:
            return None

        # First, try a direct lookup over the collected results
        for _, result_obj in self._results_in_order:
            # Check dictionary objects
            if isinstance(result_obj, dict) and "id" in result_obj:
                result_id = str(result_obj["id"])
                if result_id.startswith(short_id):
                    return result_obj
            # Check objects with an `id` attribute
            elif hasattr(result_obj, "id"):
                obj_id = getattr(result_obj, "id", None)
                if obj_id and str(obj_id).startswith(short_id):
                    # Convert to a dict if possible
                    if hasattr(result_obj, "as_dict"):
                        return result_obj.as_dict()
                    elif hasattr(result_obj, "model_dump"):
                        return result_obj.model_dump()
                    elif hasattr(result_obj, "dict"):
                        return result_obj.dict()
                    else:
                        return result_obj

        # If not found, look for chunks inside documents that were not
        # extracted into the top-level results
        for source_type, result_obj in self._results_in_order:
            if source_type == "doc":
                # Try the various ways chunks may be attached
                chunks = None
                if isinstance(result_obj, dict) and "chunks" in result_obj:
                    chunks = result_obj["chunks"]
                elif (
                    hasattr(result_obj, "chunks")
                    and result_obj.chunks is not None
                ):
                    chunks = result_obj.chunks

                if chunks:
                    for chunk in chunks:
                        chunk_id = None
                        if isinstance(chunk, dict) and "id" in chunk:
                            chunk_id = chunk["id"]
                        elif hasattr(chunk, "id"):
                            chunk_id = chunk.id

                        if chunk_id and str(chunk_id).startswith(short_id):
                            return chunk

        return None

    def get_results_by_type(self, type_name):
        """Get all results of a specific type."""
        return [
            result_obj
            for source_type, result_obj in self._results_in_order
            if source_type == type_name
        ]

    def __repr__(self):
        """String representation showing result counts by type."""
        type_counts = {}
        for source_type, _ in self._results_in_order:
            type_counts[source_type] = type_counts.get(source_type, 0) + 1
        return (
            f"SearchResultsCollector with "
            f"{len(self._results_in_order)} results: {type_counts}"
        )

    def get_all_results(self) -> list[Tuple[str, Any]]:
        """
        Return the list of (source_type, result_obj) tuples, in the order
        appended.
        """
        return self._results_in_order
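
# Minimal usage sketch (hypothetical chunk dict): dict results are
# type-detected on insertion, and citation short IDs resolve back to the
# stored object:
#
#     >>> collector = SearchResultsCollector()
#     >>> chunk = {
#     ...     "id": "abc12345-0000-4000-8000-000000000000",
#     ...     "text": "hello world",
#     ...     "score": 0.9,
#     ... }
#     >>> collector.add_result(chunk)
#     'chunk'
#     >>> collector.find_by_short_id("abc1234")["text"]
#     'hello world'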


def convert_nonserializable_objects(obj):
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    if hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    if hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    if isinstance(obj, dict):
        new_obj = {}
        for key, value in obj.items():
            # Convert the key to a string if it is a UUID or otherwise not
            # already a string
            new_key = key if isinstance(key, str) else str(key)
            new_obj[new_key] = convert_nonserializable_objects(value)
        return new_obj
    elif isinstance(obj, list):
        return [convert_nonserializable_objects(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_nonserializable_objects(item) for item in obj)
    elif isinstance(obj, set):
        return {convert_nonserializable_objects(item) for item in obj}
    elif isinstance(obj, uuid.UUID):
        return str(obj)
    elif isinstance(obj, datetime):
        return obj.isoformat()  # Convert datetime to an ISO-format string
    else:
        return obj
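
# Illustrative example: UUID keys and datetime values are converted to
# JSON-friendly strings, recursively:
#
#     >>> convert_nonserializable_objects(
#     ...     {uuid.UUID(int=1): datetime(2024, 1, 2)}
#     ... )
#     {'00000000-0000-0000-0000-000000000001': '2024-01-02T00:00:00'}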


def dump_obj(obj) -> Any:
    """Convert `obj` to a JSON-serializable structure via whichever dump
    method it exposes."""
    if hasattr(obj, "model_dump"):
        obj = obj.model_dump()
    elif hasattr(obj, "dict"):
        obj = obj.dict()
    elif hasattr(obj, "as_dict"):
        obj = obj.as_dict()
    elif hasattr(obj, "to_dict"):
        obj = obj.to_dict()

    return convert_nonserializable_objects(obj)


def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]:
    dumped = []
    for source_type, result_obj in collector.get_all_results():
        # Get a dictionary from the result object
        if hasattr(result_obj, "model_dump"):
            result_dict = result_obj.model_dump()
        elif hasattr(result_obj, "dict"):
            result_dict = result_obj.dict()
        elif hasattr(result_obj, "as_dict"):
            result_dict = result_obj.as_dict()
        elif hasattr(result_obj, "to_dict"):
            result_dict = result_obj.to_dict()
        else:
            # Fallback if no conversion method is available
            result_dict = result_obj

        # Recursively convert the entire dictionary
        result_dict = convert_nonserializable_objects(result_dict)

        dumped.append(
            {
                "source_type": source_type,
                "result": result_dict,
            }
        )
    return dumped
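
# Illustrative example, continuing the collector sketch above: the dump is a
# list of JSON-serializable records tagged with their source type:
#
#     >>> dump_collector(collector)[0]["source_type"]
#     'chunk'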


# FIXME: Tiktoken does not support gpt-4.1, so continue using gpt-4o
# https://github.com/openai/tiktoken/issues/395
def num_tokens(text, model="gpt-4o"):
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fall back to a known encoding if the model is not recognized
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text, disallowed_special=()))


class CombinedMeta(AsyncSyncMeta, ABCMeta):
    pass


async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024):
    """
    Helper that yields a single SSE event:

        event: event_name
        data: (payload as JSON)
        [blank line to end the event]

    Note: `chunk_size` is currently unused; the payload is always emitted as
    a single `data:` line.
    """
    # SSE: first the "event: ..." line
    yield f"event: {event_name}\n"

    # Convert the payload to JSON
    content_str = json.dumps(payload, default=str)

    # Then the "data: ..." line
    yield f"data: {content_str}\n"

    # A blank line signals the end of the SSE event
    yield "\n"


class SSEFormatter:
    """
    Formatter for Server-Sent Events (SSE) with citation tracking.
    """

    @staticmethod
    async def yield_citation_event(
        citation_data: dict,
    ):
        """
        Emits a citation event with an optimized payload.

        Expected keys in `citation_data`:
            id: The short ID of the citation (e.g., 'abc1234')
            span: (start, end) position tuple for this occurrence
            payload: Source object (included only for the first occurrence)
            is_new: Whether this is the first time we've seen this citation

        Yields:
            Formatted SSE event lines
        """
        # Include the full payload only for new citations
        if not citation_data.get("is_new") or "payload" not in citation_data:
            citation_data["payload"] = None

        # Yield the event
        async for line in yield_sse_event("citation", citation_data):
            yield line

    @staticmethod
    async def yield_final_answer_event(
        final_data: dict,
    ):
        # Yield the event
        async for line in yield_sse_event("final_answer", final_data):
            yield line

    # Other SSEFormatter methods, included for compatibility:

    @staticmethod
    async def yield_message_event(text_segment, msg_id=None):
        msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}"
        msg_payload = {
            "id": msg_id,
            "object": "agent.message.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("message", msg_payload):
            yield line

    @staticmethod
    async def yield_thinking_event(text_segment, thinking_id=None):
        thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}"
        thinking_data = {
            "id": thinking_id,
            "object": "agent.thinking.delta",
            "delta": {
                "content": [
                    {
                        "type": "text",
                        "payload": {
                            "value": text_segment,
                            "annotations": [],
                        },
                    }
                ]
            },
        }
        async for line in yield_sse_event("thinking", thinking_data):
            yield line

    @staticmethod
    def yield_done_event():
        return "event: done\ndata: [DONE]\n\n"

    @staticmethod
    async def yield_error_event(error_message, error_id=None):
        error_id = error_id or f"err_{uuid.uuid4().hex[:8]}"
        error_payload = {
            "id": error_id,
            "object": "agent.error",
            "error": {"message": error_message, "type": "agent_error"},
        }
        async for line in yield_sse_event("error", error_payload):
            yield line

    @staticmethod
    async def yield_tool_call_event(tool_call_data):
        from ..api.models.retrieval.responses import ToolCallEvent

        tc_event = ToolCallEvent(event="tool_call", data=tool_call_data)
        async for line in yield_sse_event(
            "tool_call", tc_event.dict()["data"]
        ):
            yield line

    # Helper for emitting search results:
    @staticmethod
    async def yield_search_results_event(aggregated_results):
        payload = {
            "id": "search_1",
            "object": "rag.search_results",
            "data": aggregated_results.as_dict(),
        }
        async for line in yield_sse_event("search_results", payload):
            yield line

    @staticmethod
    async def yield_tool_result_event(tool_result_data):
        from ..api.models.retrieval.responses import ToolResultEvent

        tr_event = ToolResultEvent(event="tool_result", data=tool_result_data)
        async for line in yield_sse_event(
            "tool_result", tr_event.dict()["data"]
        ):
            yield line
|