r2r_file.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
import logging
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional

import aiofiles
import aiofiles.os
import nest_asyncio
from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import File
from app.providers.r2r import r2r
from app.providers.storage import storage
from app.services.file.impl.oss_file import OSSFileService

# import asyncio
# Allow asyncio code to be nested inside an event loop that is already running.
  16. nest_asyncio.apply()
  17. class R2RFileService(OSSFileService):
  18. '''
  19. @staticmethod
  20. def get_file_list_by_ids(*, session: Session, file_ids: List[str]) -> List[File]:
  21. if not file_ids:
  22. return []
  23. statement = select(File).where(col(File.id).in_(file_ids))
  24. return session.execute(statement).scalars().all()
  25. @staticmethod
  26. async def get_file_list(
  27. *, session: AsyncSession, purpose: str, file_ids: List[str]
  28. ) -> List[File]:
  29. statement = select(File)
  30. if purpose is not None and len(purpose) > 0:
  31. statement = statement.where(File.purpose == purpose)
  32. if file_ids is not None:
  33. statement = statement.where(File.id.in_(file_ids))
  34. statement = statement.order_by(desc(File.created_at))
  35. result = await session.execute(statement)
  36. return result.scalars().all()
  37. '''
  38. @staticmethod
  39. async def create_file(
  40. *, session: AsyncSession, purpose: str, file: UploadFile
  41. ) -> File:
  42. # 文件是否存在
  43. # statement = (
  44. # select(File)
  45. # .where(File.purpose == purpose)
  46. # .where(File.filename == file.filename)
  47. # .where(File.bytes == file.size)
  48. # )
  49. # result = await session.execute(statement)
  50. # ext_file = result.scalars().first()
  51. # if ext_file is not None:
  52. # # TODO: 文件去重策略
  53. # return ext_file
  54. file_extension = Path(file.filename).suffix
  55. file_key = f"{uuid.uuid4()}{file_extension}"
  56. print(file_key)
  57. fileinfo = {"document_id": file_key}
  58. # file_key = f"{uuid.uuid4()}-{file.filename}"
  59. with tempfile.NamedTemporaryFile(
  60. suffix="_" + file.filename, delete=True
  61. ) as temp_file:
  62. tmp_file_path = temp_file.name
  63. async with aiofiles.open(tmp_file_path, "wb") as f:
  64. while content := await file.read(1024):
  65. await f.write(content)
  66. # storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
  67. await r2r.init()
  68. fileinfo = await r2r.ingest_file(
  69. file_path=tmp_file_path,
  70. metadata={"file_key": file_key, "title": file.filename},
  71. )
  72. fileinfo = fileinfo.get("results")
  73. # 存储
  74. db_file = File(
  75. purpose=purpose,
  76. filename=file.filename,
  77. bytes=file.size,
  78. key=fileinfo["document_id"],
  79. )
  80. session.add(db_file)
  81. await session.commit()
  82. await session.refresh(db_file)
  83. return db_file
  84. @staticmethod
  85. def search_in_files(
  86. query: str, file_keys: List[str], folder_keys: List[str] = None
  87. ) -> dict:
  88. files = {}
  89. file_key = {"$in": []}
  90. document_id = {"$in": []}
  91. filters = {"$or": []}
  92. for key in file_keys:
  93. if len(key) == 36:
  94. document_id["$in"].append(key)
  95. else:
  96. file_key["$in"].append(key)
  97. if len(document_id["$in"]) > 0:
  98. filters["$or"].append({"document_id": document_id})
  99. if len(file_key["$in"]) > 0:
  100. filters["$or"].append({"file_key": file_key})
  101. if folder_keys:
  102. filters = filters["$or"].append(
  103. {"collection_ids": {"$in": folder_keys}}
  104. ) ## {"$or": [filters, {"collection_ids": {"$in": folder_keys}}]}
  105. ##filters["collection_ids"] = {"$overlap": folder_keys}
  106. ## {"$and": {"$document_id": ..., "collection_ids": ...}}
  107. """
  108. {
  109. "$or": [
  110. {"document_id": {"$eq": "9fbe403b-..."}},
  111. {"collection_ids": {"$in": ["122fdf6a-...", "..."]}}
  112. ]
  113. }
  114. """
  115. if len(filters["$or"]) < 2:
  116. filters = filters["$or"][0]
  117. print("filtersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfiltersfilters")
  118. print(filters)
  119. """
  120. loop = asyncio.get_event_loop() # 获取当前事件循环
  121. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  122. search_results = loop.run_until_complete(r2r.search(query, filters=filters))
  123. asyncio.run(r2r.init())
  124. search_results = asyncio.run(r2r.search(query, filters=filters))
  125. search_results = loop.run_until_complete(
  126. r2r.search(query, filters={"file_key": {"$in": file_keys}})
  127. )
  128. """
  129. r2r.init_sync()
  130. search_results = r2r.search(query, filters=filters)
  131. if not search_results:
  132. return files
  133. for doc in search_results:
  134. file_key = doc.get("metadata").get("file_key")
  135. text = doc.get("text")
  136. if file_key in files and files[file_key]:
  137. files[file_key] += f"\n\n{text}"
  138. else:
  139. files[file_key] = doc.get("text")
  140. return files
  141. @staticmethod
  142. def list_in_files(
  143. ids: list[str] = None,
  144. offset: int = 0,
  145. limit: int = 100,
  146. ) -> dict:
  147. """
  148. loop = asyncio.get_event_loop() # 获取当前事件循环
  149. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  150. list_results = loop.run_until_complete(
  151. r2r.list(ids=ids, offset=offset, limit=limit)
  152. )
  153. asyncio.run(r2r.init())
  154. list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
  155. """
  156. r2r.init_sync()
  157. list_results = r2r.list(ids=ids, offset=offset, limit=limit)
  158. return list_results
  159. @staticmethod
  160. def list_documents(
  161. id: str = "",
  162. offset: int = 0,
  163. limit: int = 100,
  164. ) -> dict:
  165. """
  166. loop = asyncio.get_event_loop() # 获取当前事件循环
  167. loop.run_until_complete(r2r.init()) # 确保 r2r 已初始化
  168. list_results = loop.run_until_complete(
  169. r2r.list_documents(id=id, offset=offset, limit=limit)
  170. )
  171. asyncio.run(r2r.init())
  172. list_results = asyncio.run(
  173. r2r.list_documents(id=id, offset=offset, limit=limit)
  174. )
  175. """
  176. r2r.init_sync()
  177. list_results = r2r.list_documents(id=id, offset=offset, limit=limit)
  178. return list_results
  179. # TODO 删除s3&r2r文件