r2r.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from typing import Optional, Any
  2. from r2r import R2RAsyncClient
  3. from app.libs.util import verify_jwt_expiration
  4. from config.llm import tool_settings
  5. # import nest_asyncio
  6. import asyncio
  7. # Apply nest_asyncio to allow nested event loops
  8. # nest_asyncio.apply()
  9. class R2R:
  10. client: R2RAsyncClient
  11. def __init__(self):
  12. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  13. self.client = None
  14. async def init(self):
  15. if not self.auth_enabled:
  16. return
  17. if not self.client:
  18. self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
  19. await self.client.users.login(
  20. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  21. )
  22. print(self.client.access_token)
  23. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  24. self._check_login()
  25. loop = asyncio.get_event_loop()
  26. if loop.is_running():
  27. # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
  28. ingest_response = loop.create_task(
  29. self.client.documents.create(
  30. file_path=file_path,
  31. metadata=metadata if metadata else None,
  32. id=None,
  33. )
  34. )
  35. else:
  36. # 如果没有运行中的事件循环,使用 run_until_complete 来执行
  37. ingest_response = loop.run_until_complete(
  38. self.client.documents.create(
  39. file_path=file_path,
  40. metadata=metadata if metadata else None,
  41. id=None,
  42. )
  43. )
  44. return ingest_response.get("results")
  45. def search(self, query: str, filters: dict[str, Any]):
  46. self._check_login()
  47. loop = asyncio.get_event_loop()
  48. if loop.is_running():
  49. # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
  50. search_response = loop.create_task(
  51. self.client.retrieval.search(
  52. query=query,
  53. search_settings={
  54. "filters": filters,
  55. "limit": tool_settings.R2R_SEARCH_LIMIT,
  56. },
  57. )
  58. )
  59. else:
  60. # 如果没有运行中的事件循环,使用 run_until_complete 来执行
  61. search_response = loop.run_until_complete(
  62. self.client.retrieval.search(
  63. query=query,
  64. search_settings={
  65. "filters": filters,
  66. "limit": tool_settings.R2R_SEARCH_LIMIT,
  67. },
  68. )
  69. )
  70. return search_response.get("results").get("chunk_search_results")
  71. def _check_login(self):
  72. if not self.auth_enabled:
  73. return
  74. if not self.client.access_token and verify_jwt_expiration(
  75. self.client.access_token
  76. ):
  77. return
  78. else:
  79. asyncio.create_task(self.init())
  80. # 创建 R2R 实例
  81. r2r = R2R()
  82. # 在您的应用程序启动时调用 initialize_r2r()
  83. async def initialize_r2r():
  84. await r2r.init()