# base_utils.py

import asyncio
import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Iterable,
    Optional,
    TypeVar,
)
from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5

from ..abstractions.search import (
    AggregateSearchResult,
    KGCommunityResult,
    KGEntityResult,
    KGRelationshipResult,
)
from ..abstractions.vector import VectorQuantizationType

if TYPE_CHECKING:
    from ..pipeline.base_pipeline import AsyncPipeline

logger = logging.getLogger(__name__)


def format_search_results_for_llm(results: AggregateSearchResult) -> str:
    """
    Formats an aggregate search result as a numbered, human-readable
    context string suitable for inclusion in an LLM prompt.
    """
    formatted_results = []
    source_counter = 1

    if results.chunk_search_results:
        formatted_results.append("Vector Search Results:")
        for result in results.chunk_search_results:
            formatted_results.extend(
                (f"Source [{source_counter}]:", f"{result.text}")
            )
            source_counter += 1

    if results.graph_search_results:
        formatted_results.append("KG Search Results:")
        for kg_result in results.graph_search_results:
            formatted_results.append(f"Source [{source_counter}]:")
            if isinstance(kg_result.content, KGCommunityResult):
                formatted_results.extend(
                    (
                        f"Name: {kg_result.content.name}",
                        f"Summary: {kg_result.content.summary}",
                    )
                )
            elif isinstance(kg_result.content, KGEntityResult):
                formatted_results.extend(
                    (
                        f"Name: {kg_result.content.name}",
                        f"Description: {kg_result.content.description}",
                    )
                )
            elif isinstance(kg_result.content, KGRelationshipResult):
                formatted_results.append(
                    f"Relationship: {kg_result.content.subject} - {kg_result.content.predicate} - {kg_result.content.object}"
                )
            else:
                raise ValueError(f"Invalid KG search result: {kg_result}")
            if kg_result.metadata:
                formatted_results.append("Metadata:")
                formatted_results.extend(
                    f"- {key}: {value}"
                    for key, value in kg_result.metadata.items()
                )
            source_counter += 1

    if results.web_search_results:
        formatted_results.append("Web Search Results:")
        for result in results.web_search_results:
            formatted_results.extend(
                (
                    f"Source [{source_counter}]:",
                    f"Title: {result.title}",
                    f"Link: {result.link}",
                    f"Snippet: {result.snippet}",
                )
            )
            if result.date:
                formatted_results.append(f"Date: {result.date}")
            source_counter += 1

    return "\n".join(formatted_results)
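
# Example (illustrative; assumes `agg` is an AggregateSearchResult produced
# by the search layer, holding a single chunk hit):
#
#     print(format_search_results_for_llm(agg))
#     # Vector Search Results:
#     # Source [1]:
#     # <text of the first chunk>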


def format_search_results_for_stream(results: AggregateSearchResult) -> str:
    """
    Serializes an aggregate search result as JSON wrapped in per-source
    marker tags, for streaming to a client.
    """
    CHUNK_SEARCH_STREAM_MARKER = "chunk_search"
    GRAPH_SEARCH_STREAM_MARKER = "graph_search"
    WEB_SEARCH_STREAM_MARKER = "web_search"

    context = ""
    if results.chunk_search_results:
        context += f"<{CHUNK_SEARCH_STREAM_MARKER}>"
        vector_results_list = [
            result.as_dict() for result in results.chunk_search_results
        ]
        context += json.dumps(vector_results_list, default=str)
        context += f"</{CHUNK_SEARCH_STREAM_MARKER}>"

    if results.graph_search_results:
        context += f"<{GRAPH_SEARCH_STREAM_MARKER}>"
        kg_results_list = [
            result.dict() for result in results.graph_search_results
        ]
        context += json.dumps(kg_results_list, default=str)
        context += f"</{GRAPH_SEARCH_STREAM_MARKER}>"

    if results.web_search_results:
        context += f"<{WEB_SEARCH_STREAM_MARKER}>"
        web_results_list = [
            result.to_dict() for result in results.web_search_results
        ]
        context += json.dumps(web_results_list, default=str)
        context += f"</{WEB_SEARCH_STREAM_MARKER}>"

    return context
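
# Example (illustrative): each populated result class is wrapped in its own
# marker tags so a client can split the streamed payload, e.g.
# "<chunk_search>[{...}]</chunk_search><graph_search>[{...}]</graph_search>".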


def _generate_id_from_label(label: str) -> UUID:
    """Deterministically derives a UUID from a string label."""
    return uuid5(NAMESPACE_DNS, label)


def generate_id(label: Optional[str] = None) -> UUID:
    """
    Generates a unique run id. If a label is provided the id is
    deterministic; otherwise it is derived from a random uuid4.
    """
    return _generate_id_from_label(
        label if label is not None else str(uuid4())
    )


def generate_document_id(filename: str, user_id: UUID) -> UUID:
    """
    Generates a unique document id from a given filename and user id.
    """
    return _generate_id_from_label(f'{filename.split("/")[-1]}-{str(user_id)}')


def generate_extraction_id(
    document_id: UUID, iteration: int = 0, version: str = "0"
) -> UUID:
    """
    Generates a unique extraction id from a given document id, iteration,
    and version.
    """
    return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}")


def generate_default_user_collection_id(user_id: UUID) -> UUID:
    """
    Generates a unique collection id from a given user id.
    """
    return _generate_id_from_label(str(user_id))


def generate_user_id(email: str) -> UUID:
    """
    Generates a unique user id from a given email.
    """
    return _generate_id_from_label(email)


def generate_default_prompt_id(prompt_name: str) -> UUID:
    """
    Generates a unique prompt id from a given prompt name.
    """
    return _generate_id_from_label(prompt_name)
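
# Because these ids are uuid5-based, they are deterministic (values below are
# illustrative):
#
#     uid = generate_user_id("alice@example.com")
#     generate_document_id("docs/report.pdf", uid) == generate_document_id(
#         "report.pdf", uid
#     )  # True - only the basename of the filename is used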


def generate_entity_document_id() -> UUID:
    """
    Generates a unique document id for inserting entities into a graph.
    Unlike the other generators, this one is time-based rather than
    deterministic.
    """
    generation_time = datetime.now().isoformat()
    return _generate_id_from_label(f"entity-{generation_time}")


async def to_async_generator(
    iterable: Iterable[Any],
) -> AsyncGenerator[Any, None]:
    for item in iterable:
        yield item
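
# Example: adapt a plain iterable for async consumption:
#
#     async for item in to_async_generator([1, 2, 3]):
#         ...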


def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs):
    """
    Runs an async pipeline synchronously, wrapping non-generator input in
    an async generator first.
    """
    if not isinstance(input, AsyncGenerator):
        if not isinstance(input, list):
            input = to_async_generator([input])
        else:
            input = to_async_generator(input)

    async def _run_pipeline(input, *args, **kwargs):
        return await pipeline.run(input, *args, **kwargs)

    return asyncio.run(_run_pipeline(input, *args, **kwargs))
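
# Example (illustrative; assumes `pipeline` is a constructed AsyncPipeline):
#
#     final_state = run_pipeline(pipeline, ["doc-1", "doc-2"])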


def increment_version(version: str) -> str:
    """
    Increments the final character of a version string; assumes a
    single-digit numeric suffix (e.g. "v0" -> "v1").
    """
    prefix = version[:-1]
    suffix = int(version[-1])
    return f"{prefix}{suffix + 1}"


def decrement_version(version: str) -> str:
    """
    Decrements the final character of a version string, clamped at 0.
    """
    prefix = version[:-1]
    suffix = int(version[-1])
    return f"{prefix}{max(0, suffix - 1)}"
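
# Examples: increment_version("v0") -> "v1"; decrement_version("v0") stays
# "v0" because the suffix is clamped at zero.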


def llm_cost_per_million_tokens(
    model: str, input_output_ratio: float = 2
) -> float:
    """
    Returns the blended cost per million tokens for a given model and
    input/output ratio, where the ratio is input tokens to output tokens.
    """
    # TODO: use the provider prefix instead of discarding it
    model = model.split("/")[-1]  # simplifying assumption
    cost_dict = {
        "gpt-4o-mini": (0.15, 0.6),
        "gpt-4o": (2.5, 10),
    }
    if model not in cost_dict:
        logger.warning(f"Unknown model: {model}. Using gpt-4o as default.")
        model = "gpt-4o"
    input_cost, output_cost = cost_dict[model]
    # Weighted average of the per-million input and output costs.
    return (input_cost * input_output_ratio + output_cost) / (
        1 + input_output_ratio
    )
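
# Worked example: with input_output_ratio = 2 and gpt-4o-mini pricing
# (0.15 input / 0.60 output per million tokens), the blended cost is
# (0.15 * 2 + 0.60) / (2 + 1) = 0.30 per million tokens.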


def validate_uuid(uuid_str: str) -> UUID:
    """Parses a string as a UUID, raising ValueError if it is invalid."""
    return UUID(uuid_str)


def update_settings_from_dict(server_settings, settings_dict: dict):
    """
    Updates a settings object with values from a dictionary, returning a
    deep copy; the original settings object is left unchanged.
    """
    settings = deepcopy(server_settings)
    for key, value in settings_dict.items():
        if value is not None:
            if isinstance(value, dict):
                for k, v in value.items():
                    if isinstance(getattr(settings, key), dict):
                        getattr(settings, key)[k] = v
                    else:
                        setattr(getattr(settings, key), k, v)
            else:
                setattr(settings, key, value)
    return settings
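
# Example (illustrative; assumes a settings object with a dict-valued
# `extra` attribute):
#
#     updated = update_settings_from_dict(settings, {"extra": {"timeout": 30}})
#     # `settings` is unchanged; updated.extra["timeout"] == 30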


def _decorate_vector_type(
    input_str: str,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
) -> str:
    """Prefixes `input_str` with the db type of the given quantization."""
    return f"{quantization_type.db_type}{input_str}"
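
# Example (illustrative; the exact prefix depends on
# VectorQuantizationType.db_type, e.g. a hypothetical "vector" for FP32):
#
#     _decorate_vector_type("(512)")  # -> "vector(512)"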


def _get_str_estimation_output(x: tuple[Any, Any]) -> str:
    """Formats an estimation range as "low - high", rounding floats to 2 dp."""
    if isinstance(x[0], int) and isinstance(x[1], int):
        return " - ".join(map(str, x))
    else:
        return " - ".join(f"{round(a, 2)}" for a in x)
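
# Examples:
#
#     _get_str_estimation_output((10, 20))        # -> "10 - 20"
#     _get_str_estimation_output((0.123, 4.567))  # -> "0.12 - 4.57"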


KeyType = TypeVar("KeyType")


def deep_update(
    mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]
) -> dict[KeyType, Any]:
    """
    Recursively merges one or more mappings into a copy of `mapping`.

    Taken from Pydantic v1:
    https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
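
# Example: nested keys are merged rather than replaced, unlike dict.update:
#
#     deep_update({"a": {"x": 1}}, {"a": {"y": 2}})
#     # -> {"a": {"x": 1, "y": 2}}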