jack 4 months ago
parent
commit
6c271d2d78

+ 2 - 4
app/core/tools/file_search_tool.py

@@ -40,10 +40,8 @@ class FileSearchTool(BaseTool):
         # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
 
         loop = asyncio.get_event_loop()  # 获取当前事件循环
-        db_asst = loop.run_until_complete(
-            AssistantService.get_assistant(
-                session=session, assistant_id=run.assistant_id
-            )
+        db_asst = AssistantService.get_assistant_sync(
+            session=session, assistant_id=run.assistant_id
         )
 
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:

+ 1 - 1
app/providers/r2r.py

@@ -101,7 +101,7 @@ class R2R:
         )
         if id != "":
             listed = await self.client.collections.list_documents(
-                ids=id, limit=limit, offset=offset
+                id=id, limit=limit, offset=offset
             )
             print(listed.get("results"))
             return listed.get("results")

+ 9 - 0
app/services/assistant/assistant.py

@@ -78,3 +78,12 @@ class AssistantService:
         if assistant is None:
             raise ResourceNotFoundError(message="Assistant not found")
         return assistant
+
+    @staticmethod
+    def get_assistant_sync(*, session: AsyncSession, assistant_id: str) -> Assistant:
+        statement = select(Assistant).where(Assistant.id == assistant_id)
+        result = session.execute(statement)
+        assistant = result.scalars().one_or_none()
+        if assistant is None:
+            raise ResourceNotFoundError(message="Assistant not found")
+        return assistant

+ 1 - 1
app/services/file/impl/oss_file.py

@@ -112,7 +112,7 @@ class OSSFileService(BaseFileService):
 
     @staticmethod
     def list_documents(
-        ids: str = None,
+        id: str = "",
         offset: int = 0,
         limit: int = 100,
     ) -> dict:

+ 2 - 2
app/services/file/impl/r2r_file.py

@@ -114,14 +114,14 @@ class R2RFileService(OSSFileService):
 
     @staticmethod
     def list_documents(
-        ids: str = "",
+        id: str = "",
         offset: int = 0,
         limit: int = 100,
     ) -> dict:
         loop = asyncio.get_event_loop()  # 获取当前事件循环
         loop.run_until_complete(r2r.init())  # 确保 r2r 已初始化
         list_results = loop.run_until_complete(
-            r2r.list_documents(ids=ids, offset=offset, limit=limit)
+            r2r.list_documents(id=id, offset=offset, limit=limit)
         )
         return list_results