r2r.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
import logging
from typing import Optional, Any

from fastapi import UploadFile
from r2r import R2RAsyncClient
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
# import nest_asyncio
# Lets async code be nested inside an event loop that is already running.
# nest_asyncio.apply()
# Module-level SDK clients, constructed once at import time from the
# configured base URL: one asynchronous, one synchronous.
client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  12. class R2R:
  13. client: R2RAsyncClient
  14. client_sync: R2RClient
  15. def __init__(self):
  16. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  17. # self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  18. # self.client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  19. self.client = client
  20. self.client_sync = client_sync
  21. def init_sync(self):
  22. if not self.auth_enabled:
  23. return
  24. # if not self.client_sync:
  25. # client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  26. self.client_sync.users.login(
  27. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  28. )
  29. print(
  30. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  31. )
  32. # print(tool_settings.R2R_USERNAME)
  33. # print(tool_settings.R2R_PASSWORD)
  34. print(self.client_sync)
  35. return self.client_sync
  36. async def init(self):
  37. if not self.auth_enabled:
  38. return
  39. # if not self.client:
  40. print(
  41. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  42. )
  43. print(tool_settings.R2R_USERNAME)
  44. print(tool_settings.R2R_PASSWORD)
  45. # client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  46. await self.client.users.login(
  47. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  48. )
  49. print(self.client.access_token)
  50. return self.client
  51. async def ingest_file(self, file_path: str, metadata: Optional[dict]):
  52. client = await self._check_login()
  53. return await client.documents.create(
  54. file_path=file_path,
  55. metadata=metadata if metadata else None,
  56. ingestion_mode="fast",
  57. id=None,
  58. run_with_orchestration=False,
  59. )
  60. async def ingest_fileinfo(self, file: UploadFile, metadata: Optional[dict]):
  61. client = await self._check_login()
  62. return await client.documents.create(
  63. file=file,
  64. metadata=metadata if metadata else None,
  65. id=None,
  66. )
  67. def search(self, query: str, filters: dict[str, Any]):
  68. client = self._check_login_sync()
  69. print(
  70. "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
  71. )
  72. search_response = client.retrieval.search(
  73. query=query,
  74. # search_mode="basic",
  75. search_settings={
  76. "filters": filters,
  77. "limit": tool_settings.R2R_SEARCH_LIMIT,
  78. },
  79. )
  80. print("vvvvvvvvvvvvvvvvvvmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm")
  81. # print(search_response)
  82. print(search_response.results)
  83. return search_response.results.chunk_search_results
  84. def ingest_file_sync(self, file_path: str, metadata: Optional[dict]):
  85. client = self._check_login_sync()
  86. return client.documents.create(
  87. file_path=file_path,
  88. metadata=metadata if metadata else None,
  89. ingestion_mode="fast",
  90. id=None,
  91. run_with_orchestration=False,
  92. )
  93. def list_chunks(self, ids: list[str] = []):
  94. client = self._check_login_sync()
  95. print(
  96. "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
  97. )
  98. print(ids)
  99. allfile = []
  100. for id in ids:
  101. listed = client.documents.list_chunks(id=id)
  102. allfile += listed.results
  103. return allfile
  104. def list_documents(
  105. self,
  106. id: Optional[str] = "",
  107. offset: Optional[int] = 0,
  108. limit: Optional[int] = 100,
  109. ):
  110. client = self._check_login_sync()
  111. """
  112. docs = client.collections.list_documents(empty_coll_id).results
  113. assert len(docs) == 0, "Expected no documents in a new empty collection"
  114. """
  115. print(
  116. "collectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollections"
  117. )
  118. if id != "":
  119. try:
  120. listed = client.collections.list_documents(
  121. id=id, limit=limit, offset=offset
  122. )
  123. print(listed.results)
  124. return listed.results
  125. except Exception as e:
  126. print(e)
  127. listed = []
  128. return listed
  129. else:
  130. return []
  131. async def _check_login(self):
  132. if not self.auth_enabled:
  133. return
  134. if self.client.access_token and verify_jwt_expiration(self.client.access_token):
  135. return self.client
  136. else:
  137. return await self.init()
  138. def _check_login_sync(self):
  139. print("access_tokenaccess_tokenaccess_tokenaccess_token")
  140. # print(client_sync)
  141. if not self.auth_enabled:
  142. return
  143. try:
  144. if self.client_sync.access_token and verify_jwt_expiration(
  145. self.client_sync.access_token
  146. ):
  147. print(self.client_sync.access_token)
  148. return self.client_sync
  149. except Exception as e:
  150. print(e)
  151. return self.init_sync()
# Create the R2R instance.
# Call initialize_r2r() when your application starts.
# async def initialize_r2r():
#     await r2r.init()