r2r_file.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import tempfile
  2. import uuid
  3. from typing import List
  4. import aiofiles
  5. import aiofiles.os
  6. from fastapi import UploadFile
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from app.models import File
  9. from app.providers.r2r import R2R
  10. from app.providers.storage import storage
  11. from app.services.file.impl.oss_file import OSSFileService
  12. # import asyncio
  13. from pathlib import Path
  14. class R2RFileService(OSSFileService):
  15. @staticmethod
  16. async def create_file(
  17. *, session: AsyncSession, purpose: str, file: UploadFile
  18. ) -> File:
  19. # 文件是否存在
  20. # statement = (
  21. # select(File)
  22. # .where(File.purpose == purpose)
  23. # .where(File.filename == file.filename)
  24. # .where(File.bytes == file.size)
  25. # )
  26. # result = await session.execute(statement)
  27. # ext_file = result.scalars().first()
  28. # if ext_file is not None:
  29. # # TODO: 文件去重策略
  30. # return ext_file
  31. file_extension = Path(file.filename).suffix
  32. file_key = f"{uuid.uuid4()}{file_extension}"
  33. print(file_key)
  34. fileinfo = {"document_id": file_key}
  35. # file_key = f"{uuid.uuid4()}-{file.filename}"
  36. with tempfile.NamedTemporaryFile(
  37. suffix="_" + file.filename, delete=True
  38. ) as temp_file:
  39. tmp_file_path = temp_file.name
  40. async with aiofiles.open(tmp_file_path, "wb") as f:
  41. while content := await file.read(1024):
  42. await f.write(content)
  43. # storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
  44. r2r = R2R()
  45. # await r2r.init()
  46. fileinfo = r2r.ingest_file_sync(
  47. file_path=tmp_file_path,
  48. metadata={"file_key": file_key, "title": file.filename},
  49. )
  50. fileinfo = fileinfo.results
  51. # 存储
  52. db_file = File(
  53. purpose=purpose,
  54. filename=file.filename,
  55. bytes=file.size,
  56. key=fileinfo.document_id,
  57. )
  58. session.add(db_file)
  59. await session.commit()
  60. await session.refresh(db_file)
  61. return db_file
  62. @staticmethod
  63. def search_in_files(
  64. query: str, file_keys: List[str], folder_keys: List[str] = None
  65. ) -> dict:
  66. files = {}
  67. file_key = {"$in": []}
  68. document_id = {"$in": []}
  69. filters = {"$or": []}
  70. print(
  71. "ggggggggggggggggggggggggggggggggggggddddddddddddddddddccccccccccccccccccccc"
  72. )
  73. for key in file_keys:
  74. if len(key) == 36:
  75. document_id["$in"].append(key)
  76. else:
  77. file_key["$in"].append(key)
  78. if len(document_id["$in"]) > 0:
  79. filters["$or"].append({"document_id": document_id})
  80. if len(file_key["$in"]) > 0:
  81. filters["$or"].append({"file_key": file_key})
  82. print(file_key)
  83. print(document_id)
  84. print(filters)
  85. print(folder_keys)
  86. if folder_keys:
  87. filters["$or"].append({"collection_ids": {"$overlap": folder_keys}})
  88. ## {"$or": [filters, {"collection_ids": {"$in": folder_keys}}]}
  89. ##filters["collection_ids"] = {"$overlap": folder_keys}
  90. ## {"$and": {"$document_id": ..., "collection_ids": ...}}
  91. """
  92. {
  93. "$or": [
  94. {"document_id": {"$eq": "9fbe403b-..."}},
  95. {"collection_ids": {"$in": ["122fdf6a-...", "..."]}}
  96. ]
  97. }
  98. """
  99. print(filters)
  100. if len(filters["$or"]) < 2:
  101. filters = filters["$or"][0]
  102. print("filtersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfilters")
  103. print(filters)
  104. """
  105. loop = asyncio.get_event_loop() # 获取当前事件循环
  106. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  107. search_results = loop.run_until_complete(r2r.search(query, filters=filters))
  108. asyncio.run(r2r.init())
  109. search_results = asyncio.run(r2r.search(query, filters=filters))
  110. search_results = loop.run_until_complete(
  111. r2r.search(query, filters={"file_key": {"$in": file_keys}})
  112. )
  113. """
  114. r2r = R2R()
  115. # r2r.init_sync()
  116. search_results = r2r.search(query, filters=filters)
  117. if not search_results:
  118. return files
  119. for doc in search_results:
  120. # file_key = doc.metadata.file_key
  121. # file_key = doc.metadata.title if file_key is None else file_key
  122. text = doc.text
  123. if "text" in files:
  124. files["text"] += f"\n\n{text}"
  125. else:
  126. files["text"] = text
  127. print("aaaaaaaaaaaaaa")
  128. print(files)
  129. return files
  130. @staticmethod
  131. def list_in_files(
  132. ids: list[str] = None,
  133. offset: int = 0,
  134. limit: int = 100,
  135. ) -> dict:
  136. """
  137. loop = asyncio.get_event_loop() # 获取当前事件循环
  138. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  139. list_results = loop.run_until_complete(
  140. r2r.list(ids=ids, offset=offset, limit=limit)
  141. )
  142. asyncio.run(r2r.init())
  143. list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
  144. """
  145. r2r = R2R()
  146. # r2r.init_sync()
  147. list_results = r2r.list(ids=ids, offset=offset, limit=limit)
  148. return list_results
  149. @staticmethod
  150. def list_documents(
  151. id: str = "",
  152. offset: int = 0,
  153. limit: int = 100,
  154. ) -> dict:
  155. """
  156. loop = asyncio.get_event_loop() # 获取当前事件循环
  157. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  158. list_results = loop.run_until_complete(
  159. r2r.list_documents(id=id, offset=offset, limit=limit)
  160. )
  161. asyncio.run(r2r.init())
  162. list_results = asyncio.run(
  163. r2r.list_documents(id=id, offset=offset, limit=limit)
  164. )
  165. """
  166. r2r = R2R()
  167. # r2r.init_sync()
  168. list_results = r2r.list_documents(id=id, offset=offset, limit=limit)
  169. return list_results
  170. @staticmethod
  171. def list_chunks(ids: list[str]) -> dict:
  172. if len(ids) > 0:
  173. r2r = R2R()
  174. # r2r.init_sync()
  175. list_results = r2r.list_chunks(ids=ids)
  176. files = {}
  177. for doc in list_results:
  178. # file_key = doc.metadata.file_key
  179. # file_key = doc.metadata.title if file_key is None else file_key
  180. text = doc.text
  181. if "text" in files:
  182. files["text"] += f"\n\n{text}"
  183. else:
  184. files["text"] = text
  185. return files
  186. return {}
  187. # TODO 删除s3&r2r文件