file_search_tool.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
import logging
from typing import Type, List

# import asyncio
import nest_asyncio
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.file.file import FileService
from app.services.assistant.assistant import AssistantService
# Allow asyncio code to be nested inside an already-running event loop.
nest_asyncio.apply()
# NOTE(review): Windows-only loop policy left disabled; remove if no longer needed.
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  13. """
  14. class FileSearchToolInput(BaseModel):
  15. indexes: List[int] = Field(
  16. ..., description="file index list to look up in retrieval"
  17. )
  18. query: str = Field(..., description="query to look up in retrieval")
  19. """
  20. class FileSearchToolInput(BaseModel):
  21. query: str = Field(..., description="query to look up in retrieval")
  22. class FileSearchTool(BaseTool):
  23. name: str = "file_search"
  24. description: str = (
  25. "Can be used to look up information that was uploaded to this assistant."
  26. # "If the user is referencing particular files, that is often a good hint that information may be here."
  27. "A search engine optimized for comprehensive, accurate, and trusted results. "
  28. "Useful for when you need to answer questions about current events. "
  29. "Input should be a search query."
  30. )
  31. args_schema: Type[BaseModel] = FileSearchToolInput
  32. def __init__(self) -> None:
  33. super().__init__()
  34. self.__filenames = []
  35. self.__keys = []
  36. self.__dirkeys = []
  37. self.loop = None
  38. def configure(self, session: Session, run: Run, **kwargs):
  39. # 获取当前事件循环
  40. # document_id = []
  41. file_key = []
  42. # filesinfo = []
  43. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  44. for key in run.file_ids:
  45. if len(key) == 36:
  46. self.__keys.append(key) # 添加文件id 作为检索
  47. else:
  48. file_key.append(
  49. key
  50. ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
  51. print(
  52. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  53. )
  54. # print(document_id)
  55. print(file_key)
  56. files = []
  57. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  58. if len(file_key) > 0:
  59. ## 获取文件信息
  60. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  61. for file in files:
  62. self.__keys.append(file.key)
  63. print(files)
  64. # r2r接口不提供多条件,否则上面没必要存在
  65. """
  66. if len(document_id) > 0:
  67. filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
  68. # asyncio.run(
  69. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  70. # )
  71. for file in filesinfo:
  72. self.__filenames.append(file.get("title"))
  73. self.__keys.append(file.get("id"))
  74. print(filesinfo)
  75. """
  76. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  77. # 读取assistant的数据,获取文件夹的id
  78. db_asst = AssistantService.get_assistant_sync(
  79. session=session, assistant_id=run.assistant_id
  80. )
  81. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  82. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  83. asst_folder_ids = (
  84. db_asst.tool_resources.get("file_search")
  85. .get("vector_stores")[0]
  86. .get("folder_ids")
  87. )
  88. print(asst_folder_ids)
  89. # folder_fileinfo = []
  90. if asst_folder_ids:
  91. self.__dirkeys = asst_folder_ids
  92. """
  93. for fid in asst_folder_ids:
  94. folder_fileinfo += FileService.list_documents(
  95. id=fid, offset=0, limit=100
  96. )
  97. # folder_fileinfo += asyncio.run(
  98. # FileService.list_documents(id=fid, offset=0, limit=100)
  99. # )
  100. print(folder_fileinfo)
  101. for file in folder_fileinfo:
  102. self.__filenames.append(file.get("title"))
  103. self.__keys.append(file.get("id"))
  104. """
  105. # pre-cache data to prevent thread conflicts that may occur later on.
  106. print(
  107. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  108. )
  109. print(self.__dirkeys)
  110. """
  111. for file in files:
  112. self.__filenames.append(file.filename)
  113. self.__keys.append(file.key)
  114. """
  115. print(self.__keys)
  116. # indexes: List[int],
  117. def run(self, query: str) -> dict:
  118. print(
  119. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  120. )
  121. print(self.__keys)
  122. print(self.__dirkeys)
  123. files = FileService.search_in_files(
  124. query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
  125. )
  126. print(files)
  127. return files
  128. """
  129. file_keys = []
  130. for index in indexes:
  131. if index is not None:
  132. file_key = self.__keys[index]
  133. file_keys.append(file_key)
  134. print(file_keys)
  135. files = []
  136. if len(file_keys) > 0:
  137. # self.loop = asyncio.get_event_loop()
  138. # files = asyncio.run(
  139. # FileService.search_in_files(query=query, file_keys=file_keys)
  140. # )
  141. print(files)
  142. return files
  143. """
  144. def instruction_supplement(self) -> str:
  145. """
  146. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  147. return (
  148. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  149. # + 'Each line represents a file in the format "(index)filename":\n'
  150. # + "\n".join(filenames_info)
  151. + "\nMake sure to be extremely concise when using attached files. "
  152. )
  153. """
  154. return 'You can use the "retrieval" tool to search for relevant information. Please keywords or context you are looking for to retrieve the most relevant content.'
  155. #'You can use the "retrieval" to search for relevant information within the attached files. Please specify the keywords or context you are looking for to retrieve the most relevant content.'
  156. '''
  157. def configure(self, session: Session, run: Run, **kwargs):
  158. # 获取当前事件循环
  159. document_id = []
  160. file_key = []
  161. filesinfo = []
  162. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  163. for key in run.file_ids:
  164. if len(key) == 36:
  165. document_id.append(key)
  166. else:
  167. file_key.append(key)
  168. print(
  169. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  170. )
  171. print(document_id)
  172. print(file_key)
  173. files = []
  174. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  175. if len(file_key) > 0:
  176. ## 获取文件信息
  177. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  178. print(files)
  179. # r2r接口不提供多条件,否则上面没必要存在
  180. if len(document_id) > 0:
  181. filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
  182. # asyncio.run(
  183. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  184. # )
  185. for file in filesinfo:
  186. self.__filenames.append(file.get("title"))
  187. self.__keys.append(file.get("id"))
  188. print(filesinfo)
  189. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  190. db_asst = AssistantService.get_assistant_sync(
  191. session=session, assistant_id=run.assistant_id
  192. )
  193. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  194. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  195. asst_folder_ids = (
  196. db_asst.tool_resources.get("file_search")
  197. .get("vector_stores")[0]
  198. .get("folder_ids")
  199. )
  200. print(asst_folder_ids)
  201. folder_fileinfo = []
  202. if asst_folder_ids:
  203. for fid in asst_folder_ids:
  204. folder_fileinfo += FileService.list_documents(
  205. id=fid, offset=0, limit=100
  206. )
  207. # folder_fileinfo += asyncio.run(
  208. # FileService.list_documents(id=fid, offset=0, limit=100)
  209. # )
  210. print(folder_fileinfo)
  211. for file in folder_fileinfo:
  212. self.__filenames.append(file.get("title"))
  213. self.__keys.append(file.get("id"))
  214. # pre-cache data to prevent thread conflicts that may occur later on.
  215. print(
  216. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  217. )
  218. print(files)
  219. for file in files:
  220. self.__filenames.append(file.filename)
  221. self.__keys.append(file.key)
  222. print(self.__keys)
  223. def run(self, indexes: List[int], query: str) -> dict:
  224. file_keys = []
  225. for index in indexes:
  226. if index is not None:
  227. file_key = self.__keys[index]
  228. file_keys.append(file_key)
  229. print(
  230. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  231. )
  232. print(file_keys)
  233. files = []
  234. if len(file_keys) > 0:
  235. # self.loop = asyncio.get_event_loop()
  236. files = FileService.search_in_files(query=query, file_keys=file_keys)
  237. # files = asyncio.run(
  238. # FileService.search_in_files(query=query, file_keys=file_keys)
  239. # )
  240. print(files)
  241. return files
  242. def instruction_supplement(self) -> str:
  243. """
  244. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  245. """
  246. if len(self.__filenames) == 0:
  247. return ""
  248. else:
  249. filenames_info = [
  250. f"({index}){filename}"
  251. for index, filename in enumerate(self.__filenames)
  252. ]
  253. return (
  254. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  255. + 'Each line represents a file in the format "(index)filename":\n'
  256. + "\n".join(filenames_info)
  257. + "\nMake sure to be extremely concise when using attached files. "
  258. )
  259. '''