jack 3 viikkoa sitten
vanhempi
commit
0a71f5482b

+ 22 - 0
app/core/runner/thread_runner.py

@@ -85,6 +85,28 @@ class ThreadRunner:
         instructions = (
             [run.instructions or ""] if run.instructions else [ast.instructions or ""]
         )
+
+        asst_ids = []
+        if ast.tool_resources and "file_search" in ast.tool_resources:
+            asst_ids += (
+                ast.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("folder_ids")
+            )
+            asst_ids += (
+                ast.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
+
+        if len(asst_ids) > 0:
+            if len(run.file_ids) > 0:
+                run.tools.append({"type": "knowledge_search"})
+            else:
+                for tool in run.tools:
+                    if tool.get("type") == "file_search":
+                        tool["type"] = "knowledge_search"
+
         tools = find_tools(run, self.session)
         for tool in tools:
             tool.configure(session=self.session, run=run)

+ 3 - 0
app/core/tools/__init__.py

@@ -11,6 +11,7 @@ from app.core.tools.external_function_tool import ExternalFunctionTool
 from app.core.tools.openapi_function_tool import OpenapiFunctionTool
 from app.core.tools.file_search_tool import FileSearchTool
 from app.core.tools.web_search import WebSearchTool
+from app.core.tools.knowledge_search_tool import KnowledgeSearchTool
 
 # from app.core.tools.file_allcontent import FileContnetTool
 
@@ -18,12 +19,14 @@ from app.core.tools.web_search import WebSearchTool
 class AvailableTools(str, Enum):
     FILE_SEARCH = "file_search"
     WEB_SEARCH = "web_search"
+    KNOWLEDGE_SEARCH = "knowledge_search"
     # FILE_CONTENT = "file_content"
 
 
 TOOLS = {
     AvailableTools.FILE_SEARCH: FileSearchTool,
     AvailableTools.WEB_SEARCH: WebSearchTool,
+    AvailableTools.KNOWLEDGE_SEARCH: KnowledgeSearchTool,
     # AvailableTools.FILE_CONTENT: FileContnetTool,
 }
 

+ 2 - 1
app/core/tools/file_search_tool.py

@@ -69,6 +69,7 @@ class FileSearchTool(BaseTool):
             for file in files:
                 self.__keys.append(file.key)
             print(files)
+        """
         # 读取assistant的数据,获取文件夹的id
         db_asst = AssistantService.get_assistant_sync(
             session=session, assistant_id=run.assistant_id
@@ -85,7 +86,7 @@ class FileSearchTool(BaseTool):
             # folder_fileinfo = []
             if asst_folder_ids:
                 self.__dirkeys = asst_folder_ids
-
+        """
         # pre-cache data to prevent thread conflicts that may occur later on.
         print(
             "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"

+ 72 - 8
app/core/tools/knowledge_search_tool.py

@@ -32,7 +32,7 @@ class KnowledgeSearchTool(BaseTool):
         + "If the user is retrieve specified content from the knowledge base or file content, that is often a good hint that information may be here."
         + "Singleton operation: Strictly 1 invocation per API call"
     )
-    args_schema: Type[BaseModel] = KnowledgeSearchToolInput
+    args_schema: Type[BaseModel] = FileSearchToolInput
 
     def __init__(self) -> None:
         super().__init__()
@@ -45,13 +45,23 @@ class KnowledgeSearchTool(BaseTool):
     def configure(self, session: Session, run: Run, **kwargs):
         # 获取当前事件循环
         # document_id = []
-        print("====KnowledgeSearchToolKnowledgeSearchToolKnowledgeSearchTool====")
         file_key = []
+        files = []
+        # filesinfo = []
+        # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
+        """
         for key in run.file_ids:
             if len(key) == 36:
                 self.__keys.append(key)  # 添加文件id 作为检索
             else:
-                file_key.append(key)
+                file_key.append(
+                    key
+                )  ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
+
+        print(
+            "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
+        )
+        # print(document_id)
         print(file_key)
         files = []
         # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
@@ -60,17 +70,71 @@ class KnowledgeSearchTool(BaseTool):
             files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
             for file in files:
                 self.__keys.append(file.key)
+            print(files)
+        # 读取assistant的数据,获取文件夹的id
+        """
+        db_asst = AssistantService.get_assistant_sync(
+            session=session, assistant_id=run.assistant_id
+        )
+
+        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
+            ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
+            asst_folder_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("folder_ids")
+            )
+            file_key = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("file_ids")
+            )
+
+            print(asst_folder_ids)
+            print(file_key)
+            file_id = []
+            # folder_fileinfo = []
+            if asst_folder_ids:
+                self.__dirkeys = asst_folder_ids
+            if file_key:
+                for key in file_key:
+                    if len(key) == 36:
+                        self.__keys.append(key)  # 添加文件id 作为检索
+                    else:
+                        file_id.append(key)
+                if len(file_id) > 0:
+                    files = FileService.get_file_list_by_ids(
+                        session=session, file_ids=file_id
+                    )
+                    for file in files:
+                        self.__keys.append(file.key)
+                print(files)
+
+        # pre-cache data to prevent thread conflicts that may occur later on.
+        print(
+            "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
+        )
+        print(self.__dirkeys)
         print(self.__keys)
 
+    # indexes: List[int],
     def run(self, query: str) -> dict:
+        print(
+            "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
+        )
+
         print(self.__keys)
+        print(self.__dirkeys)
         files = []
         ## 必须有总结的内容query和才能触发
         if self.index == 0 and query:
-            files = FileService.search_in_files(
-                query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
-            )
-            self.index = 1
+            try:
+                files = FileService.search_in_files(
+                    query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
+                )
+                self.index = 1
+            except Exception as e:
+                print(e)
         # print(files)
         return files
 
@@ -83,4 +147,4 @@ class KnowledgeSearchTool(BaseTool):
         ):
             return ""
         else:
-            return ""
+            return "如果您不确定用户发的文件内容或者代码库结构,请使用文件搜索工具读取内容并收集相关信息,不要瞎猜或者编造答案。"

+ 2 - 0
app/services/run/run.py

@@ -50,6 +50,7 @@ class RunService:
             body.instructions = body.additional_instructions
 
         file_ids = []
+        """
         asst_file_ids = db_asst.file_ids
         if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
             ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
@@ -60,6 +61,7 @@ class RunService:
             )
         if asst_file_ids:
             file_ids += asst_file_ids
+        """
 
         # get thread
         db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)