# HTTP routes for the messages (and message files) of a thread.

from typing import Optional

from fastapi import APIRouter, Depends
from fastapi.params import Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from app.api.deps import get_async_session
from app.models import MessageFile
from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
from app.libs.paginate import cursor_page, CommonPage
from app.services.message.message import MessageService
from app.schemas.common import DeleteResponse

router = APIRouter()


@router.get(
    "/{thread_id}/messages",
    response_model=CommonPage[MessageRead],
)
async def list_messages(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    run_id: Optional[str] = Query(
        None, description="Filter messages by the run ID that generated them."
    ),
):
    """
    Returns a list of messages for a given thread.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread whose messages are listed (path parameter).
        run_id: Optional query parameter; when non-empty, only messages whose
            ``run_id`` equals it are returned.

    Returns:
        The page built by ``cursor_page``, with each item of ``data`` replaced
        by the message dumped with ``by_alias=True``.
    """
    statement = select(Message).where(Message.thread_id == thread_id)
    if run_id:
        # Filter by run_id; an empty string is treated the same as "no filter".
        statement = statement.where(Message.run_id == run_id)
    page = await cursor_page(statement, session)
    # Dump with aliases so aliased field names (if any) appear in the response.
    page.data = [message.model_dump(by_alias=True) for message in page.data]
    return page


@router.post("/{thread_id}/messages", response_model=MessageRead)
async def create_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    body: MessageCreate,
):
    """
    Create a message.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread the message is added to (path parameter).
        body: Message creation payload.

    Returns:
        The message created by ``MessageService.create_message``, dumped with
        ``by_alias=True``.
    """
    message = await MessageService.create_message(
        session=session, thread_id=thread_id, body=body
    )
    return message.model_dump(by_alias=True)


@router.get(
    "/{thread_id}/messages/{message_id}",
    response_model=MessageRead,
)
async def get_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str,
):
    """
    Retrieve a message.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread that owns the message (path parameter).
        message_id: ID of the message to fetch (path parameter).

    Returns:
        The message returned by ``MessageService.get_message``, dumped with
        ``by_alias=True``.
    """
    message = await MessageService.get_message(
        session=session, thread_id=thread_id, message_id=message_id
    )
    return message.model_dump(by_alias=True)


@router.post(
    "/{thread_id}/messages/{message_id}",
    response_model=MessageRead,
)
async def modify_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    body: MessageUpdate = ...,
):
    """
    Modifies a message.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread that owns the message (path parameter).
        message_id: ID of the message to modify (path parameter).
        body: Fields to update.

    Returns:
        The message returned by ``MessageService.modify_message``, dumped with
        ``by_alias=True``.
    """
    message = await MessageService.modify_message(
        session=session, thread_id=thread_id, message_id=message_id, body=body
    )
    return message.model_dump(by_alias=True)


@router.get(
    "/{thread_id}/messages/{message_id}/files",
    response_model=CommonPage[MessageFile],
)
async def list_message_files(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
):
    """
    Returns a list of message files.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread that owns the message (path parameter).
        message_id: ID of the message whose files are listed (path parameter).

    Returns:
        A cursor page of ``MessageFile`` rows attached to the message.
    """
    # Scope the listing to the thread named in the URL; without this lookup,
    # files of any message could be listed through any thread's path.
    # NOTE(review): relies on MessageService.get_message raising (e.g. 404)
    # when the message is not in the thread, as the get_message route
    # implicitly does — confirm against the service.
    await MessageService.get_message(
        session=session, thread_id=thread_id, message_id=message_id
    )
    return await cursor_page(
        select(MessageFile).where(MessageFile.message_id == message_id), session
    )


@router.get(
    "/{thread_id}/messages/{message_id}/files/{file_id}",
    response_model=MessageFile,
)
async def get_message_file(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
    file_id: str = ...,
) -> MessageFile:
    """
    Retrieves a message file.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread that owns the message (path parameter).
        message_id: ID of the message the file is attached to (path parameter).
        file_id: ID of the file to fetch (path parameter).

    Returns:
        The ``MessageFile`` returned by ``MessageService.get_message_file``.
    """
    return await MessageService.get_message_file(
        session=session, thread_id=thread_id, message_id=message_id, file_id=file_id
    )


@router.delete("/{thread_id}/messages/{message_id}", response_model=DeleteResponse)
async def delete_message(
    *,
    session: AsyncSession = Depends(get_async_session),
    thread_id: str,
    message_id: str = ...,
) -> DeleteResponse:
    """
    Deletes a message.
    \f
    Args:
        session: Async database session injected by FastAPI.
        thread_id: ID of the thread that owns the message (path parameter).
        message_id: ID of the message to delete (path parameter).

    Returns:
        The ``DeleteResponse`` returned by ``MessageService.delete_message``.

    Raises:
        ValueError: If ``thread_id`` or ``message_id`` is empty.
    """
    # Defensive guards, mainly for direct calls; presumably unreachable over
    # HTTP because an empty path segment should not match this route.
    if not thread_id:
        raise ValueError(
            f"Expected a non-empty value for `thread_id` but received {thread_id!r}"
        )
    if not message_id:
        raise ValueError(
            f"Expected a non-empty value for `message_id` but received {message_id!r}"
        )
    return await MessageService.delete_message(
        session=session, thread_id=thread_id, message_id=message_id
    )