message.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from typing import Optional
  2. from fastapi import APIRouter, Depends
  3. from fastapi.params import Query
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlmodel import select
  6. from app.api.deps import get_async_session
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
  9. from app.libs.paginate import cursor_page, CommonPage
  10. from app.services.message.message import MessageService
  # Router instance on which every thread-message endpoint in this module is registered.
  router: APIRouter = APIRouter()
  @router.get(
      "/{thread_id}/messages",
      response_model=CommonPage[MessageRead],
  )
  async def list_messages(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      run_id: Optional[str] = Query(
          None, description="Filter messages by the run ID that generated them."
      ),
  ):
      """
      Returns a list of messages for a given thread.

      Results are always restricted to ``thread_id``. When a non-empty
      ``run_id`` query parameter is given, they are further restricted to
      messages whose ``run_id`` matches. Pagination is delegated to
      ``cursor_page``, and each item in the page is serialized with
      ``model_dump(by_alias=True)``.
      """
      # Always scope to the thread; optionally narrow down to one run.
      conditions = [Message.thread_id == thread_id]
      if run_id:
          conditions.append(Message.run_id == run_id)

      page = await cursor_page(select(Message).where(*conditions), session)
      page.data = [msg.model_dump(by_alias=True) for msg in page.data]
      return page
  @router.post("/{thread_id}/messages", response_model=MessageRead)
  async def create_message(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      body: MessageCreate,
  ):
      """
      Create a message.

      Creation is delegated to ``MessageService.create_message`` for the
      thread named in the path. The new message is returned serialized with
      ``model_dump(by_alias=True)``.
      """
      created = await MessageService.create_message(
          session=session,
          thread_id=thread_id,
          body=body,
      )
      return created.model_dump(by_alias=True)
  @router.get(
      "/{thread_id}/messages/{message_id}",
      response_model=MessageRead,
  )
  async def get_message(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      message_id: str,
  ):
      """
      Retrieve a message.

      The message is looked up via ``MessageService.get_message`` using the
      thread and message IDs from the path. It is returned serialized with
      ``model_dump(by_alias=True)``.
      """
      found = await MessageService.get_message(
          session=session,
          thread_id=thread_id,
          message_id=message_id,
      )
      return found.model_dump(by_alias=True)
  @router.post(
      "/{thread_id}/messages/{message_id}",
      response_model=MessageRead,
  )
  async def modify_message(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      message_id: str = ...,
      body: MessageUpdate = ...,
  ):
      """
      Modifies a message.

      The update in ``body`` is applied through
      ``MessageService.modify_message``. The updated message is returned
      serialized with ``model_dump(by_alias=True)``.
      """
      updated = await MessageService.modify_message(
          session=session,
          thread_id=thread_id,
          message_id=message_id,
          body=body,
      )
      return updated.model_dump(by_alias=True)
  @router.get(
      "/{thread_id}/messages/{message_id}/files",
      response_model=CommonPage[MessageFile],
  )
  async def list_message_files(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      message_id: str = ...,
  ):
      """
      Returns a list of message files.

      The message is first resolved within ``thread_id`` via
      ``MessageService.get_message``, so files are only listed for a message
      addressed under the thread named in the path. The ``MessageFile`` rows
      whose ``message_id`` matches are then paginated with ``cursor_page``.
      """
      # Scope the request to the thread in the URL. Without this, the
      # {thread_id} path segment is ignored and any message's files can be
      # listed under any thread.
      # NOTE(review): relies on MessageService.get_message rejecting a message
      # that does not belong to this thread. The get_message route uses its
      # result unconditionally, which suggests it raises rather than returning
      # None; confirm.
      await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )
      statement = select(MessageFile).where(MessageFile.message_id == message_id)
      return await cursor_page(statement, session)
  @router.get(
      "/{thread_id}/messages/{message_id}/files/{file_id}",
      response_model=MessageFile,
  )
  async def get_message_file(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      message_id: str = ...,
      file_id: str = ...,
  ) -> MessageFile:
      """
      Retrieves a message file.

      The file is fetched through ``MessageService.get_message_file`` using
      the thread, message and file IDs from the path, and is returned as-is.
      """
      message_file = await MessageService.get_message_file(
          session=session,
          thread_id=thread_id,
          message_id=message_id,
          file_id=file_id,
      )
      return message_file
  @router.delete(
      "/{thread_id}/messages/{message_id}",
      # NOTE(review): MessageFile looks copy-pasted from the file routes above.
      # Confirm it matches what MessageService.delete_message returns;
      # otherwise response validation will fail.
      response_model=MessageFile,
  )
  async def delete_message(
      *,
      session: AsyncSession = Depends(get_async_session),
      thread_id: str,
      message_id: str = ...,
  ):
      """
      Deletes a message.

      Deletion is delegated to ``MessageService.delete_message`` for the
      thread and message IDs in the path, and its result is returned.

      Raises:
          ValueError: if ``thread_id`` or ``message_id`` is empty.
      """
      # Defensive guards: path segments matched by the router are presumably
      # never empty, so these mainly protect direct (non-HTTP) calls.
      if not thread_id:
          raise ValueError(
              f"Expected a non-empty value for `thread_id` but received {thread_id!r}"
          )
      if not message_id:
          raise ValueError(
              f"Expected a non-empty value for `message_id` but received {message_id!r}"
          )
      return await MessageService.delete_message(
          session=session, thread_id=thread_id, message_id=message_id
      )
  144. return await MessageService.delete_message(
  145. session=session, thread_id=thread_id, message_id=message_id
  146. )