# r2r2.py (1.9 KB)
  1. from typing import Optional, Any
  2. import pytest
  3. from r2r import R2RClient
  4. from app.libs.util import verify_jwt_expiration
  5. from config.llm import tool_settings
  6. import nest_asyncio
  7. import asyncio
# Patch asyncio via nest_asyncio so the event loop can be re-entered; presumably
# this lets asyncio.run() in run_async() work even when an event loop is already
# running (e.g. in a notebook or under an async test runner) -- confirm.
nest_asyncio.apply()
  9. class R2R:
  10. client: R2RClient
  11. async def init(self):
  12. # self.client = R2RClient(tool_settings.R2R_BASE_URL)
  13. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  14. await self._login()
  15. def ingest_file(self, file_path: str, metadata: Optional[dict]):
  16. self._check_login()
  17. ingest_response = self.client.documents.create(
  18. file_path=file_path, metadata=metadata if metadata else None, id=None
  19. )
  20. return ingest_response.get("results")
  21. def search(self, query: str, filters: dict[str, Any]):
  22. self._check_login()
  23. search_response = self.client.retrieval.search(
  24. query=query,
  25. search_settings={
  26. "filters": filters,
  27. "limit": 10 #tool_settings.R2R_SEARCH_LIMIT,
  28. # ,"do_hybrid_search": True,
  29. },
  30. )
  31. #print(search_response)
  32. return search_response.get("results").get("chunk_search_results")
  33. #@pytest.fixture(scope="session")
  34. async def _login(self):
  35. if not self.auth_enabled:
  36. return
  37. self.client = R2RClient(tool_settings.R2R_BASE_URL)
  38. self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)
  39. return self.client
  40. def _check_login(self):
  41. if not self.auth_enabled:
  42. return
  43. if verify_jwt_expiration(self.client.access_token):
  44. return
  45. else:
  46. self._login()
  47. r2r = R2R()
  48. def run_async():
  49. asyncio.run(r2r.init()) # 运行异步函数
  50. print(r2r.search("文档内容", {'file_key': {'$in': ['5a098a2c-e1ef-40b0-84a0-d6606261fee0-test.txt']}}))
  51. run_async()