import logging
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional

import aiofiles
import aiofiles.os
from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import File
from app.providers.r2r import r2r
from app.providers.storage import storage
from app.services.file.impl.oss_file import OSSFileService

logger = logging.getLogger(__name__)

# Size of each read while spooling an upload to the temporary file.
_UPLOAD_CHUNK_SIZE = 64 * 1024

# Keys of exactly this length are treated as R2R document ids (the length of a
# canonical UUID string); any other key is matched against the "file_key"
# metadata field written by ``create_file``.
_DOCUMENT_ID_LENGTH = 36


def _group_text_by_file(docs) -> dict:
    """Concatenate chunk texts per file.

    Each item of *docs* is read with ``.get("metadata")`` and ``.get("text")``.
    The grouping key is the metadata ``file_key``, falling back to the
    metadata ``title`` when ``file_key`` is absent. Texts that share a key are
    joined with a blank line between them.
    """
    grouped: dict = {}
    for doc in docs:
        # Tolerate chunks without metadata instead of raising AttributeError.
        metadata = doc.get("metadata") or {}
        key = metadata.get("file_key")
        if key is None:
            key = metadata.get("title")
        text = doc.get("text")
        if grouped.get(key):
            grouped[key] += f"\n\n{text}"
        else:
            grouped[key] = text
    return grouped


class R2RFileService(OSSFileService):
    """File service that ingests uploads into R2R and queries them back."""

    @staticmethod
    async def create_file(
        *, session: AsyncSession, purpose: str, file: UploadFile
    ) -> File:
        """Ingest an uploaded file into R2R and persist a ``File`` row for it.

        The upload is spooled to a temporary file, handed to
        ``r2r.ingest_file`` and the ``document_id`` found under the
        ``"results"`` key of the ingest response is stored as the row's
        ``key``.

        Args:
            session: Database session used to persist the new row.
            purpose: Stored verbatim on the ``File`` row.
            file: The uploaded file.

        Returns:
            The committed and refreshed ``File`` row.
        """
        # TODO: file de-duplication strategy (e.g. reuse an existing File with
        # the same purpose, filename and size instead of ingesting again).

        # ``filename`` is client-supplied and may be missing; keep only its
        # final path component so it cannot inject directories into the
        # temporary file name.
        safe_name = Path(file.filename or "").name
        file_key = f"{uuid.uuid4()}{Path(safe_name).suffix}"
        logger.debug("Generated file key %s", file_key)

        with tempfile.NamedTemporaryFile(
            suffix="_" + safe_name, delete=True
        ) as temp_file:
            tmp_file_path = temp_file.name

            async with aiofiles.open(tmp_file_path, "wb") as f:
                while content := await file.read(_UPLOAD_CHUNK_SIZE):
                    await f.write(content)

            # Ingestion must happen inside the ``with`` block: the temporary
            # file is deleted as soon as the block exits.
            await r2r.init()
            ingest_response = await r2r.ingest_file(
                file_path=tmp_file_path,
                metadata={"file_key": file_key, "title": file.filename},
            )

        document_info = ingest_response["results"]

        # Persist the file record.
        db_file = File(
            purpose=purpose,
            filename=file.filename,
            bytes=file.size,
            key=document_info["document_id"],
        )
        session.add(db_file)
        await session.commit()
        await session.refresh(db_file)
        return db_file

    @staticmethod
    def search_in_files(
        query: str, file_keys: List[str], folder_keys: Optional[List[str]] = None
    ) -> dict:
        """Search R2R for *query*, restricted to the given files and folders.

        Args:
            query: Search text passed to ``r2r.search``.
            file_keys: Keys of the files to search. Keys that are 36
                characters long are matched on ``document_id``; all others
                are matched on the ``file_key`` metadata field.
            folder_keys: Optional collection ids; documents whose
                ``collection_ids`` overlap with them are included as well.

        Returns:
            A dict mapping each file key (or title, when a result carries no
            file key) to the matching texts joined by blank lines. Empty when
            no filter can be built or nothing matches.
        """
        document_ids = [
            key for key in (file_keys or []) if len(key) == _DOCUMENT_ID_LENGTH
        ]
        metadata_keys = [
            key for key in (file_keys or []) if len(key) != _DOCUMENT_ID_LENGTH
        ]

        clauses = []
        if document_ids:
            clauses.append({"document_id": {"$in": document_ids}})
        if metadata_keys:
            clauses.append({"file_key": {"$in": metadata_keys}})
        if folder_keys:
            clauses.append({"collection_ids": {"$overlap": folder_keys}})

        if not clauses:
            # Nothing to restrict the search to. Return no results rather than
            # failing (or searching every document).
            return {}

        # A lone clause is sent as-is; several clauses are OR-ed together,
        # e.g. {"$or": [{"document_id": {...}}, {"collection_ids": {...}}]}.
        filters = clauses[0] if len(clauses) == 1 else {"$or": clauses}
        logger.debug("R2R search filters: %s", filters)

        r2r.init_sync()
        search_results = r2r.search(query, filters=filters)
        if not search_results:
            return {}
        return _group_text_by_file(search_results)

    @staticmethod
    def list_in_files(
        ids: Optional[list[str]] = None,
        offset: int = 0,
        limit: int = 100,
    ) -> dict:
        """Return the result of ``r2r.list`` for the given ids and page window."""
        r2r.init_sync()
        return r2r.list(ids=ids, offset=offset, limit=limit)

    @staticmethod
    def list_documents(
        id: str = "",
        offset: int = 0,
        limit: int = 100,
    ) -> dict:
        """Return the result of ``r2r.list_documents`` for *id* and page window."""
        r2r.init_sync()
        return r2r.list_documents(id=id, offset=offset, limit=limit)

    @staticmethod
    def list_chunks(ids: list[str]) -> dict:
        """Return the result of ``r2r.list_chunks`` for *ids*.

        Returns an empty dict without contacting R2R when *ids* is empty.
        """
        if not ids:
            return {}
        r2r.init_sync()
        # NOTE(review): the previous implementation also grouped the chunk
        # texts per file but discarded that mapping and returned the raw R2R
        # result. The raw result is still what is returned here; confirm with
        # callers whether the grouped mapping was the intended return value.
        return r2r.list_chunks(ids=ids)

    # TODO: delete the file from S3 and R2R