file_search_tool.py 11 KB


  1. from typing import Type, List
  2. from pydantic import BaseModel, Field
  3. from sqlalchemy.orm import Session
  4. from app.core.tools.base_tool import BaseTool
  5. from app.models.run import Run
  6. from app.services.file.file import FileService
  7. from app.services.assistant.assistant import AssistantService
  8. # import asyncio
  9. import nest_asyncio
  10. # Allow asynchronous code to run nested inside an already-running event loop
  11. nest_asyncio.apply()
  12. # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  13. """
  14. class FileSearchToolInput(BaseModel):
  15. indexes: List[int] = Field(
  16. ..., description="file index list to look up in retrieval"
  17. )
  18. query: str = Field(..., description="query to look up in retrieval")
  19. """
  20. class FileSearchToolInput(BaseModel):
  21. query: str = Field(..., description="query to look up in retrieval")
  22. class FileSearchTool(BaseTool):
  23. name: str = "file_search"
  24. description: str = (
  25. "Can be used to look up information that was uploaded to this assistant."
  26. "If the user is referencing particular files, that is often a good hint that information may be here."
  27. )
  28. args_schema: Type[BaseModel] = FileSearchToolInput
  29. def __init__(self) -> None:
  30. super().__init__()
  31. self.__filenames = []
  32. self.__keys = []
  33. self.__dirkeys = []
  34. self.loop = None
  35. def configure(self, session: Session, run: Run, **kwargs):
  36. # 获取当前事件循环
  37. # document_id = []
  38. file_key = []
  39. # filesinfo = []
  40. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  41. for key in run.file_ids:
  42. if len(key) == 36:
  43. self.__keys.append(key) # 添加文件id 作为检索
  44. else:
  45. file_key.append(
  46. key
  47. ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
  48. print(
  49. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  50. )
  51. # print(document_id)
  52. print(file_key)
  53. files = []
  54. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  55. if len(file_key) > 0:
  56. ## 获取文件信息
  57. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  58. for file in files:
  59. self.__keys.append(file.key)
  60. print(files)
  61. # r2r接口不提供多条件,否则上面没必要存在
  62. """
  63. if len(document_id) > 0:
  64. filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
  65. # asyncio.run(
  66. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  67. # )
  68. for file in filesinfo:
  69. self.__filenames.append(file.get("title"))
  70. self.__keys.append(file.get("id"))
  71. print(filesinfo)
  72. """
  73. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  74. # 读取assistant的数据,获取文件夹的id
  75. db_asst = AssistantService.get_assistant_sync(
  76. session=session, assistant_id=run.assistant_id
  77. )
  78. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  79. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  80. asst_folder_ids = (
  81. db_asst.tool_resources.get("file_search")
  82. .get("vector_stores")[0]
  83. .get("folder_ids")
  84. )
  85. print(asst_folder_ids)
  86. # folder_fileinfo = []
  87. if asst_folder_ids:
  88. self.__dirkeys = asst_folder_ids
  89. """
  90. for fid in asst_folder_ids:
  91. folder_fileinfo += FileService.list_documents(
  92. id=fid, offset=0, limit=100
  93. )
  94. # folder_fileinfo += asyncio.run(
  95. # FileService.list_documents(id=fid, offset=0, limit=100)
  96. # )
  97. print(folder_fileinfo)
  98. for file in folder_fileinfo:
  99. self.__filenames.append(file.get("title"))
  100. self.__keys.append(file.get("id"))
  101. """
  102. # pre-cache data to prevent thread conflicts that may occur later on.
  103. print(
  104. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  105. )
  106. print(self.__dirkeys)
  107. """
  108. for file in files:
  109. self.__filenames.append(file.filename)
  110. self.__keys.append(file.key)
  111. """
  112. print(self.__keys)
  113. # indexes: List[int],
  114. def run(self, query: str) -> dict:
  115. print(
  116. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  117. )
  118. print(self.__keys)
  119. print(self.__dirkeys)
  120. files = FileService.search_in_files(
  121. query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
  122. )
  123. print(files)
  124. return files
  125. """
  126. file_keys = []
  127. for index in indexes:
  128. if index is not None:
  129. file_key = self.__keys[index]
  130. file_keys.append(file_key)
  131. print(file_keys)
  132. files = []
  133. if len(file_keys) > 0:
  134. # self.loop = asyncio.get_event_loop()
  135. # files = asyncio.run(
  136. # FileService.search_in_files(query=query, file_keys=file_keys)
  137. # )
  138. print(files)
  139. return files
  140. """
  141. def instruction_supplement(self) -> str:
  142. """
  143. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  144. """
  145. return (
  146. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  147. # + 'Each line represents a file in the format "(index)filename":\n'
  148. # + "\n".join(filenames_info)
  149. + "\nMake sure to be extremely concise when using attached files. "
  150. )
  151. '''
  152. def configure(self, session: Session, run: Run, **kwargs):
  153. # 获取当前事件循环
  154. document_id = []
  155. file_key = []
  156. filesinfo = []
  157. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  158. for key in run.file_ids:
  159. if len(key) == 36:
  160. document_id.append(key)
  161. else:
  162. file_key.append(key)
  163. print(
  164. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  165. )
  166. print(document_id)
  167. print(file_key)
  168. files = []
  169. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  170. if len(file_key) > 0:
  171. ## 获取文件信息
  172. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  173. print(files)
  174. # r2r接口不提供多条件,否则上面没必要存在
  175. if len(document_id) > 0:
  176. filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
  177. # asyncio.run(
  178. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  179. # )
  180. for file in filesinfo:
  181. self.__filenames.append(file.get("title"))
  182. self.__keys.append(file.get("id"))
  183. print(filesinfo)
  184. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  185. db_asst = AssistantService.get_assistant_sync(
  186. session=session, assistant_id=run.assistant_id
  187. )
  188. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  189. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  190. asst_folder_ids = (
  191. db_asst.tool_resources.get("file_search")
  192. .get("vector_stores")[0]
  193. .get("folder_ids")
  194. )
  195. print(asst_folder_ids)
  196. folder_fileinfo = []
  197. if asst_folder_ids:
  198. for fid in asst_folder_ids:
  199. folder_fileinfo += FileService.list_documents(
  200. id=fid, offset=0, limit=100
  201. )
  202. # folder_fileinfo += asyncio.run(
  203. # FileService.list_documents(id=fid, offset=0, limit=100)
  204. # )
  205. print(folder_fileinfo)
  206. for file in folder_fileinfo:
  207. self.__filenames.append(file.get("title"))
  208. self.__keys.append(file.get("id"))
  209. # pre-cache data to prevent thread conflicts that may occur later on.
  210. print(
  211. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  212. )
  213. print(files)
  214. for file in files:
  215. self.__filenames.append(file.filename)
  216. self.__keys.append(file.key)
  217. print(self.__keys)
  218. def run(self, indexes: List[int], query: str) -> dict:
  219. file_keys = []
  220. for index in indexes:
  221. if index is not None:
  222. file_key = self.__keys[index]
  223. file_keys.append(file_key)
  224. print(
  225. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  226. )
  227. print(file_keys)
  228. files = []
  229. if len(file_keys) > 0:
  230. # self.loop = asyncio.get_event_loop()
  231. files = FileService.search_in_files(query=query, file_keys=file_keys)
  232. # files = asyncio.run(
  233. # FileService.search_in_files(query=query, file_keys=file_keys)
  234. # )
  235. print(files)
  236. return files
  237. def instruction_supplement(self) -> str:
  238. """
  239. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  240. """
  241. if len(self.__filenames) == 0:
  242. return ""
  243. else:
  244. filenames_info = [
  245. f"({index}){filename}"
  246. for index, filename in enumerate(self.__filenames)
  247. ]
  248. return (
  249. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  250. + 'Each line represents a file in the format "(index)filename":\n'
  251. + "\n".join(filenames_info)
  252. + "\nMake sure to be extremely concise when using attached files. "
  253. )
  254. '''