# base_utils.py

import asyncio
import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Iterable,
    Optional,
    TypeVar,
)
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

from ..abstractions.search import (
    AggregateSearchResult,
    KGCommunityResult,
    KGEntityResult,
    KGGlobalResult,
    KGRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

logger = logging.getLogger()


def format_search_results_for_llm(results: AggregateSearchResult) -> str:
    formatted_results = []
    source_counter = 1

    if results.chunk_search_results:
        formatted_results.append("Vector Search Results:")
        for result in results.chunk_search_results:
            formatted_results.extend(
                (f"Source [{source_counter}]:", f"{result.text}")
            )
            source_counter += 1

    if results.graph_search_results:
        formatted_results.append("KG Search Results:")
        for kg_result in results.graph_search_results:
            formatted_results.append(f"Source [{source_counter}]:")
            if isinstance(kg_result.content, KGCommunityResult):
                formatted_results.extend(
                    (
                        f"Name: {kg_result.content.name}",
                        f"Summary: {kg_result.content.summary}",
                    )
                )
            elif isinstance(kg_result.content, KGEntityResult):
                formatted_results.extend(
                    (
                        f"Name: {kg_result.content.name}",
                        f"Description: {kg_result.content.description}",
                    )
                )
            elif isinstance(kg_result.content, KGRelationshipResult):
                formatted_results.append(
                    f"Relationship: {kg_result.content.subject} - {kg_result.content.predicate} - {kg_result.content.object}"
                )
            if kg_result.metadata:
                formatted_results.append("Metadata:")
                formatted_results.extend(
                    f"- {key}: {value}"
                    for key, value in kg_result.metadata.items()
                )
            source_counter += 1

    if results.web_search_results:
        formatted_results.append("Web Search Results:")
        for result in results.web_search_results:
            formatted_results.extend(
                (
                    f"Source [{source_counter}]:",
                    f"Title: {result.title}",
                    f"Link: {result.link}",
                    f"Snippet: {result.snippet}",
                )
            )
            if result.date:
                formatted_results.append(f"Date: {result.date}")
            source_counter += 1

    return "\n".join(formatted_results)
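
# Example output shape (a sketch; the field values are hypothetical):
#
#   Vector Search Results:
#   Source [1]:
#   <chunk text>
#   KG Search Results:
#   Source [2]:
#   Name: <entity name>
#   Description: <entity description>
#
# Sources are numbered sequentially across all result types so an LLM can
# cite each one as [n].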


def format_search_results_for_stream(result: AggregateSearchResult) -> str:
    CHUNK_SEARCH_STREAM_MARKER = "chunk_search"
    GRAPH_SEARCH_STREAM_MARKER = "graph_search"
    WEB_SEARCH_STREAM_MARKER = "web_search"

    context = ""
    if result.chunk_search_results:
        context += f"<{CHUNK_SEARCH_STREAM_MARKER}>"
        vector_results_list = [
            chunk_result.as_dict()
            for chunk_result in result.chunk_search_results
        ]
        context += json.dumps(vector_results_list, default=str)
        context += f"</{CHUNK_SEARCH_STREAM_MARKER}>"

    if result.graph_search_results:
        context += f"<{GRAPH_SEARCH_STREAM_MARKER}>"
        kg_results_list = [
            graph_result.dict()
            for graph_result in result.graph_search_results
        ]
        context += json.dumps(kg_results_list, default=str)
        context += f"</{GRAPH_SEARCH_STREAM_MARKER}>"

    if result.web_search_results:
        context += f"<{WEB_SEARCH_STREAM_MARKER}>"
        web_results_list = [
            web_result.to_dict()
            for web_result in result.web_search_results
        ]
        context += json.dumps(web_results_list, default=str)
        context += f"</{WEB_SEARCH_STREAM_MARKER}>"

    return context
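
# The stream format wraps each result set in XML-ish markers around a JSON
# payload, e.g. (the payload keys here are hypothetical):
#
#   <chunk_search>[{"text": "...", "score": 0.87}]</chunk_search>
#
# so a consumer can split the stream on the markers and json.loads() each block.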


if TYPE_CHECKING:
    from ..pipeline.base_pipeline import AsyncPipeline


def _generate_id_from_label(label: str) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """
    Generates a unique run id.
    """
    return _generate_id_from_label(label if label is not None else str(uuid4()))
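
# Note: uuid5 is deterministic, so ids derived from a label are stable across
# runs, while generate_id() with no label is effectively random. For example:
#
#   generate_id("run-alpha") == generate_id("run-alpha")  # True
#   generate_id() == generate_id()                        # False (two uuid4 labels)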


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """
    Generates a unique document id from a given filename and user id.
    """
    return _generate_id_from_label(f'{filename.split("/")[-1]}-{str(user_id)}')
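
# Only the basename participates in the id, so two paths with the same final
# component map to the same document id for a given user, e.g. (sketch):
#
#   generate_document_id("docs/report.pdf", uid) == generate_document_id("report.pdf", uid)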


def generate_extraction_id(
    document_id: UUID, iteration: int = 0, version: str = "0"
) -> UUID:
    """
    Generates a unique extraction id from a given document id and iteration.
    """
    return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")


def generate_default_user_collection_id(user_id: UUID) -> UUID:
    """
    Generates a unique collection id from a given user id.
    """
    return _generate_id_from_label(str(user_id))


def generate_user_id(email: str) -> UUID:
    """
    Generates a unique user id from a given email.
    """
    return _generate_id_from_label(email)


def generate_default_prompt_id(prompt_name: str) -> UUID:
    """
    Generates a unique prompt id.
    """
    return _generate_id_from_label(prompt_name)


def generate_entity_document_id() -> UUID:
    """
    Generates a unique document id for inserting entities into a graph.
    """
    generation_time = datetime.now().isoformat()
    return _generate_id_from_label(f"entity-{generation_time}")


async def to_async_generator(
    iterable: Iterable[Any],
) -> AsyncGenerator[Any, None]:
    for item in iterable:
        yield item


def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs):
    # Normalize the input into an async generator: wrap single values in a
    # one-element list, then lift lists into an async generator.
    if not isinstance(input, AsyncGenerator):
        if not isinstance(input, list):
            input = to_async_generator([input])
        else:
            input = to_async_generator(input)

    async def _run_pipeline(input, *args, **kwargs):
        return await pipeline.run(input, *args, **kwargs)

    return asyncio.run(_run_pipeline(input, *args, **kwargs))
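
# Usage sketch (the pipeline construction here is hypothetical):
#
#   pipeline = AsyncPipeline(...)  # any object with an async run()
#   output = run_pipeline(pipeline, ["doc1", "doc2"])
#
# Note that run_pipeline calls asyncio.run(), so it must be invoked from
# synchronous code, not from inside an already-running event loop.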


def increment_version(version: str) -> str:
    # Parse the full trailing number rather than only the last character,
    # so "v19" increments to "v20" instead of "v110".
    prefix = version.rstrip("0123456789")
    suffix = int(version[len(prefix):])
    return f"{prefix}{suffix + 1}"


def decrement_version(version: str) -> str:
    prefix = version.rstrip("0123456789")
    suffix = int(version[len(prefix):])
    return f"{prefix}{max(0, suffix - 1)}"
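
# Examples:
#
#   increment_version("v1")   # "v2"
#   increment_version("v19")  # "v20"
#   decrement_version("v0")   # "v0" (clamped at zero)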


def llm_cost_per_million_tokens(
    model: str, input_output_ratio: float = 2
) -> float:
    """
    Returns the blended cost per million tokens for a given model and
    input/output ratio (the ratio of input tokens to output tokens).
    """
    # TODO: use the provider prefix to disambiguate models in the future.
    model = model.split("/")[-1]  # simplifying assumption
    cost_dict = {
        # model: (input cost, output cost) in USD per million tokens
        "gpt-4o-mini": (0.15, 0.6),
        "gpt-4o": (2.5, 10),
    }
    if model not in cost_dict:
        # Fall back to gpt-4o pricing for unknown models.
        logger.warning(f"Unknown model: {model}. Using gpt-4o as default.")
        model = "gpt-4o"
    input_cost, output_cost = cost_dict[model]
    # Weighted average of the input and output prices by token share.
    return (input_cost * input_output_ratio + output_cost) / (
        1 + input_output_ratio
    )
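
# Worked example: for gpt-4o-mini at the default ratio of 2 (two input tokens
# per output token), the blended price is
#
#   (0.15 * 2 + 0.6) / (2 + 1) = 0.90 / 3 = 0.30 USD per million tokens.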


def validate_uuid(uuid_str: str) -> UUID:
    return UUID(uuid_str)


def update_settings_from_dict(server_settings, settings_dict: dict):
    """
    Updates a settings object with values from a dictionary.
    """
    settings = deepcopy(server_settings)
    for key, value in settings_dict.items():
        if value is None:
            continue
        if isinstance(value, dict):
            # Merge nested dictionaries one level deep, either into a dict
            # attribute or onto a nested settings object.
            for k, v in value.items():
                if isinstance(getattr(settings, key), dict):
                    getattr(settings, key)[k] = v
                else:
                    setattr(getattr(settings, key), k, v)
        else:
            setattr(settings, key, value)
    return settings
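
# Usage sketch (the settings fields here are hypothetical):
#
#   settings = update_settings_from_dict(
#       server_settings,
#       {"log_level": "DEBUG", "database": {"pool_size": 10}},
#   )
#
# The original settings object is left untouched; a deep copy is returned.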


def _decorate_vector_type(
    input_str: str,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
) -> str:
    return f"{quantization_type.db_type}{input_str}"
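
# Sketch, assuming VectorQuantizationType.FP32.db_type is "vector" (an
# assumption; the actual db_type values live in the abstractions module):
#
#   _decorate_vector_type("(512)")  # -> "vector(512)"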


def _get_str_estimation_output(x: tuple[Any, Any]) -> str:
    if isinstance(x[0], int) and isinstance(x[1], int):
        return " - ".join(map(str, x))
    else:
        return " - ".join(f"{round(a, 2)}" for a in x)
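
# Examples:
#
#   _get_str_estimation_output((1, 5))          # "1 - 5"
#   _get_str_estimation_output((1.234, 5.678))  # "1.23 - 5.68"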


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
) -> dict[KeyType, Any]:
    """
    Recursively merges one or more mappings into a copy of `mapping`.

    Taken from Pydantic v1:
    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
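
# Example: nested dicts are merged key-by-key instead of replaced wholesale.
#
#   deep_update({"a": {"b": 1}}, {"a": {"c": 2}})
#   # -> {"a": {"b": 1, "c": 2}}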