# file_search_tool.py

import logging
from typing import Type, List

from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.file.file import FileService
from app.services.assistant.assistant import AssistantService
# Legacy snippets kept for reference (not executed):
# return '## important:You can use the "retrieval" tool to search for relevant information.\n If you are asking about the content of the files, please specify any keywords, topics, or context you are looking for to help retrieve the most relevant content.'
# query: str = Field(
#     ...,
#     description="query to look up in retrieval",
# )
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
# query: str = Field(..., description="query to look up in retrieval")
  15. class FileSearchToolInput(BaseModel):
  16. query: str = Field(
  17. description="query to look up in retrieval",
  18. )
  19. class FileSearchTool(BaseTool):
  20. name: str = "file_search"
  21. description: str = (
  22. "When a user requests specific knowledge, document content, or related information, the retrieval function should be invoked. This includes:\n"
  23. + "- User explicitly asking to find or retrieve certain information\n"
  24. + "- User inquiring about specific content in the knowledge base or uploaded files\n"
  25. + "- User questions involving specialized/domain-specific knowledge likely stored in the knowledge base\n"
  26. + "- User queries requiring access to complete documents or large datasets for accurate answers\n"
  27. + "Note: Singleton operation: Strictly 1 invocation per API call. Retrieval results should be used directly to answer the user's question without explaining the retrieval process.\n"
  28. )
  29. args_schema: Type[BaseModel] = FileSearchToolInput
  30. def __init__(self) -> None:
  31. super().__init__()
  32. self.__filenames = []
  33. self.__keys = []
  34. self.__dirkeys = []
  35. self.loop = None
  36. self.index = 0
  37. def configure(self, session: Session, run: Run, **kwargs):
  38. # 获取当前事件循环
  39. # document_id = []
  40. file_key = []
  41. # filesinfo = []
  42. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  43. for key in run.file_ids:
  44. if len(key) == 36:
  45. self.__keys.append(key) # 添加文件id 作为检索
  46. else:
  47. file_key.append(
  48. key
  49. ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
  50. print(
  51. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  52. )
  53. # print(document_id)
  54. print(file_key)
  55. files = []
  56. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  57. if len(file_key) > 0:
  58. ## 获取文件信息
  59. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  60. for file in files:
  61. self.__keys.append(file.key)
  62. print(files)
  63. # 读取assistant的数据,获取文件夹的id
  64. db_asst = AssistantService.get_assistant_sync(
  65. session=session, assistant_id=run.assistant_id
  66. )
  67. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  68. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  69. asst_folder_ids = (
  70. db_asst.tool_resources.get("file_search")
  71. .get("vector_stores")[0]
  72. .get("folder_ids")
  73. )
  74. print(asst_folder_ids)
  75. # folder_fileinfo = []
  76. if asst_folder_ids:
  77. self.__dirkeys = asst_folder_ids
  78. # pre-cache data to prevent thread conflicts that may occur later on.
  79. print(
  80. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  81. )
  82. print(self.__dirkeys)
  83. print(self.__keys)
  84. # indexes: List[int],
  85. def run(self, query: str) -> dict:
  86. print(
  87. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  88. )
  89. print(self.__keys)
  90. print(self.__dirkeys)
  91. files = []
  92. if self.index == 0:
  93. files = FileService.search_in_files(
  94. query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
  95. )
  96. self.index = 1
  97. print(files)
  98. return files
  99. def instruction_supplement(self) -> str:
  100. """
  101. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  102. """
  103. if (self.__keys and len(self.__keys) > 0) or (
  104. self.__dirkeys and len(self.__dirkeys) > 0
  105. ):
  106. return ""
  107. else:
  108. return ""