123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- from typing import Optional
- from fastapi import APIRouter, Depends
- from fastapi.params import Query
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlmodel import select
- from app.api.deps import get_async_session
- from app.models import MessageFile
- from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
- from app.libs.paginate import cursor_page, CommonPage
- from app.services.message.message import MessageService
- from app.schemas.common import DeleteResponse
- router = APIRouter()
@router.get("/{thread_id}/messages", response_model=CommonPage[MessageRead])
async def list_messages(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: Optional[str] = Query(
        None, description="Filter messages by the run ID that generated them."
    ),
):
    """
    List the messages that belong to a thread.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread whose messages are selected.
        run_id: Optional run ID; when truthy, only messages with this
            run ID are selected.

    Returns:
        The page produced by ``cursor_page``, with each entry in ``data``
        replaced by its ``model_dump(by_alias=True)`` dictionary.
    """
    # Always scope to the thread; narrow by run only when a run ID was given.
    criteria = [Message.thread_id == thread_id]
    if run_id:
        criteria.append(Message.run_id == run_id)

    query = select(Message).where(*criteria)
    result_page = await cursor_page(query, session)

    # Dump with aliases so the payload uses the aliased field names.
    result_page.data = [
        record.model_dump(by_alias=True) for record in result_page.data
    ]
    return result_page
@router.post(
    "/{thread_id}/messages",
    response_model=MessageRead,
)
async def create_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    body: MessageCreate,
):
    """
    Create a message in a thread.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread the new message is created under.
        body: Creation payload, passed through to the service unchanged.

    Returns:
        The object returned by ``MessageService.create_message``, dumped
        with ``model_dump(by_alias=True)``.
    """
    created = await MessageService.create_message(
        session=session,
        thread_id=thread_id,
        body=body,
    )
    serialized = created.model_dump(by_alias=True)
    return serialized
@router.get("/{thread_id}/messages/{message_id}", response_model=MessageRead)
async def get_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str,
):
    """
    Retrieve a single message from a thread.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread ID forwarded to the service.
        message_id: Message ID forwarded to the service.

    Returns:
        The object returned by ``MessageService.get_message``, dumped
        with ``model_dump(by_alias=True)``.
    """
    found = await MessageService.get_message(
        session=session,
        thread_id=thread_id,
        message_id=message_id,
    )
    serialized = found.model_dump(by_alias=True)
    return serialized
@router.post("/{thread_id}/messages/{message_id}", response_model=MessageRead)
async def modify_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    body: MessageUpdate = ...,
):
    """
    Modify an existing message.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread ID forwarded to the service.
        message_id: Message ID forwarded to the service.
        body: Update payload, passed through to the service unchanged.

    Returns:
        The object returned by ``MessageService.modify_message``, dumped
        with ``model_dump(by_alias=True)``.
    """
    updated = await MessageService.modify_message(
        session=session,
        thread_id=thread_id,
        message_id=message_id,
        body=body,
    )
    serialized = updated.model_dump(by_alias=True)
    return serialized
@router.get(
    "/{thread_id}/messages/{message_id}/files", response_model=CommonPage[MessageFile]
)
async def list_message_files(
    *,
    session: AsyncSession = Depends(get_async_session),
    message_id: str = ...,
):
    """
    List the files attached to a message.

    Args:
        session: Async database session injected by FastAPI.
        message_id: Message whose files are selected.

    Returns:
        The page produced by ``cursor_page`` over the matching
        ``MessageFile`` rows.
    """
    # NOTE(review): the route's ``thread_id`` segment is not part of this
    # filter; rows are matched on ``message_id`` only. Confirm this is intended.
    files_query = select(MessageFile).where(MessageFile.message_id == message_id)
    files_page = await cursor_page(files_query, session)
    return files_page
@router.get(
    "/{thread_id}/messages/{message_id}/files/{file_id}", response_model=MessageFile
)
async def get_message_file(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    file_id: str = ...,
) -> MessageFile:
    """
    Retrieve a single file attached to a message.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread ID forwarded to the service.
        message_id: Message ID forwarded to the service.
        file_id: File ID forwarded to the service.

    Returns:
        Whatever ``MessageService.get_message_file`` returns for the
        given identifiers.
    """
    message_file = await MessageService.get_message_file(
        session=session,
        thread_id=thread_id,
        message_id=message_id,
        file_id=file_id,
    )
    return message_file
@router.delete("/{thread_id}/messages/{message_id}", response_model=DeleteResponse)
async def delete_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
) -> DeleteResponse:
    """
    Delete a message from a thread.

    Args:
        session: Async database session injected by FastAPI.
        thread_id: Thread ID forwarded to the service; must be non-empty.
        message_id: Message ID forwarded to the service; must be non-empty.

    Returns:
        Whatever ``MessageService.delete_message`` returns for the given
        identifiers.

    Raises:
        ValueError: If ``thread_id`` or ``message_id`` is empty.
    """
    # NOTE(review): both values come from path segments, so these guards are
    # presumably defensive only; confirm whether they can ever trigger.
    if not thread_id:
        raise ValueError(
            f"Expected a non-empty value for `thread_id` but received {thread_id!r}"
        )
    if not message_id:
        raise ValueError(
            f"Expected a non-empty value for `message_id` but received {message_id!r}"
        )

    return await MessageService.delete_message(
        session=session, thread_id=thread_id, message_id=message_id
    )
|