|
@@ -11,6 +11,7 @@ from app.models import File
|
|
from app.providers.r2r import r2r
|
|
from app.providers.r2r import r2r
|
|
from app.providers.storage import storage
|
|
from app.providers.storage import storage
|
|
from app.services.file.impl.oss_file import OSSFileService
|
|
from app.services.file.impl.oss_file import OSSFileService
|
|
|
|
+import asyncio
|
|
|
|
|
|
|
|
|
|
class R2RFileService(OSSFileService):
|
|
class R2RFileService(OSSFileService):
|
|
@@ -42,7 +43,7 @@ class R2RFileService(OSSFileService):
|
|
await f.write(content)
|
|
await f.write(content)
|
|
|
|
|
|
storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
|
|
storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
|
|
- r2r.init()
|
|
|
|
|
|
+ await r2r.init()
|
|
r2r.ingest_file(file_path=tmp_file_path, metadata={"file_key": file_key})
|
|
r2r.ingest_file(file_path=tmp_file_path, metadata={"file_key": file_key})
|
|
|
|
|
|
# 存储
|
|
# 存储
|
|
@@ -57,7 +58,7 @@ class R2RFileService(OSSFileService):
|
|
@staticmethod
|
|
@staticmethod
|
|
def search_in_files(query: str, file_keys: List[str]) -> dict:
|
|
def search_in_files(query: str, file_keys: List[str]) -> dict:
|
|
files = {}
|
|
files = {}
|
|
- r2r.init()
|
|
|
|
|
|
+ asyncio.create_task(r2r.init())
|
|
search_results = r2r.search(query, filters={"file_key": {"$in": file_keys}})
|
|
search_results = r2r.search(query, filters={"file_key": {"$in": file_keys}})
|
|
if not search_results:
|
|
if not search_results:
|
|
return files
|
|
return files
|