mcp.py

# Add to your local machine with `mcp install r2r/mcp.py -v R2R_API_URL=http://localhost:7272` or so.
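#
# Alternatively, you can register this server manually in an MCP client. A sketch of a
# Claude Desktop `claude_desktop_config.json` entry (the command, file path, and env var
# below are assumptions for a local setup; adjust them to your environment):
#
#   "mcpServers": {
#     "r2r": {
#       "command": "python",
#       "args": ["/path/to/r2r/mcp.py"],
#       "env": {"R2R_API_URL": "http://localhost:7272"}
#     }
#   }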
from r2r import R2RClient


def id_to_shorthand(id: str) -> str:
    # Shorten a UUID to its first seven characters for compact source labels.
    return str(id)[:7]

def format_search_results_for_llm(
    results,
) -> str:
    """
    Format R2R search results into a plain-text context block for an LLM.

    Instead of resetting a 'source_counter' to 1 for each section, every
    chunk / graph / web / document result is labelled with a shortened,
    stable source ID so the model can cite it unambiguously.
    """
    lines = []

    # 1) Chunk (vector) search results
    if results.chunk_search_results:
        lines.append("Vector Search Results:")
        for c in results.chunk_search_results:
            lines.append(f"Source ID [{id_to_shorthand(c.id)}]:")
            lines.append(c.text or "")  # or c.text[:200] to truncate

    # 2) Graph search results (communities, entities, relationships)
    if results.graph_search_results:
        lines.append("Graph Search Results:")
        for g in results.graph_search_results:
            lines.append(f"Source ID [{id_to_shorthand(g.id)}]:")
            if hasattr(g.content, "summary"):
                lines.append(f"Community Name: {g.content.name}")
                lines.append(f"ID: {g.content.id}")
                lines.append(f"Summary: {g.content.summary}")
            elif hasattr(g.content, "name") and hasattr(
                g.content, "description"
            ):
                lines.append(f"Entity Name: {g.content.name}")
                lines.append(f"Description: {g.content.description}")
            elif (
                hasattr(g.content, "subject")
                and hasattr(g.content, "predicate")
                and hasattr(g.content, "object")
            ):
                lines.append(
                    f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}"
                )
            # Add metadata here if needed

    # 3) Web search results
    if results.web_search_results:
        lines.append("Web Search Results:")
        for w in results.web_search_results:
            lines.append(f"Source ID [{id_to_shorthand(w.id)}]:")
            lines.append(f"Title: {w.title}")
            lines.append(f"Link: {w.link}")
            lines.append(f"Snippet: {w.snippet}")

    # 4) Local context documents
    if results.document_search_results:
        lines.append("Local Context Documents:")
        for doc_result in results.document_search_results:
            doc_title = doc_result.title or "Untitled Document"
            doc_id = doc_result.id
            summary = doc_result.summary

            lines.append(f"Full Document ID: {doc_id}")
            lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}")
            lines.append(f"Document Title: {doc_title}")
            if summary:
                lines.append(f"Summary: {summary}")

            if doc_result.chunks:
                # Each chunk inside the document:
                for chunk in doc_result.chunks:
                    lines.append(
                        f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}"
                    )

    return "\n".join(lines)
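
# A rough sketch of the `results` shape this formatter expects, using
# SimpleNamespace stand-ins for the R2R SDK response objects (the attribute
# names mirror the accesses above; the real classes come from the r2r package):
#
#   from types import SimpleNamespace
#
#   fake = SimpleNamespace(
#       chunk_search_results=[
#           SimpleNamespace(id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
#                           text="Example chunk text ..."),
#       ],
#       graph_search_results=[],
#       web_search_results=[],
#       document_search_results=[],
#   )
#   print(format_search_results_for_llm(fake))
#   # Vector Search Results:
#   # Source ID [9fbe403]:
#   # Example chunk text ...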

# Create a FastMCP server
try:
    from mcp.server.fastmcp import FastMCP
except Exception as e:
    raise ImportError(
        "MCP is not installed. Please run `pip install mcp`"
    ) from e

mcp = FastMCP("R2R Retrieval System")

# Search tool
@mcp.tool()
async def search(query: str) -> str:
    """
    Perform a search over the R2R knowledge base.

    Args:
        query: The question to answer using the knowledge base
    Returns:
        Relevant context from the knowledge base, formatted for an LLM
    """
    client = R2RClient()
    # Call the search endpoint
    search_response = client.retrieval.search(
        query=query,
    )
    return format_search_results_for_llm(search_response.results)
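
# If you need to tune retrieval, R2R's search endpoint also accepts a
# `search_settings` payload; a sketch (the field names below follow the R2R
# docs and should be treated as assumptions for your installed version):
#
#   search_response = client.retrieval.search(
#       query=query,
#       search_settings={"limit": 10, "use_hybrid_search": True},
#   )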

# RAG query tool
@mcp.tool()
async def rag(query: str) -> str:
    """
    Perform a Retrieval-Augmented Generation query.

    Args:
        query: The question to answer using the knowledge base
    Returns:
        A response generated based on relevant context from the knowledge base
    """
    client = R2RClient()
    # Call the RAG endpoint
    rag_response = client.retrieval.rag(
        query=query,
    )
    return rag_response.results.generated_answer  # type: ignore
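
# The RAG call can likewise be customized with a generation config; a sketch
# (parameter and field names per the R2R docs, and the model string is an
# assumption for your deployment):
#
#   rag_response = client.retrieval.rag(
#       query=query,
#       rag_generation_config={"model": "openai/gpt-4o-mini", "temperature": 0.7},
#   )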

# Run the server if executed directly
if __name__ == "__main__":
    mcp.run()
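
# For local debugging you can also launch this file with the MCP Inspector,
# e.g. `mcp dev mcp.py` (command assumed from the MCP Python SDK CLI), and
# exercise the `search` and `rag` tools interactively.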