message.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. from typing import List, Optional
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select
  6. from app.exceptions.exception import ResourceNotFoundError
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate
  9. from app.services.thread.thread import ThreadService
  10. from app.models.token_relation import RelationType
  11. from app.providers.auth_provider import auth_policy
  12. from app.schemas.common import DeleteResponse
  13. class MessageService:
  14. @staticmethod
  15. def format_message_content(message_create: MessageCreate) -> List:
  16. content = []
  17. if isinstance(message_create.content, str):
  18. content.append(
  19. {
  20. "type": "text",
  21. "text": {"value": message_create.content, "annotations": []},
  22. }
  23. )
  24. elif isinstance(message_create.content, list):
  25. for msg in message_create.content:
  26. if msg.get("type") == "text":
  27. msg_value = msg.get("text")
  28. content.append(
  29. {
  30. "type": "text",
  31. "text": {"value": msg_value, "annotations": []},
  32. }
  33. )
  34. elif msg.get("type") == "image_file" or msg.get("type") == "image_url":
  35. content.append(msg)
  36. return content
  37. @staticmethod
  38. def new_message(
  39. *, session: Session, content, role, assistant_id, thread_id, run_id
  40. ) -> Message:
  41. message = Message(
  42. content=[{"type": "text", "text": {"value": content, "annotations": []}}],
  43. role=role,
  44. assistant_id=assistant_id,
  45. thread_id=thread_id,
  46. run_id=run_id,
  47. )
  48. session.add(message)
  49. session.commit()
  50. session.refresh(message)
  51. return message
  52. @staticmethod
  53. def get_message_list(*, session: Session, thread_id) -> List[Message]:
  54. statement = (
  55. select(Message)
  56. .where(Message.thread_id == thread_id)
  57. .order_by(Message.created_at)
  58. )
  59. return session.execute(statement).scalars().all()
  60. @staticmethod
  61. async def create_message(
  62. *, session: AsyncSession, body: MessageCreate, thread_id: str
  63. ) -> Message:
  64. # get thread
  65. thread = await ThreadService.get_thread(thread_id=thread_id, session=session)
  66. # TODO message annotations
  67. body_file_ids = body.file_ids
  68. if body.attachments:
  69. body_file_ids = [a.get("file_id") for a in body.attachments]
  70. if body_file_ids:
  71. thread_file_ids = []
  72. if thread.tool_resources and "file_search" in thread.tool_resources:
  73. thread_file_ids = (
  74. thread.tool_resources.get("file_search")
  75. .get("vector_stores")[0]
  76. .get("file_ids")
  77. )
  78. for file_id in body_file_ids:
  79. if file_id not in thread_file_ids:
  80. thread_file_ids.append(file_id)
  81. if thread_file_ids:
  82. if not thread.tool_resources:
  83. thread.tool_resources = {}
  84. if "file_search" not in thread.tool_resources:
  85. thread.tool_resources["file_search"] = {
  86. "vector_stores": [{"file_ids": []}]
  87. }
  88. thread.tool_resources.get("file_search").get("vector_stores")[0][
  89. "file_ids"
  90. ] = thread_file_ids
  91. session.add(thread)
  92. content = MessageService.format_message_content(body)
  93. db_message = Message.model_validate(
  94. body.model_dump(by_alias=True),
  95. update={"thread_id": thread_id, "content": content},
  96. from_attributes=True,
  97. )
  98. session.add(db_message)
  99. await session.commit()
  100. await session.refresh(db_message)
  101. return db_message
  102. @staticmethod
  103. def get_message_sync(
  104. *, session: Session, thread_id: str, message_id: str
  105. ) -> Message:
  106. statement = (
  107. select(Message)
  108. .where(Message.thread_id == thread_id)
  109. .where(Message.id == message_id)
  110. )
  111. result = session.execute(statement)
  112. message = result.scalars().one_or_none()
  113. if message is None:
  114. raise HTTPException(status_code=404, detail="Message not found")
  115. return message
  116. @staticmethod
  117. def modify_message_sync(
  118. *, session: Session, thread_id: str, message_id: str, body: MessageUpdate
  119. ) -> Message:
  120. if body.content:
  121. body.content = [
  122. {"type": "text", "text": {"value": body.content, "annotations": []}}
  123. ]
  124. # get thread
  125. ThreadService.get_thread_sync(thread_id=thread_id, session=session)
  126. # get message
  127. db_message = MessageService.get_message_sync(
  128. session=session, thread_id=thread_id, message_id=message_id
  129. )
  130. update_data = body.dict(exclude_unset=True)
  131. for key, value in update_data.items():
  132. setattr(db_message, key, value)
  133. session.add(db_message)
  134. session.commit()
  135. session.refresh(db_message)
  136. return db_message
  137. @staticmethod
  138. async def modify_message(
  139. *, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate
  140. ) -> Message:
  141. if body.content:
  142. body.content = [
  143. {"type": "text", "text": {"value": body.content, "annotations": []}}
  144. ]
  145. # get thread
  146. await ThreadService.get_thread(thread_id=thread_id, session=session)
  147. # get message
  148. db_message = await MessageService.get_message(
  149. session=session, thread_id=thread_id, message_id=message_id
  150. )
  151. update_data = body.dict(exclude_unset=True)
  152. for key, value in update_data.items():
  153. setattr(db_message, key, value)
  154. session.add(db_message)
  155. await session.commit()
  156. await session.refresh(db_message)
  157. return db_message
  158. @staticmethod
  159. async def delete_message(
  160. *, session: AsyncSession, thread_id: str, message_id: str
  161. ) -> Message:
  162. message = await MessageService.get_message(
  163. session=session, thread_id=thread_id, message_id=message_id
  164. )
  165. await session.delete(message)
  166. await auth_policy.delete_token_rel(
  167. session=session, relation_type=RelationType.Message, relation_id=message_id
  168. )
  169. await session.commit()
  170. return DeleteResponse(id=message_id, object="message.deleted", deleted=True)
  171. @staticmethod
  172. async def get_message(
  173. *, session: AsyncSession, thread_id: str, message_id: str
  174. ) -> Message:
  175. statement = (
  176. select(Message)
  177. .where(Message.thread_id == thread_id)
  178. .where(Message.id == message_id)
  179. )
  180. result = await session.execute(statement)
  181. message = result.scalars().one_or_none()
  182. if message is None:
  183. raise HTTPException(status_code=404, detail="Message not found")
  184. return message
  185. @staticmethod
  186. async def get_message_file(
  187. *, session: AsyncSession, thread_id: str, message_id: str, file_id: str
  188. ) -> MessageFile:
  189. await MessageService.get_message(
  190. session=session, thread_id=thread_id, message_id=message_id
  191. )
  192. # get message files
  193. statement = (
  194. select(MessageFile)
  195. .where(MessageFile.id == file_id)
  196. .where(MessageFile.message_id == message_id)
  197. )
  198. result = await session.execute(statement)
  199. msg_file = result.scalars().one_or_none()
  200. if msg_file is None:
  201. raise ResourceNotFoundError(message="Message file not found")
  202. return msg_file
  203. @staticmethod
  204. async def copy_messages(
  205. *,
  206. session: AsyncSession,
  207. from_thread_id: str,
  208. to_thread_id: str,
  209. end_message_id: str
  210. ):
  211. """
  212. copy thread messages to another thread
  213. """
  214. statement = select(Message).where(Message.thread_id == from_thread_id)
  215. if end_message_id:
  216. statement = statement.where(Message.id <= end_message_id)
  217. result = await session.execute(statement.order_by(Message.id))
  218. original_messages = result.scalars().all()
  219. for original_message in original_messages:
  220. new_message = Message(
  221. thread_id=to_thread_id,
  222. **original_message.model_dump(
  223. exclude={"id", "created_at", "updated_at", "thread_id"}
  224. ),
  225. )
  226. session.add(new_message)
  227. await session.commit()
  228. @staticmethod
  229. async def create_messages(
  230. *,
  231. session: AsyncSession,
  232. thread_id: str,
  233. run_id: str,
  234. assistant_id: str,
  235. messages: list
  236. ):
  237. for original_message in messages:
  238. content = MessageService.format_message_content(original_message)
  239. new_message = Message.model_validate(
  240. original_message.model_dump(by_alias=True),
  241. update={
  242. "thread_id": thread_id,
  243. "run_id": run_id,
  244. "assistant_id": assistant_id,
  245. "content": content,
  246. "role": original_message.role,
  247. },
  248. )
  249. session.add(new_message)