r2r.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from typing import Optional, Any
  2. from r2r import R2RAsyncClient
  3. from app.libs.util import verify_jwt_expiration
  4. from config.llm import tool_settings
  5. # import nest_asyncio
  6. import asyncio
  7. # Apply nest_asyncio to allow nested event loops
  8. # nest_asyncio.apply()
  9. class R2R:
  10. client: R2RAsyncClient
  11. def __init__(self):
  12. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  13. self.client = None
  14. async def init(self):
  15. if not self.auth_enabled:
  16. return
  17. if not self.client:
  18. self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
  19. await self.client.users.login(
  20. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  21. )
  22. print(self.client.access_token)
  23. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  24. self._check_login()
  25. ingest_response = asyncio.run(
  26. self.client.documents.create(
  27. file_path=file_path, metadata=metadata if metadata else None, id=None
  28. )
  29. )
  30. return ingest_response.get("results")
  31. def search(self, query: str, filters: dict[str, Any]):
  32. self._check_login()
  33. search_response = asyncio.run(
  34. self.client.retrieval.search(
  35. query=query,
  36. search_settings={
  37. "filters": filters,
  38. "limit": tool_settings.R2R_SEARCH_LIMIT,
  39. },
  40. )
  41. )
  42. return search_response.get("results").get("chunk_search_results")
  43. def _check_login(self):
  44. if not self.auth_enabled:
  45. return
  46. if not self.client.access_token and verify_jwt_expiration(
  47. self.client.access_token
  48. ):
  49. return
  50. else:
  51. asyncio.create_task(self.init())
  52. # 创建 R2R 实例
  53. r2r = R2R()
  54. # 在您的应用程序启动时调用 initialize_r2r()
  55. async def initialize_r2r():
  56. await r2r.init()