graph_search_pipe.py

import json
import logging
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import (
    AsyncState,
    CompletionProvider,
    DatabaseProvider,
    EmbeddingProvider,
)
from core.base.abstractions import (
    GraphSearchResult,
    GraphSearchSettings,
    KGCommunityResult,
    KGEntityResult,
    KGRelationshipResult,
    KGSearchResultType,
    SearchSettings,
)

from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger()
class GraphSearchSearchPipe(GeneratorPipe):
    """
    Searches the knowledge graph for entities, relationships, and communities
    relevant to an incoming query, using the configured embedding provider
    and database.
    """

    def __init__(
        self,
        llm_provider: CompletionProvider,
        database_provider: DatabaseProvider,
        embedding_provider: EmbeddingProvider,
        config: GeneratorPipe.PipeConfig,
        *args,
        **kwargs,
    ):
        """
        Initializes the graph search pipe with the required providers and
        configuration.
        """
        super().__init__(
            llm_provider,
            database_provider,
            config,
            *args,
            **kwargs,
        )
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider
        self.pipe_run_info = None
    def filter_responses(self, map_responses):
        """
        Parses the mapped LLM responses, keeps only items with a positive
        score, and returns their descriptions sorted by score (highest first)
        as a single newline-joined string.
        """
        filtered_responses = []
        for response in map_responses:
            try:
                parsed_response = json.loads(response)
                for item in parsed_response["points"]:
                    try:
                        if item["score"] > 0:
                            filtered_responses.append(item)
                    except KeyError:
                        # Skip this item if it doesn't have a 'score' key
                        logger.warning("Item in response missing 'score' key")
                        continue
            except json.JSONDecodeError:
                logger.warning(
                    f"Response is not valid JSON: {response[:100]}..."
                )
                continue
            except KeyError:
                logger.warning(
                    f"Response is missing 'points' key: {response[:100]}..."
                )
                continue

        filtered_responses = sorted(
            filtered_responses, key=lambda x: x["score"], reverse=True
        )

        responses = "\n".join(
            [
                response.get("description", "")
                for response in filtered_responses
            ]
        )
        return responses
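
    # Example of the per-response JSON shape that `filter_responses` above
    # expects from the mapped LLM output (illustrative only; the exact
    # upstream schema is an assumption, not defined in this module):
    #   {"points": [{"score": 2, "description": "..."},
    #               {"score": 0, "description": "..."}]}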
    async def search(
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: UUID,
        search_settings: SearchSettings,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[GraphSearchResult, None]:
        if not search_settings.graph_settings.enabled:
            return

        async for message in input.message:
            query_embedding = (
                await self.embedding_provider.async_get_embedding(message)
            )

            # entity search
            search_type = "entities"
            base_limit = search_settings.limit
            if search_type not in search_settings.graph_settings.limits:
                logger.warning(
                    f"No limit set for graph search type {search_type}, "
                    f"defaulting to global settings limit of {base_limit}"
                )
            async for search_result in self.database_provider.graphs_handler.graph_search(  # type: ignore
                message,
                search_type=search_type,
                limit=search_settings.graph_settings.limits.get(
                    search_type, base_limit
                ),
                query_embedding=query_embedding,
                property_names=[
                    "name",
                    "description",
                    "chunk_ids",
                ],
                filters=search_settings.filters,
            ):
                yield GraphSearchResult(
                    content=KGEntityResult(
                        name=search_result["name"],
                        description=search_result["description"],
                    ),
                    result_type=KGSearchResultType.ENTITY,
                    score=(
                        search_result["similarity_score"]
                        if search_settings.include_scores
                        else None
                    ),
                    # chunk_ids=search_result["chunk_ids"],
                    metadata=(
                        {
                            "associated_query": message,
                            **(search_result["metadata"] or {}),
                        }
                        if search_settings.include_metadatas
                        else None
                    ),
                )
            # relationship search
            # NOTE: we will check evaluations and see whether this search type is still needed
            search_type = "relationships"
            if search_type not in search_settings.graph_settings.limits:
                logger.warning(
                    f"No limit set for graph search type {search_type}, "
                    f"defaulting to global settings limit of {base_limit}"
                )
            async for search_result in self.database_provider.graphs_handler.graph_search(  # type: ignore
                message,
                search_type=search_type,
                limit=search_settings.graph_settings.limits.get(
                    search_type, base_limit
                ),
                query_embedding=query_embedding,
                property_names=[
                    "subject",
                    "predicate",
                    "object",
                    "description",
                    # "name",
                    # "chunk_ids",
                    # "document_ids",
                ],
            ):
                try:
                    # TODO - remove this nasty hack: metadata is sometimes
                    # returned as a JSON string rather than a dict
                    search_result["metadata"] = json.loads(
                        search_result["metadata"]
                    )
                except (json.JSONDecodeError, TypeError, KeyError):
                    pass

                yield GraphSearchResult(
                    content=KGRelationshipResult(
                        # name=search_result["name"],
                        subject=search_result["subject"],
                        predicate=search_result["predicate"],
                        object=search_result["object"],
                        description=search_result["description"],
                    ),
                    result_type=KGSearchResultType.RELATIONSHIP,
                    score=(
                        search_result["similarity_score"]
                        if search_settings.include_scores
                        else None
                    ),
                    # chunk_ids=search_result["chunk_ids"],
                    # document_ids=search_result["document_ids"],
                    metadata=(
                        {
                            "associated_query": message,
                            **(search_result["metadata"] or {}),
                        }
                        if search_settings.include_metadatas
                        else None
                    ),
                )
            # community search
            search_type = "communities"
            async for search_result in self.database_provider.graphs_handler.graph_search(  # type: ignore
                message,
                search_type=search_type,
                limit=search_settings.graph_settings.limits.get(
                    search_type, base_limit
                ),
                # embedding_type="embedding",
                query_embedding=query_embedding,
                property_names=[
                    "community_id",
                    "name",
                    "findings",
                    "rating",
                    "rating_explanation",
                    "summary",
                ],
                filters=search_settings.filters,
            ):
                yield GraphSearchResult(
                    content=KGCommunityResult(
                        name=search_result["name"],
                        summary=search_result["summary"],
                        rating=search_result["rating"],
                        rating_explanation=search_result["rating_explanation"],
                        findings=search_result["findings"],
                    ),
                    result_type=KGSearchResultType.COMMUNITY,
                    metadata=(
                        {
                            "associated_query": message,
                            **(search_result["metadata"] or {}),
                        }
                        if search_settings.include_metadatas
                        else None
                    ),
                    score=(
                        search_result["similarity_score"]
                        if search_settings.include_scores
                        else None
                    ),
                )
    async def _run_logic(  # type: ignore
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: UUID,
        search_settings: SearchSettings,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[GraphSearchResult, None]:
        async for result in self.search(input, state, run_id, search_settings):
            yield result