oss_file.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import uuid
  2. from typing import List, Union, Generator, Tuple
  3. from fastapi import UploadFile
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy.orm import Session
  6. from sqlmodel import select, col, desc
  7. from app.core.doc_loaders import doc_loader
  8. from app.exceptions.exception import ResourceNotFoundError
  9. from app.models import File
  10. from app.providers.storage import storage
  11. from app.schemas.common import DeleteResponse
  12. from app.services.file.file import BaseFileService
  13. import json
  14. class OSSFileService(BaseFileService):
  15. @staticmethod
  16. def get_file_list_by_ids(*, session: Session, file_ids: List[str]) -> List[File]:
  17. if not file_ids:
  18. return []
  19. statement = select(File).where(col(File.id).in_(file_ids))
  20. return session.execute(statement).scalars().all()
  21. @staticmethod
  22. async def get_file_list(
  23. *, session: AsyncSession, purpose: str, file_ids: List[str]
  24. ) -> List[File]:
  25. statement = select(File)
  26. if purpose is not None and len(purpose) > 0:
  27. statement = statement.where(File.purpose == purpose)
  28. if file_ids is not None:
  29. statement = statement.where(File.id.in_(file_ids))
  30. statement = statement.order_by(desc(File.created_at))
  31. result = await session.execute(statement)
  32. return result.scalars().all()
  33. @staticmethod
  34. async def create_file(
  35. *, session: AsyncSession, purpose: str, file: UploadFile
  36. ) -> File:
  37. # 文件是否存在
  38. # statement = (
  39. # select(File)
  40. # .where(File.purpose == purpose)
  41. # .where(File.filename == file.filename)
  42. # .where(File.bytes == file.size)
  43. # )
  44. # result = await session.execute(statement)
  45. # ext_file = result.scalars().first()
  46. # if ext_file is not None:
  47. # # TODO: 文件去重策略
  48. # return ext_file
  49. file_key = f"{uuid.uuid4()}-{file.filename}"
  50. storage.save(filename=file_key, data=file.file.read())
  51. # 存储
  52. db_file = File(
  53. purpose=purpose, filename=file.filename, bytes=file.size, key=file_key
  54. )
  55. session.add(db_file)
  56. await session.commit()
  57. await session.refresh(db_file)
  58. return db_file
  59. @staticmethod
  60. async def get_file(*, session: AsyncSession, file_id: str) -> File:
  61. statement = select(File).where(File.id == file_id)
  62. result = await session.execute(statement)
  63. ext_file = result.scalars().one_or_none()
  64. if ext_file is None:
  65. raise ResourceNotFoundError(message="File not found")
  66. return ext_file
  67. @staticmethod
  68. async def get_file_content(
  69. *, session: AsyncSession, file_id: str
  70. ) -> Tuple[Union[bytes, Generator], str]:
  71. ext_file = await OSSFileService.get_file(session=session, file_id=file_id)
  72. file_data = storage.load(ext_file.key)
  73. return file_data, ext_file.filename
  74. @staticmethod
  75. async def delete_file(*, session: AsyncSession, file_id: str) -> DeleteResponse:
  76. ext_file = await OSSFileService.get_file(session=session, file_id=file_id)
  77. # TODO 删除s3文件
  78. # 删除记录
  79. await session.delete(ext_file)
  80. await session.commit()
  81. return DeleteResponse(id=file_id, deleted=True)
  82. @staticmethod
  83. def search_in_files(query: str, file_keys: List[str]) -> dict:
  84. files = {}
  85. for file_key in file_keys:
  86. file_data = storage.load(file_key)
  87. # 截取前 5000 字符,防止超出 LLM 最大上下文限制
  88. files[file_key] = doc_loader.load(file_data)[:5000]
  89. return files