r2r.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
import logging
import uuid
from typing import Optional, Any

from fastapi import UploadFile
from r2r import R2RAsyncClient
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
# import nest_asyncio
# Allows asynchronous code to be nested inside an already-running event loop
# nest_asyncio.apply()

# Module-level SDK clients. Every R2R instance assigns these same objects to
# self.client / self.client_sync, so connections and login state are shared.
# NOTE(review): the second positional argument (300) is presumably the request
# timeout in seconds — TODO confirm against the installed r2r SDK signature.
client = R2RAsyncClient(tool_settings.R2R_BASE_URL, 300)
client_sync = R2RClient(tool_settings.R2R_BASE_URL, 300)
  13. class R2R:
  14. client: R2RAsyncClient
  15. client_sync: R2RClient
  16. def __init__(self):
  17. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  18. # self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  19. # self.client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  20. self.client = client
  21. self.client_sync = client_sync
  22. def init_sync(self):
  23. if not self.auth_enabled:
  24. return
  25. # if not self.client_sync:
  26. # client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  27. self.client_sync.users.login(
  28. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  29. )
  30. print(
  31. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  32. )
  33. # print(tool_settings.R2R_USERNAME)
  34. # print(tool_settings.R2R_PASSWORD)
  35. print(self.client_sync)
  36. return self.client_sync
  37. async def init(self):
  38. if not self.auth_enabled:
  39. return
  40. # if not self.client:
  41. print(
  42. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  43. )
  44. print(tool_settings.R2R_USERNAME)
  45. print(tool_settings.R2R_PASSWORD)
  46. # client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  47. await self.client.users.login(
  48. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  49. )
  50. print(self.client.access_token)
  51. return self.client
  52. async def ingest_file(self, file_path: str, metadata: Optional[dict]):
  53. client = await self._check_login()
  54. return await client.documents.create(
  55. file_path=file_path,
  56. metadata=metadata if metadata else None,
  57. ingestion_mode="fast",
  58. id=uuid.uuid4(),
  59. run_with_orchestration=False,
  60. )
  61. async def ingest_fileinfo(self, file: UploadFile, metadata: Optional[dict]):
  62. client = await self._check_login()
  63. return await client.documents.create(
  64. file=file,
  65. metadata=metadata if metadata else None,
  66. id=None,
  67. )
  68. def search(self, query: str, filters: dict[str, Any]):
  69. client = self._check_login_sync()
  70. print(
  71. "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
  72. )
  73. search_response = client.retrieval.search(
  74. query=query,
  75. # search_mode="basic",
  76. search_settings={
  77. "filters": filters,
  78. "limit": tool_settings.R2R_SEARCH_LIMIT,
  79. },
  80. )
  81. print("vvvvvvvvvvvvvvvvvvmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm")
  82. # print(search_response)
  83. print(search_response.results)
  84. return search_response.results.chunk_search_results
  85. def ingest_file_sync(self, file_path: str, metadata: Optional[dict]):
  86. client = self._check_login_sync()
  87. return client.documents.create(
  88. file_path=file_path,
  89. metadata=metadata if metadata else None,
  90. ingestion_mode="fast",
  91. id=None,
  92. run_with_orchestration=False,
  93. )
  94. def list_chunks(self, ids: list[str] = []):
  95. client = self._check_login_sync()
  96. print(
  97. "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
  98. )
  99. print(ids)
  100. allfile = []
  101. for id in ids:
  102. listed = client.documents.list_chunks(id=id)
  103. allfile += listed.results
  104. return allfile
  105. def list_documents(
  106. self,
  107. id: Optional[str] = "",
  108. offset: Optional[int] = 0,
  109. limit: Optional[int] = 100,
  110. ):
  111. client = self._check_login_sync()
  112. """
  113. docs = client.collections.list_documents(empty_coll_id).results
  114. assert len(docs) == 0, "Expected no documents in a new empty collection"
  115. """
  116. print(
  117. "collectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollections"
  118. )
  119. if id != "":
  120. try:
  121. listed = client.collections.list_documents(
  122. id=id, limit=limit, offset=offset
  123. )
  124. print(listed.results)
  125. return listed.results
  126. except Exception as e:
  127. print(e)
  128. listed = []
  129. return listed
  130. else:
  131. return []
  132. async def _check_login(self):
  133. if not self.auth_enabled:
  134. return
  135. if self.client.access_token and verify_jwt_expiration(self.client.access_token):
  136. return self.client
  137. else:
  138. return await self.init()
  139. def _check_login_sync(self):
  140. print("access_tokenaccess_tokenaccess_tokenaccess_token")
  141. # print(client_sync)
  142. if not self.auth_enabled:
  143. return
  144. try:
  145. if self.client_sync.access_token and verify_jwt_expiration(
  146. self.client_sync.access_token
  147. ):
  148. print(self.client_sync.access_token)
  149. return self.client_sync
  150. except Exception as e:
  151. print(e)
  152. return self.init_sync()
# Create the R2R instance
# Call initialize_r2r() at application startup
# async def initialize_r2r():
#     await r2r.init()