message.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from typing import Optional
  2. from fastapi import APIRouter, Depends
  3. from fastapi.params import Query
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlmodel import select
  6. from app.api.deps import get_async_session
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
  9. from app.libs.paginate import cursor_page, CommonPage
  10. from app.services.message.message import MessageService
  11. router = APIRouter()
  12. @router.get(
  13. "/{thread_id}/messages",
  14. response_model=CommonPage[MessageRead],
  15. )
  16. async def list_messages(
  17. *,
  18. session: AsyncSession = Depends(get_async_session),
  19. thread_id: str,
  20. run_id: Optional[str] = Query(None, description="Filter messages by the run ID that generated them."),
  21. ):
  22. """
  23. Returns a list of messages for a given thread.
  24. """
  25. statement = select(Message).where(Message.thread_id == thread_id)
  26. if run_id:
  27. # 根据 run_id 进行过滤
  28. statement = statement.where(Message.run_id == run_id)
  29. page = await cursor_page(statement, session)
  30. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  31. return page
  32. @router.post("/{thread_id}/messages", response_model=MessageRead)
  33. async def create_message(
  34. *, session: AsyncSession = Depends(get_async_session), thread_id: str, body: MessageCreate
  35. ):
  36. """
  37. Create a message.
  38. """
  39. message = await MessageService.create_message(session=session, thread_id=thread_id, body=body)
  40. return message.model_dump(by_alias=True)
  41. @router.get(
  42. "/{thread_id}/messages/{message_id}",
  43. response_model=MessageRead,
  44. )
  45. async def get_message(
  46. *, session: AsyncSession = Depends(get_async_session), thread_id: str, message_id: str
  47. ):
  48. """
  49. Retrieve a message.
  50. """
  51. message = await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
  52. return message.model_dump(by_alias=True)
  53. @router.post(
  54. "/{thread_id}/messages/{message_id}",
  55. response_model=MessageRead,
  56. )
  57. async def modify_message(
  58. *,
  59. session: AsyncSession = Depends(get_async_session),
  60. thread_id: str,
  61. message_id: str = ...,
  62. body: MessageUpdate = ...,
  63. ):
  64. """
  65. Modifies a message.
  66. """
  67. message = await MessageService.modify_message(session=session, thread_id=thread_id, message_id=message_id, body=body)
  68. return message.model_dump(by_alias=True)
  69. @router.get(
  70. "/{thread_id}/messages/{message_id}/files",
  71. response_model=CommonPage[MessageFile],
  72. )
  73. async def list_message_files(
  74. *,
  75. session: AsyncSession = Depends(get_async_session),
  76. message_id: str = ...,
  77. ):
  78. """
  79. Returns a list of message files.
  80. """
  81. return await cursor_page(select(MessageFile).where(MessageFile.message_id == message_id), session)
  82. @router.get(
  83. "/{thread_id}/messages/{message_id}/files/{file_id}",
  84. response_model=MessageFile,
  85. )
  86. async def get_message_file(
  87. *,
  88. session: AsyncSession = Depends(get_async_session),
  89. thread_id: str,
  90. message_id: str = ...,
  91. file_id: str = ...,
  92. ) -> MessageFile:
  93. """
  94. Retrieves a message file.
  95. """
  96. return await MessageService.get_message_file(
  97. session=session, thread_id=thread_id, message_id=message_id, file_id=file_id
  98. )