Gogs 4 maanden geleden
commit
371ecc6b86
100 gewijzigde bestanden met toevoegingen van 5299 en 0 verwijderingen
  1. 9 0
      .dockerignore
  2. 63 0
      .env.example
  3. 59 0
      .github/workflows/docker-image-api.yml
  4. 60 0
      .github/workflows/docker-image-playground-ui.yml
  5. 9 0
      .gitignore
  6. BIN
      .r2r2.py.swp
  7. 30 0
      Dockerfile
  8. 21 0
      LICENSE
  9. 19 0
      Makefile
  10. 167 0
      README.md
  11. 163 0
      README_CN.md
  12. 157 0
      README_JP.md
  13. 72 0
      alembic.ini
  14. 0 0
      app/__init__.py
  15. 0 0
      app/api/__init__.py
  16. 113 0
      app/api/deps.py
  17. 15 0
      app/api/routes.py
  18. 0 0
      app/api/v1/__init__.py
  19. 84 0
      app/api/v1/action.py
  20. 65 0
      app/api/v1/assistant.py
  21. 60 0
      app/api/v1/assistant_file.py
  22. 84 0
      app/api/v1/files.py
  23. 113 0
      app/api/v1/message.py
  24. 192 0
      app/api/v1/runs.py
  25. 45 0
      app/api/v1/thread.py
  26. 45 0
      app/api/v1/token.py
  27. 0 0
      app/core/__init__.py
  28. 0 0
      app/core/doc_loaders/__init__.py
  29. 44 0
      app/core/doc_loaders/doc_loader.py
  30. 0 0
      app/core/runner/__init__.py
  31. 61 0
      app/core/runner/llm_backend.py
  32. 83 0
      app/core/runner/llm_callback_handler.py
  33. 87 0
      app/core/runner/memory.py
  34. 263 0
      app/core/runner/pub_handler.py
  35. 298 0
      app/core/runner/thread_runner.py
  36. 0 0
      app/core/runner/utils/__init__.py
  37. 58 0
      app/core/runner/utils/message_util.py
  38. 67 0
      app/core/runner/utils/tool_call_util.py
  39. 44 0
      app/core/tools/__init__.py
  40. 75 0
      app/core/tools/base_tool.py
  41. 37 0
      app/core/tools/external_function_tool.py
  42. 66 0
      app/core/tools/file_search_tool.py
  43. 52 0
      app/core/tools/openapi_function_tool.py
  44. 36 0
      app/core/tools/web_search.py
  45. 0 0
      app/exceptions/__init__.py
  46. 110 0
      app/exceptions/exception.py
  47. 0 0
      app/libs/__init__.py
  48. 36 0
      app/libs/bson/errors.py
  49. 276 0
      app/libs/bson/objectid.py
  50. 53 0
      app/libs/bson/tz_util.py
  51. 25 0
      app/libs/class_loader.py
  52. 84 0
      app/libs/paginate.py
  53. 52 0
      app/libs/thread_executor.py
  54. 1 0
      app/libs/types.py
  55. 37 0
      app/libs/util.py
  56. 26 0
      app/models/__init__.py
  57. 37 0
      app/models/action.py
  58. 51 0
      app/models/assistant.py
  59. 23 0
      app/models/assistant_file.py
  60. 46 0
      app/models/base_model.py
  61. 18 0
      app/models/file.py
  62. 41 0
      app/models/message.py
  63. 8 0
      app/models/message_file.py
  64. 119 0
      app/models/run.py
  65. 39 0
      app/models/run_step.py
  66. 26 0
      app/models/thread.py
  67. 27 0
      app/models/token.py
  68. 30 0
      app/models/token_relation.py
  69. 0 0
      app/providers/__init__.py
  70. 41 0
      app/providers/app_provider.py
  71. 112 0
      app/providers/auth_provider.py
  72. 9 0
      app/providers/celery_app.py
  73. 63 0
      app/providers/database.py
  74. 42 0
      app/providers/handle_exception.py
  75. 46 0
      app/providers/logging_provider.py
  76. 0 0
      app/providers/middleware/__init__.py
  77. 11 0
      app/providers/middleware/http_process_time.py
  78. 19 0
      app/providers/middleware/unhandled_exception_handler.py
  79. 5 0
      app/providers/pagination_provider.py
  80. 80 0
      app/providers/r2r.py
  81. 24 0
      app/providers/response.py
  82. 15 0
      app/providers/route_provider.py
  83. 80 0
      app/providers/storage.py
  84. 0 0
      app/schemas/__init__.py
  85. 14 0
      app/schemas/common.py
  86. 9 0
      app/schemas/files.py
  87. 23 0
      app/schemas/runs.py
  88. 15 0
      app/schemas/threads.py
  89. 0 0
      app/schemas/tool/__init__.py
  90. 254 0
      app/schemas/tool/action.py
  91. 86 0
      app/schemas/tool/authentication.py
  92. 0 0
      app/services/__init__.py
  93. 0 0
      app/services/assistant/__init__.py
  94. 67 0
      app/services/assistant/assistant.py
  95. 53 0
      app/services/assistant/assistant_file.py
  96. 0 0
      app/services/file/__init__.py
  97. 9 0
      app/services/file/file.py
  98. 0 0
      app/services/file/impl/__init__.py
  99. 47 0
      app/services/file/impl/base.py
  100. 94 0
      app/services/file/impl/oss_file.py

+ 9 - 0
.dockerignore

@@ -0,0 +1,9 @@
+__pycache__/
+.idea/
+venv/
+.venv
+.env
+volumes/
+logs
+.vscode
+.DS_Store

+ 63 - 0
.env.example

@@ -0,0 +1,63 @@
+# app
+APP_NAME=open-assistant-api
+APP_DEBUG=True
+APP_ENV=local
+APP_SERVER_HOST=0.0.0.0
+APP_SERVER_PORT=8086
+APP_SERVER_WORKERS=1
+APP_API_PREFIX=/api
+APP_AUTH_ENABLE=False
+APP_AUTH_ADMIN_TOKEN=admin
+
+LOG_LEVEL=DEBUG
+
+# database
+DB_HOST=127.0.0.1
+DB_PORT=3306
+DB_DATABASE=open_assistant
+DB_USER=root
+DB_PASSWORD=123456
+DB_POOL_SIZE=1
+
+# redis
+REDIS_HOST=localhost
+REDIS_PORT=6379
+REDIS_DB=0
+REDIS_PASSWORD=123456
+
+# s3 storage
+S3_ENDPOINT=http://minio:9000
+S3_BUCKET_NAME=oas
+S3_ACCESS_KEY=minioadmin
+S3_SECRET_KEY=minioadmin
+S3_REGION=us-east-1
+
+# celery
+CELERY_BROKER_URL=redis://:123456@127.0.0.1:6379/1
+
+# llm
+OPENAI_API_BASE=
+OPENAI_API_KEY=
+LLM_MAX_STEP=25
+
+# tool
+TOOL_WORKER_NUM=10
+TOOL_WORKER_EXECUTION_TIMEOUT=180
+
+# web search tool
+BING_SEARCH_URL=https://api.bing.microsoft.com/v7.0/search
+BING_SUBSCRIPTION_KEY=xxxx
+WEB_SEARCH_NUM_RESULTS=5
+
+# file service
+FILE_SERVICE_MODULE=app.services.file.impl.oss_file.OSSFileService
+#FILE_SERVICE_MODULE=app.services.file.impl.r2r_file.R2RFileService
+
+# file search tool
+R2R_BASE_URL=http://127.0.0.1:8000
+R2R_USERNAME=admin@example.com
+R2R_PASSWORD=change_me_immediately
+R2R_SEARCH_LIMIT=10
+
+# secret
+APP_AES_ENCRYPTION_KEY=7700b2f9c8dd982dfaddf8b47a92f1d900507ee8ac335f96a64e9ca0f018b195

+ 59 - 0
.github/workflows/docker-image-api.yml

@@ -0,0 +1,59 @@
+name: API Docker Image CI
+
+on:
+  push:
+    paths-ignore:
+      - 'playground-ui/**'
+    branches:
+      - 'main'
+  release:
+    types: [ published ]
+
+env:
+  IMAGE_NAME: 'samepaage/open-assistant-api'
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+    permissions:
+      packages: write
+      contents: read
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Login to DockerHub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USER }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.IMAGE_NAME }}
+          tags: |
+            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
+            type=ref,event=branch
+            type=sha,enable=true,priority=100,prefix=,suffix=,format=long
+            type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
+
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          platforms: |
+            linux/amd64
+            linux/arm64
+          build-args: |
+            COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max

+ 60 - 0
.github/workflows/docker-image-playground-ui.yml

@@ -0,0 +1,60 @@
+name: Playground UI Docker Image CI
+
+on:
+  push:
+    paths:
+      - 'playground-ui/**'
+    branches:
+      - 'main'
+  release:
+    types: [ published ]
+
+env:
+  IMAGE_NAME: 'samepaage/open-assistant-playground-ui'
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+    permissions:
+      packages: write
+      contents: read
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Login to DockerHub
+        uses: docker/login-action@v2
+        with:
+          username: ${{ secrets.DOCKERHUB_USER }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.IMAGE_NAME }}
+          tags: |
+            type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
+            type=ref,event=branch
+            type=sha,enable=true,priority=100,prefix=,suffix=,format=long
+            type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
+
+      - name: Build and push
+        uses: docker/build-push-action@v5
+        with:
+          context: ./playground-ui
+          platforms: |
+            linux/amd64
+            linux/arm64
+          build-args: |
+            COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max

+ 9 - 0
.gitignore

@@ -0,0 +1,9 @@
+__pycache__/
+.idea/
+venv/
+.venv
+.env
+volumes/
+logs
+.vscode
+.DS_Store

BIN
.r2r2.py.swp


+ 30 - 0
Dockerfile

@@ -0,0 +1,30 @@
+FROM python:3.10-slim AS base
+
+LABEL maintainer="xujiawei@cocorobo.cc"
+
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends bash curl wget vim libmagic-dev \
+    && apt-get autoremove \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir poetry -i https://pypi.tuna.tsinghua.edu.cn/simple \
+    && poetry config virtualenvs.create false
+
+COPY poetry.lock /env/poetry.lock
+COPY pyproject.toml /env/pyproject.toml
+
+
+RUN poetry config repositories.pypi https://pypi.tuna.tsinghua.edu.cn/simple
+#RUN cd /env && poetry lock --no-update && poetry install --only main
+RUN cd /env && poetry lock --no-update && poetry install --only main
+
+EXPOSE 8086
+
+WORKDIR /app
+
+COPY . /app
+
+COPY docker/entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+
+ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 MarCo
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 19 - 0
Makefile

@@ -0,0 +1,19 @@
+.PHONY: all format lint
+
+all: help
+
+help:
+	@echo "make"
+
+	@echo "    format"
+	@echo "        Apply black formatting to code."
+	@echo "    lint"
+	@echo "        Lint code with ruff, and check if black formatter should be applied."
+
+format:
+	poetry run black .
+	poetry run ruff . --fix
+
+lint:
+	poetry run black . --check
+	poetry run ruff .

+ 167 - 0
README.md

@@ -0,0 +1,167 @@
+<div align="center">
+
+# Open Assistant API
+
+_✨ An out-of-the-box AI intelligent assistant API ✨_
+
+</div>
+
+<p align="center">
+  <a href="./README.md">English</a> |
+  <a href="./README_CN.md">简体中文</a> |
+  <a href="./README_JP.md">日本語</a>
+</p>
+
+## Introduction
+
+Open Assistant API is an open-source, self-hosted AI intelligent assistant API, compatible with the official OpenAI
+interface. It can be used directly with the official OpenAI [Client](https://github.com/openai/openai-python) to build
+LLM applications.
+
+It supports [One API](https://github.com/songquanpeng/one-api) for integration with more commercial and private models.
+
+It supports the [R2R](https://github.com/SciPhi-AI/R2R) RAG engine.
+
+## Usage
+
+Below is an example of using the official OpenAI Python `openai` library:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    base_url="http://127.0.0.1:8086/api/v1",
+    api_key="xxx"
+)
+
+assistant = client.beta.assistants.create(
+    name="demo",
+    instructions="You are a helpful assistant.",
+    model="gpt-4-1106-preview"
+)
+```
+
+## Why Choose Open Assistant API
+
+| Feature                  | Open Assistant API | OpenAI Assistant API |
+|--------------------------|--------------------|----------------------|
+| Ecosystem Strategy       | Open Source        | Closed Source        |
+| RAG Engine               | Support R2R        | Supported            |
+| Internet Search          | Supported          | Not Supported        |
+| Custom Functions         | Supported          | Supported            |
+| Built-in Tool            | Extendable         | Not Extendable       |
+| Code Interpreter         | Under Development  | Supported            |
+| Multimodal               | Supported          | Supported            |
+| LLM Support              | Supports More LLMs | Only GPT             |
+| Message Streaming Output | Supported          | Supported            |
+| Local Deployment         | Supported          | Not Supported        |
+
+- **LLM Support**: Compared to the official OpenAI version, more models can be supported by integrating with One API.
+- **Tool**: Currently supports online search; can easily expand more tools.
+- **RAG Engine**: The currently supported file types are txt, html, markdown, pdf, docx, pptx, xlsx, png, mp3, mp4, etc. We provide a preliminary
+  implementation.
+- **Message Streaming Output**: Support message streaming output for a smoother user experience.
+- **Ecosystem Strategy**: Open source, you can deploy the service locally and expand the existing features.
+
+## Quick Start
+
+The easiest way to start the Open Assistant API is to run the docker-compose.yml file. Make sure Docker and Docker
+Compose are installed on your machine before running.
+
+### Configuration
+
+Go to the project root directory, open `docker-compose.yml`, fill in the openai api_key and bing search key (optional).
+
+```sh
+# openai api_key (supports OneAPI api_key)
+OPENAI_API_KEY=<openai_api_key>
+
+# bing search key (optional)
+BING_SUBSCRIPTION_KEY=<bing_subscription_key>
+```
+
+It is recommended to configure the R2R RAG engine to replace the default RAG implementation to provide better RAG capabilities.
+You can learn about and use R2R through the [R2R Github repository](https://github.com/SciPhi-AI/R2R).
+
+```sh
+# RAG config
+# FILE_SERVICE_MODULE=app.services.file.impl.oss_file.OSSFileService
+FILE_SERVICE_MODULE=app.services.file.impl.r2r_file.R2RFileService
+R2R_BASE_URL=http://<r2r_api_address>
+R2R_USERNAME=<r2r_username>
+R2R_PASSWORD=<r2r_password>
+```
+
+### Run
+
+#### Run with Docker Compose:
+
+ ```sh
+docker compose up -d
+ ```
+
+### Access API
+
+Api Base URL: http://127.0.0.1:8086/api/v1
+
+Interface documentation address: http://127.0.0.1:8086/docs
+
+### Complete Usage Example
+
+In this example, an AI assistant is created and run using the official OpenAI client library. If you need to explore other usage methods,
+such as streaming output, tools (web_search, retrieval, function), etc., you can find the corresponding code under the examples directory.
+Before running, you need to run `pip install openai` to install the Python `openai` library.
+
+```sh
+# !pip install openai
+export PYTHONPATH=$(pwd)
+python examples/run_assistant.py
+```
+
+
+### Permissions
+Simple user isolation is provided based on tokens to meet SaaS deployment requirements. It can be enabled by configuring `APP_AUTH_ENABLE`.
+
+![](docs/imgs/user.png)
+
+1. The authentication method is Bearer token. You can include `Authorization: Bearer ***` in the header for authentication.
+2. Token management is described in the token section of the API documentation. Relevant APIs need to be authenticated with an admin token, which is configured as `APP_AUTH_ADMIN_TOKEN` and defaults to "admin".
+3. When creating a token, you need to provide the base URL and API key of the large model. The created assistant will use the corresponding configuration to access the large model.
+
+### Tools
+According to the OpenAPI/Swagger specification, it allows the integration of various tools into the assistant, empowering and enhancing its capability to connect with the external world.
+
+1. Facilitates connecting your application with other systems or services, enabling interaction with the external environment, such as code execution or accessing proprietary information sources.
+2. During usage, you need to create tools first, and then you can integrate them with the assistant. Refer to the test cases for more details. [Assistant With Action](tests/tools/assistant_action_test.py)
+3. If you need to use tools with authentication information, simply add the authentication information at runtime. The specific parameter format can be found in the API documentation. Refer to the test cases for more details. [Run With Auth Action](tests/tools/run_with_auth_action_test.py)
+
+
+## Community and Support
+
+- Join the [Slack](https://join.slack.com/t/openassistant-qbu7007/shared_invite/zt-29t8j9y12-9og5KZL6GagXTEvbEDf6UQ)
+  channel to see new releases, discuss issues, and participate in community interactions.
+- Join the [Discord](https://discord.gg/VfBruz4B) channel to interact with other community members.
+- Join the WeChat group:
+
+  ![](docs/imgs/wx.png)
+
+## Special Thanks
+
+We mainly referred to and relied on the following projects:
+
+- [OpenOpenAI](https://github.com/transitive-bullshit/OpenOpenAI): Assistant API implemented in Node
+- [One API](https://github.com/songquanpeng/one-api): Multi-model management tool
+- [R2R](https://github.com/SciPhi-AI/R2R): RAG engine
+- [OpenAI-Python](https://github.com/openai/openai-python): OpenAI Python Client
+- [OpenAI API](https://github.com/openai/openai-openapi): OpenAI interface definition
+- [LangChain](https://github.com/langchain-ai/langchain): LLM application development library
+- [OpenGPTs](https://github.com/langchain-ai/opengpts): LangChain GPTs
+- [TaskingAI](https://github.com/TaskingAI/TaskingAI): TaskingAI Client SDK
+
+## Contributing
+
+Please read our [contribution document](./docs/CONTRIBUTING.md) to learn how to contribute.
+
+## Open Source License
+
+This repository follows the MIT open source license. For more information, please see the [LICENSE](./LICENSE) file.

+ 163 - 0
README_CN.md

@@ -0,0 +1,163 @@
+<div align="center">
+
+# Open Assistant API
+
+_✨ 开箱即用的 AI 智能助手 API ✨_
+
+</div>
+
+<p align="center">
+  <a href="./README.md">English</a> |
+  <a href="./README_CN.md">简体中文</a> |
+  <a href="./README_JP.md">日本語</a>
+</p>
+
+## 简介
+
+Open Assistant API 是一个开源自托管的 AI 智能助手 API,兼容 OpenAI 官方接口,
+可以直接使用 OpenAI 官方的 [Client](https://github.com/openai/openai-python) 构建 LLM 应用。
+
+支持 [One API](https://github.com/songquanpeng/one-api) 可以用其接入更多商业和私有模型。
+
+支持 [R2R](https://github.com/SciPhi-AI/R2R) RAG 引擎。
+
+## 使用
+
+以下是使用了 OpenAI 官方的 Python `openai` 库的使用示例:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    base_url="http://127.0.0.1:8086/api/v1",
+    api_key="xxx"
+)
+
+assistant = client.beta.assistants.create(
+    name="demo",
+    instructions="You are a helpful assistant.",
+    model="gpt-4-1106-preview"
+)
+```
+
+## 为什么选择 Open Assistant API
+
+| 功能               | Open Assistant API | OpenAI Assistant API |
+|------------------|--------------------|----------------------|
+| 生态策略             | 开源                 | 闭源                   |
+| RAG 引擎           | 支持 R2R             | 支持                   |
+| 联网搜索             | 支持                 | 不支持                  |
+| 自定义 Functions    | 支持                 | 支持                   |
+| 内置 Tool          | 支持扩展               | 不支持扩展                |
+| Code Interpreter | 待开发                | 支持                   |
+| 多模态识别            | 支持                 | 支持                   |
+| LLM 支持           | 支持更多的 LLM          | 仅 GPT                |
+| Message 流式输出     | 支持                 | 支持                   |
+| 本地部署             | 支持                 | 不支持                  |
+
+- **LLM 支持**: 相较于 OpenAI 官方版本,可以通过接入 One API 来支持更多的模型。
+- **Tool**: 目前支持联网搜索;可以较容易扩展更多的 Tool。
+- **RAG 引擎**: 支持 R2R RAG 引擎,目前支持的文件类型有 txt、html、markdown、pdf、docx、pptx、xlsx、png、mp3、mp4 等。
+- **Message 流式输出**: 支持 Message 流式输出,提供更流畅的用户体验。
+- **生态策略**: 开源,你可以将服务部署在本地,可以对已有功能进行扩展。
+
+## 快速上手
+
+启动 Open Assistant API 最简单方法是运行 docker-compose.yml 文件。 运行之前确保机器上安装了 Docker 和 Docker Compose。
+
+### 配置
+
+进入项目根目录,打开 `docker-compose.yml`,填写 openai api_key 和 bing search key (非必填)。
+
+```sh
+# openai api_key (支持 OneAPI api_key)
+OPENAI_API_KEY=<openai_api_key>
+
+# bing search key (非必填)
+BING_SUBSCRIPTION_KEY=<bing_subscription_key>
+```
+
+建议配置 R2R RAG 引擎替换默认的 RAG 实现,以提供更好的 RAG 能力。
+关于 R2R,可以通过 [R2R Github 仓库](https://github.com/SciPhi-AI/R2R) 了解和使用。
+
+```sh
+# RAG 配置
+# FILE_SERVICE_MODULE=app.services.file.impl.oss_file.OSSFileService
+FILE_SERVICE_MODULE=app.services.file.impl.r2r_file.R2RFileService
+R2R_BASE_URL=http://<r2r_api_address>
+R2R_USERNAME=<r2r_username>
+R2R_PASSWORD=<r2r_password>
+```
+
+### 运行
+
+#### 使用 Docker Compose 运行:
+
+ ```sh
+docker compose up -d
+ ```
+
+### 访问 API
+
+Api Base URL: http://127.0.0.1:8086/api/v1
+
+接口文档地址: http://127.0.0.1:8086/docs
+
+### 完整使用示例
+
+此示例中使用 OpenAI 官方的 client 库创建并运行了一个 AI 助手。如果需要查看其它使用方式,如流式输出、工具(web_search、retrieval、function)的使用等,
+可以在 examples 查看对应示例。
+运行之前需要运行 `pip install openai` 安装 Python `openai` 库。
+
+```sh
+# !pip install openai
+export PYTHONPATH=$(pwd)
+python examples/run_assistant.py
+```
+
+### 权限
+基于 token 提供简单用户隔离,满足 SaaS 部署需求,可通过配置 ```APP_AUTH_ENABLE``` 开启
+
+![](docs/imgs/user.png)
+
+1. 验证方式为 Bearer token,可在 Header 中填入 ```Authorization: Bearer ***``` 进行验证
+2. token 管理参考 api 文档中的 token 小节
+相关 api 需通过 admin token 验证,配置为 ```APP_AUTH_ADMIN_TOKEN```,默认为 admin
+3. 创建 token 需填入大模型 base_url 和 api_key,创建的 assistant 将使用相关配置访问大模型
+### 工具
+根据 OpenAPI/Swagger 规范,允许将多种工具集成到助手中,赋予并增强了 LLM 连接外部世界的能力。
+
+1. 方便将你的应用与其他系统或服务连接,与外部环境交互,如代码执行、对专属信息源的访问
+2. 在使用过程中,需创建工具,接着将工具与助手搭配即可,查看测试用例[Assistant With Action](tests/tools/assistant_action_test.py)
+3. 若需要使用带认证信息的工具,只需在运行时添加认证信息即可,具体参数格式可在接口文档中查看。查看测试用例[Run With Auth Action](tests/tools/run_with_auth_action_test.py)
+
+## 社区与支持
+
+- 加入 [Slack](https://join.slack.com/t/openassistant-qbu7007/shared_invite/zt-29t8j9y12-9og5KZL6GagXTEvbEDf6UQ)
+  频道,查看新发布的内容,交流问题,参与社区互动。
+- 加入 [Discord](https://discord.gg/VfBruz4B) 频道,与其他社区成员交流。
+- 加入 Open Assistant Api 微信交流群:
+
+  ![](docs/imgs/wx.png)
+
+## 特别感谢
+
+我们主要参考和依赖了以下项目:
+
+- [OpenOpenAI](https://github.com/transitive-bullshit/OpenOpenAI): Node 实现的 Assistant API
+- [One API](https://github.com/songquanpeng/one-api): 多模型管理工具
+- [R2R](https://github.com/SciPhi-AI/R2R): RAG 引擎
+- [OpenAI-Python](https://github.com/openai/openai-python): OpenAI Python Client
+- [OpenAI API](https://github.com/openai/openai-openapi): OpenAI 接口定义
+- [LangChain](https://github.com/langchain-ai/langchain): LLM 应用开发库
+- [OpenGPTs](https://github.com/langchain-ai/opengpts): LangChain GPTs
+- [TaskingAI](https://github.com/TaskingAI/TaskingAI): TaskingAI 原生应用开发
+
+
+## 参与贡献
+
+请阅读我们的[贡献文档](./docs/CONTRIBUTING_CN.md),了解如何参与贡献。
+
+## 开源协议
+
+本仓库遵循 MIT 开源协议。有关详细信息,请参阅 [LICENSE](./LICENSE) 文件。

+ 157 - 0
README_JP.md

@@ -0,0 +1,157 @@
+<div align="center">
+
+# Open Assistant API
+
+_✨ すぐに使える AI インテリジェントアシスタント API ✨_
+
+</div>
+
+<p align="center">
+  <a href="./README.md">English</a> |
+  <a href="./README_CN.md">简体中文</a> |
+  <a href="./README_JP.md">日本語</a>
+</p>
+
+## 紹介
+
+Open Assistant API は、オープンソースのセルフホスティング型 AI インテリジェントアシスタント API であり、OpenAI 公式インターフェースと互換性があります。OpenAI 公式の [Client](https://github.com/openai/openai-python) を使用して LLM アプリケーションを構築することができます。
+
+[One API](https://github.com/songquanpeng/one-api) をサポートしており、より多くの商用およびプライベートモデルと統合できます。
+
+[R2R](https://github.com/SciPhi-AI/R2R) RAG エンジンをサポートしています。
+
+## 使用方法
+
+以下は、OpenAI 公式の Python `openai` ライブラリを使用した例です:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    base_url="http://127.0.0.1:8086/api/v1",
+    api_key="xxx"
+)
+
+assistant = client.beta.assistants.create(
+    name="demo",
+    instructions="You are a helpful assistant.",
+    model="gpt-4-1106-preview"
+)
+```
+
+## なぜ Open Assistant API を選ぶのか
+
+| 機能                  | Open Assistant API | OpenAI Assistant API |
+|----------------------|--------------------|----------------------|
+| エコシステム戦略          | オープンソース            | クローズドソース            |
+| RAG エンジン           | R2R をサポート           | サポートされている            |
+| インターネット検索         | サポートされている           | サポートされていない           |
+| カスタム関数            | サポートされている           | サポートされている            |
+| 内蔵ツール              | 拡張可能                | 拡張不可                 |
+| コードインタープリタ       | 開発中                 | サポートされている            |
+| マルチモーダル            | サポートされている           | サポートされている            |
+| LLM サポート           | より多くの LLM をサポート      | GPT のみ                |
+| メッセージストリーミング出力   | サポートされている           | サポートされている            |
+| ローカルデプロイメント       | サポートされている           | サポートされていない           |
+
+- **LLM サポート**: 公式の OpenAI バージョンと比較して、One API を統合することでより多くのモデルをサポートできます。
+- **ツール**: 現在、オンライン検索をサポートしています。より多くのツールを簡単に拡張できます。
+- **RAG エンジン**: 現在サポートされているファイルタイプは txt、html、markdown、pdf、docx、pptx、xlsx、png、mp3、mp4 などです。初期実装を提供しています。
+- **メッセージストリーミング出力**: メッセージストリーミング出力をサポートし、よりスムーズなユーザー体験を提供します。
+- **エコシステム戦略**: オープンソースであり、サービスをローカルにデプロイし、既存の機能を拡張することができます。
+
+## クイックスタート
+
+Open Assistant API を開始する最も簡単な方法は、docker-compose.yml ファイルを実行することです。実行する前に、マシンに Docker と Docker Compose がインストールされていることを確認してください。
+
+### 設定
+
+プロジェクトのルートディレクトリに移動し、`docker-compose.yml` を開いて、openai api_key と bing search key(オプション)を入力します。
+
+```sh
+# openai api_key (OneAPI api_key をサポート)
+OPENAI_API_KEY=<openai_api_key>
+
+# bing search key(オプション)
+BING_SUBSCRIPTION_KEY=<bing_subscription_key>
+```
+
+R2R RAG エンジンを設定して、デフォルトの RAG 実装を置き換え、より優れた RAG 機能を提供することをお勧めします。R2R については、[R2R Github リポジトリ](https://github.com/SciPhi-AI/R2R) を通じて学び、使用することができます。
+
+```sh
+# RAG 設定
+# FILE_SERVICE_MODULE=app.services.file.impl.oss_file.OSSFileService
+FILE_SERVICE_MODULE=app.services.file.impl.r2r_file.R2RFileService
+R2R_BASE_URL=http://<r2r_api_address>
+R2R_USERNAME=<r2r_username>
+R2R_PASSWORD=<r2r_password>
+```
+
+### 実行
+
+#### Docker Compose を使用して実行:
+
+ ```sh
+docker compose up -d
+ ```
+
+### API にアクセス
+
+Api Base URL: http://127.0.0.1:8086/api/v1
+
+インターフェースドキュメントのアドレス: http://127.0.0.1:8086/docs
+
+### 完全な使用例
+
+この例では、公式の OpenAI クライアントライブラリを使用して AI アシスタントを作成し、実行します。他の使用方法(ストリーミング出力、ツール(web_search、retrieval、function)など)を確認する場合は、examples ディレクトリで対応するコードを見つけることができます。実行する前に、Python `openai` ライブラリをインストールするために `pip install openai` を実行する必要があります。
+
+```sh
+# !pip install openai
+export PYTHONPATH=$(pwd)
+python examples/run_assistant.py
+```
+
+### 権限
+トークンに基づいて簡単なユーザー分離を提供し、SaaS デプロイメント要件を満たします。`APP_AUTH_ENABLE` を設定することで有効にできます。
+
+![](docs/imgs/user.png)
+
+1. 認証方法は Bearer トークンです。ヘッダーに `Authorization: Bearer ***` を含めて認証を行うことができます。
+2. トークン管理は API ドキュメントのトークンセクションに記載されています。関連する API は管理者トークンで認証する必要があり、`APP_AUTH_ADMIN_TOKEN` として設定され、デフォルトでは "admin" です。
+3. トークンを作成する際には、大規模モデルのベース URL と API キーを提供する必要があります。作成されたアシスタントは、対応する設定を使用して大規模モデルにアクセスします。
+
+### ツール
+OpenAPI/Swagger 仕様に従って、さまざまなツールをアシスタントに統合することができ、外部の世界と接続する能力を強化します。
+
+1. アプリケーションを他のシステムやサービスと接続し、外部環境と対話することができます。たとえば、コードの実行や専用情報源へのアクセスなどです。
+2. 使用中にツールを作成し、その後アシスタントと組み合わせることができます。詳細はテストケースを参照してください。[Assistant With Action](tests/tools/assistant_action_test.py)
+3. 認証情報を持つツールを使用する必要がある場合は、実行時に認証情報を追加するだけです。具体的なパラメータ形式は API ドキュメントで確認できます。詳細はテストケースを参照してください。[Run With Auth Action](tests/tools/run_with_auth_action_test.py)
+
+## コミュニティとサポート
+
+- [Slack](https://join.slack.com/t/openassistant-qbu7007/shared_invite/zt-29t8j9y12-9og5KZL6GagXTEvbEDf6UQ) チャンネルに参加して、新しいリリースを確認し、問題を議論し、コミュニティの交流に参加してください。
+- [Discord](https://discord.gg/VfBruz4B) チャンネルに参加して、他のコミュニティメンバーと交流してください。
+- Open Assistant Api WeChat グループに参加してください:
+
+  ![](docs/imgs/wx.png)
+
+## 特別な感謝
+
+主に以下のプロジェクトを参考にし、依存しています:
+
+- [OpenOpenAI](https://github.com/transitive-bullshit/OpenOpenAI): Node で実装された Assistant API
+- [One API](https://github.com/songquanpeng/one-api): マルチモデル管理ツール
+- [R2R](https://github.com/SciPhi-AI/R2R): RAG エンジン
+- [OpenAI-Python](https://github.com/openai/openai-python): OpenAI Python クライアント
+- [OpenAI API](https://github.com/openai/openai-openapi): OpenAI インターフェース定義
+- [LangChain](https://github.com/langchain-ai/langchain): LLM アプリケーション開発ライブラリ
+- [OpenGPTs](https://github.com/langchain-ai/opengpts): LangChain GPTs
+- [TaskingAI](https://github.com/TaskingAI/TaskingAI): TaskingAI クライアント SDK
+
+## 貢献
+
+貢献方法については、[貢献ドキュメント](./docs/CONTRIBUTING.md) をお読みください。
+
+## オープンソースライセンス
+
+このリポジトリは MIT オープンソースライセンスに従います。詳細については、[LICENSE](./LICENSE) ファイルを参照してください。

+ 72 - 0
alembic.ini

@@ -0,0 +1,72 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = migrations
+
+# template used to generate migration files
+# file_template = %%(rev)s_%%(slug)s
+file_template = %%(year)d-%%(month).2d-%%(day).2d-%%(hour).2d-%%(minute).2d_%%(rev)s
+
+# timezone to use when rendering the date
+# within the migration file as well as the filename.
+# string value is passed to dateutil.tz.gettz()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+#truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; this defaults
+# to alembic/versions.  When using multiple version
+# directories, initial revisions must be specified with --version-path
+# version_locations = %(here)s/bar %(here)s/bat alembic/versions
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S

+ 0 - 0
app/__init__.py


+ 0 - 0
app/api/__init__.py


+ 113 - 0
app/api/deps.py

@@ -0,0 +1,113 @@
+from typing import AsyncGenerator
+
+from fastapi import Depends, Request
+from fastapi.security import APIKeyHeader
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.exceptions.exception import AuthenticationError, AuthorizationError, ResourceNotFoundError
+from app.models.token import Token
+from app.models.token_relation import RelationType, TokenRelationQuery
+from app.providers import database
+from app.services.token.token import TokenService
+from app.services.token.token_relation import TokenRelationService
+from config.config import settings
+
+
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession``; intended for use as a FastAPI ``Depends``."""
    session_ctx = database.async_session_local()
    async with session_ctx as session:
        yield session
+
+
class OAuth2Bearer(APIKeyHeader):
    """
    Extract a bearer token from the configured request header.

    Returns the bare token (without the ``Bearer `` scheme prefix), or
    ``None`` when the header is missing or not a bearer credential.
    """

    def __init__(
        self, *, name: str, scheme_name: str | None = None, description: str | None = None, auto_error: bool = True
    ):
        super().__init__(name=name, scheme_name=scheme_name, description=description, auto_error=auto_error)

    async def __call__(self, request: Request) -> str | None:
        # Fix: the original annotated the return type as ``str`` although the
        # missing/malformed-header path falls through to ``return None``.
        authorization_header_value = request.headers.get(self.model.name)
        if authorization_header_value:
            scheme, _, param = authorization_header_value.partition(" ")
            if scheme.lower() == "bearer" and param.strip() != "":
                return param.strip()
        return None


oauth_token = OAuth2Bearer(name="Authorization")
+
+
async def verify_admin_token(token=Depends(oauth_token)) -> None:
    """
    Admin token authentication.

    Raises:
        AuthenticationError: no bearer token was supplied.
        AuthorizationError: the token does not match ``settings.AUTH_ADMIN_TOKEN``.
    """
    # Fix: the original annotated ``-> Token`` but never returns a value;
    # this dependency only raises on failure.
    if token is None:
        raise AuthenticationError()
    if settings.AUTH_ADMIN_TOKEN != token:
        raise AuthorizationError()
+
+
async def get_token(session=Depends(get_async_session), token=Depends(oauth_token)) -> Token | None:
    """Resolve the bearer token string to a ``Token`` row, or ``None`` if absent/unknown."""
    # Fix: the original annotated ``-> Token`` although the missing/unknown-token
    # path returns ``None`` (callers such as ``verfiy_token`` rely on that).
    if token:  # covers both None and empty string
        try:
            return await TokenService.get_token(session=session, token=token)
        except ResourceNotFoundError:
            # Unknown token: treat as unauthenticated rather than surfacing a 404.
            pass
    return None
+
+
async def verfiy_token(token: Token = Depends(get_token)):
    """Reject the request when no valid token could be resolved.

    Raises:
        AuthenticationError: the bearer token is missing or unknown.
    """
    # NOTE(review): the name is misspelled ("verfiy"); kept as-is because other
    # modules may depend on it — confirm before renaming.
    if token is None:
        raise AuthenticationError()
+
+
async def get_token_id(token: Token = Depends(get_token)):
    """Return the id of the resolved token; acts as the caller's identity."""
    if token is None:
        return None
    return token.id
+
+
def get_param(name: str):
    """Build a dependency that extracts parameter ``name`` from a request.

    Lookup order: path params, then query params, then the JSON body.
    Returns ``None`` when the parameter cannot be found.
    """

    async def get_param_from_request(request: Request):
        if name in request.path_params:
            return request.path_params[name]
        if name in request.query_params:
            return request.query_params[name]
        try:
            body = await request.json()
        except ValueError:
            # Fix: the original raised on requests without a JSON body
            # (e.g. GET/DELETE); treat that as "parameter not present".
            return None
        if isinstance(body, dict) and name in body:
            return body[name]
        return None

    return get_param_from_request
+
+
def verify_token_relation(relation_type: RelationType, name: str, ignore_none_relation_id: bool = False):
    """Build a dependency asserting the caller's token owns the target resource.

    :param relation_type: relation type to check
    :param name: name of the request parameter holding the resource id
    :param ignore_none_relation_id: when set, skip the ownership check entirely
        (used by the copy-thread api, where relation_id may be None)
    """

    async def verify_authorization(
        session=Depends(get_async_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name))
    ):
        if token_id:
            if ignore_none_relation_id:
                return
            if relation_id:
                query = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id)
                owns = await TokenRelationService.verify_relation(session=session, verify=query)
                if owns:
                    return
        raise AuthorizationError()

    return verify_authorization

+ 15 - 0
app/api/routes.py

@@ -0,0 +1,15 @@
+from fastapi import APIRouter
+from app.api.v1 import assistant, assistant_file, thread, message, files, runs, token, action
+
api_router = APIRouter(prefix="/v1")


def router_init():
    """Mount all v1 sub-routers on the shared ``api_router``."""
    # (router, path prefix, openapi tag) — registration order is preserved.
    registrations = [
        (assistant.router, "/assistants", "assistants"),
        (assistant_file.router, "/assistants", "assistants"),
        (thread.router, "/threads", "threads"),
        (message.router, "/threads", "messages"),
        (runs.router, "/threads", "runs"),
        (files.router, "/files", "files"),
        (token.router, "/tokens", "tokens"),
        (action.router, "/actions", "actions"),
    ]
    for sub_router, prefix, tag in registrations:
        api_router.include_router(sub_router, prefix=prefix, tags=[tag])

+ 0 - 0
app/api/v1/__init__.py


+ 84 - 0
app/api/v1/action.py

@@ -0,0 +1,84 @@
+from typing import Dict, List
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from app.api.deps import get_async_session, get_token_id
+from app.libs.paginate import cursor_page, CommonPage
+from app.models.action import Action, ActionRead
+from app.models.token_relation import RelationType
+from app.providers.auth_provider import auth_policy
+from app.schemas.common import DeleteResponse, BaseSuccessDataResponse
+from app.schemas.tool.action import ActionBulkCreateRequest, ActionUpdateRequest, ActionRunRequest
+from app.services.tool.action import ActionService
+
+router = APIRouter()
+
+
@router.get("", response_model=CommonPage[ActionRead])
async def list_actions(*, session: AsyncSession = Depends(get_async_session), token_id=Depends(get_token_id)):
    """List the actions visible to the calling token, as a cursor page."""
    stmt = auth_policy.token_filter(
        select(Action), field=Action.id, relation_type=RelationType.Action, token_id=token_id
    )
    page = await cursor_page(stmt, session)
    page.data = [item.model_dump(by_alias=True) for item in page.data]
    return page.model_dump(by_alias=True)
+
+
@router.post("", response_model=List[ActionRead])
async def create_actions(
    *, session: AsyncSession = Depends(get_async_session), body: ActionBulkCreateRequest, token_id=Depends(get_token_id)
):
    """Bulk-create actions from an OpenAPI schema."""
    created = await ActionService.create_actions(session=session, body=body, token_id=token_id)
    return [action.model_dump(by_alias=True) for action in created]
+
+
@router.get("/{action_id}", response_model=ActionRead)
async def get_action(*, session: AsyncSession = Depends(get_async_session), action_id: str):
    """Fetch a single action by id."""
    found = await ActionService.get_action(session=session, action_id=action_id)
    return found.model_dump(by_alias=True)
+
+
@router.post("/{action_id}", response_model=ActionRead)
async def modify_action(
    *, session: AsyncSession = Depends(get_async_session), action_id: str, body: ActionUpdateRequest
):
    """Apply an update to an existing action."""
    updated = await ActionService.modify_action(session=session, action_id=action_id, body=body)
    return updated.model_dump(by_alias=True)
+
+
@router.delete("/{action_id}", response_model=DeleteResponse)
async def delete_action(*, session: AsyncSession = Depends(get_async_session), action_id: str) -> DeleteResponse:
    """Remove an action."""
    result = await ActionService.delete_action(session=session, action_id=action_id)
    return result
+
+
@router.post(
    "/{action_id}/run",
    response_model=BaseSuccessDataResponse,
)
async def api_run_action(*, session: AsyncSession = Depends(get_async_session), action_id: str, body: ActionRunRequest):
    """Execute an action with the supplied parameters and headers."""
    result: Dict = await ActionService.run_action(
        session=session,
        action_id=action_id,
        parameters=body.parameters,
        headers=body.headers,
    )
    return BaseSuccessDataResponse(data=result)

+ 65 - 0
app/api/v1/assistant.py

@@ -0,0 +1,65 @@
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from app.api.deps import get_token_id, get_async_session
+from app.models.assistant import Assistant, AssistantUpdate, AssistantCreate, AssistantRead
+from app.libs.paginate import cursor_page, CommonPage
+from app.models.token_relation import RelationType
+from app.providers.auth_provider import auth_policy
+from app.schemas.common import DeleteResponse
+from app.services.assistant.assistant import AssistantService
+
+router = APIRouter()
+
+
@router.get("", response_model=CommonPage[AssistantRead])
async def list_assistants(*, session: AsyncSession = Depends(get_async_session), token_id=Depends(get_token_id)):
    """List the assistants visible to the calling token, as a cursor page."""
    stmt = auth_policy.token_filter(
        select(Assistant), field=Assistant.id, relation_type=RelationType.Assistant, token_id=token_id
    )
    page = await cursor_page(stmt, session)
    page.data = [item.model_dump(by_alias=True) for item in page.data]
    return page
+
+
@router.post("", response_model=AssistantRead)
async def create_assistant(
    *, session: AsyncSession = Depends(get_async_session), body: AssistantCreate, token_id=Depends(get_token_id)
):
    """Create an assistant owned by the calling token."""
    created = await AssistantService.create_assistant(session=session, body=body, token_id=token_id)
    return created.model_dump(by_alias=True)
+
+
@router.get("/{assistant_id}", response_model=AssistantRead)
async def get_assistant(*, session: AsyncSession = Depends(get_async_session), assistant_id: str):
    """Fetch a single assistant by id."""
    assistant = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
    return assistant.model_dump(by_alias=True)
+
+
@router.post("/{assistant_id}", response_model=AssistantRead)
async def modify_assistant(
    *, session: AsyncSession = Depends(get_async_session), assistant_id: str, body: AssistantUpdate
):
    """Apply an update to an existing assistant."""
    updated = await AssistantService.modify_assistant(session=session, assistant_id=assistant_id, body=body)
    return updated.model_dump(by_alias=True)
+
+
@router.delete("/{assistant_id}", response_model=DeleteResponse)
async def delete_assistant(*, session: AsyncSession = Depends(get_async_session), assistant_id: str) -> DeleteResponse:
    """Remove an assistant."""
    result = await AssistantService.delete_assistant(session=session, assistant_id=assistant_id)
    return result

+ 60 - 0
app/api/v1/assistant_file.py

@@ -0,0 +1,60 @@
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from app.api.deps import get_async_session
+from app.libs.paginate import cursor_page, CommonPage
+from app.models.assistant_file import AssistantFileCreate, AssistantFile
+from app.schemas.common import DeleteResponse
+from app.services.assistant.assistant_file import AssistantFileService
+
+router = APIRouter()
+
+
@router.get("/{assistant_id}/files", response_model=CommonPage[AssistantFile])
async def list_assistant_files(
    *,
    session: AsyncSession = Depends(get_async_session),
    assistant_id: str,
):
    """List the files attached to an assistant, as a cursor page."""
    stmt = select(AssistantFile).where(AssistantFile.assistant_id == assistant_id)
    return await cursor_page(stmt, db=session)
+
+
@router.post("/{assistant_id}/files", response_model=AssistantFile)
async def create_assistant_file(
    *,
    session: AsyncSession = Depends(get_async_session),
    assistant_id: str,
    body: AssistantFileCreate,
) -> AssistantFile:
    """
    Create an assistant file by attaching a [File](/docs/api-reference/files)
    to an [assistant](/docs/api-reference/assistants).
    """
    created = await AssistantFileService.create_assistant_file(session=session, assistant_id=assistant_id, body=body)
    return created
+
+
@router.get("/{assistant_id}/files/{file_id}", response_model=AssistantFile)
async def get_assistant_file(
    *, session: AsyncSession = Depends(get_async_session), assistant_id: str, file_id: str
) -> AssistantFile:
    """Fetch a single assistant-file attachment."""
    return await AssistantFileService.get_assistant_file(
        session=session, assistant_id=assistant_id, file_id=file_id
    )
+
+
@router.delete(
    "/{assistant_id}/files/{file_id}",
    response_model=DeleteResponse,
)
async def delete_assistant_file(
    *, session: AsyncSession = Depends(get_async_session), assistant_id: str, file_id: str
) -> DeleteResponse:
    """Detach a file from an assistant."""
    result = await AssistantFileService.delete_assistant_file(
        session=session, assistant_id=assistant_id, file_id=file_id
    )
    return result

+ 84 - 0
app/api/v1/files.py

@@ -0,0 +1,84 @@
+import io
+import os
+import urllib.parse
+from typing import Optional, List
+
+from fastapi import APIRouter, Depends, UploadFile, Form, HTTPException, Query
+from starlette.responses import StreamingResponse
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_async_session
+from app.models import File
+from app.schemas.common import DeleteResponse
+from app.schemas.files import ListFilesResponse
+from app.services.file.file import FileService
+
+router = APIRouter()
+
# Maximum accepted upload size (512 MB), enforced in ``create_file``.
max_size = 512 * 1024 * 1024
# Whitelisted upload file extensions, checked in ``create_file``.
file_ext = [".csv", ".docx", ".html", ".json", ".md", ".pdf", ".pptx", ".txt",
            ".xlsx", ".gif", ".png", ".jpg", ".jpeg", ".svg", ".mp3", ".mp4"]
+
+
@router.get("", response_model=ListFilesResponse)
async def list_files(
    *,
    purpose: Optional[str] = None,
    file_ids: Optional[List[str]] = Query(None, alias="ids[]"),
    session: AsyncSession = Depends(get_async_session),
) -> ListFilesResponse:
    """List stored files, optionally filtered by purpose and explicit ids."""
    file_rows = await FileService.get_file_list(session=session, purpose=purpose, file_ids=file_ids)
    return ListFilesResponse(data=file_rows)
+
+
@router.post("", response_model=File)
async def create_file(
    *, session: AsyncSession = Depends(get_async_session), purpose: str = Form(default="assistants"), file: UploadFile
) -> File:
    """
    The size of individual files can be a maximum of 512 MB. See the [Assistants Tools guide]
    (/docs/assistants/tools) to learn more about the types of files supported.

    Raises:
        HTTPException: 400 for an unsupported extension, 413 for empty/oversized uploads.
    """
    # Validate the extension case-insensitively (fix: ".PDF" was rejected before).
    _, file_extension = os.path.splitext(file.filename)
    if file_extension.lower() not in file_ext:
        raise HTTPException(status_code=400, detail=f"文件类型{file_extension}暂时不支持")
    # Reject empty or oversized uploads (``not file.size`` also covers a None size).
    if not file.size or file.size > max_size:
        raise HTTPException(status_code=413, detail="File too large")
    # Fix: removed a stray debug ``print(FileService)`` left in the handler.
    return await FileService.create_file(session=session, purpose=purpose, file=file)
+
+
@router.delete("/{file_id}", response_model=DeleteResponse)
async def delete_file(*, session: AsyncSession = Depends(get_async_session), file_id: str) -> DeleteResponse:
    """Remove a stored file."""
    result = await FileService.delete_file(session=session, file_id=file_id)
    return result
+
+
@router.get("/{file_id}", response_model=File)
async def retrieve_file(*, session: AsyncSession = Depends(get_async_session), file_id: str) -> File:
    """Fetch metadata for a single stored file."""
    file_row = await FileService.get_file(session=session, file_id=file_id)
    return file_row
+
+
@router.get("/{file_id}/content", response_class=StreamingResponse)
async def download_file(*, file_id: str, session: AsyncSession = Depends(get_async_session)):
    """Stream the raw bytes of a stored file as an attachment download."""
    file_data, filename = await FileService.get_file_content(session=session, file_id=file_id)

    # RFC 5987 encoding so non-ASCII filenames survive the header.
    disposition = f"attachment; filename*=UTF-8''{urllib.parse.quote(filename)}"
    response = StreamingResponse(io.BytesIO(file_data), media_type="application/octet-stream")
    response.headers["Content-Disposition"] = disposition
    response.headers["Content-Type"] = "application/octet-stream"
    return response

+ 113 - 0
app/api/v1/message.py

@@ -0,0 +1,113 @@
+from typing import Optional
+from fastapi import APIRouter, Depends
+from fastapi.params import Query
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from app.api.deps import get_async_session
+from app.models import MessageFile
+from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
+from app.libs.paginate import cursor_page, CommonPage
+from app.services.message.message import MessageService
+
+router = APIRouter()
+
+
@router.get(
    "/{thread_id}/messages",
    response_model=CommonPage[MessageRead],
)
async def list_messages(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: Optional[str] = Query(None, description="Filter messages by the run ID that generated them."),
):
    """List the messages of a thread, optionally restricted to one run."""
    stmt = select(Message).where(Message.thread_id == thread_id)
    if run_id:
        # Narrow the listing to messages produced by the given run.
        stmt = stmt.where(Message.run_id == run_id)

    page = await cursor_page(stmt, session)
    page.data = [msg.model_dump(by_alias=True) for msg in page.data]
    return page
+
+
@router.post("/{thread_id}/messages", response_model=MessageRead)
async def create_message(
    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: MessageCreate
):
    """Append a new message to a thread."""
    created = await MessageService.create_message(session=session, thread_id=thread_id, body=body)
    return created.model_dump(by_alias=True)
+
+
@router.get(
    "/{thread_id}/messages/{message_id}",
    response_model=MessageRead,
)
async def get_message(
    *, session: AsyncSession = Depends(get_async_session), thread_id: str, message_id: str
):
    """Fetch a single message from a thread."""
    msg = await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
    return msg.model_dump(by_alias=True)
+
+
@router.post(
    "/{thread_id}/messages/{message_id}",
    response_model=MessageRead,
)
async def modify_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    body: MessageUpdate = ...,
):
    """Apply an update to an existing message."""
    updated = await MessageService.modify_message(
        session=session, thread_id=thread_id, message_id=message_id, body=body
    )
    return updated.model_dump(by_alias=True)
+
+
@router.get(
    "/{thread_id}/messages/{message_id}/files",
    response_model=CommonPage[MessageFile],
)
async def list_message_files(
    *,
    session: AsyncSession = Depends(get_async_session),
    message_id: str = ...,
):
    """List the files referenced by a message, as a cursor page."""
    stmt = select(MessageFile).where(MessageFile.message_id == message_id)
    return await cursor_page(stmt, session)
+
+
@router.get(
    "/{thread_id}/messages/{message_id}/files/{file_id}",
    response_model=MessageFile,
)
async def get_message_file(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    file_id: str = ...,
) -> MessageFile:
    """Fetch a single message file."""
    message_file = await MessageService.get_message_file(
        session=session, thread_id=thread_id, message_id=message_id, file_id=file_id
    )
    return message_file

+ 192 - 0
app/api/v1/runs.py

@@ -0,0 +1,192 @@
+from fastapi import APIRouter, Depends, Request
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+from starlette.responses import StreamingResponse
+
+from app.api.deps import get_async_session
+from app.core.runner import pub_handler
+from app.libs.paginate import cursor_page, CommonPage
+from app.models.run import RunCreate, RunRead, RunUpdate, Run
+from app.models.run_step import RunStep, RunStepRead
+from app.schemas.runs import SubmitToolOutputsRunRequest
+from app.schemas.threads import CreateThreadAndRun
+from app.services.run.run import RunService
+from app.services.thread.thread import ThreadService
+from app.tasks.run_task import run_task
+import json
+
+router = APIRouter()
+#print(run_task)
+
@router.get(
    "/{thread_id}/runs",
    response_model=CommonPage[RunRead],
)
async def list_runs(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
):
    """List the runs belonging to a thread, as a cursor page."""
    # Resolve the thread first so a bad thread id fails before pagination.
    await ThreadService.get_thread(session=session, thread_id=thread_id)
    stmt = select(Run).where(Run.thread_id == thread_id)
    page = await cursor_page(stmt, session)
    page.data = [run.model_dump(by_alias=True) for run in page.data]
    return page.model_dump(by_alias=True)
+
+
@router.post(
    "/{thread_id}/runs",
    response_model=RunRead,
)
async def create_run(
    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: RunCreate = ..., request: Request
):
    """
    Create a run.

    Persists the run, publishes the ``created``/``queued`` stream events, and
    schedules the background execution task.  When ``body.stream`` is set the
    response is the event stream; otherwise the serialized run object.
    """
    # Fix: removed leftover debug ``print`` statements and commented-out code.
    db_run = await RunService.create_run(session=session, thread_id=thread_id, body=body)
    event_handler = pub_handler.StreamEventHandler(run_id=db_run.id, is_stream=body.stream)
    event_handler.pub_run_created(db_run)
    event_handler.pub_run_queued(db_run)
    # Kick off asynchronous execution of the run.
    run_task.apply_async(args=(db_run.id, body.stream))
    # ``file_ids`` is stored as a JSON string; decode it for the response model.
    db_run.file_ids = json.loads(db_run.file_ids)
    if body.stream:
        return pub_handler.sub_stream(db_run.id, request)
    else:
        return db_run.model_dump(by_alias=True)
+
+
@router.get(
    "/{thread_id}/runs/{run_id}"
    # NOTE(review): response_model=RunRead was disabled in the original;
    # the handler serializes manually via ``model_dump`` instead.
)
async def get_run(*, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...) -> RunRead:
    """
    Retrieves a run.
    """
    run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
    # ``file_ids`` is persisted as a JSON string; decode it for the response.
    run.file_ids = json.loads(run.file_ids)
    # Expose the datetime columns as unix timestamps.
    run.failed_at = int(run.failed_at.timestamp()) if run.failed_at else None
    run.completed_at = int(run.completed_at.timestamp()) if run.completed_at else None
    # Fix: removed a leftover debug ``print(run)``.
    return run.model_dump(by_alias=True)
+
+
@router.post(
    "/{thread_id}/runs/{run_id}",
    response_model=RunRead,
)
async def modify_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    body: RunUpdate = ...,
) -> RunRead:
    """Apply an update to an existing run."""
    updated = await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
    return updated.model_dump(by_alias=True)
+
+
@router.post(
    "/{thread_id}/runs/{run_id}/cancel",
    response_model=RunRead,
)
async def cancel_run(
    *, session: AsyncSession = Depends(get_async_session), thread_id: str, run_id: str = ...
) -> RunRead:
    """Cancels a run that is `in_progress`."""
    cancelled = await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
    return cancelled.model_dump(by_alias=True)
+
+
@router.get(
    "/{thread_id}/runs/{run_id}/steps",
    response_model=CommonPage[RunStepRead],
)
async def list_run_steps(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
):
    """List the steps recorded for a run, as a cursor page."""
    stmt = select(RunStep).where(RunStep.thread_id == thread_id).where(RunStep.run_id == run_id)
    page = await cursor_page(stmt, session)
    page.data = [step.model_dump(by_alias=True) for step in page.data]
    return page.model_dump(by_alias=True)
+
+
@router.get(
    "/{thread_id}/runs/{run_id}/steps/{step_id}",
    response_model=RunStepRead,
)
async def get_run_step(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    step_id: str = ...,
) -> RunStep:
    """Fetch a single run step."""
    step = await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
    return step.model_dump(by_alias=True)
+
+
@router.post(
    "/{thread_id}/runs/{run_id}/submit_tool_outputs",
    response_model=RunRead,
)
async def submit_tool_outputs_to_run(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: str = ...,
    body: SubmitToolOutputsRunRequest = ...,
    request: Request,
) -> RunRead:
    """
    When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`,
    this endpoint can be used to submit the outputs from the tool calls once they're all completed.
    All outputs must be submitted in a single request.
    """
    db_run = await RunService.submit_tool_outputs_to_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
    # Once outputs are accepted the run is re-queued; resume the async task.
    if db_run.status == "queued":
        run_task.apply_async(args=(db_run.id, body.stream))

    if not body.stream:
        return db_run.model_dump(by_alias=True)
    return pub_handler.sub_stream(db_run.id, request)
+
+
@router.post("/runs", response_model=RunRead)
async def create_thread_and_run(
    *, session: AsyncSession = Depends(get_async_session), body: CreateThreadAndRun, request: Request
) -> RunRead:
    """
    Create a thread and run it in one request.
    """
    run = await RunService.create_thread_and_run(session=session, body=body)
    if not body.stream:
        return run.model_dump(by_alias=True)
    return pub_handler.sub_stream(run.id, request)

+ 45 - 0
app/api/v1/thread.py

@@ -0,0 +1,45 @@
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_token_id, get_async_session
+from app.models.thread import Thread, ThreadUpdate, ThreadCreate
+from app.schemas.common import DeleteResponse
+from app.services.thread.thread import ThreadService
+
+router = APIRouter()
+
+
@router.post("", response_model=Thread)
async def create_thread(
    *, session: AsyncSession = Depends(get_async_session), body: ThreadCreate, token_id=Depends(get_token_id)
) -> Thread:
    """Create a thread owned by the calling token."""
    thread = await ThreadService.create_thread(session=session, body=body, token_id=token_id)
    return thread
+
+
@router.get("/{thread_id}", response_model=Thread)
async def get_thread(*, session: AsyncSession = Depends(get_async_session), thread_id: str) -> Thread:
    """Fetch a single thread by id."""
    thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
    return thread
+
+
@router.post("/{thread_id}", response_model=Thread)
async def modify_thread(
    *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: ThreadUpdate
) -> Thread:
    """Apply an update to an existing thread."""
    thread = await ThreadService.modify_thread(session=session, thread_id=thread_id, body=body)
    return thread
+
+
@router.delete("/{thread_id}", response_model=DeleteResponse)
async def delete_thread(*, session: AsyncSession = Depends(get_async_session), thread_id: str) -> DeleteResponse:
    """
    Delete a thread.
    """
    # NOTE(review): the service method is named ``delete_assistant`` even though
    # it is called with a thread_id — confirm this is ThreadService's intended API.
    return await ThreadService.delete_assistant(session=session, thread_id=thread_id)

+ 45 - 0
app/api/v1/token.py

@@ -0,0 +1,45 @@
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import select
+
+from app.api.deps import verify_admin_token, get_async_session
+from app.libs.paginate import CommonPage, cursor_page
+from app.models.token import Token, TokenCreate, TokenUpdate
+from app.services.token.token import TokenService
+
+router = APIRouter()
+
+
@router.get("", response_model=CommonPage[Token], dependencies=[Depends(verify_admin_token)])
async def list_tokens(*, session: AsyncSession = Depends(get_async_session)):
    """List all tokens as a cursor page (admin only)."""
    return await cursor_page(select(Token), session)
+
+
@router.post("", response_model=Token, dependencies=[Depends(verify_admin_token)])
async def create_token(*, session: AsyncSession = Depends(get_async_session), body: TokenCreate) -> Token:
    """Create a token with a llm url & token (admin only)."""
    created = await TokenService.create_token(session=session, body=body)
    return created
+
+
@router.get("/{token}", response_model=Token, dependencies=[Depends(verify_admin_token)])
async def get_token(*, session: AsyncSession = Depends(get_async_session), token: str) -> Token:
    """Fetch a single token record (admin only)."""
    record = await TokenService.get_token(session=session, token=token)
    return record
+
+
@router.get("/refresh_token/{token}", response_model=Token, dependencies=[Depends(verify_admin_token)])
async def refresh_token(*, session: AsyncSession = Depends(get_async_session), token: str) -> Token:
    """Refresh the given token via ``TokenService`` and return the record (admin only)."""
    refreshed = await TokenService.refresh_token(session=session, token=token)
    return refreshed
+
+
@router.post("/modify_token/{token}", response_model=Token, dependencies=[Depends(verify_admin_token)])
async def modify_token(*, session: AsyncSession = Depends(get_async_session), body: TokenUpdate, token: str) -> Token:
    """Update an existing token record (admin only)."""
    updated = await TokenService.modify_token(session=session, body=body, token=token)
    return updated

+ 0 - 0
app/core/__init__.py


+ 0 - 0
app/core/doc_loaders/__init__.py


+ 44 - 0
app/core/doc_loaders/doc_loader.py

@@ -0,0 +1,44 @@
+from langchain.document_loaders import Blob
+from langchain.document_loaders.parsers import BS4HTMLParser, PyMuPDFParser
+from langchain.document_loaders.parsers.generic import MimeTypeBasedParser
+from langchain.document_loaders.parsers.txt import TextParser
+
# Mapping of MIME type -> langchain blob parser used for that type.
PARSER_HANDLERS = {
    "application/pdf": PyMuPDFParser(),
    "text/plain": TextParser(),
    "text/html": BS4HTMLParser(),
}

# Dispatching parser; no fallback parser is configured for unknown MIME types.
MIMETYPE_PARSER = MimeTypeBasedParser(
    handlers=PARSER_HANDLERS,
    fallback_parser=None,
)
+
+
def _get_mimetype(file_bytes: bytes) -> str:
    """Sniff the MIME type of raw bytes via libmagic."""
    try:
        import magic
    except ImportError:
        raise ImportError(
            "magic package not found, please install it with `pip install python-magic` and `brew install libmagic`"
        )

    return magic.Magic(mime=True).from_buffer(file_bytes)
+
+
def load(data: bytes) -> str:
    """Parse raw document bytes and return the concatenated text of all parsed docs."""
    blob = Blob.from_data(
        data=data,
        mime_type=_get_mimetype(data),
    )
    documents = list(MIMETYPE_PARSER.lazy_parse(blob))
    return "\n\n".join(doc.page_content for doc in documents)

+ 0 - 0
app/core/runner/__init__.py


+ 61 - 0
app/core/runner/llm_backend.py

@@ -0,0 +1,61 @@
+import logging
+from typing import List
+
+from openai import OpenAI, Stream
+from openai.types.chat import ChatCompletionChunk, ChatCompletion
+
+
class LLMBackend:
    """
    Thin wrapper around the OpenAI chat-completions API.
    """

    def __init__(self, base_url: str, api_key) -> None:
        # Append a trailing slash to custom base URLs, as the client expects.
        self.base_url = base_url + "/" if base_url else None
        self.api_key = api_key
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)

    def run(
        self,
        messages: List,
        model: str,
        tools: List = None,
        tool_choice="auto",
        stream=False,
        stream_options=None,
        extra_body=None,
        temperature=None,
        top_p=None,
        response_format=None,
    ) -> ChatCompletion | Stream[ChatCompletionChunk]:
        """Issue a chat-completion request and return the (possibly streaming) response.

        Raises:
            ValueError: if ``extra_body["model_params"]`` contains ``n``.
        """
        chat_params = {
            "messages": messages,
            "model": model,
            "stream": stream,
        }
        if extra_body:
            model_params = extra_body.get("model_params")
            if model_params:
                if "n" in model_params:
                    raise ValueError("n is not allowed in model_params")
                chat_params.update(model_params)
        if isinstance(stream_options, dict) and "include_usage" in stream_options:
            chat_params["stream_options"] = {"include_usage": bool(stream_options["include_usage"])}
        # Fix: compare against None so explicit 0 / 0.0 values are honoured
        # (the original truthiness checks silently dropped temperature=0).
        if temperature is not None:
            chat_params["temperature"] = temperature
        if top_p is not None:
            chat_params["top_p"] = top_p
        if tools:
            chat_params["tools"] = tools
            chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
        if isinstance(response_format, dict) and response_format.get("type") == "json_object":
            chat_params["response_format"] = {"type": "json_object"}
        # Some providers reject messages lacking "content"; default it to "".
        for message in chat_params["messages"]:
            if "content" not in message:
                message["content"] = ""
        logging.info("chat_params: %s", chat_params)
        response = self.client.chat.completions.create(**chat_params)
        logging.info("chat_response: %s", response)
        return response

+ 83 - 0
app/core/runner/llm_callback_handler.py

@@ -0,0 +1,83 @@
+import logging
+
+
+from openai import Stream
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage
+
+from app.core.runner.pub_handler import StreamEventHandler
+from app.core.runner.utils import message_util
+
+
class LLMCallbackHandler:
    """
    LLM chat callback handler, handling message sending and message merging.

    Consumes a streamed chat completion, merging content and tool-call deltas
    into one ChatCompletionMessage while publishing progress events. The run
    step and message records are created lazily on the first content delta;
    tool-call-only responses create neither.
    """

    def __init__(
        self, run_id: str, on_step_create_func, on_message_create_func, event_handler: StreamEventHandler
    ) -> None:
        """
        :param run_id: id of the run this stream belongs to
        :param on_step_create_func: callback(message_id) -> run step record
        :param on_message_create_func: callback(content=...) -> message record
        :param event_handler: publisher for stream events
        """
        super().__init__()
        self.run_id = run_id
        self.final_message_started = False
        self.on_step_create_func = on_step_create_func
        self.step = None
        self.on_message_create_func = on_message_create_func
        self.message = None
        self.event_handler: StreamEventHandler = event_handler

    def handle_llm_response(
        self,
        response_stream: Stream[ChatCompletionChunk],
    ) -> ChatCompletionMessage:
        """
        Handle LLM response stream
        :param response_stream: ChatCompletionChunk stream
        :return: merged ChatCompletionMessage
        """
        message = ChatCompletionMessage(content="", role="assistant", tool_calls=[])

        # content-part index used in delta events; a single text part keeps it 0
        index = 0
        try:
            for chunk in response_stream:
                logging.debug(chunk)

                # usage-only chunks (stream_options include_usage) carry no choices
                if chunk.usage:
                    self.event_handler.pub_message_usage(chunk)
                    continue

                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                if not delta:
                    continue

                # merge tool call delta
                if delta.tool_calls:
                    for tool_call_delta in delta.tool_calls:
                        message_util.merge_tool_call_delta(message.tool_calls, tool_call_delta)

                elif delta.content:
                    # first content delta: create message + step records and
                    # announce them before any delta event is published
                    if not self.final_message_started:
                        self.final_message_started = True

                        self.message = self.on_message_create_func(content="")
                        self.step = self.on_step_create_func(self.message.id)
                        logging.debug("create message and step (%s), (%s)", self.message, self.step)

                        self.event_handler.pub_run_step_created(self.step)
                        self.event_handler.pub_run_step_in_progress(self.step)
                        self.event_handler.pub_message_created(self.message)
                        self.event_handler.pub_message_in_progress(self.message)

                    # append message content delta
                    message.content += delta.content
                    self.event_handler.pub_message_delta(self.message.id, index, delta.content, delta.role)
        except Exception:
            # fix: logging.exception records the traceback (logging.error did not),
            # and a bare raise re-raises without rebinding the exception
            logging.exception("handle_llm_response error")
            raise

        return message

+ 87 - 0
app/core/runner/memory.py

@@ -0,0 +1,87 @@
+"""
+限于 llm 对上下文长度的限制和成本控制,需要对上下文进行筛选整合,本模块用于相关策略描述
+"""
+from enum import Enum
+from typing import List
+from abc import ABC, abstractmethod
+
+
class MemoryType(str, Enum):
    """
    The three supported kinds of context memory.
    """

    WINDOW = "window"  # recent window bounded by message count and content length
    ZERO = "zero"  # only the context from the last user message onwards
    NAIVE = "naive"  # the entire history, unchanged
+
+
class Memory(ABC):
    """
    Interface for context memory strategies.
    """

    @abstractmethod
    def integrate_context(self, messages: List[dict]) -> List[dict]:
        """
        Select/condense the chat history according to the strategy.

        :param messages: full chat history (newest last, per the slicing in subclasses)
        :return: subset of messages to send to the LLM
        """
+
+
class WindowMemory(Memory):
    """
    Keep only the most recent context: at most ``window_size`` messages,
    truncated further once the accumulated content length reaches
    ``max_token_size`` (content length is used as a cheap token proxy).
    """

    def __init__(self, window_size: int = 20, max_token_size: int = 4000):
        if min(window_size, max_token_size) < 1:
            raise ValueError("window size and max token size should be greater than 0")
        self.window_size = window_size
        self.max_token_size = max_token_size

    def integrate_context(self, messages: List[dict]) -> List[dict]:
        if not messages:
            return []
        window = messages[-self.window_size :]
        budget = self.max_token_size
        # walk backwards from the newest message, spending the length budget
        for offset, msg in enumerate(reversed(window)):
            budget -= len(msg["content"])
            if budget <= 0:
                # keep the message that crossed the limit so at least one survives
                return window[len(window) - offset - 1 :]
        return window
+
+
class NaiveMemory(Memory):
    """
    Naive memory: keeps the entire context unchanged.
    """

    def integrate_context(self, messages: List[dict]) -> List[dict]:
        # pass the full history through untouched
        return messages
+
+
class ZeroMemory(Memory):
    """
    Zero memory: keep only the context from the last user message onwards.
    """

    def integrate_context(self, messages: List[dict]) -> List[dict]:
        """
        Return the suffix of ``messages`` starting at the most recent
        user-role message.

        :param messages: full chat history
        :return: suffix starting at the last user message; [] when the input
            is empty or contains no user message (the previous implementation
            fell through and returned None in that case, breaking callers
            that concatenate the result with other lists)
        """
        for i, message in enumerate(reversed(messages)):
            if message["role"] == "user":
                return messages[len(messages) - i - 1 :]
        # no user message at all (or empty input): return an empty context
        return []
+
+
# registry mapping memory type to its strategy class; MemoryType is a str Enum,
# so plain string keys from metadata resolve here too
Memories = {MemoryType.WINDOW: WindowMemory, MemoryType.ZERO: ZeroMemory, MemoryType.NAIVE: NaiveMemory}


def find_memory(assistants_metadata: dict) -> Memory:
    """
    Build a Memory strategy from assistant metadata.

    Expected shape: {"type": "window", "window_size": 20, ...}; every key
    except "type" is forwarded to the strategy constructor. Defaults to
    NaiveMemory when "type" is missing.

    :param assistants_metadata: memory config dict (may be empty)
    :return: configured Memory instance
    """
    kwargs = assistants_metadata.copy()
    memory_type = kwargs.pop("type", MemoryType.NAIVE)
    # cls(**{}) is identical to cls(), so no separate empty-kwargs branch is needed
    return Memories[memory_type](**kwargs)

+ 263 - 0
app/core/runner/pub_handler.py

@@ -0,0 +1,263 @@
+from datetime import datetime
+from typing import List, Tuple, Optional
+from fastapi import Request
+from sse_starlette import EventSourceResponse
+from openai.types.beta import assistant_stream_event as events
+import json
+from app.exceptions.exception import ResourceNotFoundError, InternalServerError
+from app.providers.database import redis_client
+
+"""
+LLM chat message event pub/sub handler
+"""
+
+
def generate_channel_name(key: str) -> str:
    """Build the redis stream key that carries events for the given run key."""
    return "generate_event:" + key
+
+
def channel_exist(channel: str) -> bool:
    """
    Whether the event stream key exists in redis.

    Uses EXISTS on the exact key instead of KEYS: KEYS scans the entire
    keyspace and treats its argument as a glob pattern, which is slow and
    subtly wrong if the key ever contains pattern characters.
    """
    return bool(redis_client.exists(channel))
+
+
def pub_event(channel: str, data: dict) -> None:
    """
    Publish an event onto the channel's redis stream.

    The stream's TTL is refreshed to 10 minutes on every publish, so idle
    streams expire on their own.

    :param channel: channel (redis stream key) name
    :param data: event payload, e.g. {"event": ..., "data": ...}
    """
    redis_client.xadd(channel, data)
    redis_client.expire(channel, 10 * 60)
+
+
def read_event(channel: str, x_index: str = None) -> Tuple[Optional[str], Optional[dict]]:
    """
    Read one event from the channel, starting just after ``x_index``.

    Blocks for up to 180 seconds waiting for a new entry.

    :param channel: channel name
    :param x_index: previous event id; empty/None means read from the start
    :return: (event index, event data), or (None, None) on timeout
    """
    start_id = x_index or "0-0"

    entries = redis_client.xread({channel: start_id}, count=1, block=180_000)
    if not entries:
        return None, None

    # xread returns [(stream_name, [(entry_id, fields), ...]), ...]
    stream_id, event = entries[0][1][0]
    return stream_id, event
+
+
def save_last_stream_id(run_id: str, stream_id: str):
    """
    Persist the latest consumed stream_id for this run (10-minute TTL).

    :param run_id: current run id
    :param stream_id: latest redis stream entry id
    """
    redis_client.set(f"run:{run_id}:last_stream_id", stream_id, 10 * 60)
+
+
def get_last_stream_id(run_id: str) -> str:
    """
    Fetch the previously saved stream_id for this run.

    :param run_id: current run id
    :return: the last stream_id, or None when absent/expired
    """
    return redis_client.get(f"run:{run_id}:last_stream_id")
+
+
+def _data_adjust_tools(tools: List[dict]) -> List[dict]:
+    def _adjust_tool(tool: dict):
+        if tool["type"] not in {"code_interpreter", "file_search", "function"}:
+            return {
+                "type": "function",
+                "function": {
+                    "name": tool["type"],
+                },
+            }
+        else:
+            return tool
+
+    if tools:
+        return [_adjust_tool(tool) for tool in tools]
+    return []
+
+
+def _data_adjust(obj):
+    """
+    event data adjust:
+    """
+    id = obj.id
+    data = obj.model_dump(exclude={"id"})
+    data.update({"id": id})
+    if hasattr(obj, "tools"):
+        data["tools"] = _data_adjust_tools(data["tools"])
+
+    if hasattr(obj, "file_ids") and data["file_ids"] is None:
+        data["file_ids"] = []
+
+    for key, value in data.items():
+        if isinstance(value, datetime):
+            data[key] = value.timestamp()
+    data['parallel_tool_calls'] = True
+    data["file_ids"] = json.loads(data['file_ids']) 
+    return data
+
+
def _data_adjust_message(obj):
    """Event payload for a message object; status defaults to in_progress."""
    data = _data_adjust(obj)
    data.setdefault("status", "in_progress")
    return data
+
+
+def _data_adjust_message_delta(step_details):
+    for index, tool_call in enumerate(step_details["tool_calls"]):
+        tool_call["index"] = index
+    return step_details
+
+
def sub_stream(run_id, request: Request, prefix_events: List[dict] = None, suffix_events: List[dict] = None):
    """
    Subscribe to the chat response stream of a run as an SSE response.

    :param run_id: run whose event stream to follow
    :param request: incoming request, used to detect client disconnects
    :param prefix_events: events yielded before the live stream (default: none)
    :param suffix_events: events yielded after the live stream (default: none)
    :raises ResourceNotFoundError: when the stream key no longer exists
    :raises InternalServerError: when an "error" event is received
    """
    # fix: mutable default arguments ([]) are shared across calls; use None
    # sentinels and normalize here instead
    prefix_events = prefix_events or []
    suffix_events = suffix_events or []
    channel = generate_channel_name(run_id)

    async def _stream():
        for event in prefix_events:
            yield event

        # resume from the last acknowledged stream id, if any
        last_index = get_last_stream_id(run_id)
        x_index = last_index or None
        while True:
            if await request.is_disconnected():
                break
            if not channel_exist(channel):
                raise ResourceNotFoundError()

            x_index, data = read_event(channel, x_index)
            if not data:
                break

            if data["event"] == "done":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                break

            if data["event"] == "error":
                save_last_stream_id(run_id, x_index)  # record the latest stream_id
                raise InternalServerError(data["data"])

            yield data
            save_last_stream_id(run_id, x_index)  # record the latest stream_id

        for event in suffix_events:
            yield event

    return EventSourceResponse(_stream())
+
+
class StreamEventHandler:
    """
    Publishes assistant streaming events for one run onto its redis channel.

    The ``pub_*`` event helpers are no-ops unless constructed with
    ``is_stream=True``; ``pub_done`` and ``pub_error`` always publish so the
    stream is terminated for every consumer.
    """

    def __init__(self, run_id: str, is_stream: bool = False) -> None:
        self._channel = generate_channel_name(key=run_id)
        self._is_stream = is_stream

    def pub_event(self, event) -> None:
        # consumers JSON-decode the "data" field
        if self._is_stream:
            pub_event(self._channel, {"event": event.event, "data": event.data.json()})

    def pub_run_created(self, run):
        # fix: removed a leftover debug print() of the adjusted payload
        self.pub_event(events.ThreadRunCreated(data=_data_adjust(run), event="thread.run.created"))

    def pub_run_queued(self, run):
        self.pub_event(events.ThreadRunQueued(data=_data_adjust(run), event="thread.run.queued"))

    def pub_run_in_progress(self, run):
        self.pub_event(events.ThreadRunInProgress(data=_data_adjust(run), event="thread.run.in_progress"))

    def pub_run_completed(self, run):
        self.pub_event(events.ThreadRunCompleted(data=_data_adjust(run), event="thread.run.completed"))

    def pub_run_requires_action(self, run):
        self.pub_event(events.ThreadRunRequiresAction(data=_data_adjust(run), event="thread.run.requires_action"))

    def pub_run_failed(self, run):
        self.pub_event(events.ThreadRunFailed(data=_data_adjust(run), event="thread.run.failed"))

    def pub_run_step_created(self, step):
        self.pub_event(events.ThreadRunStepCreated(data=_data_adjust(step), event="thread.run.step.created"))

    def pub_run_step_in_progress(self, step):
        self.pub_event(events.ThreadRunStepInProgress(data=_data_adjust(step), event="thread.run.step.in_progress"))

    def pub_run_step_delta(self, step_id, step_details):
        self.pub_event(
            events.ThreadRunStepDelta(
                data={
                    "id": step_id,
                    "delta": {"step_details": _data_adjust_message_delta(step_details)},
                    "object": "thread.run.step.delta",
                },
                event="thread.run.step.delta",
            )
        )

    def pub_run_step_completed(self, step):
        self.pub_event(events.ThreadRunStepCompleted(data=_data_adjust(step), event="thread.run.step.completed"))

    def pub_run_step_failed(self, step):
        self.pub_event(events.ThreadRunStepFailed(data=_data_adjust(step), event="thread.run.step.failed"))

    def pub_message_created(self, message):
        self.pub_event(events.ThreadMessageCreated(data=_data_adjust_message(message), event="thread.message.created"))

    def pub_message_in_progress(self, message):
        self.pub_event(
            events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress")
        )

    def pub_message_usage(self, chunk):
        """
        The stream protocol has no usage event yet; piggyback the usage payload
        on thread.message.in_progress until the official API adds one.
        """
        data = {
            "id": chunk.id,
            "content": [],
            "created_at": 0,
            "object": "thread.message",
            "role": "assistant",
            "status": "in_progress",
            "thread_id": "",
            "metadata": {"usage": chunk.usage.json()},
        }
        self.pub_event(events.ThreadMessageInProgress(data=data, event="thread.message.in_progress"))

    def pub_message_completed(self, message):
        self.pub_event(
            events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed")
        )

    def pub_message_delta(self, message_id, index, content, role):
        """
        Publish an incremental text delta for the given message.
        """
        self.pub_event(
            events.ThreadMessageDelta(
                data=events.MessageDeltaEvent(
                    id=message_id,
                    delta={"content": [{"index": index, "type": "text", "text": {"value": content}}], "role": role},
                    object="thread.message.delta",
                ),
                event="thread.message.delta",
            )
        )

    def pub_done(self):
        # always published, regardless of is_stream, so readers terminate
        pub_event(self._channel, {"event": "done", "data": "done"})

    def pub_error(self, msg):
        # always published, regardless of is_stream, so readers see the failure
        pub_event(self._channel, {"event": "error", "data": msg})

+ 298 - 0
app/core/runner/thread_runner.py

@@ -0,0 +1,298 @@
+from functools import partial
+import logging
+
+from typing import List
+from concurrent.futures import Executor
+
+from sqlalchemy.orm import Session
+
+from app.models.token_relation import RelationType
+from config.config import settings
+from config.llm import llm_settings, tool_settings
+
+from app.core.runner.llm_backend import LLMBackend
+from app.core.runner.llm_callback_handler import LLMCallbackHandler
+from app.core.runner.memory import Memory, find_memory
+from app.core.runner.pub_handler import StreamEventHandler
+from app.core.runner.utils import message_util as msg_util
+from app.core.runner.utils.tool_call_util import (
+    tool_call_recognize,
+    internal_tool_call_invoke,
+    tool_call_request,
+    tool_call_id,
+    tool_call_output,
+)
+from app.core.tools import find_tools, BaseTool
+from app.libs.thread_executor import get_executor_for_config, run_with_executor
+from app.models.message import Message, MessageUpdate
+from app.models.run import Run
+from app.models.run_step import RunStep
+from app.models.token_relation import RelationType
+from app.services.assistant.assistant import AssistantService
+from app.services.file.file import FileService
+from app.services.message.message import MessageService
+from app.services.run.run import RunService
+from app.services.run.run_step import RunStepService
+from app.services.token.token import TokenService
+from app.services.token.token_relation import TokenRelationService
+
+
class ThreadRunner:
    """
    ThreadRunner encapsulates the execution logic of one run.
    """

    # shared worker pool for executing internal tool calls
    tool_executor: Executor = get_executor_for_config(tool_settings.TOOL_WORKER_NUM, "tool_worker_")

    def __init__(self, run_id: str, session: Session, stream: bool = False):
        # id of the run to execute
        self.run_id = run_id
        # DB session used by every service call in this run
        self.session = session
        # whether events are published to the stream channel
        self.stream = stream
        # max LLM steps; past this limit tool_choice is forced to "none"
        self.max_step = llm_settings.LLM_MAX_STEP
        self.event_handler: StreamEventHandler = None

    def run(self):
        """
        Execute one run. Basic steps:
        1. initialize: load the run and its tools, build system instructions;
        2. loop: fetch existing run steps, build the chat messages;
        3. call the LLM and parse the response;
        4. from the response, create a new run step (tool calls) or a message.
        """
        # TODO: refactor — move run state transitions into RunService
        run = RunService.get_run_sync(session=self.session, run_id=self.run_id)
        self.event_handler = StreamEventHandler(run_id=self.run_id, is_stream=self.stream)

        run = RunService.to_in_progress(session=self.session, run_id=self.run_id)
        self.event_handler.pub_run_in_progress(run)
        logging.info("processing ThreadRunner task, run_id: %s", self.run_id)

        # get memory from assistant metadata
        # format likes {"memory": {"type": "window", "window_size": 20, "max_token_size": 4000}}
        ast = AssistantService.get_assistant_sync(session=self.session, assistant_id=run.assistant_id)
        metadata = ast.metadata_ or {}
        memory = find_memory(metadata.get("memory", {}))

        # run-level instructions override the assistant's
        instructions = [run.instructions] if run.instructions else [ast.instructions]
        tools = find_tools(run, self.session)
        for tool in tools:
            tool.configure(session=self.session, run=run)
            instruction_supplement = tool.instruction_supplement()
            if instruction_supplement:
                instructions += [instruction_supplement]
        instruction = "\n".join(instructions)

        llm = self.__init_llm_backend(run.assistant_id)
        loop = True
        while loop:
            run_steps = RunStepService.get_run_step_list(
                session=self.session, run_id=self.run_id, thread_id=run.thread_id
            )
            loop = self.__run_step(llm, run, run_steps, instruction, tools, memory)

        # run finished
        # NOTE(review): these events also fire after a requires_action exit —
        # confirm that "completed"/"done" is intended in that case.
        self.event_handler.pub_run_completed(run)
        self.event_handler.pub_done()

    def __run_step(
        self,
        llm: LLMBackend,
        run: Run,
        run_steps: List[RunStep],
        instruction: str,
        tools: List[BaseTool],
        memory: Memory,
    ):
        """
        Execute one run step.

        :return: True to continue the loop (another LLM call is needed),
            False when the run has reached a terminal or waiting state.
        """
        logging.info("step %d is running", len(run_steps) + 1)

        assistant_system_message = [msg_util.system_message(instruction)]

        # fetch existing message context records
        chat_messages = self.__generate_chat_messages(
            MessageService.get_message_list(session=self.session, thread_id=run.thread_id)
        )

        tool_call_messages = []
        for step in run_steps:
            if step.type == "tool_calls" and step.status == "completed":
                tool_call_messages += self.__convert_assistant_tool_calls_to_chat_messages(step)

        # apply the memory strategy to the history, then append tool-call context
        messages = assistant_system_message + memory.integrate_context(chat_messages) + tool_call_messages

        response_stream = llm.run(
            messages=messages,
            model=run.model,
            tools=[tool.openai_function for tool in tools],
            # past max_step, disable further tool use to force a final answer
            tool_choice="auto" if len(run_steps) < self.max_step else "none",
            stream=True,
            stream_options=run.stream_options,
            extra_body=run.extra_body,
            temperature=run.temperature,
            top_p=run.top_p,
            response_format=run.response_format,
        )

        # create message callback
        create_message_callback = partial(
            MessageService.new_message,
            session=self.session,
            assistant_id=run.assistant_id,
            thread_id=run.thread_id,
            run_id=run.id,
            role="assistant",
        )

        # create 'message creation' run step callback
        def _create_message_creation_run_step(message_id):
            return RunStepService.new_run_step(
                session=self.session,
                type="message_creation",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={"type": "message_creation", "message_creation": {"message_id": message_id}},
            )

        llm_callback_handler = LLMCallbackHandler(
            run_id=run.id,
            on_step_create_func=_create_message_creation_run_step,
            on_message_create_func=create_message_callback,
            event_handler=self.event_handler,
        )
        response_msg = llm_callback_handler.handle_llm_response(response_stream)
        message_creation_run_step = llm_callback_handler.step
        logging.info("chat_response_message: %s", response_msg)

        if msg_util.is_tool_call(response_msg):
            # tool & tool_call definition dict
            tool_calls = [tool_call_recognize(tool_call, tools) for tool_call in response_msg.tool_calls]

            # new run step for tool calls
            new_run_step = RunStepService.new_run_step(
                session=self.session,
                type="tool_calls",
                assistant_id=run.assistant_id,
                thread_id=run.thread_id,
                run_id=run.id,
                step_details={"type": "tool_calls", "tool_calls": [tool_call_dict for _, tool_call_dict in tool_calls]},
            )
            self.event_handler.pub_run_step_created(new_run_step)
            self.event_handler.pub_run_step_in_progress(new_run_step)

            # internal calls have a matched tool; external ones were mapped to None
            internal_tool_calls = list(filter(lambda _tool_calls: _tool_calls[0] is not None, tool_calls))
            external_tool_call_dict = [tool_call_dict for tool, tool_call_dict in tool_calls if tool is None]

            # to keep thread-synchronization logic simple, handle internal then external tool calls sequentially
            if internal_tool_calls:
                try:
                    tool_calls_with_outputs = run_with_executor(
                        executor=ThreadRunner.tool_executor,
                        func=internal_tool_call_invoke,
                        tasks=internal_tool_calls,
                        timeout=tool_settings.TOOL_WORKER_EXECUTION_TIMEOUT,
                    )
                    new_run_step = RunStepService.update_step_details(
                        session=self.session,
                        run_step_id=new_run_step.id,
                        step_details={"type": "tool_calls", "tool_calls": tool_calls_with_outputs},
                        # only mark completed if no external calls are still pending
                        completed=not external_tool_call_dict,
                    )
                except Exception as e:
                    RunStepService.to_failed(session=self.session, run_step_id=new_run_step.id, last_error=e)
                    raise e

            if external_tool_call_dict:
                # set the run to requires_action and wait for the business side
                # to submit tool outputs and resume the run
                run = RunService.to_requires_action(
                    session=self.session,
                    run_id=run.id,
                    required_action={
                        "type": "submit_tool_outputs",
                        "submit_tool_outputs": {"tool_calls": external_tool_call_dict},
                    },
                )
                self.event_handler.pub_run_step_delta(
                    step_id=new_run_step.id, step_details={"type": "tool_calls", "tool_calls": external_tool_call_dict}
                )
                self.event_handler.pub_run_requires_action(run)
            else:
                self.event_handler.pub_run_step_completed(new_run_step)
                return True
        else:
            # no tool call info: message generation is finished, update statuses
            new_message = MessageService.modify_message_sync(
                session=self.session,
                thread_id=run.thread_id,
                message_id=llm_callback_handler.message.id,
                body=MessageUpdate(content=response_msg.content),
            )
            self.event_handler.pub_message_completed(new_message)

            new_step = RunStepService.update_step_details(
                session=self.session,
                run_step_id=message_creation_run_step.id,
                step_details={"type": "message_creation", "message_creation": {"message_id": new_message.id}},
                completed=True,
            )
            RunService.to_completed(session=self.session, run_id=run.id)
            self.event_handler.pub_run_step_completed(new_step)

        return False

    def __init_llm_backend(self, assistant_id):
        # when auth is enabled, resolve per-assistant LLM credentials via tokens;
        # otherwise fall back to global settings
        if settings.AUTH_ENABLE:
            # init llm backend with token id
            token_id = TokenRelationService.get_token_id_by_relation(
                session=self.session, relation_type=RelationType.Assistant, relation_id=assistant_id
            )
            token = TokenService.get_token_by_id(self.session, token_id)
            return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)
        else:
            # init llm backend with llm settings
            return LLMBackend(base_url=llm_settings.OPENAI_API_BASE, api_key=llm_settings.OPENAI_API_KEY)

    def __generate_chat_messages(self, messages: List[Message]):
        """
        Generate chat messages from the stored message history.
        """

        chat_messages = []
        for message in messages:
            role = message.role
            if role == "user":
                message_content = []
                if message.file_ids:
                    # NOTE(review): when file_ids are present, the message's own
                    # text content is not forwarded — confirm this is intended.
                    files = FileService.get_file_list_by_ids(session=self.session, file_ids=message.file_ids)
                    for file in files:
                        chat_messages.append(msg_util.new_message(role, f'The file "{file.filename}" can be used as a reference'))
                else:
                    for content in message.content:
                        if content["type"] == "text":
                            message_content.append({"type": "text", "text": content["text"]["value"]})
                        elif content["type"] == "image_url":
                            message_content.append(content)
                    chat_messages.append(msg_util.new_message(role, message_content))
            elif role == "assistant":
                # assistant history is flattened to plain text
                message_content = ""
                for content in message.content:
                    if content["type"] == "text":
                        message_content += content["text"]["value"]
                chat_messages.append(msg_util.new_message(role, message_content))

        return chat_messages

    def __convert_assistant_tool_calls_to_chat_messages(self, run_step: RunStep):
        """
        Build chat messages from a completed tool-call run step.
        Each tool-call step contributes two parts: the original request and
        its result(s) — the result may span multiple messages.
        """
        tool_calls = run_step.step_details["tool_calls"]
        tool_call_requests = [msg_util.tool_calls([tool_call_request(tool_call) for tool_call in tool_calls])]
        tool_call_outputs = [
            msg_util.tool_call_result(tool_call_id(tool_call), tool_call_output(tool_call)) for tool_call in tool_calls
        ]
        return tool_call_requests + tool_call_outputs

+ 0 - 0
app/core/runner/utils/__init__.py


+ 58 - 0
app/core/runner/utils/message_util.py

@@ -0,0 +1,58 @@
+"""
+This module provides utility functions for working with messages in the OpenAI API.
+
+Functions:
+- new_message(role: str, content: str) -> dict: Creates a new message with the specified role and content.
+- system_message(content: str) -> dict: Creates a system message with the specified content.
+- user_message(content: str) -> dict: Creates a user message with the specified content.
+- assistant_message(content: str) -> dict: Creates an assistant message with the specified content.
+- tool_calls(tool_calls) -> dict: Creates a message with assistant tool calls.
+- tool_call_result(id, content) -> dict: Creates a tool call result message with the specified ID and content.
+- is_tool_call(message: ChatCompletionMessage) -> bool: Checks if a message is a tool call.
+"""
+from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
+
+
def new_message(role: str, content: str):
    """
    Create a chat message dict for the given role.

    :param role: one of "user", "system", "assistant"
    :param content: message content (string or content-part list)
    :raises ValueError: for any other role
    """
    # membership test replaces the chained != comparisons
    if role not in ("user", "system", "assistant"):
        raise ValueError(f"Invalid role {role}")

    return {"role": role, "content": content}
+
+
def system_message(content: str):
    """Create a system-role chat message."""
    return new_message("system", content)
+
+
def user_message(content: str):
    """Create a user-role chat message."""
    return new_message("user", content)
+
+
def assistant_message(content: str):
    """Create an assistant-role chat message."""
    return new_message("assistant", content)
+
+
def tool_calls(tool_calls):
    """
    Create an assistant message carrying tool calls (no text content).

    Note: the parameter shadows this function's own name; kept as-is for
    keyword-argument compatibility with existing callers.
    """
    return {"role": "assistant", "tool_calls": tool_calls}
+
+
def tool_call_result(id, content):
    """
    Create a tool-result message referencing the originating tool call.

    Note: parameter ``id`` shadows the builtin; kept for caller compatibility.
    """
    return {"role": "tool", "tool_call_id": id, "content": content}
+
+
def is_tool_call(message: ChatCompletionMessage) -> bool:
    """True when the assistant message carries at least one tool call."""
    return message.tool_calls is not None and len(message.tool_calls) > 0
+
+
def merge_tool_call_delta(tool_calls, tool_call_delta):
    """
    Merge one streamed tool-call delta into the accumulated tool_calls list.

    Deltas for an already-seen index append to that call's function arguments
    (arguments arrive piecewise over the stream); a new index starts a fresh
    tool-call entry.
    """
    idx = tool_call_delta.index
    if idx < len(tool_calls):
        # continuation chunk for an existing call
        tool_calls[idx].function.arguments += tool_call_delta.function.arguments
    else:
        new_call = ChatCompletionMessageToolCall(
            id=tool_call_delta.id,
            function=Function(name=tool_call_delta.function.name, arguments=tool_call_delta.function.arguments),
            type=tool_call_delta.type,
        )
        tool_calls.append(new_call)

+ 67 - 0
app/core/runner/utils/tool_call_util.py

@@ -0,0 +1,67 @@
+"""
+tool calls 常用转换方法
+从 llm 接口获取的 tool calls 为 ChatCompletionMessageToolCall 形式,类型为 function
+可通过 json() 获取 json 形式,格式如下:
+{
+    "id": "tool_call_0",
+    "function": {
+        "name": "file_search",
+        "arguments": "{\"file_keys\": [\"file_0\", \"file_1\"], \"query\": \"query\"}"
+    }
+}
+查询结果将放入 ["function"]["output"] 中
+"""
+
+from typing import List
+import json
+from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
+
+from app.core.tools.base_tool import BaseTool
+from app.core.tools.external_function_tool import ExternalFunctionTool
+
+
def tool_call_recognize(tool_call: ChatCompletionMessageToolCall, tools: List[BaseTool]) -> (BaseTool, dict):
    """
    Match a tool call against the run's tool instances (internal tools only).

    :param tool_call: tool call returned by the LLM
    :param tools: tool instances configured on the run
    :return: (tool, tool_call_dict); tool is None for external function tools,
        whose execution is delegated to the caller
    :raises ValueError: via unpacking, when zero or several tools match the name
    """
    tool_name = tool_call.function.name
    # single-element unpack asserts exactly one tool matches the name
    [tool] = [tool for tool in tools if tool.name == tool_name]
    if isinstance(tool, ExternalFunctionTool):
        tool = None
    return (tool, json.loads(tool_call.json()))
+
+
def internal_tool_call_invoke(tool: BaseTool, tool_call_dict: dict) -> dict:
    """
    Execute an internal tool call and store the JSON-encoded result under
    ``["function"]["output"]``.

    :param tool: tool instance to run
    :param tool_call_dict: tool call dict; mutated in place and returned
    """
    arguments = json.loads(tool_call_dict["function"]["arguments"])
    result = tool.run(**arguments)
    tool_call_dict["function"]["output"] = json.dumps(result, ensure_ascii=False)
    return tool_call_dict
+
+
def tool_call_request(tool_call_dict: dict) -> dict:
    """
    Rebuild the original tool-call request from a stored tool call dict.

    The DB does not keep the raw request, so it is reassembled from the
    stored id, function name and arguments (dropping any "output").
    """
    function = tool_call_dict["function"]
    return {
        "id": tool_call_dict["id"],
        "type": "function",
        "function": {"name": function["name"], "arguments": function["arguments"]},
    }
+
+
def tool_call_id(tool_call_dict: dict) -> str:
    """
    Extract the id from a tool call dict.
    """
    return tool_call_dict["id"]
+
+
def tool_call_output(tool_call_dict: dict) -> str:
    """
    Extract the stored execution output from a tool call dict
    (written by internal_tool_call_invoke, or supplied externally).
    """
    return tool_call_dict["function"]["output"]

+ 44 - 0
app/core/tools/__init__.py

@@ -0,0 +1,44 @@
+from enum import Enum
+from typing import List
+
+from sqlalchemy.orm import Session
+from sqlmodel import select
+
+from app.exceptions.exception import ServerError
+from app.models.action import Action
+from app.core.tools.base_tool import BaseTool
+from app.core.tools.external_function_tool import ExternalFunctionTool
+from app.core.tools.openapi_function_tool import OpenapiFunctionTool
+from app.core.tools.file_search_tool import FileSearchTool
+from app.core.tools.web_search import WebSearchTool
+
+
class AvailableTools(str, Enum):
    # Built-in tool types that can be requested by name in a run's tool list.
    FILE_SEARCH = "file_search"
    WEB_SEARCH = "web_search"


# Maps tool type -> tool class. Membership tests with plain strings work
# because AvailableTools subclasses str (hash/eq match the raw value).
TOOLS = {
    AvailableTools.FILE_SEARCH: FileSearchTool,
    AvailableTools.WEB_SEARCH: WebSearchTool,
}
+
+
def find_tools(run, session: Session) -> List[BaseTool]:
    """
    Build the tool instances declared by a run: built-ins by type name,
    external functions from their inline definition, and OpenAPI actions
    loaded from the database in one batch.
    """
    requested_action_ids = [entry.get("id") for entry in run.tools if entry.get("type") == "action"]
    loaded = session.execute(select(Action).where(Action.id.in_(requested_action_ids))).scalars().all()
    actions_by_id = {loaded_action.id: loaded_action for loaded_action in loaded}

    result: List[BaseTool] = []
    for tool in run.tools:
        kind = tool["type"]
        if kind in TOOLS:
            result.append(TOOLS[kind]())
        elif kind == "function":
            result.append(ExternalFunctionTool(tool))
        elif kind == "action":
            result.append(OpenapiFunctionTool(tool, run.extra_body, actions_by_id.get(tool.get("id"))))
        else:
            raise ServerError(f"Unknown tool type {tool}")
    return result

+ 75 - 0
app/core/tools/base_tool.py

@@ -0,0 +1,75 @@
+from abc import ABC
+from typing import Type, Dict, Any, Optional
+
+from langchain.tools import BaseTool as LCBaseTool
+from langchain.tools.render import format_tool_to_openai_function
+from pydantic import BaseModel, Field
+
+
class BaseToolInput(BaseModel):
    """
    Base schema for tool input arguments.

    Default ``args_schema`` for tools that do not declare their own.
    """

    input: str = Field(..., description="input")
+
+
class BaseTool(ABC):
    """
    Base class for tools.

    Attributes:
        name (str): The name of the tool.
        description (str): The description of the tool.
        args_schema (Optional[Type[BaseModel]]): The schema for the tool's input arguments.
        openai_function (Dict): The OpenAI function representation of the tool.
    """

    name: str
    description: str

    args_schema: Optional[Type[BaseModel]] = BaseToolInput

    openai_function: Dict

    def __init_subclass__(cls) -> None:
        # Build the OpenAI "function" spec for every subclass at class-creation
        # time, using a throwaway LangChain tool as the converter's input.
        # NOTE(review): LCTool is defined later in this module — fine at
        # runtime because subclasses are only created after this module has
        # fully loaded. This also runs for subclasses that set name/description
        # to "" (e.g. ExternalFunctionTool) — confirm that is intended.
        lc_tool = LCTool(name=cls.name, description=cls.description, args_schema=cls.args_schema, _run=lambda x: x)
        cls.openai_function = {"type": "function", "function": dict(format_tool_to_openai_function(lc_tool))}

    def configure(self, **kwargs):
        """
        Configure the tool with the provided keyword arguments.

        Default implementation is a no-op; stateful tools override this.

        Args:
            **kwargs: Additional configuration parameters.
        """
        return

    def run(self, **kwargs) -> Any:
        """
        Executes the tool with the given arguments.

        Args:
            **kwargs: Additional keyword arguments for the tool.

        Returns:
            Any: The result of executing the tool.
        """
        raise NotImplementedError()

    def instruction_supplement(self) -> str:
        """
        Provides additional instructions to supplement the run instruction for the tool.

        Returns:
            str: The additional instructions (empty by default).
        """
        return ""
+
+
class LCTool(LCBaseTool):
    # Minimal concrete LangChain tool; only serves as a carrier object for
    # format_tool_to_openai_function in BaseTool.__init_subclass__ and is
    # never actually executed.
    name: str = ""
    description: str = ""

    def _run(self):
        # Required by the LangChain BaseTool ABC; intentionally does nothing.
        pass

+ 37 - 0
app/core/tools/external_function_tool.py

@@ -0,0 +1,37 @@
+from app.core.tools.base_tool import BaseTool
+
+
class ExternalFunctionTool(BaseTool):
    """
    External (caller-executed) function tool, definition as follows:
    {
        "type": "function",
        "function": {
            "name": "calculator",
            "parameters": {
                "type": "object",
                "required": ["input"],
                "properties": {
                    "input": {
                        "type": "string",
                        "description": "arithmetic expression to evaluate"
                    }
                }
            },
            "description": "calculator"
        }
    }
    """

    # Empty name/description and args_schema=None keep
    # BaseTool.__init_subclass__ from generating a schema for this class.
    name = ""
    description = ""
    args_schema = None

    def __init__(self, definition: dict) -> None:
        """
        Args:
            definition: the raw OpenAI-style function definition.

        Raises:
            ValueError: when the definition is not a function definition.
        """
        # .get() instead of ["type"]: a definition missing the "type" key now
        # surfaces as the intended ValueError rather than a bare KeyError.
        if definition.get("type") != "function" or "function" not in definition:
            raise ValueError(f"definition format error: {definition}")

        # Other fields are unused for now; keep the raw definition as-is.
        self.openai_function = definition
        self.name = definition["function"]["name"]

+ 66 - 0
app/core/tools/file_search_tool.py

@@ -0,0 +1,66 @@
+from typing import Type, List
+
+from pydantic import BaseModel, Field
+from sqlalchemy.orm import Session
+
+from app.core.tools.base_tool import BaseTool
+from app.models.run import Run
+from app.services.file.file import FileService
+
+
class FileSearchToolInput(BaseModel):
    # Indexes refer to positions in the file list advertised by
    # FileSearchTool.instruction_supplement(), not to file ids.
    indexes: List[int] = Field(..., description="file index list to look up in retrieval")
    query: str = Field(..., description="query to look up in retrieval")
+
+
class FileSearchTool(BaseTool):
    """
    Retrieval over the files attached to the current run: the LLM picks file
    indexes from instruction_supplement(), run() resolves them to storage keys
    and performs the search.
    """

    name: str = "file_search"
    description: str = (
        "Can be used to look up information that was uploaded to this assistant."
        "If the user is referencing particular files, that is often a good hint that information may be here."
    )

    args_schema: Type[BaseModel] = FileSearchToolInput

    def __init__(self) -> None:
        super().__init__()
        # Parallel lists: index i maps a display filename to its storage key.
        self.__filenames = []
        self.__keys = []

    def configure(self, session: Session, run: Run, **kwargs):
        """
        Cache filename/key pairs for the files referenced by the run.
        """
        files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
        # pre-cache data to prevent thread conflicts that may occur later on.
        # (Stray debug print statements removed from this method and run().)
        for file in files:
            self.__filenames.append(file.filename)
            self.__keys.append(file.key)

    def run(self, indexes: List[int], query: str) -> dict:
        """
        Search *query* within the files selected by *indexes*.

        Raises:
            IndexError: when an index is outside the configured file list.
        """
        file_keys = [self.__keys[index] for index in indexes]
        return FileService.search_in_files(query=query, file_keys=file_keys)

    def instruction_supplement(self) -> str:
        """
        Advertise the attached files ("(index)filename" per line) so the LLM
        can choose which ones to retrieve from.
        """
        if not self.__filenames:
            return ""
        # Fixed: the original emitted a literal "(unknown)" placeholder and
        # never interpolated the filename, contradicting the documented
        # "(index)filename" format below.
        filenames_info = [f"({index}){filename}" for index, filename in enumerate(self.__filenames)]
        return (
            'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
            + 'Each line represents a file in the format "(index)filename":\n'
            + "\n".join(filenames_info)
            + "\nMake sure to be extremely concise when using attached files. "
        )

+ 52 - 0
app/core/tools/openapi_function_tool.py

@@ -0,0 +1,52 @@
+from app.models.action import Action
+from app.core.tools.base_tool import BaseTool
+from app.exceptions.exception import ResourceNotFoundError
+from app.services.tool.openapi_call import call_action_api
+from app.schemas.tool.action import ActionMethod, ActionBodyType
+from app.services.tool.openapi_utils import action_param_dict_to_schema
+from app.schemas.tool.authentication import Authentication
+
+
class OpenapiFunctionTool(BaseTool):
    """
    openapi tool, definition as follows:
    {'id': '65d6c295a09d481250cc8ed1', 'type': 'action'}
    """

    # Empty name/description and args_schema=None keep
    # BaseTool.__init_subclass__ from generating a schema for this wrapper.
    name = ""
    description = ""
    args_schema = None
    action: Action = None

    def __init__(self, definition: dict, extra_body: dict, action: Action) -> None:
        """
        Args:
            definition: tool definition dict, e.g. {"id": ..., "type": "action"}.
            extra_body: per-run extras; may carry per-user
                "action_authentications" keyed by action id. May be None.
            action: the Action loaded for definition["id"], or None.

        Raises:
            ValueError: when the definition is not an action definition.
            ResourceNotFoundError: when no matching action exists.
        """
        # .get() instead of ["type"]: a missing "type" key surfaces as the
        # intended ValueError rather than a bare KeyError.
        if definition.get("type") != "action" or "id" not in definition:
            raise ValueError(f"definition format error: {definition}")
        # an exception is thrown if no action is found
        if action is None:
            raise ResourceNotFoundError(message="action not found")
        if not action.use_for_everyone:
            # Guard against extra_body being None instead of crashing with an
            # AttributeError; behavior is unchanged when it is provided.
            action_authentications = (extra_body or {}).get("action_authentications")
            if action_authentications:
                # Fall back to no authentication when the caller supplied a
                # map without an entry for this particular action.
                action.authentication = action_authentications.get(action.id) or {"type": "none"}
        self.action = action
        self.openai_function = {"type": "function", "function": action.function_def}
        self.name = action.function_def["name"]

    def run(self, **arguments: dict) -> dict:
        """Invoke the remote API described by the action with *arguments*."""
        action = self.action
        response = call_action_api(
            url=action.url,
            method=ActionMethod(action.method),
            path_param_schema=action_param_dict_to_schema(action.path_param_schema),
            query_param_schema=action_param_dict_to_schema(action.query_param_schema),
            body_param_schema=action_param_dict_to_schema(action.body_param_schema),
            body_type=ActionBodyType(action.body_type),
            parameters=arguments,
            headers={},
            authentication=Authentication(**action.authentication),
        )
        return response

+ 36 - 0
app/core/tools/web_search.py

@@ -0,0 +1,36 @@
+from typing import Type
+
+from langchain.utilities import BingSearchAPIWrapper
+from pydantic import BaseModel, Field
+
+from app.core.tools.base_tool import BaseTool
+from config.llm import tool_settings
+
+
class WebSearchToolInput(BaseModel):
    # Single free-text query forwarded to Bing.
    query: str = Field(
        ...,
        description="Search query. Use a format suitable for Bing and, if necessary, "
        "use Bing's advanced search function",
    )
+
+
class WebSearchTool(BaseTool):
    """Bing-backed web search tool exposed to the LLM as "web_search"."""

    name: str = "web_search"
    description: str = (
        "A tool for performing a Bing search and extracting snippets and webpages "
        "when you need to search for something you don't know or when your information "
        "is not up to date. "
        "Input should be a search query."
    )

    args_schema: Type[BaseModel] = WebSearchToolInput

    # Shared by all instances and constructed at class-definition (module
    # import) time — missing/invalid Bing settings therefore surface as an
    # import-time error rather than at first use.
    _bing_search_api_wrapper = BingSearchAPIWrapper(
        bing_search_url=tool_settings.BING_SEARCH_URL,
        bing_subscription_key=tool_settings.BING_SUBSCRIPTION_KEY,
        k=tool_settings.WEB_SEARCH_NUM_RESULTS,
    )

    def run(self, query: str) -> dict:
        # Returns Bing results (snippet/title/link entries) for the query.
        return self._bing_search_api_wrapper.results(query=query, num_results=tool_settings.WEB_SEARCH_NUM_RESULTS)

+ 0 - 0
app/exceptions/__init__.py


+ 110 - 0
app/exceptions/exception.py

@@ -0,0 +1,110 @@
+from typing import Any
+
+from fastapi import HTTPException
+
+
class BaseHTTPException(HTTPException):
    """
    Base HTTP exception carrying an OpenAI-style error payload
    (error_code / message / type / param) on top of FastAPI's HTTPException.
    """

    type: str = None
    param: str = None

    def __init__(
        self,
        status_code: int,
        error_code: str,
        message: str = None,
        type: str = None,
        param: str = None,
        detail: Any = None,
    ):
        self.status_code = status_code
        self.error_code = error_code
        self.message = message
        self.type = type
        self.param = param
        # Run HTTPException.__init__ so self.detail is initialized for
        # FastAPI's exception handling.
        super().__init__(status_code, detail)

    def __str__(self) -> str:
        return f"status_code={self.status_code} error_code={self.error_code} message={self.message}"
+
+
class BadRequestError(BaseHTTPException):
    """
    Invalid request parameters (HTTP 400).
    """

    def __init__(self, message: str, error_code: str = "bad_request"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original assigned attributes
        # directly and skipped it, leaving detail unset).
        super().__init__(
            status_code=400,
            error_code=error_code,
            message=message,
            type="invalid_request_error",
        )
+
+
class ValidateFailedError(BaseHTTPException):
    """
    Validation failed (HTTP 422).
    """

    def __init__(self, message: str = "Validation failed", error_code: str = "validation_failed"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original skipped it).
        super().__init__(
            status_code=422,
            error_code=error_code,
            message=message,
            type=error_code,
        )
+
+
class AuthenticationError(BaseHTTPException):
    """
    Not authenticated (HTTP 401).
    """

    def __init__(self, message: str = "Unauthorized", error_code: str = "unauthorized"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original skipped it). `type` keeps
        # the class default of None, matching the original behavior.
        super().__init__(
            status_code=401,
            error_code=error_code,
            message=message,
        )
+
+
class AuthorizationError(BaseHTTPException):
    """
    Not authorized (HTTP 403).
    """

    def __init__(self, message: str = "Forbidden", error_code: str = "forbidden"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original skipped it). `type` keeps
        # the class default of None, matching the original behavior.
        super().__init__(
            status_code=403,
            error_code=error_code,
            message=message,
        )
+
+
class ResourceNotFoundError(BaseHTTPException):
    """
    Resource does not exist (HTTP 404).
    """

    def __init__(self, message: str = "Resource not found", error_code: str = "resource_not_found"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original skipped it).
        super().__init__(
            status_code=404,
            error_code=error_code,
            message=message,
            type="not_found_error",
        )
+
+
class InternalServerError(BaseHTTPException):
    """
    Internal server error (HTTP 500).
    """

    def __init__(self, message: str = "Internal Server Error", error_code: str = "internal_server_error"):
        # Route through BaseHTTPException so HTTPException.__init__ runs and
        # self.detail is initialized (the original skipped it).
        super().__init__(
            status_code=500,
            error_code=error_code,
            message=message,
            type=error_code,
        )
+
+
class ServerError(Exception):
    """
    Server-side (non-HTTP) error raised by internal components.
    """

    def __init__(self, message: str):
        # Exception, not BaseException: BaseException subclasses escape
        # ``except Exception`` handlers, so generic error handling and
        # framework middleware would never catch this error. Also forward the
        # message to Exception.__init__ so str(err) and tracebacks show it.
        super().__init__(message)
        self.message = message

+ 0 - 0
app/libs/__init__.py


+ 36 - 0
app/libs/bson/errors.py

@@ -0,0 +1,36 @@
+# Copyright 2009-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exceptions raised by the BSON package."""
+from __future__ import annotations
+
+
class BSONError(Exception):
    """Base class for all BSON exceptions."""
+
+
class InvalidBSONError(BSONError):
    """Raised when trying to create a BSON object from invalid data."""
+
+
class InvalidStringDataError(BSONError):
    """Raised when trying to encode a string containing non-UTF8 data."""
+
+
class InvalidDocumentError(BSONError):
    """Raised when trying to create a BSON object from an invalid document."""
+
+
class InvalidIdError(BSONError):
    """Raised when trying to create an ObjectId from invalid data."""

+ 276 - 0
app/libs/bson/objectid.py

@@ -0,0 +1,276 @@
+# Copyright 2009-2015 MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for working with MongoDB ObjectIds."""
+from __future__ import annotations
+
+import binascii
+import calendar
+import datetime
+import os
+import struct
+import threading
+import time
+from random import SystemRandom
+from typing import Any, NoReturn, Optional, Type, Union
+
+from app.libs.bson.errors import InvalidIdError
+
+from app.libs.bson.tz_util import utc
+
+_MAX_COUNTER_VALUE = 0xFFFFFF
+
+
def _raise_invalid_id(oid: str) -> NoReturn:
    # Shared error helper so all invalid-id paths produce a uniform message.
    raise InvalidIdError("%r is not a valid ObjectId, it must be a 12-byte input" " or a 24-character hex string" % oid)
+
+
def _random_bytes() -> bytes:
    """Get the 5-byte random field of an ObjectId."""
    return os.urandom(5)
+
+
class ObjectId:
    """A MongoDB ObjectId.

    Vendored from pymongo's bson.objectid; layout is
    4-byte timestamp + 5-byte per-process random + 3-byte counter.
    """

    # Process id captured at class creation; used by _random() to detect a
    # fork and re-seed the per-process random field.
    _pid = os.getpid()

    # 3-byte counter shared by all instances, starting at a random value and
    # guarded by _inc_lock for thread safety.
    _inc = SystemRandom().randint(0, _MAX_COUNTER_VALUE)
    _inc_lock = threading.Lock()

    # 5-byte per-process random field (see _random()).
    __random = _random_bytes()

    # Only the 12-byte payload is stored per instance.
    __slots__ = ("__id",)

    # BSON type number for ObjectId.
    _type_marker = 7

    def __init__(self, oid: Optional[Union[str, ObjectId, bytes]] = None) -> None:
        """Initialize a new ObjectId.

        An ObjectId is a 12-byte unique identifier consisting of:

          - a 4-byte value representing the seconds since the Unix epoch,
          - a 5-byte random value,
          - a 3-byte counter, starting with a random value.

        By default, ``ObjectId()`` creates a new unique identifier. The
        optional parameter `oid` can be an :class:`ObjectId`, or any 12
        :class:`bytes`.

        For example, the 12 bytes b'foo-bar-quux' do not follow the ObjectId
        specification but they are acceptable input::

          >>> ObjectId(b'foo-bar-quux')
          ObjectId('666f6f2d6261722d71757578')

        `oid` can also be a :class:`str` of 24 hex digits::

          >>> ObjectId('0123456789ab0123456789ab')
          ObjectId('0123456789ab0123456789ab')

        Raises :class:`~bson.errors.InvalidId` if `oid` is not 12 bytes nor
        24 hex digits, or :class:`TypeError` if `oid` is not an accepted type.

        :param oid: a valid ObjectId.

        .. seealso:: The MongoDB documentation on  `ObjectIds <http://dochub.mongodb.org/core/objectids>`_.

        .. versionchanged:: 3.8
           :class:`~bson.objectid.ObjectId` now implements the `ObjectID
           specification version 0.2
           <https://github.com/mongodb/specifications/blob/master/source/
           objectid.rst>`_.
        """
        if oid is None:
            self.__generate()
        elif isinstance(oid, bytes) and len(oid) == 12:
            self.__id = oid
        else:
            self.__validate(oid)

    @classmethod
    def from_datetime(cls: Type[ObjectId], generation_time: datetime.datetime) -> ObjectId:
        """Create a dummy ObjectId instance with a specific generation time.

        This method is useful for doing range queries on a field
        containing :class:`ObjectId` instances.

        .. warning::
           It is not safe to insert a document containing an ObjectId
           generated using this method. This method deliberately
           eliminates the uniqueness guarantee that ObjectIds
           generally provide. ObjectIds generated with this method
           should be used exclusively in queries.

        `generation_time` will be converted to UTC. Naive datetime
        instances will be treated as though they already contain UTC.

        An example using this helper to get documents where ``"_id"``
        was generated before January 1, 2010 would be:

        >>> gen_time = datetime.datetime(2010, 1, 1)
        >>> dummy_id = ObjectId.from_datetime(gen_time)
        >>> result = collection.find({"_id": {"$lt": dummy_id}})

        :param generation_time: :class:`~datetime.datetime` to be used
            as the generation time for the resulting ObjectId.
        """
        offset = generation_time.utcoffset()
        if offset is not None:
            generation_time = generation_time - offset
        timestamp = calendar.timegm(generation_time.timetuple())
        # Timestamp in the first 4 bytes; random and counter bytes zeroed.
        oid = struct.pack(">I", int(timestamp)) + b"\x00\x00\x00\x00\x00\x00\x00\x00"
        return cls(oid)

    @classmethod
    def is_valid(cls: Type[ObjectId], oid: Any) -> bool:
        """Checks if a `oid` string is valid or not.

        :param oid: the object id to validate

        .. versionadded:: 2.3
        """
        if not oid:
            return False

        try:
            ObjectId(oid)
            return True
        except (InvalidIdError, TypeError):
            return False

    @classmethod
    def _random(cls) -> bytes:
        """Generate a 5-byte random number once per process."""
        pid = os.getpid()
        # A changed pid means the process forked; re-seed so the child does
        # not share the parent's random field.
        if pid != cls._pid:
            cls._pid = pid
            cls.__random = _random_bytes()
        return cls.__random

    def __generate(self) -> None:
        """Generate a new value for this ObjectId."""
        # 4 bytes current time
        oid = struct.pack(">I", int(time.time()))

        # 5 bytes random
        oid += ObjectId._random()

        # 3 bytes inc
        with ObjectId._inc_lock:
            oid += struct.pack(">I", ObjectId._inc)[1:4]
            ObjectId._inc = (ObjectId._inc + 1) % (_MAX_COUNTER_VALUE + 1)

        self.__id = oid

    def __validate(self, oid: Any) -> None:
        """Validate and use the given id for this ObjectId.

        Raises TypeError if id is not an instance of :class:`str`,
        :class:`bytes`, or ObjectId. Raises InvalidId if it is not a
        valid ObjectId.

        :param oid: a valid ObjectId
        """
        if isinstance(oid, ObjectId):
            self.__id = oid.binary
        elif isinstance(oid, str):
            if len(oid) == 24:
                try:
                    self.__id = bytes.fromhex(oid)
                except (TypeError, ValueError):
                    _raise_invalid_id(oid)
            else:
                _raise_invalid_id(oid)
        else:
            raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}")

    @property
    def binary(self) -> bytes:
        """12-byte binary representation of this ObjectId."""
        return self.__id

    @property
    def generation_time(self) -> datetime.datetime:
        """A :class:`datetime.datetime` instance representing the time of
        generation for this :class:`ObjectId`.

        The :class:`datetime.datetime` is timezone aware, and
        represents the generation time in UTC. It is precise to the
        second.
        """
        timestamp = struct.unpack(">I", self.__id[0:4])[0]
        return datetime.datetime.fromtimestamp(timestamp, utc)

    def __getstate__(self) -> bytes:
        """Return value of object for pickling.
        needed explicitly because __slots__() defined.
        """
        return self.__id

    def __setstate__(self, value: Any) -> None:
        """Explicit state set from pickling"""
        # Provide backwards compatibility with OIDs
        # pickled with pymongo-1.9 or older.
        if isinstance(value, dict):
            oid = value["_ObjectId__id"]
        else:
            oid = value
        # ObjectIds pickled in python 2.x used `str` for __id.
        # In python 3.x this has to be converted to `bytes`
        # by encoding latin-1.
        if isinstance(oid, str):
            self.__id = oid.encode("latin-1")
        else:
            self.__id = oid

    def __str__(self) -> str:
        # Canonical 24-character lowercase hex form.
        return binascii.hexlify(self.__id).decode()

    def __repr__(self) -> str:
        return f"ObjectId('{self!s}')"

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id == other.binary
        return NotImplemented

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id != other.binary
        return NotImplemented

    def __lt__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id < other.binary
        return NotImplemented

    def __le__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id <= other.binary
        return NotImplemented

    def __gt__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id > other.binary
        return NotImplemented

    def __ge__(self, other: Any) -> bool:
        if isinstance(other, ObjectId):
            return self.__id >= other.binary
        return NotImplemented

    def __hash__(self) -> int:
        """Get a hash value for this :class:`ObjectId`."""
        return hash(self.__id)

+ 53 - 0
app/libs/bson/tz_util.py

@@ -0,0 +1,53 @@
+# Copyright 2010-2015 MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Timezone related utilities for BSON."""
+from __future__ import annotations
+
+from datetime import datetime, timedelta, tzinfo
+from typing import Optional, Tuple, Union
+
+ZERO: timedelta = timedelta(0)
+
+
class FixedOffset(tzinfo):
    """Fixed offset timezone, in minutes east from UTC.

    Implementation based from the Python `standard library documentation
    <http://docs.python.org/library/datetime.html#tzinfo-objects>`_.
    Defining __getinitargs__ enables pickling / copying.
    """

    def __init__(self, offset: Union[float, timedelta], name: str) -> None:
        # Accept either a ready-made timedelta or a minute count.
        self.__offset = offset if isinstance(offset, timedelta) else timedelta(minutes=offset)
        self.__name = name

    def __getinitargs__(self) -> Tuple[timedelta, str]:
        # Lets pickle/copy rebuild the instance via __init__.
        return self.__offset, self.__name

    def utcoffset(self, dt: Optional[datetime]) -> timedelta:
        return self.__offset

    def tzname(self, dt: Optional[datetime]) -> str:
        return self.__name

    def dst(self, dt: Optional[datetime]) -> timedelta:
        # Fixed offsets never observe daylight saving time.
        return ZERO
+
+
# Module-level singleton; used e.g. by ObjectId.generation_time in objectid.py.
utc: FixedOffset = FixedOffset(0, "UTC")
"""Fixed offset timezone representing UTC."""

+ 25 - 0
app/libs/class_loader.py

@@ -0,0 +1,25 @@
+import importlib
+import logging
+
+
def load_class(name: str):
    """
    Dynamically load a class (or any module attribute) by its dotted path.

    Args:
        name: Fully qualified name, e.g. ``"package.module.ClassName"``.

    Returns:
        The resolved class, or ``None`` when the path is invalid or the
        module/attribute cannot be found (errors are logged, not raised).
    """
    # A usable path needs at least "module.Class". The original checked
    # `if not name.split(".")` which can never fire (str.split always yields
    # at least one element), and a dotless name then made
    # importlib.import_module("") raise an uncaught ValueError.
    if not name or "." not in name:
        logging.error("Invalid class name: %s", name)
        return None

    module_name, _, class_name = name.rpartition(".")

    try:
        module = importlib.import_module(module_name)
        a_class = getattr(module, class_name)
        logging.info("load class: %s", a_class)
        return a_class
    except ImportError:
        logging.error("Module not found: %s", name)
    except AttributeError:
        logging.error("Class not found: %s", name)
    return None

+ 84 - 0
app/libs/paginate.py

@@ -0,0 +1,84 @@
+import logging
+from typing import TypeVar, Any, Optional, Generic, List, Sequence
+
+from fastapi import Query
+from fastapi_pagination.bases import AbstractPage, AbstractParams, CursorRawParams
+from fastapi_pagination.cursor import encode_cursor
+from fastapi_pagination.ext.sqlmodel import paginate
+from fastapi_pagination.types import Cursor
+from fastapi_pagination.utils import verify_params, create_pydantic_model
+from sqlmodel import asc, desc
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.models.base_model import BaseModel
+
+ModelType = TypeVar("ModelType", bound=BaseModel)
+
+
class CursorParams(BaseModel, AbstractParams):
    """Cursor-style pagination parameters in the OpenAI list-API shape."""

    # NOTE(review): the description "Page offset" looks wrong for a page-size
    # limit; it is runtime OpenAPI metadata, so it is left untouched here.
    limit: int = Query(20, ge=1, le=100, description="Page offset")
    order: str = Query(default="desc", description="Sort order")
    after: Optional[str] = Query(None, description="Page after")
    before: Optional[str] = Query(None, description="Page before")

    def to_raw_params(self) -> CursorRawParams:
        # fastapi-pagination only drives the page size from here; the
        # before/after filtering is applied separately in cursor_page().
        return CursorRawParams(cursor=None, size=self.limit, include_total=True)
+
+
class CommonPage(AbstractPage[ModelType], Generic[ModelType]):
    """OpenAI-style list page: {object, data, first_id, last_id, has_more}."""

    __params_type__ = CursorParams

    object: str = "list"
    data: List[ModelType] = []
    first_id: Optional[str] = ""
    last_id: Optional[str] = ""
    has_more: bool = False

    @classmethod
    def create(
        cls,
        items: Sequence[ModelType],
        params: CursorParams,
        *,
        current: Optional[Cursor] = None,
        current_backwards: Optional[Cursor] = None,
        next_: Optional[Cursor] = None,
        previous: Optional[Cursor] = None,
        **kwargs: Any,
    ) -> AbstractPage[ModelType]:
        """Build the page; has_more is derived from the presence of a next cursor."""
        next_page = encode_cursor(next_)
        # NOTE(review): next_page is passed to create_pydantic_model but is not
        # a declared field on CommonPage — confirm it is intentionally dropped.
        return create_pydantic_model(
            CommonPage,
            next_page=next_page,
            first_id=items[0].id if items else None,
            last_id=items[len(items) - 1].id if items else None,
            has_more=False if next_page is None else True,
            data=list(items),
        )
+
+
async def cursor_page(query: Any, db: AsyncSession) -> CommonPage[ModelType]:
    """
    Paginate *query* with before/after id filtering plus created_at ordering,
    using the CursorParams resolved from the current request.
    """
    params, _ = verify_params(None, "cursor")
    # NOTE(review): _propagate_attrs is private SQLAlchemy API and assumes the
    # select targets exactly one mapped entity — confirm on SQLAlchemy upgrades.
    model = query._propagate_attrs["plugin_subject"].class_

    logging.debug(
        f"Page model={model}, sort={params.order}, filter_parameters=before:{params.before}, after:{params.after}",
    )

    # The id comparison direction depends on sort order: for DESC, "before"
    # means newer ids (greater) and "after" means older ids (smaller).
    if params.before is not None:
        if params.order.upper() == "DESC":
            query = query.where(model.id > params.before)
        else:
            query = query.where(model.id < params.before)
    if params.after is not None:
        if params.order.upper() == "DESC":
            query = query.where(model.id < params.after)
        else:
            query = query.where(model.id > params.after)

    # Sort by creation time, matching the requested direction.
    if params.order.upper() == "DESC":
        query = query.order_by(desc(model.__dict__["created_at"]))
    else:
        query = query.order_by(asc(model.__dict__["created_at"]))

    return await paginate(db, query)

+ 52 - 0
app/libs/thread_executor.py

@@ -0,0 +1,52 @@
+import atexit
+from concurrent.futures import Executor, ThreadPoolExecutor
+import concurrent
+import concurrent.futures
+
+from typing import List
+
+
def get_executor_for_config(worker_num: int, thread_name_prefix: str) -> Executor:
    """
    Create a ThreadPoolExecutor with the specified number of workers.

    The executor's shutdown (non-blocking) is registered with atexit so
    worker threads are cleaned up at interpreter exit.

    Args:
        worker_num (int): The number of worker threads in the ThreadPoolExecutor.
        thread_name_prefix (str): thread name prefix.

    Returns:
        Executor: A ThreadPoolExecutor instance.
    """
    executor = ThreadPoolExecutor(max_workers=worker_num, thread_name_prefix=thread_name_prefix)
    atexit.register(executor.shutdown, wait=False)
    return executor
+
+
def run_with_executor(executor: Executor, func, tasks: List, timeout: int):
    """
    Executes the given function with the provided executor and tasks.

    Each task is an argument tuple unpacked into ``func(*task)``.

    Args:
        executor (Executor): The executor to use for running the tasks.
        func: The function to be executed.
        tasks (List): The list of argument tuples to be executed.
        timeout (int): The maximum time to wait for the tasks to complete.

    Returns:
        List: The results of the completed tasks. NOTE(review): results come
        from the "done" set, so their order is arbitrary (not the order of
        *tasks*), and on timeout still-running tasks are silently omitted —
        confirm callers tolerate both.

    Raises:
        Exception: The first exception raised by any task, re-raised here.
    """
    futures = [executor.submit(lambda args: func(*args), task) for task in tasks]
    # Stop waiting as soon as any task fails (or the timeout expires).
    done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION, timeout=timeout)

    results = []
    for future in done:
        if future.exception():
            raise future.exception()

        if future.done():
            results.append(future.result())
    return results

+ 1 - 0
app/libs/types.py

@@ -0,0 +1 @@
+from app.libs.bson.objectid import ObjectId as ObjectId  # noqa

+ 37 - 0
app/libs/util.py

@@ -0,0 +1,37 @@
+import uuid
+from datetime import datetime
+
+import jwt
+
+
+def datetime2timestamp(value: datetime):
+    if not value:
+        return None
+    return value.timestamp()
+
+
+def str2datetime(value: str):
+    if not value:
+        return None
+    return datetime.fromisoformat(value)
+
+
+def is_valid_datetime(date_str, format="%Y-%m-%d %H:%M:%S"):
+    if not date_str or not isinstance(date_str, str):
+        return False
+    try:
+        datetime.strptime(date_str, format)
+        return True
+    except ValueError:
+        return False
+
+
+def random_uuid() -> str:
+    return "ml-" + str(uuid.uuid4()).replace("-", "")
+
+
+def verify_jwt_expiration(token):
+    decoded_token = jwt.decode(token, options={"verify_signature": False, "verify_exp": False})
+    expiration_time = datetime.fromtimestamp(decoded_token['exp'])
+    current_time = datetime.now()
+    return current_time < expiration_time

+ 26 - 0
app/models/__init__.py

@@ -0,0 +1,26 @@
+from app.models.action import Action
+from app.models.assistant import Assistant
+from app.models.assistant_file import AssistantFile
+from app.models.file import File
+from app.models.message import Message
+from app.models.message_file import MessageFile
+from app.models.run import Run
+from app.models.run_step import RunStep
+from app.models.thread import Thread
+from app.models.token import Token
+from app.models.token_relation import TokenRelation
+
+
+__all__ = [
+    "Assistant",
+    "AssistantFile",
+    "File",
+    "Message",
+    "MessageFile",
+    "Run",
+    "RunStep",
+    "Thread",
+    "Token",
+    "TokenRelation",
+    "Action",
+]

+ 37 - 0
app/models/action.py

@@ -0,0 +1,37 @@
+from sqlalchemy import Column, JSON
+from typing import Optional
+
+from pydantic import Field as PDField
+
+from sqlmodel import Field
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+# This class utilizes code from the Open Source Project TaskingAI.
+# The original code can be found at: https://github.com/TaskingAI/TaskingAI
+class ActionBase(BaseModel):
+    name: str = Field(nullable=False)
+    description: Optional[str] = Field(nullable=False)
+    openapi_schema: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    authentication: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    extra: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    operation_id: str = Field(nullable=False)
+    url: str = Field(nullable=False)
+    method: str = Field(nullable=False)
+    path_param_schema: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    query_param_schema: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    body_param_schema: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    body_type: str = Field(nullable=False)
+    function_def: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    use_for_everyone: bool = Field(default=False, nullable=False)
+    object: str = Field(nullable=False, default="action")
+
+
+class Action(ActionBase, PrimaryKeyMixin, TimeStampMixin, table=True):
+    pass
+
+
+class ActionRead(ActionBase, PrimaryKeyMixin, TimeStampMixin):
+    metadata_: Optional[dict] = PDField(default=None, alias="metadata")

+ 51 - 0
app/models/assistant.py

@@ -0,0 +1,51 @@
+from typing import Optional, Union
+
+from pydantic import Field as PDField
+
+from sqlalchemy import Column
+from sqlmodel import Field, JSON, TEXT
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class AssistantBase(BaseModel):
+    model: str = Field(nullable=False)
+    description: Optional[str] = Field(default=None)
+    file_ids: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    name: Optional[str] = Field(default=None)
+    tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
+    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    temperature: Optional[float] = Field(default=None)  # 温度
+    top_p: Optional[float] = Field(default=None)  # top_p
+    object: str = Field(nullable=False, default="assistant")
+
+
+class Assistant(AssistantBase, PrimaryKeyMixin, TimeStampMixin, table=True):
+    pass
+
+
+class AssistantCreate(AssistantBase):
+    pass
+
+
+class AssistantUpdate(BaseModel):
+    model: Optional[str] = Field(default=None)
+    description: Optional[str] = Field(default=None)
+    file_ids: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    name: Optional[str] = Field(default=None)
+    tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
+    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+    temperature: Optional[float] = Field(default=None)  # 温度
+    top_p: Optional[float] = Field(default=None)  # top_p
+
+
+class AssistantRead(AssistantBase, PrimaryKeyMixin, TimeStampMixin):
+    metadata_: Optional[dict] = PDField(default=None, alias="metadata")

+ 23 - 0
app/models/assistant_file.py

@@ -0,0 +1,23 @@
+from sqlalchemy import Index
+from sqlmodel import Field
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class AssistantFileBase(BaseModel):
+    __table_args__ = (Index("assistant_file_assistant_id_id_idx", "assistant_id", "id"),)
+
+    assistant_id: str = Field(nullable=False)
+    object: str = Field(nullable=False, default="assistant.file")
+
+
+class AssistantFile(AssistantFileBase, PrimaryKeyMixin, TimeStampMixin, table=True):
+    pass
+
+
+class AssistantFileCreate(AssistantFileBase):
+    pass
+
+
+class AssistantFileUpdate(BaseModel, PrimaryKeyMixin, TimeStampMixin):
+    pass

+ 46 - 0
app/models/base_model.py

@@ -0,0 +1,46 @@
+from datetime import datetime
+from typing import Optional
+
+import orjson
+from sqlalchemy import DateTime, text
+from sqlalchemy.orm import declared_attr
+from sqlmodel import SQLModel, Field
+
+from app.libs.types import ObjectId
+from app.libs.util import datetime2timestamp
+
+
+def orjson_dumps(v, *, default):
+    # orjson.dumps returns bytes, to match standard json.dumps we need to decode
+    return orjson.dumps(v, default=default).decode()
+
+
+def to_snake_case(string: str) -> str:
+    return "".join(["_" + i.lower() if i.isupper() else i for i in string]).lstrip("_")
+
+
+class BaseModel(SQLModel):
+    class Config:
+        from_attributes = True
+        populate_by_name = True
+        json_encoders = {
+            datetime: lambda v: datetime2timestamp(v),
+        }
+
+    @classmethod
+    @declared_attr
+    def __tablename__(cls) -> str:
+        return to_snake_case(cls.__name__)
+
+
+class TimeStampMixin(SQLModel):
+    created_at: Optional[datetime] = Field(
+        sa_type=DateTime, default=None, nullable=False,  sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")}
+    )
+    updated_at: Optional[datetime] = Field(
+        sa_type=DateTime, default=None, sa_column_kwargs={"onupdate": text("CURRENT_TIMESTAMP")}
+    )
+
+
+class PrimaryKeyMixin(SQLModel):
+    id: str = Field(primary_key=True, default_factory=ObjectId)

+ 18 - 0
app/models/file.py

@@ -0,0 +1,18 @@
+from typing import Optional
+
+from sqlalchemy import Index, Column, Enum
+from sqlmodel import Field
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class File(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
+    __table_args__ = (Index("file_purpose_idx", "purpose"),)
+
+    bytes: int = Field(nullable=False)
+    filename: str = Field(nullable=False)
+    purpose: str = Field(nullable=False)
+    object: str = Field(nullable=False, default="file")
+    key: str = Field(nullable=False)
+    status: Optional[str] = Field(default=None, sa_column=Column("status", Enum("error", "processed", "uploaded")))
+    status_details: Optional[str] = Field(default=None)

+ 41 - 0
app/models/message.py

@@ -0,0 +1,41 @@
+from typing import Optional, Union, List
+
+from pydantic import Field as PDField
+
+from sqlalchemy import Column, Enum
+from sqlmodel import Field, JSON
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class MessageBase(BaseModel):
+    role: str = Field(sa_column=Column(Enum("assistant", "user", "system", "function", "tool"), nullable=False))
+    thread_id: str = Field(nullable=False)
+    object: str = Field(nullable=False, default="thread.message")
+    content: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    file_ids: Optional[list] = Field(default=None, sa_column=Column(JSON))
+    attachments: Optional[list] = Field(default=None, sa_column=Column(JSON))  # 附件
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    assistant_id: Optional[str] = Field(default=None)
+    run_id: Optional[str] = Field(default=None)
+
+
+class Message(MessageBase, TimeStampMixin, PrimaryKeyMixin, table=True):
+    pass
+
+
+class MessageCreate(BaseModel):
+    role: str = Field(sa_column=Column(Enum("assistant", "user"), nullable=False))
+    content: Union[str, List[dict]] = Field(nullable=False)
+    file_ids: Optional[list] = Field(default=None)
+    attachments: Optional[list] = Field(default=None, sa_column=Column(JSON))  # 附件
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+
+
+class MessageUpdate(BaseModel):
+    content: Optional[str] = Field(default=None)
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+
+
+class MessageRead(MessageBase, TimeStampMixin, PrimaryKeyMixin):
+    metadata_: Optional[dict] = PDField(default=None, alias="metadata")

+ 8 - 0
app/models/message_file.py

@@ -0,0 +1,8 @@
+from sqlmodel import Field
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class MessageFile(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
+    message_id: str = Field(nullable=False)
+    object: str = Field(nullable=False, default="thread.message.file")

+ 119 - 0
app/models/run.py

@@ -0,0 +1,119 @@
+from datetime import datetime
+from typing import Optional, Any, Union, List
+
+from pydantic import Field as PDField
+
+from sqlalchemy import Column, Enum
+from sqlalchemy.sql.sqltypes import JSON, TEXT, String
+from sqlmodel import Field
+
+from pydantic import model_validator
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+from app.models.message import MessageCreate
+from app.schemas.tool.authentication import Authentication
+
+
+class RunBase(BaseModel):
+    instructions: str = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    model: str = Field(default=None)
+    status: str = Field(
+        default="queued",
+        sa_column=Column(
+            Enum(
+                "cancelled",
+                "cancelling",
+                "completed",
+                "expired",
+                "failed",
+                "in_progress",
+                "queued",
+                "requires_action",
+            ),
+            default="queued",
+            nullable=True,
+        ),
+    )
+    #id: str = Field(default=None, nullable=False)
+    created_at: Optional[datetime] = Field(default=datetime.now())
+    assistant_id: str = Field(nullable=False)
+    thread_id: str = Field(default=None)
+    object: str = Field(nullable=False, default="thread.run")
+    #file_ids: Optional[list] = Field(default=[], sa_column=Column(JSON))
+    file_ids: List[str] = Field(default_factory=list ,sa_column=Column(String))
+    #metadata: Optional[object] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    last_error: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    required_action: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    tools: Optional[list] = Field(default=[], sa_column=Column(JSON))
+    started_at: Optional[datetime] = Field(default=None)
+    completed_at: Optional[int] = Field(default=None)
+    cancelled_at: Optional[int] = Field(default=None)
+    expires_at: Optional[int] = Field(default=None)
+    failed_at: Optional[int] = Field(default=None)
+    additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
+    extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
+    stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    incomplete_details: Optional[str] = Field(default=None)  # 未完成详情
+    max_completion_tokens: Optional[int] = Field(default=None)  # 最大完成长度
+    max_prompt_tokens: Optional[int] = Field(default=None)  # 最大提示长度
+    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    tool_choice: Optional[str] = Field(default=None)  # 工具选择
+    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 截断策略
+    usage: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 调用使用情况
+    temperature: Optional[float] = Field(default=None)  # 温度
+    top_p: Optional[float] = Field(default=None)  # top_p
+
+class Run(RunBase, PrimaryKeyMixin, TimeStampMixin, table=True):
+    pass
+
+
+class RunCreate(BaseModel):
+    assistant_id: str
+    status: Optional[str] = "queued"
+    instructions: Optional[str] = None
+    additional_instructions: Optional[str] = None
+    model: Optional[str] = None
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    tools: Optional[list] = []
+    extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
+    stream: Optional[bool] = False
+    stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON))  # 消息列表
+    max_completion_tokens: Optional[int] = None  # 最大完成长度
+    max_prompt_tokens: Optional[int] = Field(default=None)  # 最大提示长度
+    truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 截断策略
+    response_format: Optional[Union[str, dict]] = Field(default="auto", sa_column=Column(JSON))  # 响应格式
+    tool_choice: Optional[str] = Field(default=None)  # 工具选择
+    temperature: Optional[float] = Field(default=None)  # 温度
+    top_p: Optional[float] = Field(default=None)  # top_p
+
+    @model_validator(mode="before")
+    def model_validator(cls, data: Any):
+        extra_body = data.get("extra_body")
+        if extra_body:
+            action_authentications = extra_body.get("action_authentications")
+            if action_authentications:
+                res = action_authentications.values()
+                [Authentication.model_validate(i).encrypt() for i in res]
+        return data
+
+
+class RunUpdate(BaseModel):
+    tools: Optional[list] = []
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    extra_body: Optional[dict[str, Authentication]] = {}
+
+    @model_validator(mode="before")
+    def model_validator(cls, data: Any):
+        extra_body = data.get("extra_body")
+        if extra_body:
+            action_authentications = extra_body.get("action_authentications")
+            if action_authentications:
+                res = action_authentications.values()
+                [Authentication.model_validate(i).encrypt() for i in res]
+        return data
+
+
+class RunRead(RunBase, TimeStampMixin, PrimaryKeyMixin):
+    metadata_: Optional[dict] = PDField(default=None, alias="metadata")

+ 39 - 0
app/models/run_step.py

@@ -0,0 +1,39 @@
+from datetime import datetime
+from typing import Optional
+
+from pydantic import Field as PDField
+
+from sqlalchemy import Index, Column, Enum
+from sqlmodel import Field, JSON
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class RunStepBase(BaseModel):
+    status: str = Field(
+        sa_column=Column(Enum("cancelled", "completed", "expired", "failed", "in_progress"), nullable=False)
+    )
+    type: str = Field(sa_column=Column(Enum("message_creation", "tool_calls"), nullable=False))
+    assistant_id: str = Field(nullable=False)
+    thread_id: str = Field(nullable=False)
+    run_id: str = Field(nullable=False)
+    object: str = Field(nullable=False, default="thread.run.step")
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    last_error: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    step_details: Optional[dict] = Field(default=None, sa_column=Column(JSON))
+    completed_at: Optional[datetime] = Field(default=None)
+    cancelled_at: Optional[datetime] = Field(default=None)
+    expires_at: Optional[datetime] = Field(default=None)
+    failed_at: Optional[datetime] = Field(default=None)
+    message_id: Optional[str] = Field(default=None)
+
+
+class RunStep(RunStepBase, PrimaryKeyMixin, TimeStampMixin, table=True):
+    __table_args__ = (
+        Index("run_step_run_id_idx", "run_id"),
+        Index("run_step_run_id_type_idx", "run_id", "type"),
+    )
+
+
+class RunStepRead(RunStepBase, PrimaryKeyMixin, TimeStampMixin):
+    metadata_: Optional[dict] = PDField(default=None, alias="metadata")

+ 26 - 0
app/models/thread.py

@@ -0,0 +1,26 @@
+from typing import Optional
+
+from sqlalchemy import Column
+from sqlmodel import Field, JSON
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+from app.models.message import MessageCreate
+
+
+class Thread(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
+    object: str = Field(nullable=False, default="thread")
+    metadata_: Optional[dict] = Field(default=None, sa_column=Column("metadata", JSON), schema_extra={"validation_alias": "metadata"})
+    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+
+
+class ThreadCreate(BaseModel):
+    object: str = "thread"
+    messages: Optional[list[MessageCreate]] = Field(default=None)
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
+    thread_id: Optional[str] = Field(default=None)
+    end_message_id: Optional[str] = Field(default=None)
+    tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON))  # 工具资源
+
+
+class ThreadUpdate(BaseModel):
+    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})

+ 27 - 0
app/models/token.py

@@ -0,0 +1,27 @@
+from typing import Optional
+
+from sqlalchemy import Index
+from sqlmodel import Field
+from app.libs import util
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class TokenBase(BaseModel):
+    llm_base_url: str = Field(nullable=False)
+    llm_api_key: str = Field(nullable=False)
+    description: Optional[str] = Field(default=None)
+
+
+class Token(TokenBase, TimeStampMixin, PrimaryKeyMixin, table=True):
+    __table_args__ = (Index("token_assistant_token_idx", "assistant_token", unique=True),)
+
+    assistant_token: str = Field(default_factory=util.random_uuid)
+
+
+class TokenCreate(TokenBase):
+    pass
+
+
+class TokenUpdate(TokenBase):
+    pass

+ 30 - 0
app/models/token_relation.py

@@ -0,0 +1,30 @@
+from enum import Enum
+from sqlmodel import Field
+
+from app.models.base_model import BaseModel, TimeStampMixin, PrimaryKeyMixin
+
+
+class RelationType(str, Enum):
+    Assistant = "assistant"
+    File = "file"
+    Thread = "thread"
+    Action = "action"
+
+
+class TokenRelationBase(BaseModel):
+    token_id: str = Field(nullable=False)
+    relation_type: RelationType = Field(nullable=False)
+    relation_id: str = Field(nullable=False)
+
+
+class TokenRelation(TokenRelationBase, TimeStampMixin, PrimaryKeyMixin, table=True):
+    pass
+
+
+class TokenRelationQuery(TokenRelationBase):
+    pass
+
+
+class TokenRelationDelete(BaseModel):
+    relation_type: RelationType = Field(nullable=False)
+    relation_id: str = Field(nullable=False)

+ 0 - 0
app/providers/__init__.py


+ 41 - 0
app/providers/app_provider.py

@@ -0,0 +1,41 @@
+import logging
+
+from fastapi.middleware.cors import CORSMiddleware
+
+from app.providers.middleware.http_process_time import HTTPProcessTimeMiddleware
+from app.providers.middleware.unhandled_exception_handler import UnhandledExceptionHandlingMiddleware
+from app.providers.database import redis_client
+from config.config import settings
+
+
+def register(app):
+    app.debug = settings.DEBUG
+    app.title = settings.NAME
+
+    add_global_middleware(app)
+
+    @app.on_event("startup")
+    def startup():
+        # create_db_and_tables()
+        pass
+
+    @app.on_event("shutdown")
+    def shutdown():
+        if redis_client:
+            redis_client.close()
+
+        logging.info("Application shutdown")
+
+
+def add_global_middleware(app):
+    app.add_middleware(UnhandledExceptionHandlingMiddleware)
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    app.add_middleware(HTTPProcessTimeMiddleware)

+ 112 - 0
app/providers/auth_provider.py

@@ -0,0 +1,112 @@
+import logging
+from fastapi import Depends
+
+from sqlmodel import select
+
+from app.api.deps import verfiy_token, verify_token_relation
+from app.models.token_relation import RelationType, TokenRelation, TokenRelationDelete
+from app.services.token.token_relation import TokenRelationService
+from config.config import settings
+
+
+class AuthPolicy(object):
+    """
+    default auth policy with nothing to do
+    """
+
+    def enable(self):
+        """
+        enable auth policy
+        """
+
+    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
+        """
+        insert a token relation to database when enable token auth policy
+        """
+
+    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
+        """
+        delete token relation when enable token auth policy
+        """
+
+    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
+        """
+        add token filter clause when enable token auth policy
+        """
+        return statement
+
+
+class SimpleTokenAuthPolicy(AuthPolicy):
+    """
+    simple token auth policy
+    """
+
+    def enable(self):
+        """
+        add auth verify dependents to path router
+        """
+        from app.api.v1 import assistant, assistant_file, thread, message, runs, action
+
+        verify_assistant_depends = Depends(
+            verify_token_relation(relation_type=RelationType.Assistant, name="assistant_id")
+        )
+        # assistant router
+        for route in assistant.router.routes:
+            if route.name == assistant.create_assistant.__name__ or route.name == assistant.list_assistants.__name__:
+                route.dependencies.append(Depends(verfiy_token))
+            else:
+                route.dependencies.append(verify_assistant_depends)
+
+        # thread router
+        verify_thread_depends = Depends(verify_token_relation(relation_type=RelationType.Thread, name="thread_id"))
+        for route in thread.router.routes:
+            if route.name == thread.create_thread.__name__:
+                route.dependencies.append(
+                    Depends(
+                        verify_token_relation(
+                            relation_type=RelationType.Thread, name="thread_id", ignore_none_relation_id=True
+                        )
+                    )
+                )
+            else:
+                route.dependencies.append(verify_thread_depends)
+
+        # action router
+        verify_action_depends = Depends(verify_token_relation(relation_type=RelationType.Action, name="action_id"))
+        for route in action.router.routes:
+            if route.name == action.create_actions.__name__ or route.name == action.list_actions.__name__:
+                route.dependencies.append(Depends(verfiy_token))
+            else:
+                route.dependencies.append(verify_action_depends)
+
+        self.__append_deps_for_all_routes(assistant_file.router, verify_assistant_depends)
+        self.__append_deps_for_all_routes(message.router, verify_thread_depends)
+        self.__append_deps_for_all_routes(runs.router, verify_thread_depends)
+
+    def insert_token_rel(self, session, token_id: str, relation_type: RelationType, relation_id: str):
+        if token_id:
+            relation = TokenRelation(token_id=token_id, relation_type=relation_type, relation_id=str(relation_id))
+            session.add(relation)
+
+    async def delete_token_rel(self, session, relation_type: RelationType, relation_id: str):
+        to_delete = TokenRelationDelete(relation_type=relation_type, relation_id=relation_id)
+        relation = await TokenRelationService.get_relation_to_delete(session=session, delete=to_delete)
+        await session.delete(relation)
+
+    def token_filter(self, statement, field, relation_type: RelationType, token_id: str):
+        id_subquery = select(TokenRelation.relation_id).where(
+            TokenRelation.relation_type == relation_type, TokenRelation.token_id == token_id
+        )
+        return statement.where(field.in_(id_subquery))
+
+    def __append_deps_for_all_routes(self, router, depends):
+        for route in router.routes:
+            route.dependencies.append(depends)
+
+
+auth_policy: AuthPolicy = SimpleTokenAuthPolicy() if settings.AUTH_ENABLE else AuthPolicy()
+
+
+def register(app):
+    logging.info("use auth polily: %s", auth_policy.__class__.__name__)
+    auth_policy.enable()

+ 9 - 0
app/providers/celery_app.py

@@ -0,0 +1,9 @@
+from celery import Celery
+
+from config.celery import settings as celery_settings
+from config.config import settings
+
+celery_app: Celery = Celery(main=settings.NAME, broker=celery_settings.CELERY_BROKER_URL, task_ignore_result=True)
+
+# Import task modules so their Celery tasks are registered (import for side effects).
+import app.tasks.run_task  # noqa

+ 63 - 0
app/providers/database.py

@@ -0,0 +1,63 @@
+import logging
+from contextvars import ContextVar
+from typing import Callable
+
+import redis
+from sqlmodel import SQLModel, create_engine
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool
+from sqlalchemy.orm import sessionmaker, scoped_session
+
+from config.config import settings
+from config.database import db_settings, redis_settings
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+# database
+connect_args = {}
+database_url = db_settings.database_url
+engine = create_engine(
+    database_url,
+    connect_args=connect_args,
+    poolclass=QueuePool,
+    pool_size=db_settings.DB_POOL_SIZE,
+    pool_recycle=db_settings.DB_POOL_RECYCLE,
+    echo=settings.DEBUG,
+)
+session = scoped_session(sessionmaker(bind=engine))
+
+async_database_url = db_settings.async_database_url
+async_engine = create_async_engine(
+    async_database_url,
+    connect_args=connect_args,
+    poolclass=AsyncAdaptedQueuePool,
+    pool_size=db_settings.DB_POOL_SIZE,
+    pool_recycle=db_settings.DB_POOL_RECYCLE,
+    echo=settings.DEBUG,
+)
+
+# Create the async session factory (sessionmaker builds AsyncSession instances).
+async_session_local: Callable[..., AsyncSession] = sessionmaker(
+    class_=AsyncSession,
+    bind=async_engine,
+)
+
+
+def create_db_and_tables():
+    logging.debug("Creating database and tables")
+    import app.models  # noqa
+
+    SQLModel.metadata.create_all(async_engine)
+    logging.debug("Database and tables created successfully")
+
+
+# redis
+redis_pool = redis.ConnectionPool(
+    host=redis_settings.REDIS_HOST,
+    port=redis_settings.REDIS_PORT,
+    db=redis_settings.REDIS_DB,
+    password=redis_settings.REDIS_PASSWORD,
+    decode_responses=True,
+)
+redis_client = redis.Redis(connection_pool=redis_pool)

+ 42 - 0
app/providers/handle_exception.py

@@ -0,0 +1,42 @@
+import logging
+
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+from app.exceptions.exception import AuthenticationError, AuthorizationError, BaseHTTPException
+from app.providers.response import ErrorResponse
+
+
+def register(app):
+    @app.exception_handler(AuthenticationError)
+    async def authentication_exception_handler(request: Request, e: AuthenticationError):
+        """
+        Handle authentication failures raised as AuthenticationError.
+        """
+        return ErrorResponse(e.status_code, e.error_code, e.message)
+
+    @app.exception_handler(AuthorizationError)
+    async def authorization_exception_handler(request: Request, e: AuthorizationError):
+        """
+        Handle authorization failures raised as AuthorizationError.
+        """
+        return ErrorResponse(e.status_code, e.error_code, e.message)
+
+    @app.exception_handler(BaseHTTPException)
+    async def business_exception_handler(request: Request, e: BaseHTTPException):
+        """
+        Handle other business exceptions derived from BaseHTTPException.
+        """
+        logging.exception(e)
+        return ErrorResponse(e.status_code, e.error_code, e.message, e.type, e.param)
+
+    @app.exception_handler(StarletteHTTPException)
+    async def starlette_http_exception_handler(request: Request, e: StarletteHTTPException):
+        logging.exception(e)
+        return ErrorResponse(e.status_code, "http_error", e.detail)
+
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(request: Request, e: RequestValidationError):
+        logging.exception(e)
+        return ErrorResponse(422, "request_validation_error", str(e))

+ 46 - 0
app/providers/logging_provider.py

@@ -0,0 +1,46 @@
+import logging
+import sys
+from loguru import logger
+
+from config.logging import settings
+
+
+def register(app=None):
+    level = settings.LOG_LEVEL
+    path = settings.LOG_PATH
+    retention = settings.LOG_RETENTION
+
+    # intercept everything at the root logger
+    logging.root.handlers = [InterceptHandler()]
+    logging.root.setLevel(level)
+
+    # remove every other logger's handlers
+    # and propagate to root logger
+    for name in logging.root.manager.loggerDict.keys():
+        logging.getLogger(name).handlers = []
+        logging.getLogger(name).propagate = True
+
+    # configure loguru
+    logger.configure(
+        handlers=[
+            {"sink": sys.stdout},
+            {"sink": path, "rotation": "00:00", "retention": retention},
+        ]
+    )
+
+
+class InterceptHandler(logging.Handler):
+    def emit(self, record):
+        # Get corresponding Loguru level if it exists
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        # Find caller from where originated the logged message
+        frame, depth = logging.currentframe(), 2
+        while frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())

+ 0 - 0
app/providers/middleware/__init__.py


+ 11 - 0
app/providers/middleware/http_process_time.py

@@ -0,0 +1,11 @@
+import time
+from starlette.middleware.base import BaseHTTPMiddleware
+
+
class HTTPProcessTimeMiddleware(BaseHTTPMiddleware):
    """Adds an ``X-Process-Time`` header (seconds) to every response."""

    async def dispatch(self, request, call_next):
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the system wall clock is adjusted mid-request
        # (time.time() is subject to NTP/clock changes).
        start = time.perf_counter()
        response = await call_next(request)
        response.headers["X-Process-Time"] = str(time.perf_counter() - start)
        return response

+ 19 - 0
app/providers/middleware/unhandled_exception_handler.py

@@ -0,0 +1,19 @@
+import logging
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from app.providers.response import ErrorResponse
+
+
# Bug: exception_handler unable to catch Exception
# https://github.com/tiangolo/fastapi/issues/4025
class UnhandledExceptionHandlingMiddleware(BaseHTTPMiddleware):
    """
    Handle any otherwise-uncaught exception by logging it and returning
    a generic 500 error response.
    """

    async def dispatch(self, request, call_next):
        try:
            response = await call_next(request)
        except Exception as e:
            logging.exception(e)
            return ErrorResponse(500, "internal_server_error", "Internal Server Error")
        return response

+ 5 - 0
app/providers/pagination_provider.py

@@ -0,0 +1,5 @@
+from fastapi_pagination import add_pagination
+
+
def register(app):
    """Install fastapi-pagination support onto the application instance."""
    add_pagination(app)

+ 80 - 0
app/providers/r2r.py

@@ -0,0 +1,80 @@
+from typing import Optional, Any
+import pytest
+
+from r2r import R2RClient
+
+from app.libs.util import verify_jwt_expiration
+from config.llm import tool_settings
+import nest_asyncio
+import asyncio
+
+nest_asyncio.apply()
+
class R2R:
    """Thin wrapper around R2RClient that handles login and token refresh.

    Login is asynchronous; when called from inside a running event loop the
    coroutine is scheduled as a task (nest_asyncio permits re-entrancy),
    otherwise it is run to completion on a fresh loop.
    """

    client: R2RClient

    def init(self):
        """Prepare the client and perform the initial login if auth is configured."""
        self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
        self.client = None
        return self._run_login()

    def ingest_file(self, file_path: str, metadata: Optional[dict]):
        """Upload one document to R2R and return the creation results."""
        self._check_login()
        ingest_response = self.client.documents.create(
            file_path=file_path, metadata=metadata if metadata else None, id=None
        )
        return ingest_response.get("results")

    def search(self, query: str, filters: dict[str, Any]):
        """Run a chunk search restricted by `filters`, bounded by the configured limit."""
        self._check_login()
        search_response = self.client.retrieval.search(
            query=query,
            search_settings={
                "filters": filters,
                "limit": tool_settings.R2R_SEARCH_LIMIT,
            },
        )
        return search_response.get("results").get("chunk_search_results")

    async def _login(self):
        # Lazily construct the client on first login; skipped entirely when
        # authentication is not configured.
        if not self.auth_enabled:
            return
        if not self.client:
            self.client = R2RClient(tool_settings.R2R_BASE_URL)
            self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)

    def _check_login(self):
        # Re-login when the stored JWT has expired; no-op when auth is disabled.
        if not self.auth_enabled:
            return
        if verify_jwt_expiration(self.client.access_token):
            return
        return self._run_login()

    def _run_login(self):
        # Schedule the login coroutine on the running loop if there is one;
        # otherwise run it to completion on a new loop.
        loop = asyncio.get_event_loop()
        if loop.is_running():
            return loop.create_task(self._login())
        return asyncio.run(self._login())
+
+
# Module-level singleton shared by the whole application; the initial login
# is kicked off eagerly at import time.
r2r = R2R()

r2r.init()

+ 24 - 0
app/providers/response.py

@@ -0,0 +1,24 @@
+from fastapi.responses import JSONResponse
+
+
+# class CustomJSONResponse(JSONResponse):
+#
+#     def render(self, data: any) -> bytes:
+#         return json.dumps(
+#             {'code': 'success', 'data': data},
+#             ensure_ascii=False,
+#             allow_nan=False,
+#             indent=None,
+#             separators=(",", ":"),
+#         ).encode("utf-8")
+
+
class ErrorResponse(JSONResponse):
    """JSON error payload shaped like OpenAI's error responses."""

    def __init__(
        self, status_code: int, error_code: str, message: str = None, type_code: str = None, param: str = None
    ) -> None:
        # OpenAI style error response
        body = {"error": {"code": error_code, "message": message, "type": type_code, "param": param}}
        super().__init__(status_code=status_code, content=body)

+ 15 - 0
app/providers/route_provider.py

@@ -0,0 +1,15 @@
+import logging
+
+from app.api.routes import api_router, router_init
+from config.config import settings
+
+
def boot(app):
    """Register the API routers [app/api/routes.py] and, in debug mode, log the route table."""
    router_init()
    app.include_router(api_router, prefix=settings.API_PREFIX)

    # Log every registered route when running in debug mode.
    if not app.debug:
        return
    for route in app.routes:
        logging.info({"path": route.path, "name": route.name, "methods": route.methods})

+ 80 - 0
app/providers/storage.py

@@ -0,0 +1,80 @@
+from contextlib import closing
+from typing import Union, Generator
+
+import boto3
+from botocore.exceptions import ClientError
+
+from app.exceptions.exception import ResourceNotFoundError
+from config.storage import settings as s3_settings
+
+
class Storage:
    """S3-backed blob storage built on a single long-lived boto3 client.

    The original wrapped ``self.client`` in ``contextlib.closing(...)`` inside
    load_once/load_stream/download/exists, which closed the shared client
    after the first call and broke every subsequent storage operation.
    The client is created once in init() and must stay open.
    """

    def __init__(self):
        self.bucket_name = None
        self.client = None

    def init(self):
        """Create the boto3 S3 client from the configured settings."""
        self.bucket_name = s3_settings.S3_BUCKET_NAME
        self.client = boto3.client(
            service_name="s3",
            aws_access_key_id=s3_settings.S3_ACCESS_KEY,
            aws_secret_access_key=s3_settings.S3_SECRET_KEY,
            endpoint_url=s3_settings.S3_ENDPOINT,
            region_name=s3_settings.S3_REGION,
        )

    def save(self, filename, data):
        """Store raw bytes under `filename`."""
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def save_from_path(self, filename, local_file_path):
        """Upload a local file under `filename`."""
        self.client.upload_file(Filename=local_file_path, Bucket=self.bucket_name, Key=filename)

    def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
        """Fetch an object either fully in memory or as a chunk generator."""
        if stream:
            return self.load_stream(filename)
        return self.load_once(filename)

    def load_once(self, filename: str) -> bytes:
        """Read the whole object into memory.

        Raises ResourceNotFoundError when the key does not exist.
        """
        try:
            data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ResourceNotFoundError("File not found")
            raise
        return data

    def load_stream(self, filename: str) -> Generator:
        """Return a generator yielding the object's bytes in chunks.

        Raises ResourceNotFoundError (from the generator) when the key is missing.
        """

        def generate(filename: str = filename) -> Generator:
            try:
                response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
                for chunk in response["Body"].iter_chunks():
                    yield chunk
            except ClientError as ex:
                if ex.response["Error"]["Code"] == "NoSuchKey":
                    raise ResourceNotFoundError("File not found")
                raise

        return generate()

    def download(self, filename, target_filepath):
        """Download the object to a local file path."""
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        """True when the key exists; False on any lookup failure (best-effort)."""
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except Exception:
            return False
+
+
# Module-level singleton shared across the application; the boto3 client is
# created eagerly at import time via init().
storage = Storage()

storage.init()

+ 0 - 0
app/schemas/__init__.py


+ 14 - 0
app/schemas/common.py

@@ -0,0 +1,14 @@
+from typing import Optional, Any
+
+from pydantic import BaseModel, Field
+
+
class DeleteResponse(BaseModel):
    """OpenAI-style deletion acknowledgement."""

    # ID of the deleted resource.
    id: str
    # Object tag; callers override per resource, e.g. "assistant.deleted".
    object: str = "file"
    # Whether the deletion took place.
    deleted: bool
+
+
class BaseSuccessDataResponse(BaseModel):
    """Generic success envelope: a status string plus an arbitrary payload."""

    status: str = Field("success")
    data: Optional[Any] = None

+ 9 - 0
app/schemas/files.py

@@ -0,0 +1,9 @@
+from typing import List
+from pydantic import BaseModel
+
+from app.models import File
+
+
class ListFilesResponse(BaseModel):
    """Response wrapper for the file-listing endpoint."""

    data: List[File]
    # NOTE(review): OpenAI list endpoints use object="list"; "file" here looks
    # like a copy-paste default — confirm with API consumers before changing.
    object: str = "file"

+ 23 - 0
app/schemas/runs.py

@@ -0,0 +1,23 @@
+from typing import List, Optional
+
+from pydantic import BaseModel
+from sqlmodel import Field
+
+
class ToolOutput(BaseModel):
    """One tool-call result submitted back to a run awaiting required_action."""

    tool_call_id: Optional[str] = Field(
        None,
        description="The ID of the tool call in the `required_action` "
        "object within the run object the output is being submitted for.",
    )
    output: Optional[str] = Field(
        None,
        description="The output of the tool call to be submitted to continue the run.",
    )
+
+
class SubmitToolOutputsRunRequest(BaseModel):
    """Request body for submitting tool outputs to continue a run."""

    tool_outputs: List[ToolOutput] = Field(
        ..., description="A list of tools for which the outputs are being submitted."
    )
    # When True the continued run is streamed back as server-sent events.
    stream: Optional[bool] = False

+ 15 - 0
app/schemas/threads.py

@@ -0,0 +1,15 @@
+from typing import Optional
+from pydantic import BaseModel
+from sqlmodel import Field
+
+from app.models.thread import ThreadCreate
+
+
class CreateThreadAndRun(BaseModel):
    """Request body for creating a thread and immediately starting a run on it."""

    assistant_id: str
    # Optional inline thread definition; presumably the server creates an
    # empty thread when omitted — verify against the endpoint handler.
    thread: Optional[ThreadCreate] = None
    instructions: Optional[str] = None
    model: Optional[str] = None
    # NOTE(review): schema_extra does not register a validation alias in
    # pydantic v2; an incoming "metadata" key may not populate this field —
    # confirm against callers (Field(alias=...) would be the usual spelling).
    metadata_: Optional[dict] = Field(default=None, schema_extra={"validation_alias": "metadata"})
    tools: Optional[list] = Field(default=[])
    stream: Optional[bool] = False

+ 0 - 0
app/schemas/tool/__init__.py


+ 254 - 0
app/schemas/tool/action.py

@@ -0,0 +1,254 @@
+from enum import Enum
+import re
+from typing import Optional, Any, Dict, List
+
+from pydantic import BaseModel, Field, model_validator
+
+import openapi_spec_validator
+
+from app.exceptions.exception import ValidateFailedError
+from app.schemas.tool.authentication import Authentication, AuthenticationType
+
+
# This function code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
def validate_openapi_schema(schema: Dict):
    """Validate an action's OpenAPI schema.

    Checks that the schema is valid OpenAPI, declares exactly one server,
    and that every operation carries a usable description (falling back to
    the summary) and a well-formed operationId. Raises ValidateFailedError
    on the first violation; returns the (possibly amended) schema.
    """
    try:
        openapi_spec_validator.validate(schema)
        # check exactly one server in the schema
    except Exception as e:
        # openapi_spec_validator raises several error types; prefer the
        # human-readable .message when present.
        if hasattr(e, "message"):
            raise ValidateFailedError(f"Invalid openapi schema: {e.message}")
        else:
            raise ValidateFailedError(f"Invalid openapi schema: {e}")

    if "servers" not in schema:
        raise ValidateFailedError("No server is found in action schema")

    if "paths" not in schema:
        raise ValidateFailedError("No paths is found in action schema")

    if len(schema["servers"]) != 1:
        raise ValidateFailedError("Exactly one server is allowed in action schema")

    # OpenAPI path items may also carry non-operation keys ("parameters",
    # "summary", "servers", "$ref", ...); only inspect real HTTP operations,
    # otherwise the checks below would misfire on those entries.
    http_methods = {"get", "put", "post", "delete", "options", "head", "patch", "trace"}

    # check each path method has a valid description and operationId
    for path, methods in schema["paths"].items():
        for method, details in methods.items():
            if method.lower() not in http_methods:
                continue

            if not details.get("description") or not isinstance(details["description"], str):
                if details.get("summary") and isinstance(details["summary"], str):
                    # use summary as its description
                    details["description"] = details["summary"]
                else:
                    raise ValidateFailedError(f"No description is found in {method} {path} in action schema")

            if len(details["description"]) > 512:
                raise ValidateFailedError(
                    f"Description cannot be longer than 512 characters in {method} {path} in action schema"
                )

            if not details.get("operationId") or not isinstance(details["operationId"], str):
                raise ValidateFailedError(f"No operationId is found in {method} {path} in action schema")

            if len(details["operationId"]) > 128:
                raise ValidateFailedError(
                    f"operationId cannot be longer than 128 characters in {method} {path} in action schema"
                )

            if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", details["operationId"]):
                raise ValidateFailedError(
                    f'Invalid operationId {details["operationId"]} in {method} {path} in action schema'
                )

    return schema
+
+
+# ----------------------------
+# Create Action
+# POST /actions
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionBulkCreateRequest(BaseModel):
    """Request body for POST /actions: bulk-create actions from one OpenAPI schema."""

    openapi_schema: Dict = Field(
        ...,
        description="The action schema is compliant with the OpenAPI Specification. "
        "If there are multiple paths and methods in the schema, "
        "the server will create multiple actions whose schema only has exactly one path and one method",
    )

    authentication: Authentication = Field(
        Authentication(type=AuthenticationType.none), description="The action API authentication."
    )

    use_for_everyone: bool = Field(default=False)

    # Renamed from "model_validator" so the method no longer shadows the
    # pydantic decorator of the same name inside the class body.
    @model_validator(mode="before")
    def validate_fields(cls, data: Any):
        validate_openapi_schema(data.get("openapi_schema"))
        authentication = data.get("authentication")
        if authentication:
            # Encrypt inbound secrets before they are persisted anywhere.
            Authentication.model_validate(authentication).encrypt()
        return data
+
+
+# ----------------------------
+# Update Action
+# POST /actions/{action_id}
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionUpdateRequest(BaseModel):
    """Request body for POST /actions/{action_id}: partial update of one action."""

    openapi_schema: Optional[Dict] = Field(
        default=None,
        description="The action schema, which is compliant with the OpenAPI Specification. "
        "It should only have exactly one path and one method.",
    )
    authentication: Optional[Authentication] = Field(None, description="The action API authentication.")

    use_for_everyone: bool = Field(default=False)

    # Renamed from "model_validator" so the method no longer shadows the
    # pydantic decorator of the same name inside the class body.
    @model_validator(mode="before")
    def validate_fields(cls, data: Any):
        # At least one updatable field must be present in the payload.
        if not any([(data.get(key) is not None) for key in ["use_for_everyone", "openapi_schema", "authentication"]]):
            raise ValidateFailedError("At least one field should be filled")
        openapi_schema = data.get("openapi_schema")
        if openapi_schema:
            validate_openapi_schema(openapi_schema)
        authentication = data.get("authentication")
        if authentication:
            # Encrypt inbound secrets before they are persisted anywhere.
            Authentication.model_validate(authentication).encrypt()
        return data
+
+
+# ----------------------------
+# Run an Action
+# POST /actions/{action_id}/run
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionRunRequest(BaseModel):
    """Request body for POST /actions/{action_id}/run."""

    # Parameter values to substitute into the action's OpenAPI operation.
    parameters: Optional[Dict[str, Any]] = Field(None)
    # Extra HTTP headers forwarded with the outgoing request.
    headers: Optional[Dict[str, Any]] = Field(None)
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionMethod(str, Enum):
    """HTTP methods an action may declare; NONE marks a method-less entry."""

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    # HEAD = "HEAD"
    # OPTIONS = "OPTIONS"
    # TRACE = "TRACE"
    NONE = "NONE"
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionParam(BaseModel):
    """A single parameter extracted from an action's OpenAPI schema."""

    type: str
    description: str
    enum: Optional[List[str]] = None
    required: bool
    properties: Optional[Dict[str, Dict]] = None

    def is_single_value_enum(self):
        # Truthy only when the enum constrains the value to exactly one option.
        return self.enum and len(self.enum) == 1
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ActionBodyType(str, Enum):
    """Encoding of an action's request body: JSON, form data, or no body."""

    JSON = "JSON"
    FORM = "FORM"
    NONE = "NONE"
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ChatCompletionFunctionParametersProperty(BaseModel):
    """One JSON-Schema property inside a chat-completion function's parameters."""

    type: str = Field(
        ...,
        pattern="^(string|number|integer|boolean|object)$",
        description="The type of the parameter.",
    )

    description: str = Field(
        "",
        max_length=256,
        description="The description of the parameter.",
    )

    properties: Optional[Dict] = Field(
        None,
        description="The properties of the parameters.",
    )

    enum: Optional[List[str]] = Field(
        None,
        description="The enum list of the parameter. Which is only allowed when type is 'string'.",
    )
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ChatCompletionFunctionParameters(BaseModel):
    """JSON-Schema style parameters object of a chat-completion function."""

    # NOTE: the original passed Literal="object" as a Field kwarg; pydantic
    # does not define such a parameter (v2 deprecates arbitrary extra Field
    # kwargs), so it is dropped — the default conveys the fixed value.
    type: str = Field(
        "object",
        description="The type of the parameters, which is always 'object'.",
    )

    properties: Dict[str, ChatCompletionFunctionParametersProperty] = Field(
        ...,
        description="The properties of the parameters.",
    )

    required: List[str] = Field(
        [],
        description="The required parameters.",
    )
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class ChatCompletionFunction(BaseModel):
    """A callable function definition in OpenAI chat-completion format."""

    name: str = Field(
        ...,
        description="The name of the function.",
        examples=["plus_a_and_b"],
    )

    description: str = Field(
        ...,
        description="The description of the function.",
        examples=["Add two numbers"],
    )

    parameters: ChatCompletionFunctionParameters = Field(
        ...,
        description="The function's parameters are represented as an object in JSON Schema format.",
        examples=[
            {
                "type": "object",
                "properties": {
                    "a": {"type": "number", "description": "The first number"},
                    "b": {"type": "number", "description": "The second number"},
                },
                "required": ["a", "b"],
            }
        ],
    )
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
# NOTE(review): this is a byte-identical re-definition of ActionRunRequest,
# already declared earlier in this module — consider removing one copy.
class ActionRunRequest(BaseModel):
    """Request body for POST /actions/{action_id}/run."""

    parameters: Optional[Dict[str, Any]] = Field(None)
    headers: Optional[Dict[str, Any]] = Field(None)

+ 86 - 0
app/schemas/tool/authentication.py

@@ -0,0 +1,86 @@
+import logging
+from enum import Enum
+from typing import Optional, Dict
+
+from pydantic import BaseModel, Field, model_validator
+
+from app.utils import aes_encrypt, aes_decrypt
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["Authentication", "AuthenticationType"]
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class AuthenticationType(str, Enum):
    """Supported authentication schemes for action API calls."""

    bearer = "bearer"
    basic = "basic"
    custom = "custom"
    none = "none"
+
+
# This function code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
def validate_authentication_data(data: Dict):
    """Validate (and normalize) a raw authentication payload.

    Raises ValueError when the payload is not a dict, has no type, or is
    missing the secret/content its type requires. For type "none" the
    secret and content are cleared. Returns the (mutated) payload.
    """
    if not isinstance(data, dict):
        raise ValueError("Authentication should be a dict.")

    # A missing or empty "type" key is rejected up front.
    if not data.get("type"):
        raise ValueError("Type is required for authentication.")

    auth_type = data["type"]

    if auth_type == AuthenticationType.custom:
        if data.get("content") is None:
            raise ValueError("Content is required for custom authentication.")

    elif auth_type == AuthenticationType.bearer:
        if data.get("secret") is None:
            raise ValueError(f'Secret is required for {data["type"]} authentication.')

    elif auth_type == AuthenticationType.basic:
        if data.get("secret") is None:
            raise ValueError(f'Secret is required for {data["type"]} authentication.')
        # assume the secret is a base64 encoded string

    elif auth_type == AuthenticationType.none:
        # "none" carries no credentials at all.
        data["secret"] = None
        data["content"] = None

    return data
+
+
# This class utilizes code from the Open Source Project TaskingAI.
# The original code can be found at: https://github.com/TaskingAI/TaskingAI
class Authentication(BaseModel):
    """Action API authentication settings with at-rest AES encryption of secrets."""

    encrypted: bool = Field(False)
    type: AuthenticationType = Field(...)
    secret: Optional[str] = Field(None, min_length=1, max_length=1024)
    content: Optional[Dict] = Field(None)

    @model_validator(mode="before")
    def validate_all_fields_at_the_same_time(cls, data: Dict):
        return validate_authentication_data(data)

    def is_encrypted(self):
        # "none" auth carries no secrets, so it counts as already encrypted.
        return self.encrypted or self.type == AuthenticationType.none

    def encrypt(self):
        """AES-encrypt secret and content values in place; idempotent."""
        if self.is_encrypted():
            return
        if self.secret is not None:
            self.secret = aes_encrypt(self.secret)
        if self.content is not None:
            for key in self.content:
                self.content[key] = aes_encrypt(self.content[key])
        self.encrypted = True

    def decrypt(self):
        """Reverse encrypt(); no-op when not encrypted or type is "none"."""
        if not self.encrypted or self.type == AuthenticationType.none:
            return
        if self.secret is not None:
            self.secret = aes_decrypt(self.secret)
        if self.content is not None:
            for key in self.content:
                self.content[key] = aes_decrypt(self.content[key])
        self.encrypted = False

+ 0 - 0
app/services/__init__.py


+ 0 - 0
app/services/assistant/__init__.py


+ 67 - 0
app/services/assistant/assistant.py

@@ -0,0 +1,67 @@
+from sqlmodel import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.exceptions.exception import ResourceNotFoundError
+from app.models.assistant import Assistant, AssistantUpdate, AssistantCreate
+from app.models.token_relation import RelationType
+from app.providers.auth_provider import auth_policy
+from app.schemas.common import DeleteResponse
+from app.utils import revise_tool_names
+
+
class AssistantService:
    """CRUD operations for Assistant rows."""

    @staticmethod
    async def create_assistant(*, session: AsyncSession, body: AssistantCreate, token_id: str = None) -> Assistant:
        """Create an assistant and bind it to the creating token."""
        revise_tool_names(body.tools)
        db_assistant = Assistant.model_validate(body.model_dump(by_alias=True))
        session.add(db_assistant)
        auth_policy.insert_token_rel(
            session=session, token_id=token_id, relation_type=RelationType.Assistant, relation_id=db_assistant.id
        )
        await session.commit()
        await session.refresh(db_assistant)
        return db_assistant

    @staticmethod
    async def modify_assistant(*, session: AsyncSession, assistant_id: str, body: AssistantUpdate) -> Assistant:
        """Apply a partial update; only fields the caller actually set are written."""
        revise_tool_names(body.tools)
        db_assistant = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
        # model_dump replaces pydantic-v1 .dict(), which is deprecated in v2
        # (this file already uses model_validate/model_dump elsewhere).
        update_data = body.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_assistant, key, value)
        session.add(db_assistant)
        await session.commit()
        await session.refresh(db_assistant)
        return db_assistant

    @staticmethod
    async def delete_assistant(
        *,
        session: AsyncSession,
        assistant_id: str,
    ) -> DeleteResponse:
        """Delete the assistant and its token relation in one transaction."""
        db_ass = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
        await session.delete(db_ass)
        await auth_policy.delete_token_rel(
            session=session, relation_type=RelationType.Assistant, relation_id=assistant_id
        )
        await session.commit()
        return DeleteResponse(id=assistant_id, object="assistant.deleted", deleted=True)

    @staticmethod
    async def get_assistant(*, session: AsyncSession, assistant_id: str) -> Assistant:
        """Fetch one assistant or raise ResourceNotFoundError."""
        statement = select(Assistant).where(Assistant.id == assistant_id)
        result = await session.execute(statement)
        assistant = result.scalars().one_or_none()
        if assistant is None:
            raise ResourceNotFoundError(message="Assistant not found")
        return assistant

    @staticmethod
    def get_assistant_sync(*, session, assistant_id: str) -> Assistant:
        """Synchronous variant of get_assistant.

        NOTE(review): the original annotated `session` as AsyncSession, but the
        body calls session.execute() without awaiting, so callers must pass a
        synchronous Session here — annotation removed pending confirmation.
        """
        statement = select(Assistant).where(Assistant.id == assistant_id)
        result = session.execute(statement)
        assistant = result.scalars().one_or_none()
        if assistant is None:
            raise ResourceNotFoundError(message="Assistant not found")
        return assistant

+ 53 - 0
app/services/assistant/assistant_file.py

@@ -0,0 +1,53 @@
+from sqlmodel import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.exceptions.exception import ResourceNotFoundError
+from app.models.assistant_file import (
+    AssistantFile,
+    AssistantFileCreate,
+    AssistantFileUpdate,
+)
+from app.schemas.common import DeleteResponse
+
+
class AssistantFileService:
    """CRUD operations for assistant <-> file relations."""

    @staticmethod
    async def create_assistant_file(
        *, session: AsyncSession, assistant_id: str, body: AssistantFileCreate
    ) -> AssistantFile:
        # TODO: the relation table is not implemented yet; return a stub.
        return AssistantFile(id="", assistant_id=assistant_id)

    @staticmethod
    async def modify_assistant_file(
        *, session: AsyncSession, assistant_id: str, body: AssistantFileUpdate
    ) -> AssistantFile:
        """Apply a partial update to an assistant-file relation."""
        db_assistant = await AssistantFileService.get_assistant_file(
            session=session, assistant_id=assistant_id, file_id=body.id
        )
        # model_dump replaces pydantic-v1 .dict(), deprecated in v2.
        update_data = body.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_assistant, key, value)
        session.add(db_assistant)
        await session.commit()
        await session.refresh(db_assistant)
        return db_assistant

    @staticmethod
    async def delete_assistant_file(*, session: AsyncSession, assistant_id: str, file_id: str) -> DeleteResponse:
        """Delete the relation, returning an OpenAI-style acknowledgement."""
        assistant_file = await AssistantFileService.get_assistant_file(
            session=session, assistant_id=assistant_id, file_id=file_id
        )
        id = assistant_file.id
        await session.delete(assistant_file)
        await session.commit()
        return DeleteResponse(id=id, object="assistant_file.deleted", deleted=True)

    @staticmethod
    async def get_assistant_file(*, session: AsyncSession, assistant_id: str, file_id: str) -> AssistantFile:
        """Fetch one relation row or raise ResourceNotFoundError.

        Fixes two defects in the original: the first filter compared
        AssistantFile.id against assistant_id (contradicting the id == file_id
        filter, so the query matched nothing), and the final line re-queried
        via the unawaited sync session.exec() instead of returning the row
        already fetched.
        """
        statement = (
            select(AssistantFile)
            .where(AssistantFile.assistant_id == assistant_id)
            .where(AssistantFile.id == file_id)
        )
        result = await session.execute(statement)
        assistant_file = result.scalars().one_or_none()
        if assistant_file is None:
            raise ResourceNotFoundError(message=f"Assistant file-{file_id} not found")
        return assistant_file

+ 0 - 0
app/services/file/__init__.py


+ 9 - 0
app/services/file/file.py

@@ -0,0 +1,9 @@
+from typing import Type
+from config.storage import settings
+from app.libs.class_loader import load_class
+from app.services.file.impl.base import BaseFileService
+
# Load the concrete file-service implementation named in storage settings
# (e.g. "app.services.file.impl.oss_file.OSSFileService").
# Removed: leftover import-time debug print statements.
FileService: Type[BaseFileService] = load_class(name=settings.FILE_SERVICE_MODULE)

+ 0 - 0
app/services/file/impl/__init__.py


+ 47 - 0
app/services/file/impl/base.py

@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from typing import List, Union, Generator, Tuple, Optional
+
+from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from fastapi import UploadFile
+
+from app.models import File
+from app.schemas.common import DeleteResponse
+
+
class BaseFileService(ABC):
    """Abstract interface every file-storage backend implementation must satisfy."""

    @staticmethod
    @abstractmethod
    def get_file_list_by_ids(*, session: Session, file_ids: List[str]) -> List[File]:
        """Fetch File rows for the given ids (synchronous)."""
        pass

    @staticmethod
    @abstractmethod
    async def get_file_list(*, session: AsyncSession, purpose: str, file_ids: Optional[List[str]]) -> List[File]:
        """List files, optionally filtered by purpose and/or an id allow-list."""
        pass

    @staticmethod
    @abstractmethod
    async def create_file(*, session: AsyncSession, purpose: str, file: UploadFile) -> File:
        """Store an uploaded file and persist its File row."""
        pass

    @staticmethod
    @abstractmethod
    async def get_file(*, session: AsyncSession, file_id: str) -> File:
        """Fetch one File row by id."""
        pass

    @staticmethod
    @abstractmethod
    async def get_file_content(*, session: AsyncSession, file_id: str) -> Tuple[Union[bytes, Generator], str]:
        """Return the file's content (bytes or a chunk generator) plus its filename."""
        pass

    @staticmethod
    @abstractmethod
    async def delete_file(*, session: AsyncSession, file_id: str) -> DeleteResponse:
        """Delete the file and return an acknowledgement."""
        pass

    @staticmethod
    @abstractmethod
    def search_in_files(*, query: str, file_keys: List[str]) -> dict:
        """Run a content search across the given stored file keys."""
        pass

+ 94 - 0
app/services/file/impl/oss_file.py

@@ -0,0 +1,94 @@
import json
import uuid
from typing import List, Union, Generator, Tuple

from fastapi import UploadFile
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlmodel import select, col, desc

from app.core.doc_loaders import doc_loader
from app.exceptions.exception import ResourceNotFoundError
from app.models import File
from app.providers.storage import storage
from app.schemas.common import DeleteResponse
# Import the ABC from its defining module, NOT from app.services.file.file:
# that module loads this class dynamically at import time, so importing it
# here would create a circular import.
from app.services.file.impl.base import BaseFileService


class OSSFileService(BaseFileService):
    """File service backed by object storage (OSS/S3-style) via ``storage``."""

    @staticmethod
    def get_file_list_by_ids(*, session: Session, file_ids: List[str]) -> List[File]:
        """Return the ``File`` rows whose ids appear in *file_ids*.

        Accepts either an actual list of ids or a JSON-encoded list (some
        call sites pass the ids as a serialized string). Returns ``[]`` for
        empty input.
        """
        if not file_ids:
            return []
        # The original code unconditionally json.loads()-ed the argument,
        # which raises TypeError when a real list is passed as the signature
        # declares. Accept both forms.
        ids = json.loads(file_ids) if isinstance(file_ids, str) else file_ids
        statement = select(File).where(col(File.id).in_(ids))
        return session.execute(statement).scalars().all()

    @staticmethod
    async def get_file_list(*, session: AsyncSession, purpose: str, file_ids: List[str]) -> List[File]:
        """List files newest-first, optionally filtered by purpose and ids."""
        statement = select(File)
        if purpose:  # empty/None purpose means "all purposes"
            statement = statement.where(File.purpose == purpose)
        if file_ids is not None:
            statement = statement.where(File.id.in_(file_ids))
        statement = statement.order_by(desc(File.created_at))
        result = await session.execute(statement)
        return result.scalars().all()

    @staticmethod
    async def create_file(*, session: AsyncSession, purpose: str, file: UploadFile) -> File:
        """Persist the upload to object storage and record it in the database.

        The object key is prefixed with a fresh UUID so identical filenames
        never collide in the bucket. Returns the committed ``File`` row.
        """
        # TODO: deduplicate identical uploads (same purpose/filename/size)
        #       before writing a new object.
        file_key = f"{uuid.uuid4()}-{file.filename}"
        storage.save(filename=file_key, data=file.file.read())

        db_file = File(purpose=purpose, filename=file.filename, bytes=file.size, key=file_key)
        session.add(db_file)
        await session.commit()
        await session.refresh(db_file)
        return db_file

    @staticmethod
    async def get_file(*, session: AsyncSession, file_id: str) -> File:
        """Fetch a single ``File`` row by id.

        Raises:
            ResourceNotFoundError: if no file with *file_id* exists.
        """
        statement = select(File).where(File.id == file_id)
        result = await session.execute(statement)
        ext_file = result.scalars().one_or_none()
        if ext_file is None:
            raise ResourceNotFoundError(message="File not found")
        return ext_file

    @staticmethod
    async def get_file_content(*, session: AsyncSession, file_id: str) -> Tuple[Union[bytes, Generator], str]:
        """Return ``(content, filename)``; content is loaded from object storage."""
        ext_file = await OSSFileService.get_file(session=session, file_id=file_id)
        file_data = storage.load(ext_file.key)
        return file_data, ext_file.filename

    @staticmethod
    async def delete_file(*, session: AsyncSession, file_id: str) -> DeleteResponse:
        """Delete the database record for *file_id*.

        Raises ResourceNotFoundError if the file does not exist.
        """
        ext_file = await OSSFileService.get_file(session=session, file_id=file_id)
        # TODO: also delete the stored object (S3/OSS) — currently orphaned.

        await session.delete(ext_file)
        await session.commit()
        return DeleteResponse(id=file_id, deleted=True)

    @staticmethod
    def search_in_files(*, query: str, file_keys: List[str]) -> dict:
        """Return ``{file_key: extracted text}`` for each key in *file_keys*.

        The ``*`` keyword-only marker matches the signature declared on
        BaseFileService (the original implementation omitted it).

        NOTE(review): *query* is currently unused — callers receive the raw
        truncated document text rather than search hits; confirm intent.
        """
        files = {}
        for file_key in file_keys:
            file_data = storage.load(file_key)
            # Keep only the first 5000 chars to stay within the LLM's
            # maximum context window.
            files[file_key] = doc_loader.load(file_data)[:5000]

        return files

Some files were not shown because too many files changed in this diff