# file_search_tool.py (6.5 KB)
import asyncio
import logging
from typing import List, Type

import nest_asyncio
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.file.file import FileService
from app.services.assistant.assistant import AssistantService
  10. # 使得异步代码可以在已运行的事件循环中嵌套
  11. nest_asyncio.apply()
  12. # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  13. try:
  14. loop = asyncio.get_running_loop() # 检查是否有运行的事件循环
  15. except RuntimeError:
  16. loop = asyncio.get_event_loop()
  17. print("事件循环未运行,手动启动")
  18. class FileSearchToolInput(BaseModel):
  19. indexes: List[int] = Field(
  20. ..., description="file index list to look up in retrieval"
  21. )
  22. query: str = Field(..., description="query to look up in retrieval")
  23. class FileSearchTool(BaseTool):
  24. name: str = "file_search"
  25. description: str = (
  26. "Can be used to look up information that was uploaded to this assistant."
  27. "If the user is referencing particular files, that is often a good hint that information may be here."
  28. )
  29. args_schema: Type[BaseModel] = FileSearchToolInput
  30. def __init__(self) -> None:
  31. super().__init__()
  32. self.__filenames = []
  33. self.__keys = []
  34. self.loop = None
  35. def configure(self, session: Session, run: Run, **kwargs):
  36. """
  37. # 提交任务到事件循环
  38. future = asyncio.run_coroutine_threadsafe(async_task(), loop)
  39. # 阻塞等待结果
  40. result = future.result()
  41. """
  42. """
  43. 置当前 Retrieval 涉及文件信息
  44. """
  45. # 获取当前事件循环
  46. document_id = []
  47. file_key = []
  48. filesinfo = []
  49. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  50. for key in run.file_ids:
  51. if len(key) == 36:
  52. document_id.append(key)
  53. else:
  54. file_key.append(key)
  55. print(
  56. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  57. )
  58. print(document_id)
  59. print(file_key)
  60. files = []
  61. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  62. if len(file_key) > 0:
  63. ## 获取文件信息
  64. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  65. print(files)
  66. # r2r接口不提供多条件,否则上面没必要存在
  67. if len(document_id) > 0:
  68. filesinfo += loop.run_until_complete(
  69. FileService.list_in_files(ids=document_id, offset=0, limit=100)
  70. )
  71. # asyncio.run(
  72. # FileService.list_in_files(ids=document_id, offset=0, limit=100)
  73. # )
  74. for file in filesinfo:
  75. self.__filenames.append(file.get("title"))
  76. self.__keys.append(file.get("id"))
  77. print(filesinfo)
  78. # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  79. db_asst = AssistantService.get_assistant_sync(
  80. session=session, assistant_id=run.assistant_id
  81. )
  82. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  83. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  84. asst_folder_ids = (
  85. db_asst.tool_resources.get("file_search")
  86. .get("vector_stores")[0]
  87. .get("folder_ids")
  88. )
  89. print(asst_folder_ids)
  90. folder_fileinfo = []
  91. if asst_folder_ids:
  92. for fid in asst_folder_ids:
  93. folder_fileinfo += loop.run_until_complete(
  94. FileService.list_documents(id=fid, offset=0, limit=100)
  95. )
  96. # folder_fileinfo += asyncio.run(
  97. # FileService.list_documents(id=fid, offset=0, limit=100)
  98. # )
  99. print(folder_fileinfo)
  100. for file in folder_fileinfo:
  101. self.__filenames.append(file.get("title"))
  102. self.__keys.append(file.get("id"))
  103. # pre-cache data to prevent thread conflicts that may occur later on.
  104. print(
  105. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  106. )
  107. print(files)
  108. for file in files:
  109. self.__filenames.append(file.filename)
  110. self.__keys.append(file.key)
  111. print(self.__keys)
  112. def run(self, indexes: List[int], query: str) -> dict:
  113. file_keys = []
  114. for index in indexes:
  115. if index is not None:
  116. file_key = self.__keys[index]
  117. file_keys.append(file_key)
  118. print(
  119. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  120. )
  121. print(file_keys)
  122. files = []
  123. if len(file_keys) > 0:
  124. # self.loop = asyncio.get_event_loop()
  125. files += loop.run_until_complete(
  126. FileService.search_in_files(query=query, file_keys=file_keys)
  127. )
  128. # files = asyncio.run(
  129. # FileService.search_in_files(query=query, file_keys=file_keys)
  130. # )
  131. return files
  132. def instruction_supplement(self) -> str:
  133. """
  134. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  135. """
  136. if len(self.__filenames) == 0:
  137. return ""
  138. else:
  139. filenames_info = [
  140. f"({index}){filename}"
  141. for index, filename in enumerate(self.__filenames)
  142. ]
  143. return (
  144. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  145. + 'Each line represents a file in the format "(index)filename":\n'
  146. + "\n".join(filenames_info)
  147. + "\nMake sure to be extremely concise when using attached files. "
  148. )