file_search_tool.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
import logging
from typing import Type, List

import nest_asyncio
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.assistant.assistant import AssistantService
from app.services.file.file import FileService

# import asyncio

# Allow async code to be nested inside an already-running event loop.
nest_asyncio.apply()
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  13. # query: str = Field(..., description="query to look up in retrieval")
  14. class FileSearchToolInput(BaseModel):
  15. query: str = Field(..., description="query to look up in retrieval")
  16. class FileSearchTool(BaseTool):
  17. name: str = "file_search"
  18. description: str = (
  19. "Can be used to look up information that was uploaded to this assistant."
  20. # "If the user is referencing particular files, that is often a good hint that information may be here."
  21. "A search engine optimized for comprehensive, accurate, and trusted results. "
  22. "Useful for when you need to answer questions about current events. "
  23. "Input should be a search query."
  24. )
  25. args_schema: Type[BaseModel] = FileSearchToolInput
  26. def __init__(self) -> None:
  27. super().__init__()
  28. self.__filenames = []
  29. self.__keys = []
  30. self.__dirkeys = []
  31. self.loop = None
  32. def configure(self, session: Session, run: Run, **kwargs):
  33. # 获取当前事件循环
  34. # document_id = []
  35. file_key = []
  36. # filesinfo = []
  37. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  38. for key in run.file_ids:
  39. if len(key) == 36:
  40. self.__keys.append(key) # 添加文件id 作为检索
  41. else:
  42. file_key.append(
  43. key
  44. ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
  45. print(
  46. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  47. )
  48. # print(document_id)
  49. print(file_key)
  50. files = []
  51. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  52. if len(file_key) > 0:
  53. ## 获取文件信息
  54. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  55. for file in files:
  56. self.__keys.append(file.key)
  57. print(files)
  58. # r2r接口不提供多条件,否则上面没必要存在
  59. """
  60. if len(document_id) > 0:
  61. filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
  62. # asyncio.run(
  63. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  64. # )
  65. for file in filesinfo:
  66. self.__filenames.append(file.get("title"))
  67. self.__keys.append(file.get("id"))
  68. print(filesinfo)
  69. """
  70. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  71. # 读取assistant的数据,获取文件夹的id
  72. db_asst = AssistantService.get_assistant_sync(
  73. session=session, assistant_id=run.assistant_id
  74. )
  75. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  76. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  77. asst_folder_ids = (
  78. db_asst.tool_resources.get("file_search")
  79. .get("vector_stores")[0]
  80. .get("folder_ids")
  81. )
  82. print(asst_folder_ids)
  83. # folder_fileinfo = []
  84. if asst_folder_ids:
  85. self.__dirkeys = asst_folder_ids
  86. """
  87. for fid in asst_folder_ids:
  88. folder_fileinfo += FileService.list_documents(
  89. id=fid, offset=0, limit=100
  90. )
  91. # folder_fileinfo += asyncio.run(
  92. # FileService.list_documents(id=fid, offset=0, limit=100)
  93. # )
  94. print(folder_fileinfo)
  95. for file in folder_fileinfo:
  96. self.__filenames.append(file.get("title"))
  97. self.__keys.append(file.get("id"))
  98. """
  99. # pre-cache data to prevent thread conflicts that may occur later on.
  100. print(
  101. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  102. )
  103. print(self.__dirkeys)
  104. """
  105. for file in files:
  106. self.__filenames.append(file.filename)
  107. self.__keys.append(file.key)
  108. """
  109. print(self.__keys)
  110. # indexes: List[int],
  111. def run(self, query: str) -> dict:
  112. print(
  113. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  114. )
  115. print(self.__keys)
  116. print(self.__dirkeys)
  117. files = FileService.search_in_files(
  118. query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
  119. )
  120. print(files)
  121. return files
  122. def instruction_supplement(self) -> str:
  123. """
  124. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  125. """
  126. if (self.__keys and len(self.__keys) > 0) or (
  127. self.__dirkeys and len(self.__dirkeys) > 0
  128. ):
  129. return (
  130. '当用户使用以下动词时,必须调用 "retrieval" tool: '
  131. + "- 搜索/查找/检索/调取/查看/找/有没有..."
  132. + "输入格式: "
  133. + "{"
  134. + ' "query": "用户问题中的关键词",'
  135. + "}"
  136. )
  137. else:
  138. return ""