file_search_tool.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  import asyncio
  import logging
  from typing import Type, List

  import nest_asyncio
  from pydantic import BaseModel, Field
  from sqlalchemy.orm import Session

  from app.core.tools.base_tool import BaseTool
  from app.models.run import Run
  from app.services.assistant.assistant import AssistantService
  from app.services.file.file import FileService
  # Patch asyncio so that async code can be nested inside an event loop that
  # is already running.
  nest_asyncio.apply()
  class FileSearchToolInput(BaseModel):
      # Argument schema for the "file_search" tool.
      # NOTE: documented with comments instead of a class docstring, because
      # pydantic copies a model's docstring into the schema it generates.

      # Indexes of the files to search.
      indexes: List[int] = Field(
          default=...,
          description="file index list to look up in retrieval",
      )

      # Text to search for inside the selected files.
      query: str = Field(
          default=...,
          description="query to look up in retrieval",
      )
  logger = logging.getLogger(__name__)


  class FileSearchTool(BaseTool):
      """Tool that lets the model search files attached to a run or its assistant.

      ``configure`` caches the display name and lookup key of every searchable
      file, ``instruction_supplement`` renders them as an indexed list for the
      prompt, and ``run`` maps the indexes chosen by the model back to keys and
      calls ``FileService.search_in_files``.
      """

      name: str = "file_search"
      description: str = (
          "Can be used to look up information that was uploaded to this assistant. "
          "If the user is referencing particular files, that is often a good hint that information may be here."
      )

      args_schema: Type[BaseModel] = FileSearchToolInput

      def __init__(self) -> None:
          super().__init__()
          # Parallel lists: __filenames[i] is the display name of __keys[i].
          self.__filenames = []
          self.__keys = []
          # Event loop captured by configure() and reused by run().
          self.loop = None

      def configure(self, session: Session, run: Run, **kwargs):
          """Cache the names and keys of the files this run may search.

          Documents from the assistant's folders are listed first, followed by
          the files referenced by ``run.file_ids``; the resulting positions are
          the indexes later passed to ``run``.

          Args:
              session: Database session used to load the files and the assistant.
              run: Run whose ``file_ids`` and ``assistant_id`` are looked up.
              **kwargs: Accepted for interface compatibility; unused here.
          """
          self.loop = asyncio.get_event_loop()

          # Start from a clean slate so that calling configure() again does not
          # duplicate entries or shift the indexes shown to the model.
          self.__filenames = []
          self.__keys = []

          files = FileService.get_file_list_by_ids(
              session=session, file_ids=run.file_ids
          )
          db_asst = AssistantService.get_assistant_sync(
              session=session, assistant_id=run.assistant_id
          )

          for document in self._list_folder_documents(db_asst):
              self.__filenames.append(document.get("title"))
              self.__keys.append(document.get("id"))

          # Pre-cache plain values now to prevent thread conflicts that may
          # occur later on.
          for file in files:
              self.__filenames.append(file.filename)
              self.__keys.append(file.key)

          logger.debug("file_search configured with keys: %s", self.__keys)

      def _list_folder_documents(self, db_asst) -> list:
          """Return the documents of every folder attached to the assistant.

          Reads ``tool_resources`` in the shape
          ``{"file_search": {"vector_stores": [{"folder_ids": [...]}]}}``; only
          the first vector store is consulted. Any missing level yields an
          empty list instead of raising.
          """
          tool_resources = getattr(db_asst, "tool_resources", None) or {}
          file_search = tool_resources.get("file_search") or {}
          vector_stores = file_search.get("vector_stores") or []
          folder_ids = (vector_stores[0].get("folder_ids") if vector_stores else None) or []
          logger.debug("file_search assistant folder ids: %s", folder_ids)

          documents = []
          for folder_id in folder_ids:
              # Only the first 100 documents of each folder are requested.
              documents += self.loop.run_until_complete(
                  FileService.list_documents(id=folder_id, offset=0, limit=100)
              )
          return documents

      def run(self, indexes: List[int], query: str) -> dict:
          """Search the selected files for ``query``.

          Args:
              indexes: Positions in the list produced by ``configure`` (the
                  same numbers shown by ``instruction_supplement``).
              query: Text to search for.

          Returns:
              Whatever ``FileService.search_in_files`` returns for the
              selected file keys.

          Raises:
              IndexError: If an index is negative or not below the number of
                  cached files. Negative values are rejected rather than
                  wrapping around to an unrelated file.
          """
          key_count = len(self.__keys)
          file_keys = []
          for index in indexes:
              if not 0 <= index < key_count:
                  raise IndexError(
                      f"file index {index} is out of range for {key_count} available file(s)"
                  )
              file_keys.append(self.__keys[index])
          logger.debug("file_search searching keys: %s", file_keys)

          return self.loop.run_until_complete(
              FileService.search_in_files(query=query, file_keys=file_keys)
          )

      def instruction_supplement(self) -> str:
          """Return prompt text listing the searchable files for the model.

          Returns:
              An empty string when no files are cached; otherwise instructions
              followed by one ``(index)filename`` line per file.
          """
          if not self.__filenames:
              return ""

          listing = "\n".join(
              f"({index}){filename}"
              for index, filename in enumerate(self.__filenames)
          )
          # Refer to the tool by its registered name so the instructions match
          # the tool the model is actually offered.
          return (
              f'You can use the "{self.name}" tool to retrieve relevant context from the following attached files. '
              'Each line represents a file in the format "(index)filename":\n'
              f"{listing}"
              "\nMake sure to be extremely concise when using attached files. "
          )