message.py 10 KB


  1. from typing import List, Optional
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select
  6. from app.exceptions.exception import ResourceNotFoundError
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate
  9. from app.services.thread.thread import ThreadService
  10. from app.models.token_relation import RelationType
  11. from app.providers.auth_provider import auth_policy
  12. from app.schemas.common import DeleteResponse
  13. from sqlalchemy.orm.attributes import flag_modified
  14. class MessageService:
  @staticmethod
  def format_message_content(message_create: MessageCreate) -> List:
      """
      Normalize ``message_create.content`` into a list of content parts.

      * A plain string becomes a single text part with an empty annotations list.
      * A list is walked item by item: ``text`` items are re-wrapped into the
        text-part shape, while ``image_file``, ``image_url`` and ``input_audio``
        items are kept unchanged. Items of any other type are dropped.
      * Any other content value yields an empty list.
      """
      raw = message_create.content

      if isinstance(raw, str):
          return [{"type": "text", "text": {"value": raw, "annotations": []}}]
      if not isinstance(raw, list):
          return []

      # Part types that are forwarded exactly as received.
      passthrough_types = ("image_file", "image_url", "input_audio")

      parts = []
      for item in raw:
          kind = item.get("type")
          if kind == "text":
              parts.append(
                  {
                      "type": "text",
                      "text": {"value": item.get("text"), "annotations": []},
                  }
              )
          elif kind in passthrough_types:
              parts.append(item)
      return parts
  @staticmethod
  def new_message(
      *, session: Session, content, role, assistant_id, thread_id, run_id
  ) -> Message:
      """
      Persist a single text message and return the stored row.

      ``content`` is wrapped as the only text part of the message, with an
      empty annotations list. The message is added to ``session``, committed,
      and refreshed before it is returned.
      """
      text_part = {"type": "text", "text": {"value": content, "annotations": []}}
      fields = {
          "content": [text_part],
          "role": role,
          "assistant_id": assistant_id,
          "thread_id": thread_id,
          "run_id": run_id,
      }
      record = Message(**fields)

      session.add(record)
      session.commit()
      session.refresh(record)
      return record
  @staticmethod
  def get_message_list(*, session: Session, thread_id) -> List[Message]:
      """
      Return every message of the given thread, ordered by ``created_at``.
      """
      query = select(Message).where(Message.thread_id == thread_id)
      query = query.order_by(Message.created_at)
      rows = session.execute(query)
      return rows.scalars().all()
  @staticmethod
  async def create_message(
      *, session: AsyncSession, body: MessageCreate, thread_id: str
  ) -> Message:
      """
      Create a message in a thread and register its files on the thread.

      File ids come from ``body.attachments`` when attachments are given,
      otherwise from ``body.file_ids``. When there are any, they are merged
      (without duplicates, preserving order) into the ``file_ids`` of the
      first ``file_search`` vector store in ``thread.tool_resources`` and the
      thread is committed. The message itself is then stored in a second
      commit, refreshed and returned.

      The thread is loaded through ``ThreadService.get_thread`` first, so any
      error that call raises for an unknown thread propagates to the caller.
      """
      thread = await ThreadService.get_thread(session=session, thread_id=thread_id)

      # TODO message annotations
      body_file_ids = body.file_ids
      if body.attachments:
          body_file_ids = [a.get("file_id") for a in body.attachments]

      if body_file_ids:
          # Build the nested structure level by level so that a missing or
          # empty "file_search" / "vector_stores" / "file_ids" entry is
          # created instead of crashing on a chained lookup.
          tool_resources = thread.tool_resources or {}
          file_search = tool_resources.get("file_search")
          if not file_search:
              file_search = {}
              tool_resources["file_search"] = file_search
          vector_stores = file_search.get("vector_stores")
          if not vector_stores:
              vector_stores = [{}]
              file_search["vector_stores"] = vector_stores
          store = vector_stores[0]

          merged_file_ids = list(store.get("file_ids") or [])
          for file_id in body_file_ids:
              # Attachments without a "file_id" yield None; skip those.
              if file_id and file_id not in merged_file_ids:
                  merged_file_ids.append(file_id)
          store["file_ids"] = merged_file_ids

          thread.tool_resources = tool_resources
          # The structure may have been mutated in place, so explicitly mark
          # the attribute as changed to make sure it is written on commit.
          flag_modified(thread, "tool_resources")
          session.add(thread)
          await session.commit()
          await session.refresh(thread)

      content = MessageService.format_message_content(body)
      db_message = Message.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "content": content},
          from_attributes=True,
      )
      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message
  @staticmethod
  def get_message_sync(
      *, session: Session, thread_id: str, message_id: str
  ) -> Message:
      """
      Fetch one message by id within a thread using a synchronous session.

      Raises ``HTTPException`` with status 404 when no message matches both
      ``thread_id`` and ``message_id``.
      """
      query = select(Message).where(
          Message.thread_id == thread_id,
          Message.id == message_id,
      )
      found = session.execute(query).scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  def modify_message_sync(
      *, session: Session, thread_id: str, message_id: str, body: MessageUpdate
  ) -> Message:
      """
      Apply the fields set on ``body`` to an existing message (sync session).

      A non-empty ``body.content`` is first wrapped, in place on ``body``, as
      a single text part with an empty annotations list. The thread is looked
      up through ``ThreadService.get_thread_sync`` and the message through
      ``get_message_sync`` (which raises a 404 ``HTTPException`` when it does
      not exist). Only explicitly set fields are copied onto the message,
      which is then committed, refreshed and returned.
      """
      if body.content:
          body.content = [
              {"type": "text", "text": {"value": body.content, "annotations": []}}
          ]
      # Looked up only so that a missing thread is reported by the thread
      # service; the returned thread itself is not used.
      ThreadService.get_thread_sync(thread_id=thread_id, session=session)
      db_message = MessageService.get_message_sync(
          session=session, thread_id=thread_id, message_id=message_id
      )
      # model_dump replaces the deprecated .dict() and matches the rest of
      # this file.
      update_data = body.model_dump(exclude_unset=True)
      for key, value in update_data.items():
          setattr(db_message, key, value)
      session.add(db_message)
      session.commit()
      session.refresh(db_message)
      return db_message
  @staticmethod
  async def modify_message(
      *, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate
  ) -> Message:
      """
      Apply the fields set on ``body`` to an existing message (async session).

      A non-empty ``body.content`` is first wrapped, in place on ``body``, as
      a single text part with an empty annotations list. The thread is looked
      up through ``ThreadService.get_thread`` and the message through
      ``get_message`` (which raises a 404 ``HTTPException`` when it does not
      exist). Only explicitly set fields are copied onto the message, which
      is then committed, refreshed and returned.
      """
      if body.content:
          body.content = [
              {"type": "text", "text": {"value": body.content, "annotations": []}}
          ]
      # Looked up only so that a missing thread is reported by the thread
      # service; the returned thread itself is not used.
      await ThreadService.get_thread(thread_id=thread_id, session=session)
      db_message = await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )
      # model_dump replaces the deprecated .dict() and matches the rest of
      # this file.
      update_data = body.model_dump(exclude_unset=True)
      for key, value in update_data.items():
          setattr(db_message, key, value)
      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message
  @staticmethod
  async def delete_message(
      *, session: AsyncSession, thread_id: str, message_id: str
  ) -> DeleteResponse:
      """
      Delete a message and its token relation, then report the deletion.

      The message is loaded through ``get_message`` (which raises a 404
      ``HTTPException`` when it does not exist in the thread), marked for
      deletion, and its ``RelationType.Message`` token relation is removed
      via ``auth_policy.delete_token_rel`` before the session is committed.

      Returns a ``DeleteResponse`` with ``object="message.deleted"`` and
      ``deleted=True`` for ``message_id``.
      """
      message = await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )
      await session.delete(message)
      await auth_policy.delete_token_rel(
          session=session, relation_type=RelationType.Message, relation_id=message_id
      )
      await session.commit()
      return DeleteResponse(id=message_id, object="message.deleted", deleted=True)
  @staticmethod
  async def get_message(
      *, session: AsyncSession, thread_id: str, message_id: str
  ) -> Message:
      """
      Fetch one message by id within a thread using an async session.

      Raises ``HTTPException`` with status 404 when no message matches both
      ``thread_id`` and ``message_id``.
      """
      query = select(Message).where(
          Message.thread_id == thread_id,
          Message.id == message_id,
      )
      rows = await session.execute(query)
      found = rows.scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  async def get_message_file(
      *, session: AsyncSession, thread_id: str, message_id: str, file_id: str
  ) -> MessageFile:
      """
      Fetch a file record attached to a message.

      The message is resolved first through ``get_message``, so a message
      that is not in the thread results in that method's 404
      ``HTTPException``. Raises ``ResourceNotFoundError`` when no
      ``MessageFile`` matches both ``file_id`` and ``message_id``.
      """
      # Existence check only; the message itself is not needed afterwards.
      await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )

      query = select(MessageFile).where(
          MessageFile.id == file_id,
          MessageFile.message_id == message_id,
      )
      rows = await session.execute(query)
      found = rows.scalars().one_or_none()
      if found is not None:
          return found
      raise ResourceNotFoundError(message="Message file not found")
  @staticmethod
  async def copy_messages(
      *,
      session: AsyncSession,
      from_thread_id: str,
      to_thread_id: str,
      end_message_id: Optional[str] = None
  ) -> None:
      """
      Copy the messages of one thread into another thread.

      Messages of ``from_thread_id`` are read in ``Message.id`` order. When
      ``end_message_id`` is given (non-empty), only messages whose id is less
      than or equal to it are copied; otherwise the whole thread is copied.
      Each copy keeps every field of the original except ``id``,
      ``created_at``, ``updated_at`` and ``thread_id``. All copies are
      committed together in one commit.

      NOTE(review): the ``<=`` cut-off presumably relies on message ids
      sorting in creation order - confirm against the id generation scheme.
      """
      statement = select(Message).where(Message.thread_id == from_thread_id)
      if end_message_id:
          statement = statement.where(Message.id <= end_message_id)
      result = await session.execute(statement.order_by(Message.id))
      original_messages = result.scalars().all()

      for original_message in original_messages:
          new_message = Message(
              thread_id=to_thread_id,
              **original_message.model_dump(
                  exclude={"id", "created_at", "updated_at", "thread_id"}
              ),
          )
          session.add(new_message)
      await session.commit()
  247. @staticmethod
  248. async def create_messages(
  249. *,
  250. session: AsyncSession,
  251. thread_id: str,
  252. run_id: str,
  253. assistant_id: str,
  254. messages: list
  255. ):
  256. for original_message in messages:
  257. content = MessageService.format_message_content(original_message)
  258. new_message = Message.model_validate(
  259. original_message.model_dump(by_alias=True),
  260. update={
  261. "thread_id": thread_id,
  262. "run_id": run_id,
  263. "assistant_id": assistant_id,
  264. "content": content,
  265. "role": original_message.role,
  266. },
  267. )
  268. session.add(new_message)