- import tempfile
- import uuid
- from typing import List
- import aiofiles
- import aiofiles.os
- from fastapi import UploadFile
- from sqlalchemy.ext.asyncio import AsyncSession
- from app.models import File
- from app.providers.r2r import r2r
- from app.providers.storage import storage
- from app.services.file.impl.oss_file import OSSFileService
- # import asyncio
- from pathlib import Path
class R2RFileService(OSSFileService):
    """File service that stores and retrieves file content via the R2R provider.

    Uploads are ingested into R2R (which handles chunking/indexing); the DB
    ``File`` row keeps the R2R ``document_id`` as its ``key``. Search and
    listing are delegated to the module-level ``r2r`` client.
    """

    @staticmethod
    async def create_file(
        *, session: AsyncSession, purpose: str, file: UploadFile
    ) -> File:
        """Ingest an uploaded file into R2R and persist a ``File`` record.

        Args:
            session: Async SQLAlchemy session used to persist the record.
            purpose: Purpose tag stored on the record.
            file: Incoming multipart upload.

        Returns:
            The committed ``File`` row whose ``key`` is the R2R document id.
        """
        # Random key that keeps the original extension so downstream
        # consumers can still infer the content type from the name.
        file_extension = Path(file.filename).suffix
        file_key = f"{uuid.uuid4()}{file_extension}"

        # Spool the upload to a local temp file: R2R ingests from a path.
        with tempfile.NamedTemporaryFile(
            suffix="_" + file.filename, delete=True
        ) as temp_file:
            tmp_file_path = temp_file.name
            async with aiofiles.open(tmp_file_path, "wb") as f:
                # Stream in chunks so large uploads are not held in memory.
                while content := await file.read(64 * 1024):
                    await f.write(content)
            await r2r.init()
            fileinfo = await r2r.ingest_file(
                file_path=tmp_file_path,
                metadata={"file_key": file_key, "title": file.filename},
            )
            fileinfo = fileinfo["results"]

        db_file = File(
            purpose=purpose,
            filename=file.filename,
            bytes=file.size,
            key=fileinfo["document_id"],
        )
        session.add(db_file)
        await session.commit()
        await session.refresh(db_file)
        return db_file

    @staticmethod
    def search_in_files(
        query: str, file_keys: List[str], folder_keys: List[str] = None
    ) -> dict:
        """Search R2R and group matching chunks by originating file.

        Args:
            query: Search query text.
            file_keys: Mixed list of R2R document ids (36-char UUID strings)
                and legacy ``file_key`` values.
            folder_keys: Optional R2R collection ids to include in the search.

        Returns:
            Mapping of file key (title fallback) to the concatenated text of
            all matching chunks; empty dict when nothing matches.
        """
        files: dict = {}

        # 36-character keys are R2R document UUIDs; anything else is a
        # legacy metadata file_key.
        document_ids = [key for key in file_keys if len(key) == 36]
        legacy_keys = [key for key in file_keys if len(key) != 36]

        clauses = []
        if document_ids:
            clauses.append({"document_id": {"$in": document_ids}})
        if legacy_keys:
            clauses.append({"file_key": {"$in": legacy_keys}})
        if folder_keys:
            clauses.append({"collection_ids": {"$overlap": folder_keys}})
        if not clauses:
            # No usable filter at all — the original code raised IndexError
            # here ($or was empty); return no results instead.
            return files
        # R2R accepts a bare clause; only wrap in $or when combining several.
        filters = clauses[0] if len(clauses) == 1 else {"$or": clauses}

        r2r.init_sync()
        search_results = r2r.search(query, filters=filters)
        if not search_results:
            return files

        for doc in search_results:
            metadata = doc.get("metadata") or {}
            key = metadata.get("file_key")
            if key is None:
                # Fall back to the document title when no file_key was stored.
                key = metadata.get("title")
            text = doc.get("text")
            if files.get(key):
                files[key] += f"\n\n{text}"
            else:
                files[key] = text
        return files

    @staticmethod
    def list_in_files(
        ids: list[str] = None,
        offset: int = 0,
        limit: int = 100,
    ) -> dict:
        """List ingested files from R2R (thin synchronous wrapper)."""
        r2r.init_sync()
        return r2r.list(ids=ids, offset=offset, limit=limit)

    @staticmethod
    def list_documents(
        id: str = "",
        offset: int = 0,
        limit: int = 100,
    ) -> dict:
        """List chunk/metadata entries for a single R2R document."""
        r2r.init_sync()
        return r2r.list_documents(id=id, offset=offset, limit=limit)

    @staticmethod
    def list_chunks(ids: list[str]) -> dict:
        """Fetch raw chunks for the given R2R document ids.

        Returns:
            The provider's raw chunk listing, or ``{}`` when ``ids`` is empty.
        """
        if not ids:
            return {}
        r2r.init_sync()
        # NOTE(review): the original also built a per-file text mapping here
        # but discarded it and returned the raw listing; that dead aggregation
        # has been removed.
        return r2r.list_chunks(ids=ids)

    # TODO: also delete the S3 & R2R copies when a file is removed.