message.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from typing import List, Optional
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select
  6. from app.exceptions.exception import ResourceNotFoundError
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate
  9. from app.services.thread.thread import ThreadService
  10. class MessageService:
  11. @staticmethod
  12. def format_message_content(message_create: MessageCreate) -> List:
  13. content = []
  14. if isinstance(message_create.content, str):
  15. content.append({"type": "text", "text": {"value": message_create.content, "annotations": []}})
  16. elif isinstance(message_create.content, list):
  17. for msg in message_create.content:
  18. if msg.get("type") == "text":
  19. msg_value = msg.get("text")
  20. content.append({"type": "text", "text": {"value": msg_value, "annotations": []}})
  21. elif msg.get("type") == "image_file" or msg.get("type") == "image_url":
  22. content.append(msg)
  23. return content
  @staticmethod
  def new_message(*, session: Session, content, role, assistant_id, thread_id, run_id) -> Message:
      """Synchronously persist a message holding a single text part.

      ``content`` becomes the ``value`` of one ``text`` part with no
      annotations. The session is committed and the new row refreshed before
      it is returned.
      """
      text_part = {"type": "text", "text": {"value": content, "annotations": []}}
      record = Message(
          content=[text_part],
          role=role,
          assistant_id=assistant_id,
          thread_id=thread_id,
          run_id=run_id,
      )
      session.add(record)
      session.commit()
      # Reload server-populated columns onto the instance.
      session.refresh(record)
      return record
  @staticmethod
  def get_message_list(*, session: Session, thread_id) -> List[Message]:
      """Return all messages of ``thread_id`` ordered by ``created_at`` ascending."""
      query = (
          select(Message)
          .where(Message.thread_id == thread_id)
          .order_by(Message.created_at)
      )
      rows = session.execute(query)
      return rows.scalars().all()
  @staticmethod
  async def create_message(*, session: AsyncSession, body: MessageCreate, thread_id: str) -> Message:
      """Create a message in ``thread_id`` and register its files on the thread.

      File ids are taken from ``body.attachments`` (each attachment's
      ``"file_id"``) when attachments are present, otherwise from
      ``body.file_ids``. Empty or missing ids are ignored. New ids are merged
      into the first vector store of the thread's ``file_search`` tool
      resources, and that structure is created if it does not exist yet.

      :param session: async session; committed by this method.
      :param body: the message payload.
      :param thread_id: id of the thread the message belongs to.
      :return: the persisted, refreshed message.
      """
      # Load the thread; its tool_resources may be updated below.
      thread = await ThreadService.get_thread(thread_id=thread_id, session=session)
      # TODO message annotations
      body_file_ids = body.file_ids
      if body.attachments:
          # Attachments take precedence over the file_ids field.
          body_file_ids = [a.get("file_id") for a in body.attachments]
      # Drop attachments without a file id instead of storing None on the thread.
      body_file_ids = [file_id for file_id in (body_file_ids or []) if file_id]
      if body_file_ids:
          # Assign a freshly built object instead of editing nested values in
          # place. SQLAlchemy does not detect in-place changes to a JSON value
          # unless the column uses mutation tracking.
          # NOTE(review): assumes tool_resources is a JSON column — confirm.
          thread.tool_resources = MessageService._merge_thread_file_ids(thread.tool_resources, body_file_ids)
          session.add(thread)
      content = MessageService.format_message_content(body)
      db_message = Message.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "content": content},
          from_attributes=True,
      )
      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message

  @staticmethod
  def _merge_thread_file_ids(tool_resources: Optional[dict], file_ids: List[str]) -> dict:
      """Return a copy of ``tool_resources`` with ``file_ids`` merged in.

      The ids are appended, without duplicates, to the ``file_ids`` list of the
      first ``file_search`` vector store. Missing, ``None`` or empty levels
      (``file_search``, ``vector_stores``, ``file_ids``) are created. The input
      is never mutated, so the caller can reassign the result and have the
      change detected.
      """
      resources = dict(tool_resources or {})
      file_search = dict(resources.get("file_search") or {})
      vector_stores = [dict(store) for store in (file_search.get("vector_stores") or [])]
      if not vector_stores:
          vector_stores = [{}]
      merged = list(vector_stores[0].get("file_ids") or [])
      for file_id in file_ids:
          if file_id not in merged:
              merged.append(file_id)
      vector_stores[0]["file_ids"] = merged
      file_search["vector_stores"] = vector_stores
      resources["file_search"] = file_search
      return resources
  @staticmethod
  def get_message_sync(*, session: Session, thread_id: str, message_id: str) -> Message:
      """Synchronously fetch message ``message_id`` of thread ``thread_id``.

      :raises HTTPException: 404 when no message matches both ids.
      """
      query = select(Message).where(
          Message.thread_id == thread_id,
          Message.id == message_id,
      )
      found = session.execute(query).scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  def modify_message_sync(*, session: Session, thread_id: str, message_id: str, body: MessageUpdate) -> Message:
      """Synchronously apply the explicitly set fields of ``body`` to a message.

      A truthy ``content`` is stored as a single ``text`` part with no
      annotations. ``body`` itself is left unmodified.

      :raises HTTPException: 404 (from ``get_message_sync``) when the message is
          not in the thread. Thread lookup failures come from
          ``ThreadService.get_thread_sync``.
      :return: the updated, refreshed message.
      """
      # Look up the thread before touching the message.
      ThreadService.get_thread_sync(thread_id=thread_id, session=session)
      db_message = MessageService.get_message_sync(session=session, thread_id=thread_id, message_id=message_id)
      # Partial update: only fields the client actually sent are applied.
      update_data = body.model_dump(exclude_unset=True)
      if update_data.get("content"):
          # Wrap on the dumped copy instead of assigning body.content, so the
          # caller's object is untouched and repeated calls don't double-wrap.
          update_data["content"] = [{"type": "text", "text": {"value": update_data["content"], "annotations": []}}]
      for key, value in update_data.items():
          setattr(db_message, key, value)
      session.add(db_message)
      session.commit()
      session.refresh(db_message)
      return db_message
  @staticmethod
  async def modify_message(*, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate) -> Message:
      """Apply the explicitly set fields of ``body`` to a message.

      A truthy ``content`` is stored as a single ``text`` part with no
      annotations. ``body`` itself is left unmodified.

      :raises HTTPException: 404 (from ``get_message``) when the message is not
          in the thread. Thread lookup failures come from
          ``ThreadService.get_thread``.
      :return: the updated, refreshed message.
      """
      # Look up the thread before touching the message.
      await ThreadService.get_thread(thread_id=thread_id, session=session)
      db_message = await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
      # Partial update: only fields the client actually sent are applied.
      update_data = body.model_dump(exclude_unset=True)
      if update_data.get("content"):
          # Wrap on the dumped copy instead of assigning body.content, so the
          # caller's object is untouched and repeated calls don't double-wrap.
          update_data["content"] = [{"type": "text", "text": {"value": update_data["content"], "annotations": []}}]
      for key, value in update_data.items():
          setattr(db_message, key, value)
      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message
  @staticmethod
  async def get_message(*, session: AsyncSession, thread_id: str, message_id: str) -> Message:
      """Fetch message ``message_id`` belonging to thread ``thread_id``.

      :raises HTTPException: 404 when no message matches both ids.
      """
      query = select(Message).where(
          Message.thread_id == thread_id,
          Message.id == message_id,
      )
      found = (await session.execute(query)).scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  async def get_message_file(*, session: AsyncSession, thread_id: str, message_id: str, file_id: str) -> MessageFile:
      """Fetch the ``MessageFile`` ``file_id`` attached to message ``message_id``.

      The message is looked up first (scoped to ``thread_id``), so a missing
      message surfaces as ``get_message``'s 404 before the file query runs.

      :raises HTTPException: 404 when the message is not in the thread.
      :raises ResourceNotFoundError: when no matching message file exists.
      """
      # Existence check only; the returned message is not needed.
      await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
      query = select(MessageFile).where(
          MessageFile.id == file_id,
          MessageFile.message_id == message_id,
      )
      attached = (await session.execute(query)).scalars().one_or_none()
      if attached is None:
          raise ResourceNotFoundError(message="Message file not found")
      return attached
  @staticmethod
  async def copy_messages(*, session: AsyncSession, from_thread_id: str, to_thread_id: str, end_message_id: str):
      """Copy the messages of ``from_thread_id`` into ``to_thread_id`` and commit.

      Source messages are read in ascending ``Message.id`` order. When
      ``end_message_id`` is truthy, only messages whose id compares ``<=`` to it
      are copied; otherwise every message of the thread is copied. Each copy
      receives ``to_thread_id``. ``id``, ``created_at`` and ``updated_at`` are
      not copied (they are excluded from the dump); all other dumped fields are
      copied as-is.
      """
      query = select(Message).where(Message.thread_id == from_thread_id)
      if end_message_id:
          # NOTE(review): a plain comparison on the id column. It is only a
          # meaningful "up to this message" cut-off if ids sort in creation
          # order — confirm the id scheme.
          query = query.where(Message.id <= end_message_id)
      query = query.order_by(Message.id)
      source_messages = (await session.execute(query)).scalars().all()
      skipped_fields = {"id", "created_at", "updated_at", "thread_id"}
      for source in source_messages:
          copied_fields = source.model_dump(exclude=skipped_fields)
          session.add(Message(thread_id=to_thread_id, **copied_fields))
      await session.commit()
  142. @staticmethod
  143. async def create_messages(*, session: AsyncSession, thread_id: str, run_id: str, assistant_id: str, messages: list):
  144. for original_message in messages:
  145. content = MessageService.format_message_content(original_message)
  146. new_message = Message.model_validate(
  147. original_message.model_dump(by_alias=True),
  148. update={
  149. "thread_id": thread_id,
  150. "run_id": run_id,
  151. "assistant_id": assistant_id,
  152. "content": content,
  153. "role": original_message.role,
  154. },
  155. )
  156. session.add(new_message)