|
@@ -52,7 +52,8 @@ class R2RFileService(OSSFileService):
|
|
|
# storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
|
|
|
await r2r.init()
|
|
|
await r2r.ingest_file(
|
|
|
- file_path=tmp_file_path, metadata={"file_key": file_key}
|
|
|
+ file_path=tmp_file_path,
|
|
|
+ metadata={"file_key": file_key, "title": file.filename},
|
|
|
)
|
|
|
# Store the file record
|
|
|
db_file = File(
|
|
@@ -64,13 +65,26 @@ class R2RFileService(OSSFileService):
|
|
|
return db_file
|
|
|
|
|
|
@staticmethod
def search_in_files(
    query: str, file_keys: List[str], folder_keys: List[str] = None
) -> dict:
    """Search r2r for ``query`` within the given files and, optionally, folders.

    Args:
        query: Search text passed to ``r2r.search``.
        file_keys: Values matched against the ``file_key`` metadata field.
        folder_keys: Optional values matched against ``collection_ids``.
            When supplied, a hit may match either the file keys or the
            folder keys. When omitted or empty, only ``file_keys`` is used.

    Returns:
        A dict of results; empty when the search yields nothing.
    """
    files = {}
    filters = {"file_key": {"$in": file_keys}}
    # Widen the filter only when folder keys were actually supplied. The
    # earlier check was inverted: it ignored supplied folder keys and sent
    # ``{"$in": None}`` when none were given.
    if folder_keys:
        # Resulting filter shape:
        #   {"$or": [{"file_key": {"$in": [...]}},
        #            {"collection_ids": {"$in": [...]}}]}
        filters = {"$or": [filters, {"collection_ids": {"$in": folder_keys}}]}

    loop = asyncio.get_event_loop()  # get the current event loop
    loop.run_until_complete(r2r.init())  # make sure r2r is initialised
    search_results = loop.run_until_complete(r2r.search(query, filters=filters))

    if not search_results:
        return files
|
|
@@ -85,4 +99,30 @@ class R2RFileService(OSSFileService):
|
|
|
|
|
|
return files
|
|
|
|
|
|
@staticmethod
def list_in_files(
    ids: list[str] = None,
    offset: int = 0,
    limit: int = 100,
) -> dict:
    """Run ``r2r.list`` synchronously and return whatever it yields.

    Args:
        ids: Forwarded unchanged to ``r2r.list`` (``None`` by default).
        offset: Forwarded unchanged to ``r2r.list``.
        limit: Forwarded unchanged to ``r2r.list``.

    Returns:
        The result of ``r2r.list``.
        NOTE(review): annotated as ``dict``; confirm against the r2r client.
    """
    event_loop = asyncio.get_event_loop()
    # r2r has to be initialised before it can be queried.
    event_loop.run_until_complete(r2r.init())
    return event_loop.run_until_complete(
        r2r.list(ids=ids, offset=offset, limit=limit)
    )
|
|
|
+
|
|
|
@staticmethod
def list_documents(
    ids: str = "",
    offset: int = 0,
    limit: int = 100,
) -> dict:
    """Run ``r2r.list_documents`` synchronously and return whatever it yields.

    Args:
        ids: Forwarded unchanged to ``r2r.list_documents``.
            NOTE(review): typed as ``str`` here but as a list in
            ``list_in_files``; confirm which form the client expects.
        offset: Forwarded unchanged to ``r2r.list_documents``.
        limit: Forwarded unchanged to ``r2r.list_documents``.

    Returns:
        The result of ``r2r.list_documents``.
    """
    event_loop = asyncio.get_event_loop()
    # r2r has to be initialised before it can be queried.
    event_loop.run_until_complete(r2r.init())
    return event_loop.run_until_complete(
        r2r.list_documents(ids=ids, offset=offset, limit=limit)
    )
|
|
|
+
|
|
|
# TODO: delete the files from both S3 and R2R
|