r2r_file.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import tempfile
  2. import uuid
  3. from typing import List
  4. import aiofiles
  5. import aiofiles.os
  6. from fastapi import UploadFile
  7. from sqlalchemy.ext.asyncio import AsyncSession
  8. from app.models import File
  9. from app.providers.r2r import r2r
  10. from app.providers.storage import storage
  11. from app.services.file.impl.oss_file import OSSFileService
  12. import asyncio
  class R2RFileService(OSSFileService):
      """File service that stores uploads in object storage and indexes them in R2R.

      Uploads are saved through ``storage`` under a unique key and ingested into
      R2R with that key as ``file_key`` metadata; ``search_in_files`` uses the
      same metadata to restrict searches to specific files.
      """

      @staticmethod
      async def create_file(
          *, session: AsyncSession, purpose: str, file: UploadFile
      ) -> File:
          """Store an uploaded file, ingest it into R2R and record it in the DB.

          The upload is spooled to a local temporary file, copied to ``storage``
          under a key of the form ``"<uuid4>-<filename>"``, ingested into R2R
          with ``{"file_key": <key>}`` metadata, and finally recorded as a
          ``File`` row.

          :param session: database session used to insert the ``File`` row.
          :param purpose: purpose value stored on the ``File`` row.
          :param file: the uploaded file.
          :return: the committed and refreshed ``File`` row.
          """
          import os

          # TODO: deduplicate uploads, e.g. return an existing File row with the
          # same purpose, filename and byte size instead of storing a new copy.
          chunk_size = 64 * 1024
          file_key = f"{uuid.uuid4()}-{file.filename}"
          # Only the base name goes into the temp-file suffix: a client-supplied
          # name containing path separators would otherwise point the temp path
          # into a non-existent directory. ``filename`` may also be None.
          suffix = "_" + os.path.basename(file.filename or "")
          written = 0
          # NOTE: the temp file is reopened by name below, which works on POSIX
          # but not on Windows while the NamedTemporaryFile is still open.
          with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
              tmp_file_path = temp_file.name
              async with aiofiles.open(tmp_file_path, "wb") as f:
                  while chunk := await file.read(chunk_size):
                      written += len(chunk)
                      await f.write(chunk)
              # Both consumers read from the temp path, so they must run before
              # the ``with`` block exits and deletes the file.
              # NOTE(review): these calls are not awaited; if they do blocking
              # I/O they stall the event loop - consider offloading them.
              storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
              await r2r.init()
              r2r.ingest_file(file_path=tmp_file_path, metadata={"file_key": file_key})

          # Persist the metadata row; fall back to the number of bytes actually
          # read when the upload did not report its size.
          db_file = File(
              purpose=purpose,
              filename=file.filename,
              bytes=file.size if file.size is not None else written,
              key=file_key,
          )
          session.add(db_file)
          await session.commit()
          await session.refresh(db_file)
          return db_file

      @staticmethod
      def _ensure_r2r_initialized() -> None:
          """Run the ``r2r.init()`` coroutine to completion from synchronous code."""
          try:
              asyncio.get_running_loop()
          except RuntimeError:
              # No event loop is running in this thread: drive the coroutine
              # on a fresh loop and wait for it.
              asyncio.run(r2r.init())
              return
          # This thread's loop is already running and cannot be blocked on
          # directly, so run the coroutine on a private loop in a worker thread
          # and wait for that instead.
          from concurrent.futures import ThreadPoolExecutor

          with ThreadPoolExecutor(max_workers=1) as pool:
              pool.submit(lambda: asyncio.run(r2r.init())).result()

      @staticmethod
      def search_in_files(query: str, file_keys: List[str]) -> dict:
          """Search R2R for ``query`` within the given files, grouped by file.

          :param query: search text passed to ``r2r.search``.
          :param file_keys: storage keys of the files to search, matched against
              the ``file_key`` metadata attached at ingest time.
          :return: mapping of file key to its matching text chunks joined by
              blank lines; empty when nothing matches.
          """
          files: dict = {}
          if not file_keys:
              # An empty ``$in`` list cannot match any file; skip the round trip.
              return files
          # R2R must be fully initialized before searching; the previous
          # fire-and-forget ``create_task`` never ran before ``search`` and
          # raised outright when no event loop was running.
          R2RFileService._ensure_r2r_initialized()
          search_results = r2r.search(query, filters={"file_key": {"$in": file_keys}})
          if not search_results:
              return files
          for doc in search_results:
              file_key = (doc.get("metadata") or {}).get("file_key")
              if file_key is None:
                  # A chunk without a file key cannot be attributed to a file.
                  continue
              text = doc.get("text")
              if files.get(file_key):
                  files[file_key] += f"\n\n{text}"
              else:
                  files[file_key] = text
          return files

      # TODO: delete the stored object (S3) and the R2R document for a file.