r2r.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
import logging
from typing import Optional, Any

import nest_asyncio
from fastapi import UploadFile
from r2r import R2RAsyncClient
from r2r import R2RClient

from app.libs.util import verify_jwt_expiration
from config.llm import tool_settings
# Allow async code to be nested inside an already-running event loop.
# This patches asyncio once, at import time of this module.
# NOTE(review): presumably needed because the sync R2RClient drives an event
# loop internally while the app (FastAPI) already runs one — confirm.
nest_asyncio.apply()
  10. class R2R:
  11. client: R2RAsyncClient
  12. client_sync: R2RClient
  13. def __init__(self):
  14. self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
  15. # self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  16. # self.client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  17. self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  18. self.client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  19. def init_sync(self):
  20. if not self.auth_enabled:
  21. return
  22. # if not self.client_sync:
  23. # client_sync = R2RClient(tool_settings.R2R_BASE_URL)
  24. self.client_sync.users.login(
  25. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  26. )
  27. print(
  28. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  29. )
  30. # print(tool_settings.R2R_USERNAME)
  31. # print(tool_settings.R2R_PASSWORD)
  32. print(self.client_sync)
  33. return self.client_sync
  34. async def init(self):
  35. if not self.auth_enabled:
  36. return
  37. # if not self.client:
  38. print(
  39. "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
  40. )
  41. print(tool_settings.R2R_USERNAME)
  42. print(tool_settings.R2R_PASSWORD)
  43. # client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
  44. await self.client.users.login(
  45. tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
  46. )
  47. print(self.client.access_token)
  48. return self.client
  49. async def ingest_file(self, file_path: str, metadata: Optional[dict]):
  50. client = await self._check_login()
  51. return await client.documents.create(
  52. file_path=file_path,
  53. metadata=metadata if metadata else None,
  54. ingestion_mode="fast",
  55. id=None,
  56. )
  57. async def ingest_fileinfo(self, file: UploadFile, metadata: Optional[dict]):
  58. client = await self._check_login()
  59. return await client.documents.create(
  60. file=file,
  61. metadata=metadata if metadata else None,
  62. id=None,
  63. )
  64. def search(self, query: str, filters: dict[str, Any]):
  65. client = self._check_login_sync()
  66. print(
  67. "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
  68. )
  69. search_response = client.retrieval.search(
  70. query=query,
  71. # search_mode="basic",
  72. search_settings={
  73. "filters": filters,
  74. "limit": tool_settings.R2R_SEARCH_LIMIT,
  75. },
  76. )
  77. print("vvvvvvvvvvvvvvvvvvmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm")
  78. # print(search_response)
  79. print(search_response.results)
  80. return search_response.results.chunk_search_results
  81. def list_chunks(self, ids: list[str] = []):
  82. client = self._check_login_sync()
  83. print(
  84. "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
  85. )
  86. print(ids)
  87. allfile = []
  88. for id in ids:
  89. listed = client.documents.list_chunks(id=id)
  90. allfile += listed.results
  91. return allfile
  92. def list_documents(
  93. self,
  94. id: Optional[str] = "",
  95. offset: Optional[int] = 0,
  96. limit: Optional[int] = 100,
  97. ):
  98. client = self._check_login_sync()
  99. """
  100. docs = client.collections.list_documents(empty_coll_id).results
  101. assert len(docs) == 0, "Expected no documents in a new empty collection"
  102. """
  103. print(
  104. "collectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollections"
  105. )
  106. if id != "":
  107. try:
  108. listed = client.collections.list_documents(
  109. id=id, limit=limit, offset=offset
  110. )
  111. print(listed.results)
  112. return listed.results
  113. except Exception as e:
  114. print(e)
  115. listed = []
  116. return listed
  117. else:
  118. return []
  119. async def _check_login(self):
  120. if not self.auth_enabled:
  121. return
  122. # if self.client.access_token and verify_jwt_expiration(self.client.access_token):
  123. # return
  124. # else:
  125. return await self.init()
  126. def _check_login_sync(self):
  127. print("access_tokenaccess_tokenaccess_tokenaccess_token")
  128. # print(client_sync)
  129. if not self.auth_enabled:
  130. return
  131. # try:
  132. # if self.client_sync.access_token and verify_jwt_expiration(
  133. # self.client_sync.access_token
  134. # ):
  135. # print(self.client_sync.access_token)
  136. # return
  137. # except Exception as e:
  138. # print(e)
  139. return self.init_sync()
# Create the module-level R2R instance at import time. Construction only
# builds the two SDK clients; it does not log in.
r2r = R2R()
  142. # 在您的应用程序启动时调用 initialize_r2r()
  143. async def initialize_r2r():
  144. await r2r.init()