# r2r_file.py
  1. import tempfile
  2. import uuid
  3. from typing import List
  4. import aiofiles
  5. import aiofiles.os
  6. from fastapi import UploadFile
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. import os
  9. from app.models import File
  10. from app.providers.r2r import R2R
  11. from app.providers.storage import storage
  12. from app.services.file.impl.oss_file import OSSFileService
  13. # import asyncio
  14. from pathlib import Path
# Image formats whose search results get a public S3 URL attached.
# NOTE(review): ".jpg" is absent although ".jpeg" is listed — confirm intentional.
allowed_formats = [".bmp", ".heic", ".jpeg", ".png", ".tiff"]
  16. class R2RFileService(OSSFileService):
  17. @staticmethod
  18. async def create_file(
  19. *, session: AsyncSession, purpose: str, file: UploadFile
  20. ) -> File:
  21. # 文件是否存在
  22. # statement = (
  23. # select(File)
  24. # .where(File.purpose == purpose)
  25. # .where(File.filename == file.filename)
  26. # .where(File.bytes == file.size)
  27. # )
  28. # result = await session.execute(statement)
  29. # ext_file = result.scalars().first()
  30. # if ext_file is not None:
  31. # # TODO: 文件去重策略
  32. # return ext_file
  33. file_extension = Path(file.filename).suffix
  34. file_key = f"{uuid.uuid4()}{file_extension}"
  35. print(file_key)
  36. fileinfo = {"document_id": file_key}
  37. # file_key = f"{uuid.uuid4()}-{file.filename}"
  38. with tempfile.NamedTemporaryFile(
  39. suffix="_" + file.filename, delete=True
  40. ) as temp_file:
  41. tmp_file_path = temp_file.name
  42. async with aiofiles.open(tmp_file_path, "wb") as f:
  43. while content := await file.read(1024):
  44. await f.write(content)
  45. # storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
  46. r2r = R2R()
  47. # await r2r.init()
  48. fileinfo = await r2r.ingest_file(
  49. file_path=tmp_file_path,
  50. metadata={"file_key": file_key, "title": file.filename},
  51. )
  52. fileinfo = fileinfo.results
  53. # 存储
  54. db_file = File(
  55. purpose=purpose,
  56. filename=file.filename,
  57. bytes=file.size,
  58. key=fileinfo.document_id,
  59. )
  60. session.add(db_file)
  61. await session.commit()
  62. await session.refresh(db_file)
  63. return db_file
  64. @staticmethod
  65. def search_in_files(
  66. query: str, file_keys: List[str], folder_keys: List[str] = None
  67. ) -> dict:
  68. files = []
  69. # {}
  70. file_key = {"$in": []}
  71. document_id = {"$in": []}
  72. filters = {"$or": []}
  73. print(
  74. "ggggggggggggggggggggggggggggggggggggddddddddddddddddddccccccccccccccccccccc"
  75. )
  76. for key in file_keys:
  77. if len(key) == 36:
  78. document_id["$in"].append(key)
  79. else:
  80. file_key["$in"].append(key)
  81. if len(document_id["$in"]) > 0:
  82. filters["$or"].append({"document_id": document_id})
  83. if len(file_key["$in"]) > 0:
  84. filters["$or"].append({"file_key": file_key})
  85. print(file_key)
  86. print(document_id)
  87. print(filters)
  88. print(folder_keys)
  89. if folder_keys:
  90. filters["$or"].append({"collection_ids": {"$overlap": folder_keys}})
  91. ## {"$or": [filters, {"collection_ids": {"$in": folder_keys}}]}
  92. ##filters["collection_ids"] = {"$overlap": folder_keys}
  93. ## {"$and": {"$document_id": ..., "collection_ids": ...}}
  94. """
  95. {
  96. "$or": [
  97. {"document_id": {"$eq": "9fbe403b-..."}},
  98. {"collection_ids": {"$in": ["122fdf6a-...", "..."]}}
  99. ]
  100. }
  101. """
  102. print(filters)
  103. if len(filters["$or"]) < 2:
  104. filters = filters["$or"][0]
  105. print("filtersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfilters")
  106. print(filters)
  107. """
  108. loop = asyncio.get_event_loop() # 获取当前事件循环
  109. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  110. search_results = loop.run_until_complete(r2r.search(query, filters=filters))
  111. asyncio.run(r2r.init())
  112. search_results = asyncio.run(r2r.search(query, filters=filters))
  113. search_results = loop.run_until_complete(
  114. r2r.search(query, filters={"file_key": {"$in": file_keys}})
  115. )
  116. """
  117. r2r = R2R()
  118. # r2r.init_sync()
  119. search_results = r2r.search(query, filters=filters)
  120. if not search_results:
  121. return files
  122. for doc in search_results:
  123. file_extension = os.path.splitext(doc.metadata["title"])[1].lower()
  124. if file_extension in allowed_formats:
  125. files.append(
  126. {
  127. "id": str(doc.id),
  128. "text": doc.text,
  129. "title": doc.metadata["title"],
  130. "url": "https://r2r.s3.cn-north-1.amazonaws.com.cn/r2r/documents/"
  131. + doc.document_id,
  132. }
  133. )
  134. # print(doc.metadata)
  135. else:
  136. files.append(
  137. {
  138. "id": str(doc.id),
  139. "text": doc.text,
  140. "title": doc.metadata["title"],
  141. }
  142. )
  143. # file_key = doc.metadata.file_key
  144. # file_key = doc.metadata.title if file_key is None else file_key
  145. # text = doc.text
  146. # if "text" in files:
  147. # files["text"] += f"\n\n{text}"
  148. # else:
  149. # files["text"] = text
  150. # print("aaaaaaaaaaaaaa")
  151. # print(files)
  152. return files
  153. @staticmethod
  154. def list_in_files(
  155. ids: list[str] = None,
  156. offset: int = 0,
  157. limit: int = 100,
  158. ) -> dict:
  159. """
  160. loop = asyncio.get_event_loop() # 获取当前事件循环
  161. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  162. list_results = loop.run_until_complete(
  163. r2r.list(ids=ids, offset=offset, limit=limit)
  164. )
  165. asyncio.run(r2r.init())
  166. list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
  167. """
  168. r2r = R2R()
  169. # r2r.init_sync()
  170. list_results = r2r.list(ids=ids, offset=offset, limit=limit)
  171. return list_results
  172. @staticmethod
  173. def list_documents(
  174. id: str = "",
  175. offset: int = 0,
  176. limit: int = 100,
  177. ) -> dict:
  178. """
  179. loop = asyncio.get_event_loop() # 获取当前事件循环
  180. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  181. list_results = loop.run_until_complete(
  182. r2r.list_documents(id=id, offset=offset, limit=limit)
  183. )
  184. asyncio.run(r2r.init())
  185. list_results = asyncio.run(
  186. r2r.list_documents(id=id, offset=offset, limit=limit)
  187. )
  188. """
  189. r2r = R2R()
  190. # r2r.init_sync()
  191. list_results = r2r.list_documents(id=id, offset=offset, limit=limit)
  192. return list_results
  193. @staticmethod
  194. def list_chunks(ids: list[str]) -> dict:
  195. if len(ids) > 0:
  196. r2r = R2R()
  197. # r2r.init_sync()
  198. print("list_chunkslist_chunkslist_chunkslist_chunkslist_chunkslist_chunks")
  199. list_results = r2r.list_chunks(ids=ids)
  200. print(list_results)
  201. files = {}
  202. for doc in list_results:
  203. text = doc.text
  204. if "text" in files:
  205. files["text"] += f"\n\n{text}"
  206. else:
  207. files["text"] = text
  208. print(files)
  209. return files
  210. else:
  211. return {}
    # TODO: also delete the backing S3 object and R2R document when a file is removed.