jack 3 kuukautta sitten
vanhempi
commit
e46dcd7c75

+ 1 - 9
app/core/runner/thread_runner.py

@@ -221,15 +221,7 @@ class ThreadRunner:
             internal_tool_calls = list(
                 filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls)
             )
-            """
-            seen = set()
-            internal_tool_calls = []
-            for _tool_call in tool_calls:
-                tool_obj = _tool_call[0]
-                if tool_obj is not None and tool_obj not in seen:
-                    seen.add(tool_obj)
-                    internal_tool_calls.append(_tool_call)
-            """
+
             external_tool_call_dict = [
                 tool_call_dict for tool, tool_call_dict in tool_calls if tool is None
             ]

+ 3 - 0
app/core/tools/__init__.py

@@ -11,16 +11,19 @@ from app.core.tools.external_function_tool import ExternalFunctionTool
 from app.core.tools.openapi_function_tool import OpenapiFunctionTool
 from app.core.tools.file_search_tool import FileSearchTool
 from app.core.tools.web_search import WebSearchTool
+from app.core.tools.file_allcontent import FileContnetTool
 
 
 class AvailableTools(str, Enum):
     FILE_SEARCH = "file_search"
     WEB_SEARCH = "web_search"
+    FILE_CONTENT = "file_content"
 
 
 TOOLS = {
     AvailableTools.FILE_SEARCH: FileSearchTool,
     AvailableTools.WEB_SEARCH: WebSearchTool,
+    AvailableTools.FILE_CONTENT: FileContnetTool,
 }
 
 

+ 28 - 0
app/core/tools/file_allcontent.py

@@ -0,0 +1,28 @@
+from typing import Type
+
+from pydantic import BaseModel, Field
+
+from app.core.tools.base_tool import BaseTool
+from config.llm import tool_settings
+from sqlalchemy.orm import Session
+from app.models.run import Run
+from app.services.file.file import FileService
+
+
+class FileContnetTool(BaseTool):
+    name: str = "file_content"
+    description: str = (
+        "读取文件的所有或者全部内容并返回给用户,这里每一次只允许触发一次"
+        "只有提到读取全部内容的时候才会返回全部内容,其他时候这个工具不会调用"
+        "和file_search工具不会同时使用,用了此工具就不会调用file_search"
+    )
+
+    file_ids: list[str] = []
+    args_schema: Type[BaseModel] = {}
+
+    def configure(self, session: Session, run: Run, **kwargs):
+        if run.file_ids is not None and len(run.file_ids) > 0:
+            self.file_ids = run.file_ids
+
+    def run(self) -> dict:
+        return FileService.list_chunks(ids=self.file_ids)

+ 3 - 34
app/core/tools/file_search_tool.py

@@ -33,7 +33,8 @@ class FileSearchTool(BaseTool):
     name: str = "file_search"
     description: str = (
         "Can be used to look up knowledge base information that was uploaded to this assistant."
-        + "If the user is referencing about specific content within files, that is often a good hint that information may be here."
+        + "If the user is retrieve specified content from the knowledge base or file system, that is often a good hint that information may be here."
+        + "Retrieve content from files or knowledge base (similar to database lookup, document search, or information fetching)"
         + "## Input Requirements:"
         + "The prompt must return a strictly standard JSON object! Absolutely no code blocks, comments, or extra symbols. Example format: {'query': 'query to look up in retrieval'}"
         + "Singleton operation: Strictly 1 invocation per API call"
@@ -75,21 +76,6 @@ class FileSearchTool(BaseTool):
             for file in files:
                 self.__keys.append(file.key)
             print(files)
-        # r2r接口不提供多条件,否则上面没必要存在
-        """
-        if len(document_id) > 0:
-            filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
-            # asyncio.run(
-            #    FileService.list_in_files(ids=document_id, offset=0, limit=100)
-            # )
-            for file in filesinfo:
-                self.__filenames.append(file.get("title"))
-                self.__keys.append(file.get("id"))
-            print(filesinfo)
-        """
-
-        # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
-
         # 读取assistant的数据,获取文件夹的id
         db_asst = AssistantService.get_assistant_sync(
             session=session, assistant_id=run.assistant_id
@@ -106,29 +92,12 @@ class FileSearchTool(BaseTool):
             # folder_fileinfo = []
             if asst_folder_ids:
                 self.__dirkeys = asst_folder_ids
-            """
-                for fid in asst_folder_ids:
-                    folder_fileinfo += FileService.list_documents(
-                        id=fid, offset=0, limit=100
-                    )
-                    # folder_fileinfo += asyncio.run(
-                    #    FileService.list_documents(id=fid, offset=0, limit=100)
-                    # )
-                print(folder_fileinfo)
-                for file in folder_fileinfo:
-                    self.__filenames.append(file.get("title"))
-                    self.__keys.append(file.get("id"))
-            """
+
         # pre-cache data to prevent thread conflicts that may occur later on.
         print(
             "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
         )
         print(self.__dirkeys)
-        """
-        for file in files:
-            self.__filenames.append(file.filename)
-            self.__keys.append(file.key)
-        """
         print(self.__keys)
 
     # indexes: List[int],

+ 16 - 0
app/providers/r2r.py

@@ -54,6 +54,7 @@ class R2R:
         return await self.client.documents.create(
             file_path=file_path,
             metadata=metadata if metadata else None,
+            ingestion_mode="fast",
             id=None,
         )
 
@@ -106,6 +107,21 @@ class R2R:
         else:
             return []
 
+    def list_chunks(
+        self,
+        ids: list[str] = [],
+    ):
+        self._check_login_sync()
+        print(
+            "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
+        )
+        print(ids)
+        allfile = []
+        for id in ids:
+            listed = self.client_sync.documents.list_chunks(id=id)
+            allfile += listed.get("results")
+        return allfile
+
     def list_documents(
         self,
         id: Optional[str] = "",

+ 5 - 0
app/services/file/impl/base.py

@@ -71,3 +71,8 @@ class BaseFileService(ABC):
         limit: int = 100,
     ) -> dict:
         pass
+
+    @staticmethod
+    @abstractmethod
+    def list_chunks(ids: list[str] = []) -> dict:
+        pass

+ 4 - 0
app/services/file/impl/oss_file.py

@@ -119,3 +119,7 @@ class OSSFileService(BaseFileService):
         limit: int = 100,
     ) -> dict:
         return []
+
+    @staticmethod
+    def list_chunks(ids: list[str] = []) -> dict:
+        return {}

+ 19 - 0
app/services/file/impl/r2r_file.py

@@ -183,4 +183,23 @@ class R2RFileService(OSSFileService):
         list_results = r2r.list_documents(id=id, offset=offset, limit=limit)
         return list_results
 
+    @staticmethod
+    def list_chunks(ids: list[str] = []) -> dict:
+        if len(ids) > 0:
+            r2r.init_sync()
+            list_results = r2r.list_chunks(ids=ids)
+            files = {}
+            for doc in list_results:
+                file_key = doc.get("metadata").get("file_key")
+                file_key = (
+                    doc.get("metadata").get("title") if file_key is None else file_key
+                )
+                text = doc.get("text")
+                if file_key in files and files[file_key]:
+                    files[file_key] += f"\n\n{text}"
+                else:
+                    files[file_key] = doc.get("text")
+            return files
+        return {}
+
     # TODO 删除s3&r2r文件