retrieval.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import json
  2. import asyncclick as click
  3. from asyncclick import pass_context
  4. from cli.utils.param_types import JSON
  5. from cli.utils.timer import timer
  6. from r2r import R2RAsyncClient, R2RException
  7. @click.group()
  8. def retrieval():
  9. """Retrieval commands."""
  10. pass
  11. @retrieval.command()
  12. @click.option(
  13. "--query", prompt="Enter your search query", help="The search query"
  14. )
  15. @click.option(
  16. "--limit", default=None, help="Number of search results to return"
  17. )
  18. @click.option(
  19. "--use-hybrid-search",
  20. default=None,
  21. help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
  22. )
  23. @click.option(
  24. "--use-semantic-search", default=None, help="Perform semantic search?"
  25. )
  26. @click.option(
  27. "--use-fulltext-search", default=None, help="Perform fulltext search?"
  28. )
  29. @click.option(
  30. "--filters",
  31. type=JSON,
  32. help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
  33. )
  34. @click.option(
  35. "--search-strategy",
  36. type=str,
  37. default="vanilla",
  38. help="Vanilla RAG or complex method like query fusion or HyDE.",
  39. )
  40. @click.option(
  41. "--graph-search-enabled", default=None, help="Use knowledge graph search?"
  42. )
  43. @click.option(
  44. "--chunk-search-enabled",
  45. default=None,
  46. help="Use search over document chunks?",
  47. )
  48. @pass_context
  49. async def search(ctx: click.Context, query, **kwargs):
  50. """Perform a search query."""
  51. search_settings = {
  52. k: v
  53. for k, v in kwargs.items()
  54. if k
  55. in [
  56. "filters",
  57. "limit",
  58. "search_strategy",
  59. "use_hybrid_search",
  60. "use_semantic_search",
  61. "use_fulltext_search",
  62. "search_strategy",
  63. ]
  64. and v is not None
  65. }
  66. graph_search_enabled = kwargs.get("graph_search_enabled")
  67. if graph_search_enabled != None:
  68. search_settings["graph_settings"] = {"enabled": graph_search_enabled}
  69. chunk_search_enabled = kwargs.get("chunk_search_enabled")
  70. if chunk_search_enabled != None:
  71. search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
  72. client: R2RAsyncClient = ctx.obj
  73. print("client.base_url = ", client.base_url)
  74. try:
  75. with timer():
  76. results = await client.retrieval.search(
  77. query,
  78. "custom",
  79. search_settings,
  80. )
  81. if isinstance(results, dict) and "results" in results:
  82. results = results["results"]
  83. if "chunk_search_results" in results: # type: ignore
  84. click.echo("Vector search results:")
  85. for result in results["chunk_search_results"]: # type: ignore
  86. click.echo(json.dumps(result, indent=2))
  87. if (
  88. "graph_search_results" in results # type: ignore
  89. and results["graph_search_results"] # type: ignore
  90. ):
  91. click.echo("KG search results:")
  92. for result in results["graph_search_results"]: # type: ignore
  93. click.echo(json.dumps(result, indent=2))
  94. except R2RException as e:
  95. click.echo(str(e), err=True)
  96. except Exception as e:
  97. click.echo(str(f"An unexpected error occurred: {e}"), err=True)
  98. @retrieval.command()
  99. @click.option(
  100. "--query", prompt="Enter your search query", help="The search query"
  101. )
  102. @click.option(
  103. "--limit", default=None, help="Number of search results to return"
  104. )
  105. @click.option(
  106. "--use-hybrid-search",
  107. default=None,
  108. help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
  109. )
  110. @click.option(
  111. "--use-semantic-search", default=None, help="Perform semantic search?"
  112. )
  113. @click.option(
  114. "--use-fulltext-search", default=None, help="Perform fulltext search?"
  115. )
  116. @click.option(
  117. "--filters",
  118. type=JSON,
  119. help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
  120. )
  121. @click.option(
  122. "--search-strategy",
  123. type=str,
  124. default="vanilla",
  125. help="Vanilla RAG or complex method like query fusion or HyDE.",
  126. )
  127. @click.option(
  128. "--graph-search-enabled", default=None, help="Use knowledge graph search?"
  129. )
  130. @click.option(
  131. "--chunk-search-enabled",
  132. default=None,
  133. help="Use search over document chunks?",
  134. )
  135. @click.option("--stream", is_flag=True, help="Stream the RAG response")
  136. @click.option("--rag-model", default=None, help="Model for RAG")
  137. @pass_context
  138. async def rag(ctx: click.Context, query, **kwargs):
  139. """Perform a RAG query."""
  140. rag_generation_config = {
  141. "stream": kwargs.get("stream", False),
  142. }
  143. if kwargs.get("rag_model"):
  144. rag_generation_config["model"] = kwargs["rag_model"]
  145. search_settings = {
  146. k: v
  147. for k, v in kwargs.items()
  148. if k
  149. in [
  150. "filters",
  151. "limit",
  152. "search_strategy",
  153. "use_hybrid_search",
  154. "use_semantic_search",
  155. "use_fulltext_search",
  156. "search_strategy",
  157. ]
  158. and v is not None
  159. }
  160. graph_search_enabled = kwargs.get("graph_search_enabled")
  161. if graph_search_enabled != None:
  162. search_settings["graph_settings"] = {"enabled": graph_search_enabled}
  163. chunk_search_enabled = kwargs.get("chunk_search_enabled")
  164. if chunk_search_enabled != None:
  165. search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
  166. client: R2RAsyncClient = ctx.obj
  167. try:
  168. with timer():
  169. response = await client.retrieval.rag(
  170. query=query,
  171. rag_generation_config=rag_generation_config,
  172. search_settings={**search_settings},
  173. )
  174. if rag_generation_config.get("stream"):
  175. async for chunk in response: # type: ignore
  176. click.echo(chunk, nl=False)
  177. click.echo()
  178. else:
  179. click.echo(json.dumps(response["results"]["completion"], indent=2)) # type: ignore
  180. except R2RException as e:
  181. click.echo(str(e), err=True)
  182. except Exception as e:
  183. click.echo(str(f"An unexpected error occurred: {e}"), err=True)