r2r.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
import asyncio
import logging
from typing import Any, Optional

import nest_asyncio
import pytest
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
  8. # Apply nest_asyncio to allow nested event loops
  9. nest_asyncio.apply()
  10. class R2R:
  11. client: R2RClient
  12. def init(self):
  13. # self.client = R2RClient(tool_settings.R2R_BASE_URL)
  14. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  15. self.client = None
  16. loop = asyncio.get_event_loop()
  17. if loop.is_running():
  18. return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
  19. else:
  20. return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
  21. # loop.create_task(self._login())
  22. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  23. self._check_login()
  24. ingest_response = self.client.documents.create(
  25. file_path=file_path, metadata=metadata if metadata else None, id=None
  26. )
  27. return ingest_response.get("results")
  28. def search(self, query: str, filters: dict[str, Any]):
  29. self._check_login()
  30. print(
  31. "aaaaaaacccccccccccccccccccccccccccccccccccccccccvvvvvvvvvvvvvvvvvvvvvvvvvvvvv"
  32. )
  33. print(filters)
  34. print(tool_settings.R2R_SEARCH_LIMIT)
  35. search_response = self.client.retrieval.search(
  36. query=query,
  37. search_settings={
  38. "filters": filters,
  39. "limit": tool_settings.R2R_SEARCH_LIMIT,
  40. # ,"do_hybrid_search": True,
  41. },
  42. )
  43. print(search_response)
  44. return search_response.get("results").get("chunk_search_results")
  45. # @pytest.fixture(scope="session")
  46. async def _login(self):
  47. if not self.auth_enabled:
  48. return
  49. print(
  50. "client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>"
  51. )
  52. print(self.client)
  53. if not self.client:
  54. self.client = R2RClient(tool_settings.R2R_BASE_URL)
  55. result = await self.client.users.login(
  56. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  57. ) # 同步调用异步函数
  58. # return self.client
  59. def _check_login(self):
  60. if not self.auth_enabled:
  61. return
  62. if verify_jwt_expiration(self.client.access_token):
  63. return
  64. else:
  65. loop = asyncio.get_event_loop()
  66. if loop.is_running():
  67. return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
  68. else:
  69. return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
  70. # loop.create_task(self._login())
  71. # self._login()
  72. r2r = R2R()
  73. r2r.init() # 运行异步函数