r2r.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from typing import Optional, Any
  2. from r2r import R2RAsyncClient
  3. from app.libs.util import verify_jwt_expiration
  4. from config.llm import tool_settings
  5. import nest_asyncio
  # Allow asyncio code to be run nested inside an event loop that is already
  # running. This executes at module import time, so the patch takes effect as
  # soon as this module is imported.
  # NOTE(review): nest_asyncio patches asyncio process-wide; confirm it is
  # compatible with the event loop implementation used in deployment.
  nest_asyncio.apply()
  class R2R:
      """Async wrapper around ``R2RAsyncClient`` that handles client setup and login.

      Authentication is enabled only when both ``tool_settings.R2R_USERNAME`` and
      ``tool_settings.R2R_PASSWORD`` are set. When auth is disabled, the client is
      still created (so ``ingest_file``/``search`` work) but never logged in.
      """

      # None until init() (or the first ingest_file/search call) creates it.
      client: Optional[R2RAsyncClient]

      def __init__(self):
          # bool() so the flag is a real boolean instead of whichever settings
          # string the ``and`` expression happened to evaluate to.
          self.auth_enabled = bool(
              tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
          )
          self.client = None

      async def init(self):
          """Create the client if it does not exist yet and log in if needed.

          Safe to call repeatedly: the client is created only once, and a login
          is attempted only while auth is enabled and the client holds no access
          token (so a previously failed login is retried on the next call).
          """
          if self.client is None:
              self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
          if self.auth_enabled and not getattr(self.client, "access_token", None):
              await self._login()

      async def _login(self):
          """Log the client in with the configured username and password."""
          # The resulting access token is deliberately not printed or logged:
          # it is a bearer credential.
          await self.client.users.login(
              tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
          )

      async def ingest_file(self, file_path: str, metadata: Optional[dict] = None):
          """Upload ``file_path`` to R2R as a new document.

          Args:
              file_path: Path of the file to ingest.
              metadata: Optional document metadata; an empty dict is sent as None.

          Returns:
              Whatever ``client.documents.create`` returns.
          """
          await self._check_login()
          return await self.client.documents.create(
              file_path=file_path,
              metadata=metadata or None,
              id=None,
          )

      async def search(self, query: str, filters: dict[str, Any]):
          """Run an R2R retrieval search.

          Args:
              query: Search query text.
              filters: Passed through unchanged as ``search_settings["filters"]``.

          Returns:
              Whatever ``client.retrieval.search`` returns. The requested
              ``limit`` comes from ``tool_settings.R2R_SEARCH_LIMIT``.
          """
          await self._check_login()
          return await self.client.retrieval.search(
              query=query,
              search_settings={
                  "filters": filters,
                  "limit": tool_settings.R2R_SEARCH_LIMIT,
              },
          )

      async def _check_login(self):
          """Ensure a client exists and, with auth enabled, holds a usable token.

          Logs in again when the token is missing or ``verify_jwt_expiration``
          rejects it.
          """
          # Creates the client on first use, so callers need not await init()
          # before the first request.
          await self.init()
          if not self.auth_enabled:
              return
          token = getattr(self.client, "access_token", None)
          # NOTE(review): assumes verify_jwt_expiration(token) is truthy while the
          # token is still valid (the original early-return implies this) — confirm.
          if token and verify_jwt_expiration(token):
              return
          await self._login()
  # Module-level R2R instance, constructed when this module is imported.
  # initialize_r2r() below awaits its init().
  r2r = R2R()
  async def initialize_r2r():
      """Initialise the module-level ``r2r`` instance.

      Await this once during your application's startup.
      """
      instance = r2r
      await instance.init()