jack 1 månad sedan
förälder
incheckning
f445d84f86

+ 10 - 5
app/core/runner/llm_backend.py

@@ -42,7 +42,9 @@ class LLMBackend:
         if stream_options:
             if isinstance(stream_options, dict):
                 if "include_usage" in stream_options:
-                    chat_params["stream_options"] = {"include_usage": bool(stream_options["include_usage"])}
+                    chat_params["stream_options"] = {
+                        "include_usage": bool(stream_options["include_usage"])
+                    }
         if temperature:
             chat_params["temperature"] = temperature
         if top_p:
@@ -50,11 +52,14 @@ class LLMBackend:
         if tools:
             chat_params["tools"] = tools
             chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
-        if isinstance(response_format, dict) and response_format.get("type") == "json_object":
+        if (
+            isinstance(response_format, dict)
+            and response_format.get("type") == "json_object"
+        ):
             chat_params["response_format"] = {"type": "json_object"}
-        for message in chat_params['messages']:
-            if 'content' not in message:
-                message['content'] = ""
+        for message in chat_params["messages"]:
+            if "content" not in message:
+                message["content"] = ""
         logging.info("chat_params: %s", chat_params)
         response = self.client.chat.completions.create(**chat_params)
         logging.info("chat_response: %s", response)

+ 4 - 2
app/core/runner/thread_runner.py

@@ -79,13 +79,15 @@ class ThreadRunner:
         metadata = ast.metadata_ or {}
         memory = find_memory(metadata.get("memory", {}))
 
-        instructions = [run.instructions] if run.instructions else [ast.instructions]
+        instructions = (
+            [run.instructions or ""] if run.instructions else [ast.instructions or ""]
+        )
         tools = find_tools(run, self.session)
         for tool in tools:
             tool.configure(session=self.session, run=run)
             instruction_supplement = tool.instruction_supplement()
             if instruction_supplement:
-                instructions += [instruction_supplement]
+                instructions += [instruction_supplement or ""]
         instruction = "\n".join(instructions)
 
         llm = self.__init_llm_backend(run.assistant_id)

+ 3 - 1
app/core/tools/__init__.py

@@ -26,7 +26,9 @@ TOOLS = {
 
 def find_tools(run, session: Session) -> List[BaseTool]:
     action_ids = [tool.get("id") for tool in run.tools if tool.get("type") == "action"]
-    actions = session.execute(select(Action).where(Action.id.in_(action_ids))).scalars().all()
+    actions = (
+        session.execute(select(Action).where(Action.id.in_(action_ids))).scalars().all()
+    )
     action_map = {action.id: action for action in actions}
 
     tools = []

+ 10 - 2
app/core/tools/base_tool.py

@@ -33,8 +33,16 @@ class BaseTool(ABC):
     openai_function: Dict
 
     def __init_subclass__(cls) -> None:
-        lc_tool = LCTool(name=cls.name, description=cls.description, args_schema=cls.args_schema, _run=lambda x: x)
-        cls.openai_function = {"type": "function", "function": dict(format_tool_to_openai_function(lc_tool))}
+        lc_tool = LCTool(
+            name=cls.name,
+            description=cls.description,
+            args_schema=cls.args_schema,
+            _run=lambda x: x,
+        )
+        cls.openai_function = {
+            "type": "function",
+            "function": dict(format_tool_to_openai_function(lc_tool)),
+        }
 
     def configure(self, **kwargs):
         """

+ 27 - 3
app/core/tools/file_search_tool.py

@@ -6,6 +6,8 @@ from sqlalchemy.orm import Session
 from app.core.tools.base_tool import BaseTool
 from app.models.run import Run
 from app.services.file.file import FileService
+from app.services.assistant.assistant import AssistantService
+import asyncio
 
 
 class FileSearchToolInput(BaseModel):
@@ -33,15 +35,37 @@ class FileSearchTool(BaseTool):
         """
         置当前 Retrieval 涉及文件信息
         """
-        files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
+        ## 获取文件信息
+        # files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
+        files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
+
+        loop = asyncio.get_event_loop()  # 获取当前事件循环
+        db_asst = loop.run_until_complete(
+            AssistantService.get_assistant(
+                session=session, assistant_id=run.assistant_id
+            )
+        )
+
+        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
+            ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
+            asst_folder_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("folder_ids")
+            )
+            print(asst_folder_ids)
+            if asst_folder_ids:
+                for fid in asst_folder_ids:
+                    files += FileService.list_documents(id=fid, offset=0, limit=100)
+
         # pre-cache data to prevent thread conflicts that may occur later on.
         print(
             "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
         )
         print(files)
         for file in files:
-            self.__filenames.append(file.filename)
-            self.__keys.append(file.key)
+            self.__filenames.append(file.title)
+            self.__keys.append(file.get("metadata").get("file_key"))
         print(self.__keys)
 
     def run(self, indexes: List[int], query: str) -> dict:

+ 47 - 0
app/providers/r2r.py

@@ -59,6 +59,53 @@ class R2R:
         print(search_response.get("results"))
         return search_response.get("results").get("chunk_search_results")
 
+    async def list(
+        self,
+        ids: Optional[list[str]] = None,
+        offset: Optional[int] = 0,
+        limit: Optional[int] = 100,
+    ):
+        await self._check_login()
+
+        """
+            listed = mutable_client.documents.list(limit=2, offset=0)
+    results = listed.results
+    assert len(results) == 2, "Expected 2 results for paginated listing"
+        """
+        print("listlistlistlistlistlistlistlistlistlistlistlistlistlistlistlistlist")
+        if len(ids) > 0:
+            listed = await self.client.documents.list(
+                ids=ids, limit=limit, offset=offset
+            )
+            print(listed.get("results"))
+            return listed.get("results")
+        else:
+            return []
+
+    async def list_documents(
+        self,
+        id: Optional[str] = "",
+        offset: Optional[int] = 0,
+        limit: Optional[int] = 100,
+    ):
+        await self._check_login()
+
+        """
+            docs = client.collections.list_documents(empty_coll_id).results
+            assert len(docs) == 0, "Expected no documents in a new empty collection"
+        """
+        print(
+            "collectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollectionscollections"
+        )
+        if id != "":
+            listed = await self.client.collections.list_documents(
+                ids=id, limit=limit, offset=offset
+            )
+            print(listed.get("results"))
+            return listed.get("results")
+        else:
+            return []
+
     async def _check_login(self):
         if not self.auth_enabled:
             return

+ 27 - 3
app/services/file/impl/base.py

@@ -18,12 +18,16 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    async def get_file_list(*, session: AsyncSession, purpose: str, file_ids: Optional[List[str]]) -> List[File]:
+    async def get_file_list(
+        *, session: AsyncSession, purpose: str, file_ids: Optional[List[str]]
+    ) -> List[File]:
         pass
 
     @staticmethod
     @abstractmethod
-    async def create_file(*, session: AsyncSession, purpose: str, file: UploadFile) -> File:
+    async def create_file(
+        *, session: AsyncSession, purpose: str, file: UploadFile
+    ) -> File:
         pass
 
     @staticmethod
@@ -33,7 +37,9 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    async def get_file_content(*, session: AsyncSession, file_id: str) -> Tuple[Union[bytes, Generator], str]:
+    async def get_file_content(
+        *, session: AsyncSession, file_id: str
+    ) -> Tuple[Union[bytes, Generator], str]:
         pass
 
     @staticmethod
@@ -45,3 +51,21 @@ class BaseFileService(ABC):
     @abstractmethod
     def search_in_files(*, query: str, file_keys: List[str]) -> dict:
         pass
+
+    @staticmethod
+    @abstractmethod
+    def list_in_files(
+        ids: list[str] = None,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def list_documents(
+        id: str = "",
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        pass

+ 16 - 0
app/services/file/impl/oss_file.py

@@ -101,3 +101,19 @@ class OSSFileService(BaseFileService):
             files[file_key] = doc_loader.load(file_data)[:5000]
 
         return files
+
+    @staticmethod
+    def list_in_files(
+        ids: list[str] = None,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        return []
+
+    @staticmethod
+    def list_documents(
+        ids: str = None,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        return []

+ 45 - 5
app/services/file/impl/r2r_file.py

@@ -52,7 +52,8 @@ class R2RFileService(OSSFileService):
             # storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
             await r2r.init()
             await r2r.ingest_file(
-                file_path=tmp_file_path, metadata={"file_key": file_key}
+                file_path=tmp_file_path,
+                metadata={"file_key": file_key, "title": file.filename},
             )
         # 存储
         db_file = File(
@@ -64,13 +65,26 @@ class R2RFileService(OSSFileService):
         return db_file
 
     @staticmethod
-    def search_in_files(query: str, file_keys: List[str]) -> dict:
+    def search_in_files(
+        query: str, file_keys: List[str], folder_keys: List[str] = None
+    ) -> dict:
         files = {}
+        filters = {"file_key": {"$in": file_keys}}
+        if not folder_keys:
+            filters = {"$or": [filters, {"collection_ids": {"$in": folder_keys}}]}
+        ##filters["collection_ids"] = {"$overlap": folder_keys}
+        ## {"$and": {"$document_id": ..., "collection_ids": ...}}
+        """
+        {
+            "$or": [
+                {"document_id": {"$eq": "9fbe403b-..."}},
+                {"collection_ids": {"$in": ["122fdf6a-...", "..."]}}
+            ]
+        }
+        """
         loop = asyncio.get_event_loop()  # 获取当前事件循环
         loop.run_until_complete(r2r.init())  # 确保 r2r 已初始化
-        search_results = loop.run_until_complete(
-            r2r.search(query, filters={"file_key": {"$in": file_keys}})
-        )
+        search_results = loop.run_until_complete(r2r.search(query, filters=filters))
 
         if not search_results:
             return files
@@ -85,4 +99,30 @@ class R2RFileService(OSSFileService):
 
         return files
 
+    @staticmethod
+    def list_in_files(
+        ids: list[str] = None,
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        loop = asyncio.get_event_loop()  # 获取当前事件循环
+        loop.run_until_complete(r2r.init())  # 确保 r2r 已初始化
+        list_results = loop.run_until_complete(
+            r2r.list(ids=ids, offset=offset, limit=limit)
+        )
+        return list_results
+
+    @staticmethod
+    def list_documents(
+        ids: str = "",
+        offset: int = 0,
+        limit: int = 100,
+    ) -> dict:
+        loop = asyncio.get_event_loop()  # 获取当前事件循环
+        loop.run_until_complete(r2r.init())  # 确保 r2r 已初始化
+        list_results = loop.run_until_complete(
+            r2r.list_documents(ids=ids, offset=offset, limit=limit)
+        )
+        return list_results
+
     # TODO 删除s3&r2r文件