message.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. from typing import List, Optional
  2. from fastapi import HTTPException
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy.orm import Session
  5. from sqlmodel import select
  6. from app.exceptions.exception import ResourceNotFoundError
  7. from app.models import MessageFile
  8. from app.models.message import Message, MessageCreate, MessageUpdate
  9. from app.services.thread.thread import ThreadService
  10. from app.models.token_relation import RelationType
  11. from app.providers.auth_provider import auth_policy
  12. from app.schemas.common import DeleteResponse
  13. class MessageService:
  14. @staticmethod
  15. def format_message_content(message_create: MessageCreate) -> List:
  16. content = []
  17. if isinstance(message_create.content, str):
  18. content.append(
  19. {
  20. "type": "text",
  21. "text": {"value": message_create.content, "annotations": []},
  22. }
  23. )
  24. elif isinstance(message_create.content, list):
  25. for msg in message_create.content:
  26. if msg.get("type") == "text":
  27. msg_value = msg.get("text")
  28. content.append(
  29. {
  30. "type": "text",
  31. "text": {"value": msg_value, "annotations": []},
  32. }
  33. )
  34. elif msg.get("type") == "image_file" or msg.get("type") == "image_url":
  35. content.append(msg)
  36. elif msg.get("type") == "input_audio":
  37. content.append(msg)
  38. return content
  @staticmethod
  def new_message(
      *, session: Session, content, role, assistant_id, thread_id, run_id
  ) -> Message:
      """Create, commit and return a single-part text message.

      ``content`` is wrapped as the value of one ``text`` content part with
      an empty ``annotations`` list. The row is added to ``session``,
      committed, and refreshed before being returned.
      """
      text_part = {"type": "text", "text": {"value": content, "annotations": []}}
      record = Message(
          content=[text_part],
          role=role,
          assistant_id=assistant_id,
          thread_id=thread_id,
          run_id=run_id,
      )

      session.add(record)
      session.commit()
      # Reload the instance so it reflects the committed row.
      session.refresh(record)
      return record
  @staticmethod
  def get_message_list(*, session: Session, thread_id) -> List[Message]:
      """Return every message of ``thread_id`` ordered by ``created_at``."""
      query = select(Message)
      query = query.where(Message.thread_id == thread_id)
      query = query.order_by(Message.created_at)

      rows = session.execute(query)
      return rows.scalars().all()
  @staticmethod
  async def create_message(
      *, session: AsyncSession, body: MessageCreate, thread_id: str
  ) -> Message:
      """Create a message in a thread and register its files on the thread.

      The thread is loaded through ``ThreadService.get_thread``. File ids come
      from ``body.attachments`` when attachments are present, otherwise from
      ``body.file_ids``. Any ids found are merged (without duplicates) into
      ``tool_resources["file_search"]["vector_stores"][0]["file_ids"]`` of the
      thread, which is committed before the message itself is stored.

      Returns the committed and refreshed ``Message``.
      """
      import copy

      # get thread
      thread = await ThreadService.get_thread(thread_id=thread_id, session=session)

      # TODO message annotations
      body_file_ids = body.file_ids
      if body.attachments:
          body_file_ids = [a.get("file_id") for a in body.attachments]

      if body_file_ids:
          # Work on a copy and assign it back as a new object: in-place edits
          # of a JSON-typed attribute are not picked up by SQLAlchemy unless
          # mutation tracking is configured (presumably tool_resources is a
          # JSON column -- TODO confirm against the Thread model).
          tool_resources = (
              copy.deepcopy(thread.tool_resources) if thread.tool_resources else {}
          )

          file_search = tool_resources.get("file_search") or {}
          tool_resources["file_search"] = file_search

          # Tolerate a missing or empty "vector_stores" list instead of
          # failing on the [0] lookup.
          vector_stores = file_search.get("vector_stores")
          if not vector_stores:
              vector_stores = [{}]
              file_search["vector_stores"] = vector_stores

          merged_file_ids = list(vector_stores[0].get("file_ids") or [])
          for file_id in body_file_ids:
              # Attachments without a "file_id" yield None; skip those.
              if file_id is not None and file_id not in merged_file_ids:
                  merged_file_ids.append(file_id)
          vector_stores[0]["file_ids"] = merged_file_ids

          thread.tool_resources = tool_resources
          session.add(thread)
          await session.commit()
          await session.refresh(thread)

      content = MessageService.format_message_content(body)
      db_message = Message.model_validate(
          body.model_dump(by_alias=True),
          update={"thread_id": thread_id, "content": content},
          from_attributes=True,
      )
      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message
  @staticmethod
  def get_message_sync(
      *, session: Session, thread_id: str, message_id: str
  ) -> Message:
      """Fetch one message by id within a thread (synchronous session).

      Raises:
          HTTPException: 404 when no message matches both ``thread_id`` and
              ``message_id``.
      """
      query = select(Message)
      query = query.where(Message.thread_id == thread_id)
      query = query.where(Message.id == message_id)

      found = session.execute(query).scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  def modify_message_sync(
      *, session: Session, thread_id: str, message_id: str, body: MessageUpdate
  ) -> Message:
      """Apply the explicitly-set fields of ``body`` to a message (sync).

      A truthy ``body.content`` is stored as a single ``text`` content part
      with an empty ``annotations`` list. The caller's ``body`` object is not
      modified.

      Returns the committed and refreshed ``Message``.

      Raises:
          HTTPException: 404 from ``get_message_sync`` when the message does
              not exist in the thread.
      """
      # get thread -- the result is unused; presumably this call raises when
      # the thread does not exist (verify against ThreadService).
      ThreadService.get_thread_sync(thread_id=thread_id, session=session)
      # get message
      db_message = MessageService.get_message_sync(
          session=session, thread_id=thread_id, message_id=message_id
      )

      # model_dump matches the rest of this file; .dict() is the deprecated
      # Pydantic v1 spelling.
      update_data = body.model_dump(exclude_unset=True)
      if body.content:
          # Wrap on the outgoing update only, leaving the request model as is.
          update_data["content"] = [
              {"type": "text", "text": {"value": body.content, "annotations": []}}
          ]

      for key, value in update_data.items():
          setattr(db_message, key, value)

      session.add(db_message)
      session.commit()
      session.refresh(db_message)
      return db_message
  @staticmethod
  async def modify_message(
      *, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate
  ) -> Message:
      """Apply the explicitly-set fields of ``body`` to a message (async).

      A truthy ``body.content`` is stored as a single ``text`` content part
      with an empty ``annotations`` list. The caller's ``body`` object is not
      modified.

      Returns the committed and refreshed ``Message``.

      Raises:
          HTTPException: 404 from ``get_message`` when the message does not
              exist in the thread.
      """
      # get thread -- the result is unused; presumably this call raises when
      # the thread does not exist (verify against ThreadService).
      await ThreadService.get_thread(thread_id=thread_id, session=session)
      # get message
      db_message = await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )

      # model_dump matches the rest of this file; .dict() is the deprecated
      # Pydantic v1 spelling.
      update_data = body.model_dump(exclude_unset=True)
      if body.content:
          # Wrap on the outgoing update only, leaving the request model as is.
          update_data["content"] = [
              {"type": "text", "text": {"value": body.content, "annotations": []}}
          ]

      for key, value in update_data.items():
          setattr(db_message, key, value)

      session.add(db_message)
      await session.commit()
      await session.refresh(db_message)
      return db_message
  @staticmethod
  async def delete_message(
      *, session: AsyncSession, thread_id: str, message_id: str
  ) -> DeleteResponse:
      """Delete a message and its token relation, then commit.

      The message is looked up with ``get_message`` first, so a missing
      message surfaces as that method's 404 ``HTTPException`` before anything
      is deleted. The message delete and the
      ``auth_policy.delete_token_rel`` call are both issued before the single
      commit.

      Returns:
          DeleteResponse: ``id`` set to ``message_id``, ``object`` set to
          ``"message.deleted"`` and ``deleted`` set to ``True``.
      """
      message = await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )
      await session.delete(message)
      await auth_policy.delete_token_rel(
          session=session, relation_type=RelationType.Message, relation_id=message_id
      )
      await session.commit()
      return DeleteResponse(id=message_id, object="message.deleted", deleted=True)
  @staticmethod
  async def get_message(
      *, session: AsyncSession, thread_id: str, message_id: str
  ) -> Message:
      """Fetch one message by id within a thread (async session).

      Raises:
          HTTPException: 404 when no message matches both ``thread_id`` and
              ``message_id``.
      """
      query = select(Message)
      query = query.where(Message.thread_id == thread_id)
      query = query.where(Message.id == message_id)

      rows = await session.execute(query)
      found = rows.scalars().one_or_none()
      if found is not None:
          return found
      raise HTTPException(status_code=404, detail="Message not found")
  @staticmethod
  async def get_message_file(
      *, session: AsyncSession, thread_id: str, message_id: str, file_id: str
  ) -> MessageFile:
      """Fetch a file record attached to a message.

      The message is looked up first via ``get_message`` (its result is
      discarded), so a missing message raises that method's 404 before the
      file query runs.

      Raises:
          ResourceNotFoundError: when no ``MessageFile`` matches both
              ``file_id`` and ``message_id``.
      """
      await MessageService.get_message(
          session=session, thread_id=thread_id, message_id=message_id
      )

      # get message files
      query = select(MessageFile)
      query = query.where(MessageFile.id == file_id)
      query = query.where(MessageFile.message_id == message_id)

      rows = await session.execute(query)
      found = rows.scalars().one_or_none()
      if found is not None:
          return found
      raise ResourceNotFoundError(message="Message file not found")
  @staticmethod
  async def copy_messages(
      *,
      session: AsyncSession,
      from_thread_id: str,
      to_thread_id: str,
      end_message_id: str
  ):
      """Copy the messages of one thread into another thread.

      Source messages are read in ``Message.id`` order. When
      ``end_message_id`` is truthy, only messages whose id is ``<=`` that
      value are copied (this presumes ids compare in the intended order --
      TODO confirm). Each copy carries over every dumped field except ``id``,
      ``created_at``, ``updated_at`` and ``thread_id``. All copies are
      committed together at the end.
      """
      query = select(Message).where(Message.thread_id == from_thread_id)
      if end_message_id:
          query = query.where(Message.id <= end_message_id)
      query = query.order_by(Message.id)

      rows = await session.execute(query)
      source_messages = rows.scalars().all()

      # Fields the copy must not inherit from the source row.
      skipped_fields = {"id", "created_at", "updated_at", "thread_id"}
      for source in source_messages:
          carried = source.model_dump(exclude=skipped_fields)
          duplicate = Message(thread_id=to_thread_id, **carried)
          session.add(duplicate)

      await session.commit()
  240. @staticmethod
  241. async def create_messages(
  242. *,
  243. session: AsyncSession,
  244. thread_id: str,
  245. run_id: str,
  246. assistant_id: str,
  247. messages: list
  248. ):
  249. for original_message in messages:
  250. content = MessageService.format_message_content(original_message)
  251. new_message = Message.model_validate(
  252. original_message.model_dump(by_alias=True),
  253. update={
  254. "thread_id": thread_id,
  255. "run_id": run_id,
  256. "assistant_id": assistant_id,
  257. "content": content,
  258. "role": original_message.role,
  259. },
  260. )
  261. session.add(new_message)