message.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from typing import Optional
  2. from fastapi import APIRouter, Depends
  3. from fastapi.params import Query
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlmodel import select
  6. from app.api.deps import get_async_session
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
  9. from app.libs.paginate import cursor_page, CommonPage
  10. from app.services.message.message import MessageService
  11. from app.schemas.common import DeleteResponse
# Router for the thread-message endpoints defined in this module. All paths are
# relative to wherever it is included (presumably a "/threads" prefix — verify
# at the include_router call site).
router = APIRouter()
  13. @router.get(
  14. "/{thread_id}/messages",
  15. response_model=CommonPage[MessageRead],
  16. )
  17. async def list_messages(
  18. *,
  19. session: AsyncSession = Depends(get_async_session),
  20. thread_id: str,
  21. run_id: Optional[str] = Query(
  22. None, description="Filter messages by the run ID that generated them."
  23. ),
  24. ):
  25. """
  26. Returns a list of messages for a given thread.
  27. """
  28. statement = select(Message).where(Message.thread_id == thread_id)
  29. if run_id:
  30. # 根据 run_id 进行过滤
  31. statement = statement.where(Message.run_id == run_id)
  32. page = await cursor_page(statement, session)
  33. page.data = [ast.model_dump(by_alias=True) for ast in page.data]
  34. return page
  35. @router.post("/{thread_id}/messages", response_model=MessageRead)
  36. async def create_message(
  37. *,
  38. session: AsyncSession = Depends(get_async_session),
  39. thread_id: str,
  40. body: MessageCreate,
  41. ):
  42. """
  43. Create a message.
  44. """
  45. message = await MessageService.create_message(
  46. session=session, thread_id=thread_id, body=body
  47. )
  48. return message.model_dump(by_alias=True)
  49. @router.get(
  50. "/{thread_id}/messages/{message_id}",
  51. response_model=MessageRead,
  52. )
  53. async def get_message(
  54. *,
  55. session: AsyncSession = Depends(get_async_session),
  56. thread_id: str,
  57. message_id: str,
  58. ):
  59. """
  60. Retrieve a message.
  61. """
  62. message = await MessageService.get_message(
  63. session=session, thread_id=thread_id, message_id=message_id
  64. )
  65. return message.model_dump(by_alias=True)
  66. @router.post(
  67. "/{thread_id}/messages/{message_id}",
  68. response_model=MessageRead,
  69. )
  70. async def modify_message(
  71. *,
  72. session: AsyncSession = Depends(get_async_session),
  73. thread_id: str,
  74. message_id: str = ...,
  75. body: MessageUpdate = ...,
  76. ):
  77. """
  78. Modifies a message.
  79. """
  80. message = await MessageService.modify_message(
  81. session=session, thread_id=thread_id, message_id=message_id, body=body
  82. )
  83. return message.model_dump(by_alias=True)
  84. @router.get(
  85. "/{thread_id}/messages/{message_id}/files",
  86. response_model=CommonPage[MessageFile],
  87. )
  88. async def list_message_files(
  89. *,
  90. session: AsyncSession = Depends(get_async_session),
  91. message_id: str = ...,
  92. ):
  93. """
  94. Returns a list of message files.
  95. """
  96. return await cursor_page(
  97. select(MessageFile).where(MessageFile.message_id == message_id), session
  98. )
  99. @router.get(
  100. "/{thread_id}/messages/{message_id}/files/{file_id}",
  101. response_model=MessageFile,
  102. )
  103. async def get_message_file(
  104. *,
  105. session: AsyncSession = Depends(get_async_session),
  106. thread_id: str,
  107. message_id: str = ...,
  108. file_id: str = ...,
  109. ) -> MessageFile:
  110. """
  111. Retrieves a message file.
  112. """
  113. return await MessageService.get_message_file(
  114. session=session, thread_id=thread_id, message_id=message_id, file_id=file_id
  115. )
  116. @router.delete("/{thread_id}/messages/{message_id}", response_model=DeleteResponse)
  117. async def delete_message(
  118. *,
  119. session: AsyncSession = Depends(get_async_session),
  120. thread_id: str,
  121. message_id: str = ...,
  122. ) -> DeleteResponse:
  123. """
  124. Deletes a message.
  125. Args:
  126. extra_headers: Send extra headers
  127. extra_query: Add additional query parameters to the request
  128. extra_body: Add additional JSON properties to the request
  129. timeout: Override the client-level default timeout for this request, in seconds
  130. """
  131. if not thread_id:
  132. raise ValueError(
  133. f"Expected a non-empty value for `thread_id` but received {thread_id!r}"
  134. )
  135. if not message_id:
  136. raise ValueError(
  137. f"Expected a non-empty value for `message_id` but received {message_id!r}"
  138. )
  139. """
  140. Modifies a message.
  141. """
  142. return await MessageService.delete_message(
  143. session=session, thread_id=thread_id, message_id=message_id
  144. )