123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
import asyncio
import logging
from typing import List, Type

import nest_asyncio
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.assistant.assistant import AssistantService
from app.services.file.file import FileService
# Patch asyncio globally so that asyncio.run() / run_until_complete() can be
# nested inside an event loop that is already running. This module calls
# asyncio.run() from synchronous methods, which would otherwise fail when the
# caller is itself running inside an event loop. Applied once, at import time.
nest_asyncio.apply()
# Disabled Windows-only alternative, kept for reference:
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class FileSearchToolInput(BaseModel):
    # Argument schema for the file_search tool. Comments are used instead of a
    # class docstring on purpose: pydantic copies a model docstring into the
    # JSON schema "description", which would alter the schema shown to the LLM.

    # Positions (as listed to the LLM) of the files to search; required.
    indexes: List[int] = Field(
        default=...,
        description="file index list to look up in retrieval",
    )
    # Free-text query to run against the selected files; required.
    query: str = Field(
        default=...,
        description="query to look up in retrieval",
    )
_logger = logging.getLogger(__name__)


class FileSearchTool(BaseTool):
    """Tool that searches files attached to a run or to its assistant.

    ``configure`` collects the candidate files. ``instruction_supplement``
    lists them to the LLM as ``(index)title``. ``run`` searches the files
    whose indexes the LLM selected.
    """

    name: str = "file_search"
    description: str = (
        "Can be used to look up information that was uploaded to this assistant. "
        "If the user is referencing particular files, that is often a good hint that information may be here."
    )
    args_schema: Type[BaseModel] = FileSearchToolInput

    def __init__(self) -> None:
        super().__init__()
        # Parallel lists: __filenames[i] is the title shown to the LLM as "(i)",
        # and __keys[i] is the id passed to FileService.search_in_files.
        self.__filenames = []
        self.__keys = []
        # Not read anywhere in this class; kept in case external code uses it.
        self.loop = None

    def configure(self, session: Session, run: Run, **kwargs):
        """Collect the files involved in the current retrieval.

        Rebuilds the title/key lists from three sources. They are appended in
        this order, and the order defines the indexes shown to the LLM:

        1. Ids in ``run.file_ids`` that are 36 characters long, fetched with
           ``FileService.list_in_files``.
        2. Documents in the folders listed at
           ``tool_resources["file_search"]["vector_stores"][0]["folder_ids"]``
           of the run's assistant, fetched with ``FileService.list_documents``.
        3. The remaining (legacy) ids in ``run.file_ids``, fetched from the
           database with ``FileService.get_file_list_by_ids``.

        Only the first 100 results of each remote listing are requested.

        Args:
            session: Database session for the assistant and legacy-file lookups.
            run: The run whose ``file_ids`` and ``assistant_id`` are used.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        # Start from scratch so that configuring the same instance twice does
        # not accumulate duplicates and shift the advertised indexes.
        self.__filenames = []
        self.__keys = []

        # Files may be picked from the knowledge base, so they need not exist
        # in this service's own database. A 36-character id (UUID length) is
        # treated as a document id; anything else as a legacy file key.
        document_ids = []
        legacy_keys = []
        for key in run.file_ids or []:
            if len(key) == 36:
                document_ids.append(key)
            else:
                legacy_keys.append(key)
        _logger.debug("file_search document ids: %s", document_ids)
        _logger.debug("file_search legacy file keys: %s", legacy_keys)

        # Legacy "uuid.ext"-style keys from early versions (slated for removal).
        # Fetched first, as before, but appended after the documents.
        legacy_files = []
        if legacy_keys:
            legacy_files = FileService.get_file_list_by_ids(
                session=session, file_ids=legacy_keys
            )

        # The r2r API cannot combine several filter conditions in one call,
        # which is why the ids are split above.
        if document_ids:
            self._add_documents(
                asyncio.run(
                    FileService.list_in_files(ids=document_ids, offset=0, limit=100)
                )
            )

        for folder_id in self._assistant_folder_ids(session, run.assistant_id):
            self._add_documents(
                asyncio.run(
                    FileService.list_documents(id=folder_id, offset=0, limit=100)
                )
            )

        # All metadata is fetched eagerly here. The original author's note
        # says this avoids thread conflicts that may occur later on.
        for file in legacy_files:
            self.__filenames.append(file.filename)
            self.__keys.append(file.key)
        _logger.debug("file_search keys: %s", self.__keys)

    def _add_documents(self, documents) -> None:
        """Append the ``title`` and ``id`` of each document dict to the cached lists."""
        for document in documents or []:
            self.__filenames.append(document.get("title"))
            self.__keys.append(document.get("id"))

    @staticmethod
    def _assistant_folder_ids(session: Session, assistant_id) -> list:
        """Return the folder ids configured for the assistant's file_search, or ``[]``.

        Expected shape:
        ``{"file_search": {"vector_stores": [{"folder_ids": [...]}]}}``.
        Only the first vector store is consulted. A missing or empty level
        yields ``[]`` instead of raising.
        """
        db_asst = AssistantService.get_assistant_sync(
            session=session, assistant_id=assistant_id
        )
        resources = db_asst.tool_resources
        if not resources or "file_search" not in resources:
            return []
        vector_stores = (resources.get("file_search") or {}).get("vector_stores") or []
        if not vector_stores:
            return []
        return (vector_stores[0] or {}).get("folder_ids") or []

    def run(self, indexes: List[int], query: str) -> dict:
        """Search the selected files for ``query``.

        Args:
            indexes: Positions into the file list advertised by
                ``instruction_supplement``. ``None`` entries are skipped.
                Positions outside ``[0, number of files)`` are skipped with a
                warning instead of raising or wrapping around.
            query: Text to search for.

        Returns:
            The result of ``FileService.search_in_files``, or an empty list when
            no valid file was selected. NOTE(review): the ``dict`` annotation is
            kept from the original; confirm what ``search_in_files`` returns.
        """
        file_keys = []
        for index in indexes or []:
            if index is None:
                continue
            if not 0 <= index < len(self.__keys):
                _logger.warning("file_search: ignoring out-of-range file index %s", index)
                continue
            file_keys.append(self.__keys[index])
        _logger.debug("file_search selected keys: %s", file_keys)

        # Previously this tested the leaked loop variable, not the collected
        # list, and raised UnboundLocalError when nothing was selected.
        if not file_keys:
            return []
        return asyncio.run(
            FileService.search_in_files(query=query, file_keys=file_keys)
        )

    def instruction_supplement(self) -> str:
        """Describe the available files so the LLM can choose which to search.

        Returns:
            ``""`` when no files are configured. Otherwise, an instruction naming
            this tool and listing one ``(index)filename`` line per file.
        """
        if not self.__filenames:
            return ""
        filenames_info = "\n".join(
            f"({index}){filename}" for index, filename in enumerate(self.__filenames)
        )
        # Name the tool by its registered name; the previous text pointed the
        # LLM at a non-existent "retrieval" tool.
        return (
            f'You can use the "{self.name}" tool to retrieve relevant context from the following attached files. '
            + 'Each line represents a file in the format "(index)filename":\n'
            + filenames_info
            + "\nMake sure to be extremely concise when using attached files. "
        )
|