r2r.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
import logging
from typing import Optional, Any

import nest_asyncio
from fastapi import UploadFile
from r2r import R2RAsyncClient
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
# Allow asynchronous code to be nested inside an event loop that is already
# running. This call runs once, at module import time.
nest_asyncio.apply()
  10. class R2R:
  11. client: R2RAsyncClient
  12. client_sync: R2RClient
  13. def __init__(self):
  14. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  15. self.client = None
  16. self.client_sync = None
  17. def init_sync(self):
  18. if not self.auth_enabled:
  19. return
  20. if not self.client_sync:
  21. self.client_sync = R2RClient(tool_settings.R2R_BASE_URL, "/v3")
  22. self.client_sync.users.login(
  23. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  24. )
  25. print(
  26. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  27. )
  28. print(tool_settings.R2R_USERNAME)
  29. print(tool_settings.R2R_PASSWORD)
  30. print(self.client_sync.access_token)
  31. async def init(self):
  32. if not self.auth_enabled:
  33. return
  34. if not self.client:
  35. self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
  36. print(
  37. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  38. )
  39. print(tool_settings.R2R_USERNAME)
  40. print(tool_settings.R2R_PASSWORD)
  41. await self.client.users.login(
  42. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  43. )
  44. print(self.client.access_token)
  45. async def ingest_file(self, file_path: str, metadata: Optional[dict]):
  46. await self._check_login()
  47. return await self.client.documents.create(
  48. file_path=file_path,
  49. metadata=metadata if metadata else None,
  50. ingestion_mode="fast",
  51. id=None,
  52. )
  53. async def ingest_fileinfo(self, file: UploadFile, metadata: Optional[dict]):
  54. await self._check_login()
  55. return await self.client.documents.create(
  56. file=file,
  57. metadata=metadata if metadata else None,
  58. id=None,
  59. )
  60. def search(self, query: str, filters: dict[str, Any]):
  61. self._check_login_sync()
  62. print(
  63. "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
  64. )
  65. search_response = self.client_sync.retrieval.search(
  66. query=query,
  67. search_settings={
  68. "filters": filters,
  69. "limit": tool_settings.R2R_SEARCH_LIMIT,
  70. },
  71. )
  72. print("vvvvvvvvvvvvvvvvvvmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm")
  73. print(search_response.get("results"))
  74. return search_response.get("results").get("chunk_search_results")
  75. def list(
  76. self,
  77. ids: Optional[list[str]] = None,
  78. offset: Optional[int] = 0,
  79. limit: Optional[int] = 100,
  80. ):
  81. self._check_login_sync()
  82. print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaassssssssssssssssssssssssssssssssssss")
  83. print(ids)
  84. """
  85. listed = mutable_client.documents.list(limit=2, offset=0)
  86. results = listed.results
  87. assert len(results) == 2, "Expected 2 results for paginated listing"
  88. """
  89. print("listlistlistlistlistlistlistlistlistlistlistlistlistlistlistlistlist")
  90. if len(ids) > 0:
  91. listed = self.client_sync.documents.list(
  92. ids=ids, limit=limit, offset=offset
  93. )
  94. print(listed.get("results"))
  95. return listed.get("results")
  96. else:
  97. return []
  98. def list_chunks(
  99. self,
  100. ids: list[str] = None,
  101. ):
  102. self._check_login_sync()
  103. print(
  104. "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
  105. )
  106. print(ids)
  107. allfile = []
  108. for id in ids:
  109. listed = self.client_sync.documents.list_chunks(id=id)
  110. allfile += listed.get("results")
  111. return allfile
  112. def list_documents(
  113. self,
  114. id: Optional[str] = "",
  115. offset: Optional[int] = 0,
  116. limit: Optional[int] = 100,
  117. ):
  118. self._check_login_sync()
  119. """
  120. docs = client.collections.list_documents(empty_coll_id).results
  121. assert len(docs) == 0, "Expected no documents in a new empty collection"
  122. """
  123. print(
  124. "collectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollections"
  125. )
  126. if id != "":
  127. try:
  128. listed = self.client_sync.collections.list_documents(
  129. id=id, limit=limit, offset=offset
  130. )
  131. print(listed.get("results"))
  132. return listed.get("results")
  133. except Exception as e:
  134. print(e)
  135. listed = []
  136. return listed
  137. else:
  138. return []
  139. async def _check_login(self):
  140. if not self.auth_enabled:
  141. return
  142. if not self.client.access_token and verify_jwt_expiration(
  143. self.client.access_token
  144. ):
  145. return
  146. else:
  147. await self.init()
  148. def _check_login_sync(self):
  149. if not self.auth_enabled:
  150. return
  151. try:
  152. if not self.client_sync.access_token and verify_jwt_expiration(
  153. self.client_sync.access_token
  154. ):
  155. return
  156. except Exception as e:
  157. print(e)
  158. self.client_sync.users.login(
  159. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  160. )
# Create the shared, module-level R2R wrapper instance.
r2r = R2R()
  163. # 在您的应用程序启动时调用 initialize_r2r()
  164. async def initialize_r2r():
  165. await r2r.init()