r2r.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from typing import Optional, Any
  2. import pytest
  3. from r2r import R2RClient
  4. from app.libs.util import verify_jwt_expiration
  5. from config.llm import tool_settings
  6. import nest_asyncio
  7. import asyncio
# Patch asyncio at import time so event loops may be re-entered (nest_asyncio).
# NOTE(review): presumably required because the R2R SDK / login path may drive an
# event loop while one is already running (e.g. when imported from async code) —
# confirm whether it is still needed with the SDK version in use.
nest_asyncio.apply()
  9. class R2R:
  10. client: R2RClient
  11. def init(self):
  12. # self.client = R2RClient(tool_settings.R2R_BASE_URL)
  13. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  14. self.client = None
  15. loop = asyncio.get_event_loop()
  16. if loop.is_running():
  17. return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
  18. else:
  19. return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
  20. #loop.create_task(self._login())
  21. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  22. self._check_login()
  23. ingest_response = self.client.documents.create(
  24. file_path=file_path, metadata=metadata if metadata else None, id=None
  25. )
  26. return ingest_response.get("results")
  27. def search(self, query: str, filters: dict[str, Any]):
  28. self._check_login()
  29. print("aaaaaaacccccccccccccccccccccccccccccccccccccccccvvvvvvvvvvvvvvvvvvvvvvvvvvvvv")
  30. print(filters)
  31. print(tool_settings.R2R_SEARCH_LIMIT)
  32. search_response = self.client.retrieval.search(
  33. query=query,
  34. search_settings={
  35. "filters": filters,
  36. "limit": tool_settings.R2R_SEARCH_LIMIT,
  37. # ,"do_hybrid_search": True,
  38. },
  39. )
  40. print(search_response)
  41. return search_response.get("results").get("chunk_search_results")
  42. #@pytest.fixture(scope="session")
  43. async def _login(self):
  44. if not self.auth_enabled:
  45. return
  46. if not self.client:
  47. self.client = R2RClient(tool_settings.R2R_BASE_URL)
  48. result = self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD) # 同步调用异步函数
  49. #self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)
  50. #return self.client
  51. def _check_login(self):
  52. if not self.auth_enabled:
  53. return
  54. if verify_jwt_expiration(self.client.access_token):
  55. return
  56. else:
  57. loop = asyncio.get_event_loop()
  58. if loop.is_running():
  59. return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
  60. else:
  61. return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
  62. #loop.create_task(self._login())
  63. #self._login()
  64. r2r = R2R()
  65. #async def run_async():
  66. r2r.init() # 运行异步函数
  67. #asyncio.run(run_async())