@@ -16,12 +16,17 @@ nest_asyncio.apply()
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
+"""
class FileSearchToolInput(BaseModel):
    indexes: List[int] = Field(
        ..., description="file index list to look up in retrieval"
    )
    query: str = Field(..., description="query to look up in retrieval")
+"""
+
+
+class FileSearchToolInput(BaseModel):
+    query: str = Field(..., description="query to look up in retrieval")
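+    # The earlier `indexes` field is gone: configure() now resolves which files
+    # and folders to search, so the model only needs the free-text query.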


class FileSearchTool(BaseTool):
@@ -37,18 +42,142 @@ class FileSearchTool(BaseTool):
        super().__init__()
        self.__filenames = []
        self.__keys = []
+        self.__dirkeys = []
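+        # knowledge-base folder ids taken from the assistant's file_search
+        # tool_resources in configure(); run() passes them to search_in_files as folder_keys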
        self.loop = None

    def configure(self, session: Session, run: Run, **kwargs):
+        # Get the current event loop
+        # document_id = []
+        file_key = []
+        # filesinfo = []
+        # Files will later be picked from the knowledge base, so they may not exist in the openassistant database
+        for key in run.file_ids:
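+            # run.file_ids can mix knowledge-base document ids (36-character UUIDs,
+            # searchable in r2r as-is) with OpenAssistant file ids that must first be
+            # resolved to their storage keys below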
+            if len(key) == 36:
+                self.__keys.append(key)  # add the file id directly for retrieval
+            else:
+                file_key.append(
+                    key
+                )  ## assistant-side ids cannot be searched in r2r directly; the file key field has to be looked up
+
+        print("---- configure: document_id / file_key ----")
+        # print(document_id)
+        print(file_key)
+        files = []
+        # This handles the old uuid.ex style ids that existed early on; it should be removed later
+        if len(file_key) > 0:
+            ## fetch the file records
+            files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
+            for file in files:
+                self.__keys.append(file.key)
+            print(files)
+        # The r2r API does not support combining these filter conditions, otherwise the block above would be unnecessary
+        """
+        if len(document_id) > 0:
+            filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
+            # asyncio.run(
+            #    FileService.list_in_files(ids=document_id, offset=0, limit=100)
+            # )
+            for file in filesinfo:
+                self.__filenames.append(file.get("title"))
+                self.__keys.append(file.get("id"))
+            print(filesinfo)
+        """
+
+        # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
+
+        # Read the assistant record to get its folder ids
+        db_asst = AssistantService.get_assistant_sync(
+            session=session, assistant_id=run.assistant_id
+        )
+
+        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
+            ## {"file_search": {"vector_stores": [{"folder_ids": []}]}}
+            asst_folder_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("folder_ids")
+            )
+            print(asst_folder_ids)
+            # folder_fileinfo = []
+            if asst_folder_ids:
+                self.__dirkeys = asst_folder_ids
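+                # folders are no longer expanded into individual documents here;
+                # run() hands them to search_in_files as folder_keys instead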
+            """
+            for fid in asst_folder_ids:
+                folder_fileinfo += FileService.list_documents(
+                    id=fid, offset=0, limit=100
+                )
+                # folder_fileinfo += asyncio.run(
+                #    FileService.list_documents(id=fid, offset=0, limit=100)
+                # )
+            print(folder_fileinfo)
+            for file in folder_fileinfo:
+                self.__filenames.append(file.get("title"))
+                self.__keys.append(file.get("id"))
+            """
+        # pre-cache data to prevent thread conflicts that may occur later on.
+        print("---- configure: dirkeys ----")
+        print(self.__dirkeys)
        """
-        # Submit the task to the event loop
-        future = asyncio.run_coroutine_threadsafe(async_task(), loop)
-        # Block until the result is ready
-        result = future.result()
+        for file in files:
+            self.__filenames.append(file.filename)
+            self.__keys.append(file.key)
        """
+        print(self.__keys)
+
+    # indexes: List[int],
+    def run(self, query: str) -> dict:
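+        # Search everything gathered in configure(): individual file keys plus
+        # whole knowledge-base folders (self.__dirkeys).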
+
+        print("---- file_search run: file_keys ----")
+
+        print(self.__keys)
+        print(self.__dirkeys)
+        files = FileService.search_in_files(
+            query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
+        )
+        print(files)
+        return files
+
        """
-        Set the file information involved in the current Retrieval
+        file_keys = []
+        for index in indexes:
+            if index is not None:
+                file_key = self.__keys[index]
+                file_keys.append(file_key)
+
+        print(file_keys)
+        files = []
+        if len(file_keys) > 0:
+            # self.loop = asyncio.get_event_loop()
+
+            # files = asyncio.run(
+            #    FileService.search_in_files(query=query, file_keys=file_keys)
+            # )
+        print(files)
+        return files
+        """
+
+    def instruction_supplement(self) -> str:
+        """
+        Provide file-selection information for Retrieval, used by the llm to decide how to call the tool
"""
+        return (
+            'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
+            # + 'Each line represents a file in the format "(index)filename":\n'
+            # + "\n".join(filenames_info)
+            + "\nMake sure to be extremely concise when using attached files. "
+        )
+
+
+'''
+    def configure(self, session: Session, run: Run, **kwargs):
        # Get the current event loop
        document_id = []
        file_key = []
@@ -119,6 +248,8 @@ class FileSearchTool(BaseTool):
            self.__keys.append(file.key)
        print(self.__keys)
+
+
    def run(self, indexes: List[int], query: str) -> dict:

        file_keys = []
@@ -158,3 +289,4 @@ class FileSearchTool(BaseTool):
            + "\n".join(filenames_info)
            + "\nMake sure to be extremely concise when using attached files. "
        )
+'''