file_search_tool.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Type, List
  2. from pydantic import BaseModel, Field
  3. from sqlalchemy.orm import Session
  4. from app.core.tools.base_tool import BaseTool
  5. from app.models.run import Run
  6. from app.services.file.file import FileService
  7. from app.services.assistant.assistant import AssistantService
  8. import asyncio
  9. class FileSearchToolInput(BaseModel):
  10. indexes: List[int] = Field(
  11. ..., description="file index list to look up in retrieval"
  12. )
  13. query: str = Field(..., description="query to look up in retrieval")
  14. class FileSearchTool(BaseTool):
  15. name: str = "file_search"
  16. description: str = (
  17. "Can be used to look up information that was uploaded to this assistant."
  18. "If the user is referencing particular files, that is often a good hint that information may be here."
  19. )
  20. args_schema: Type[BaseModel] = FileSearchToolInput
  21. def __init__(self) -> None:
  22. super().__init__()
  23. self.__filenames = []
  24. self.__keys = []
  25. def configure(self, session: Session, run: Run, **kwargs):
  26. """
  27. 置当前 Retrieval 涉及文件信息
  28. """
  29. ## 获取文件信息
  30. # files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
  31. files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
  32. loop = asyncio.get_event_loop() # 获取当前事件循环
  33. db_asst = loop.run_until_complete(
  34. AssistantService.get_assistant(
  35. session=session, assistant_id=run.assistant_id
  36. )
  37. )
  38. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  39. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  40. asst_folder_ids = (
  41. db_asst.tool_resources.get("file_search")
  42. .get("vector_stores")[0]
  43. .get("folder_ids")
  44. )
  45. print(asst_folder_ids)
  46. if asst_folder_ids:
  47. for fid in asst_folder_ids:
  48. files += FileService.list_documents(id=fid, offset=0, limit=100)
  49. # pre-cache data to prevent thread conflicts that may occur later on.
  50. print(
  51. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  52. )
  53. print(files)
  54. for file in files:
  55. self.__filenames.append(file.title)
  56. self.__keys.append(file.get("metadata").get("file_key"))
  57. print(self.__keys)
  58. def run(self, indexes: List[int], query: str) -> dict:
  59. file_keys = []
  60. print(self.__keys)
  61. for index in indexes:
  62. file_key = self.__keys[index]
  63. file_keys.append(file_key)
  64. files = FileService.search_in_files(query=query, file_keys=file_keys)
  65. return files
  66. def instruction_supplement(self) -> str:
  67. """
  68. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  69. """
  70. if len(self.__filenames) == 0:
  71. return ""
  72. else:
  73. filenames_info = [
  74. f"({index}){filename}"
  75. for index, filename in enumerate(self.__filenames)
  76. ]
  77. return (
  78. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  79. + 'Each line represents a file in the format "(index)filename":\n'
  80. + "\n".join(filenames_info)
  81. + "\nMake sure to be extremely concise when using attached files. "
  82. )