import asyncio
import logging
from typing import List, Type

import nest_asyncio
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.file.file import FileService
from app.services.assistant.assistant import AssistantService

logger = logging.getLogger(__name__)

# Allow async code to be nested inside an already-running event loop, so the
# synchronous tool methods below can drive FileService coroutines.
nest_asyncio.apply()
# NOTE: on Windows, asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# may be needed; it is intentionally not applied here.

# Length of a canonical UUID string (32 hex digits + 4 hyphens). run.file_ids of
# this length are treated as knowledge-base document ids; others are legacy keys.
_UUID_STR_LEN = 36

# Page size passed to every FileService listing call (results beyond it are not fetched).
_LIST_LIMIT = 100


def _run_sync(coro):
    """Run *coro* to completion on the current event loop and return its result."""
    return asyncio.get_event_loop().run_until_complete(coro)


class FileSearchToolInput(BaseModel):
    indexes: List[int] = Field(..., description="file index list to look up in retrieval")
    query: str = Field(..., description="query to look up in retrieval")


class FileSearchTool(BaseTool):
    """Tool that lets the LLM run retrieval queries over the files attached to a
    run and to its assistant's file_search folders."""

    name: str = "file_search"
    description: str = (
        "Can be used to look up information that was uploaded to this assistant. "
        "If the user is referencing particular files, that is often a good hint that information may be here."
    )

    args_schema: Type[BaseModel] = FileSearchToolInput

    def __init__(self) -> None:
        super().__init__()
        # Parallel lists: __filenames[i] is the display name of the file whose
        # retrieval key is __keys[i]; the LLM refers to files by that index i.
        self.__filenames = []
        self.__keys = []
        # Not read anywhere in this class; kept for backward compatibility.
        self.loop = None

    def configure(self, session: Session, run: Run, **kwargs):
        """
        Collect the files the current retrieval may search over.

        Files are appended in this order. The order fixes the indexes shown to
        the LLM by instruction_supplement and accepted by run.
          1. run.file_ids of UUID length: knowledge-base documents;
          2. documents in the folders of the assistant's file_search resources;
          3. remaining run.file_ids: legacy "uuid.ext" file keys.

        :param session: database session used for the synchronous lookups.
        :param run: the run whose attached files and assistant are inspected.
        """
        # Reset so that configuring the same instance twice does not duplicate entries.
        self.__filenames = []
        self.__keys = []

        # Files will later be selected from the knowledge base, so they may not
        # exist in the openassistant database; split the ids by their format.
        file_ids = run.file_ids or []
        document_ids = [key for key in file_ids if len(key) == _UUID_STR_LEN]
        legacy_keys = [key for key in file_ids if len(key) != _UUID_STR_LEN]
        logger.debug("file_search document ids: %s; legacy file keys: %s", document_ids, legacy_keys)

        # Keys in the old "uuid.ext" format date from the earliest versions;
        # this branch is meant to be removed later.
        legacy_files = []
        if legacy_keys:
            legacy_files = FileService.get_file_list_by_ids(session=session, file_ids=legacy_keys)

        # The r2r API does not support multi-condition queries; otherwise the
        # split above would be unnecessary.
        if document_ids:
            self._add_documents(
                _run_sync(FileService.list_in_files(ids=document_ids, offset=0, limit=_LIST_LIMIT))
            )

        db_asst = AssistantService.get_assistant_sync(session=session, assistant_id=run.assistant_id)
        for folder_id in self._folder_ids(db_asst):
            self._add_documents(
                _run_sync(FileService.list_documents(id=folder_id, offset=0, limit=_LIST_LIMIT))
            )

        # pre-cache data to prevent thread conflicts that may occur later on.
        for file in legacy_files:
            self.__filenames.append(file.filename)
            self.__keys.append(file.key)
        logger.debug("file_search keys: %s", self.__keys)

    def _add_documents(self, documents) -> None:
        """Append the "title"/"id" of each document mapping to the file lists."""
        for doc in documents:
            self.__filenames.append(doc.get("title"))
            self.__keys.append(doc.get("id"))

    @staticmethod
    def _folder_ids(db_asst) -> list:
        """
        Return the folder ids from the assistant's file_search tool resources, or [].

        Expected shape (as read here):
        {"file_search": {"vector_stores": [{"folder_ids": [...]}, ...]}}.
        Only the first vector store is consulted.
        """
        tool_resources = db_asst.tool_resources or {}
        file_search = tool_resources.get("file_search") or {}
        vector_stores = file_search.get("vector_stores") or []
        if not vector_stores:
            return []
        return vector_stores[0].get("folder_ids") or []

    def run(self, indexes: List[int], query: str) -> dict:
        """
        Search the selected files for *query*.

        :param indexes: positions in the "(index)filename" list produced by
            instruction_supplement.
        :param query: retrieval query text.
        :return: the result of FileService.search_in_files.
        :raises IndexError: if an index does not refer to a configured file.
        """
        file_keys = []
        for index in indexes:
            # Reject negatives explicitly: Python would otherwise wrap them to the
            # end of the list and silently search a file the LLM did not pick.
            if not 0 <= index < len(self.__keys):
                raise IndexError(
                    f"file index {index} is out of range for {len(self.__keys)} attached file(s)"
                )
            file_keys.append(self.__keys[index])
        logger.debug("file_search searching keys: %s", file_keys)
        return _run_sync(FileService.search_in_files(query=query, file_keys=file_keys))

    def instruction_supplement(self) -> str:
        """
        Describe the attached files to the LLM so it can decide which to retrieve from.

        :return: "" when no files are configured; otherwise an instruction that
            lists each file as "(index)filename", one per line.
        """
        if not self.__filenames:
            return ""
        filenames_info = [f"({index}){filename}" for index, filename in enumerate(self.__filenames)]
        return (
            # Name the tool by its registered name so the LLM can actually call it.
            f'You can use the "{self.name}" tool to retrieve relevant context from the following attached files. '
            + 'Each line represents a file in the format "(index)filename":\n'
            + "\n".join(filenames_info)
            + "\nMake sure to be extremely concise when using attached files. "
        )