r2r.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from typing import Optional, Any
  2. from r2r import R2RClient
  3. from app.libs.util import verify_jwt_expiration
  4. from config.llm import tool_settings
  5. import nest_asyncio
  6. import asyncio
  7. # Apply nest_asyncio to allow nested event loops
  8. nest_asyncio.apply()
  9. class R2R:
  10. client: R2RClient
  11. def __init__(self):
  12. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  13. self.client = None
  14. def init(self):
  15. if not self.auth_enabled:
  16. return
  17. if not self.client:
  18. self.client = R2RClient(tool_settings.R2R_BASE_URL)
  19. asyncio.run(
  20. self.client.users.login(
  21. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  22. )
  23. )
  24. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  25. self._check_login()
  26. ingest_response = self.client.documents.create(
  27. file_path=file_path, metadata=metadata if metadata else None, id=None
  28. )
  29. return ingest_response.get("results")
  30. def search(self, query: str, filters: dict[str, Any]):
  31. self._check_login()
  32. search_response = self.client.retrieval.search(
  33. query=query,
  34. search_settings={
  35. "filters": filters,
  36. "limit": tool_settings.R2R_SEARCH_LIMIT,
  37. },
  38. )
  39. return search_response.get("results").get("chunk_search_results")
  40. def _check_login(self):
  41. if not self.auth_enabled:
  42. return
  43. if verify_jwt_expiration(self.client.access_token):
  44. return
  45. else:
  46. self.init()
  47. # Create the module-level R2R instance
  48. r2r = R2R()