retrieval.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import json
  2. import asyncclick as click
  3. from asyncclick import pass_context
  4. from cli.utils.param_types import JSON
  5. from cli.utils.timer import timer
  6. from r2r import R2RAsyncClient
  7. @click.group()
  8. def retrieval():
  9. """Retrieval commands."""
  10. pass
  11. @retrieval.command()
  12. @click.option(
  13. "--query", prompt="Enter your search query", help="The search query"
  14. )
  15. @click.option(
  16. "--limit", default=None, help="Number of search results to return"
  17. )
  18. @click.option(
  19. "--use-hybrid-search",
  20. default=None,
  21. help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
  22. )
  23. @click.option(
  24. "--use-semantic-search", default=None, help="Perform semantic search?"
  25. )
  26. @click.option(
  27. "--use-fulltext-search", default=None, help="Perform fulltext search?"
  28. )
  29. @click.option(
  30. "--filters",
  31. type=JSON,
  32. help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
  33. )
  34. @click.option(
  35. "--search-strategy",
  36. type=str,
  37. default="vanilla",
  38. help="Vanilla RAG or complex method like query fusion or HyDE.",
  39. )
  40. @click.option(
  41. "--graph-search-enabled", default=None, help="Use knowledge graph search?"
  42. )
  43. @click.option(
  44. "--chunk-search-enabled",
  45. default=None,
  46. help="Use search over document chunks?",
  47. )
  48. @pass_context
  49. async def search(ctx, query, **kwargs):
  50. """Perform a search query."""
  51. client: R2RAsyncClient = ctx.obj
  52. search_settings = {
  53. k: v
  54. for k, v in kwargs.items()
  55. if k
  56. in [
  57. "filters",
  58. "limit",
  59. "search_strategy",
  60. "use_hybrid_search",
  61. "use_semantic_search",
  62. "use_fulltext_search",
  63. "search_strategy",
  64. ]
  65. and v is not None
  66. }
  67. graph_search_enabled = kwargs.get("graph_search_enabled")
  68. if graph_search_enabled != None:
  69. search_settings["graph_settings"] = {"enabled": graph_search_enabled}
  70. chunk_search_enabled = kwargs.get("chunk_search_enabled")
  71. if chunk_search_enabled != None:
  72. search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
  73. with timer():
  74. results = await client.retrieval.search(
  75. query,
  76. "custom",
  77. search_settings,
  78. )
  79. if isinstance(results, dict) and "results" in results:
  80. results = results["results"]
  81. if "chunk_search_results" in results:
  82. click.echo("Vector search results:")
  83. for result in results["chunk_search_results"]:
  84. click.echo(json.dumps(result, indent=2))
  85. if (
  86. "graph_search_results" in results
  87. and results["graph_search_results"]
  88. ):
  89. click.echo("KG search results:")
  90. for result in results["graph_search_results"]:
  91. click.echo(json.dumps(result, indent=2))
  92. @retrieval.command()
  93. @click.option(
  94. "--query", prompt="Enter your search query", help="The search query"
  95. )
  96. @click.option(
  97. "--limit", default=None, help="Number of search results to return"
  98. )
  99. @click.option(
  100. "--use-hybrid-search",
  101. default=None,
  102. help="Perform hybrid search? Equivalent to `use-semantic-search` and `use-fulltext-search`",
  103. )
  104. @click.option(
  105. "--use-semantic-search", default=None, help="Perform semantic search?"
  106. )
  107. @click.option(
  108. "--use-fulltext-search", default=None, help="Perform fulltext search?"
  109. )
  110. @click.option(
  111. "--filters",
  112. type=JSON,
  113. help="""Filters to apply to the vector search as a JSON, e.g. --filters='{"document_id":{"$in":["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "3e157b3a-8469-51db-90d9-52e7d896b49b"]}}'""",
  114. )
  115. @click.option(
  116. "--search-strategy",
  117. type=str,
  118. default="vanilla",
  119. help="Vanilla RAG or complex method like query fusion or HyDE.",
  120. )
  121. @click.option(
  122. "--graph-search-enabled", default=None, help="Use knowledge graph search?"
  123. )
  124. @click.option(
  125. "--chunk-search-enabled",
  126. default=None,
  127. help="Use search over document chunks?",
  128. )
  129. @click.option("--stream", is_flag=True, help="Stream the RAG response")
  130. @click.option("--rag-model", default=None, help="Model for RAG")
  131. @pass_context
  132. async def rag(ctx, query, **kwargs):
  133. """Perform a RAG query."""
  134. client: R2RAsyncClient = ctx.obj
  135. rag_generation_config = {
  136. "stream": kwargs.get("stream", False),
  137. }
  138. if kwargs.get("rag_model"):
  139. rag_generation_config["model"] = kwargs["rag_model"]
  140. search_settings = {
  141. k: v
  142. for k, v in kwargs.items()
  143. if k
  144. in [
  145. "filters",
  146. "limit",
  147. "search_strategy",
  148. "use_hybrid_search",
  149. "use_semantic_search",
  150. "use_fulltext_search",
  151. "search_strategy",
  152. ]
  153. and v is not None
  154. }
  155. graph_search_enabled = kwargs.get("graph_search_enabled")
  156. if graph_search_enabled != None:
  157. search_settings["graph_settings"] = {"enabled": graph_search_enabled}
  158. chunk_search_enabled = kwargs.get("chunk_search_enabled")
  159. if chunk_search_enabled != None:
  160. search_settings["chunk_settings"] = {"enabled": chunk_search_enabled}
  161. with timer():
  162. response = await client.retrieval.rag(
  163. query=query,
  164. rag_generation_config=rag_generation_config,
  165. search_settings={**search_settings},
  166. )
  167. if rag_generation_config.get("stream"):
  168. async for chunk in response:
  169. click.echo(chunk, nl=False)
  170. click.echo()
  171. else:
  172. click.echo(json.dumps(response["results"]["completion"], indent=2))