|
@@ -14,10 +14,13 @@ import nest_asyncio
|
|
|
# 使得异步代码可以在已运行的事件循环中嵌套
|
|
|
nest_asyncio.apply()
|
|
|
|
|
|
-
|
|
|
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
|
-# query: str = Field(..., description="query to look up in retrieval")
|
|
|
+
|
|
|
+
|
|
|
class FileSearchToolInput(BaseModel):
|
|
|
+ indexes: List[int] = Field(
|
|
|
+ ..., description="file index list to look up in retrieval"
|
|
|
+ )
|
|
|
query: str = Field(..., description="query to look up in retrieval")
|
|
|
|
|
|
|
|
@@ -25,10 +28,7 @@ class FileSearchTool(BaseTool):
|
|
|
name: str = "file_search"
|
|
|
description: str = (
|
|
|
"Can be used to look up information that was uploaded to this assistant."
|
|
|
- # "If the user is referencing particular files, that is often a good hint that information may be here."
|
|
|
- "A search engine optimized for comprehensive, accurate, and trusted results. "
|
|
|
- "Useful for when you need to answer questions about current events. "
|
|
|
- "Input should be a search query."
|
|
|
+ "If the user is referencing particular files, that is often a good hint that information may be here."
|
|
|
)
|
|
|
|
|
|
args_schema: Type[BaseModel] = FileSearchToolInput
|
|
@@ -37,38 +37,41 @@ class FileSearchTool(BaseTool):
|
|
|
super().__init__()
|
|
|
self.__filenames = []
|
|
|
self.__keys = []
|
|
|
- self.__dirkeys = []
|
|
|
self.loop = None
|
|
|
|
|
|
def configure(self, session: Session, run: Run, **kwargs):
|
|
|
+ """
|
|
|
+ # 提交任务到事件循环
|
|
|
+ future = asyncio.run_coroutine_threadsafe(async_task(), loop)
|
|
|
+ # 阻塞等待结果
|
|
|
+ result = future.result()
|
|
|
+ """
|
|
|
+ """
|
|
|
+ 置当前 Retrieval 涉及文件信息
|
|
|
+ """
|
|
|
# 获取当前事件循环
|
|
|
- # document_id = []
|
|
|
+ document_id = []
|
|
|
file_key = []
|
|
|
- # filesinfo = []
|
|
|
+ filesinfo = []
|
|
|
# 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
|
|
|
for key in run.file_ids:
|
|
|
if len(key) == 36:
|
|
|
- self.__keys.append(key) # 添加文件id 作为检索
|
|
|
+ document_id.append(key)
|
|
|
else:
|
|
|
- file_key.append(
|
|
|
- key
|
|
|
- ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
|
|
|
+ file_key.append(key)
|
|
|
|
|
|
print(
|
|
|
"document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
|
|
|
)
|
|
|
- # print(document_id)
|
|
|
+ print(document_id)
|
|
|
print(file_key)
|
|
|
files = []
|
|
|
# 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
|
|
|
if len(file_key) > 0:
|
|
|
## 获取文件信息
|
|
|
files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
|
|
|
- for file in files:
|
|
|
- self.__keys.append(file.key)
|
|
|
print(files)
|
|
|
# r2r接口不提供多条件,否则上面没必要存在
|
|
|
- """
|
|
|
if len(document_id) > 0:
|
|
|
filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
|
|
|
# asyncio.run(
|
|
@@ -78,11 +81,8 @@ class FileSearchTool(BaseTool):
|
|
|
self.__filenames.append(file.get("title"))
|
|
|
self.__keys.append(file.get("id"))
|
|
|
print(filesinfo)
|
|
|
- """
|
|
|
|
|
|
# files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
|
|
|
-
|
|
|
- # 读取assistant的数据,获取文件夹的id
|
|
|
db_asst = AssistantService.get_assistant_sync(
|
|
|
session=session, assistant_id=run.assistant_id
|
|
|
)
|
|
@@ -95,10 +95,8 @@ class FileSearchTool(BaseTool):
|
|
|
.get("folder_ids")
|
|
|
)
|
|
|
print(asst_folder_ids)
|
|
|
- # folder_fileinfo = []
|
|
|
+ folder_fileinfo = []
|
|
|
if asst_folder_ids:
|
|
|
- self.__dirkeys = asst_folder_ids
|
|
|
- """
|
|
|
for fid in asst_folder_ids:
|
|
|
folder_fileinfo += FileService.list_documents(
|
|
|
id=fid, offset=0, limit=100
|
|
@@ -110,30 +108,36 @@ class FileSearchTool(BaseTool):
|
|
|
for file in folder_fileinfo:
|
|
|
self.__filenames.append(file.get("title"))
|
|
|
self.__keys.append(file.get("id"))
|
|
|
- """
|
|
|
+
|
|
|
# pre-cache data to prevent thread conflicts that may occur later on.
|
|
|
print(
|
|
|
"---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
|
|
|
)
|
|
|
- print(self.__dirkeys)
|
|
|
- """
|
|
|
+ print(files)
|
|
|
for file in files:
|
|
|
self.__filenames.append(file.filename)
|
|
|
self.__keys.append(file.key)
|
|
|
- """
|
|
|
print(self.__keys)
|
|
|
|
|
|
- # indexes: List[int],
|
|
|
- def run(self, query: str) -> dict:
|
|
|
+ def run(self, indexes: List[int], query: str) -> dict:
|
|
|
+
|
|
|
+ file_keys = []
|
|
|
+
|
|
|
+ for index in indexes:
|
|
|
+ if index is not None:
|
|
|
+ file_key = self.__keys[index]
|
|
|
+ file_keys.append(file_key)
|
|
|
print(
|
|
|
"file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
|
|
|
)
|
|
|
-
|
|
|
- print(self.__keys)
|
|
|
- print(self.__dirkeys)
|
|
|
- files = FileService.search_in_files(
|
|
|
- query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
|
|
|
- )
|
|
|
+ print(file_keys)
|
|
|
+ files = []
|
|
|
+ if len(file_keys) > 0:
|
|
|
+ # self.loop = asyncio.get_event_loop()
|
|
|
+ files = FileService.search_in_files(query=query, file_keys=file_keys)
|
|
|
+ # files = asyncio.run(
|
|
|
+ # FileService.search_in_files(query=query, file_keys=file_keys)
|
|
|
+ # )
|
|
|
print(files)
|
|
|
return files
|
|
|
|
|
@@ -141,26 +145,16 @@ class FileSearchTool(BaseTool):
|
|
|
"""
|
|
|
为 Retrieval 提供文件选择信息,用于 llm 调用抉择
|
|
|
"""
|
|
|
- if (self.__keys and len(self.__keys) > 0) or (
|
|
|
- self.__dirkeys and len(self.__dirkeys) > 0
|
|
|
- ):
|
|
|
+ if len(self.__filenames) == 0:
|
|
|
+ return ""
|
|
|
+ else:
|
|
|
+ filenames_info = [
|
|
|
+ f"({index}){filename}"
|
|
|
+ for index, filename in enumerate(self.__filenames)
|
|
|
+ ]
|
|
|
return (
|
|
|
- "## 工具使用规范"
|
|
|
- + "可调用工具:"
|
|
|
- + "- retrieval:用于在文件库中搜索与问题相关的具体内容"
|
|
|
- + "**调用规则**:"
|
|
|
- + "1. 当问题涉及以下情况时必须调用本工具:"
|
|
|
- + " - 询问文件/文档中的具体内容"
|
|
|
- + " - 需要查找数据、条款或技术细节"
|
|
|
- + ' - 用户明确要求"查文件"或"搜索资料"'
|
|
|
- + "2. 调用时需遵循:"
|
|
|
- + " ```json"
|
|
|
- + " {"
|
|
|
- + ' "action": "retrieval",'
|
|
|
- + ' "action_input": {'
|
|
|
- + ' "query": "精炼后的搜索语句,需包含至少2个关键要素(用户问题的原始上下文)"'
|
|
|
- + " }"
|
|
|
- + " }"
|
|
|
+ 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
|
|
|
+ + 'Each line represents a file in the format "(index)filename":\n'
|
|
|
+ + "\n".join(filenames_info)
|
|
|
+ + "\nMake sure to be extremely concise when using attached files. "
|
|
|
)
|
|
|
- else:
|
|
|
- return ""
|