123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- from typing import List, Optional
- from fastapi import HTTPException
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import Session
- from sqlmodel import select
- from app.exceptions.exception import ResourceNotFoundError
- from app.models import MessageFile
- from app.models.message import Message, MessageCreate, MessageUpdate
- from app.services.thread.thread import ThreadService
- from app.models.token_relation import RelationType
- from app.providers.auth_provider import auth_policy
- from app.schemas.common import DeleteResponse
class MessageService:
    """Data-access helpers for thread messages.

    Every method is a ``@staticmethod`` that receives the database session as a
    keyword argument. Some methods use a synchronous ``Session``; the others use
    an ``AsyncSession``.
    """

    @staticmethod
    def _text_part(value) -> dict:
        """Build a ``text`` content part holding ``value`` with no annotations."""
        return {"type": "text", "text": {"value": value, "annotations": []}}

    @staticmethod
    def format_message_content(message_create: MessageCreate) -> List:
        """Normalise ``message_create.content`` into a list of content parts.

        * A plain string becomes a single ``text`` part.
        * In a list, ``{"type": "text", "text": ...}`` items are rewritten into
          the ``{"value": ..., "annotations": []}`` shape, ``image_file`` and
          ``image_url`` items are passed through unchanged, and items of any
          other type are dropped.

        Any other ``content`` type yields an empty list.
        """
        content = []
        if isinstance(message_create.content, str):
            content.append(MessageService._text_part(message_create.content))
        elif isinstance(message_create.content, list):
            for part in message_create.content:
                part_type = part.get("type")
                if part_type == "text":
                    content.append(MessageService._text_part(part.get("text")))
                elif part_type in ("image_file", "image_url"):
                    content.append(part)
        return content

    @staticmethod
    def new_message(
        *, session: Session, content, role, assistant_id, thread_id, run_id
    ) -> Message:
        """Create a message whose content is the single text ``content`` (sync).

        The message is committed and refreshed before being returned.
        """
        message = Message(
            content=[MessageService._text_part(content)],
            role=role,
            assistant_id=assistant_id,
            thread_id=thread_id,
            run_id=run_id,
        )
        session.add(message)
        session.commit()
        session.refresh(message)
        return message

    @staticmethod
    def get_message_list(*, session: Session, thread_id) -> List[Message]:
        """Return all messages of ``thread_id`` ordered by ``created_at`` (sync)."""
        statement = (
            select(Message)
            .where(Message.thread_id == thread_id)
            .order_by(Message.created_at)
        )
        return session.execute(statement).scalars().all()

    @staticmethod
    def _merge_file_search_file_ids(
        tool_resources: Optional[dict], file_ids: List
    ) -> dict:
        """Return a copy of ``tool_resources`` with ``file_ids`` merged into
        ``file_search.vector_stores[0].file_ids``.

        Ids already present are not added again, and existing order is kept.
        Missing, ``None`` or empty levels of that path are created.

        The input is never mutated: every container along the path is copied.
        ``thread.tool_resources`` is presumably a JSON column (verify in the
        model). The ORM does not track in-place edits of such values, so the
        caller must assign the returned object back for the change to be
        flushed.
        """
        resources = dict(tool_resources or {})
        file_search = dict(resources.get("file_search") or {})
        vector_stores = list(file_search.get("vector_stores") or [])
        first_store = dict(vector_stores[0] or {}) if vector_stores else {}

        merged_ids = list(first_store.get("file_ids") or [])
        known = set(merged_ids)
        for file_id in file_ids:
            if file_id not in known:
                merged_ids.append(file_id)
                known.add(file_id)

        first_store["file_ids"] = merged_ids
        if vector_stores:
            vector_stores[0] = first_store
        else:
            vector_stores.append(first_store)
        file_search["vector_stores"] = vector_stores
        resources["file_search"] = file_search
        return resources

    @staticmethod
    async def create_message(
        *, session: AsyncSession, body: MessageCreate, thread_id: str
    ) -> Message:
        """Create a message in ``thread_id``, commit it and return it.

        Any file ids the message references are also registered on the thread's
        ``tool_resources`` (see ``_merge_file_search_file_ids``). When
        ``body.attachments`` is non-empty its ``file_id`` values replace
        ``body.file_ids``. The thread is loaded through
        ``ThreadService.get_thread``, whose errors propagate.
        """
        thread = await ThreadService.get_thread(thread_id=thread_id, session=session)
        # TODO message annotations
        body_file_ids = body.file_ids
        if body.attachments:
            # Attachments take precedence; skip entries that carry no file id.
            body_file_ids = [
                a.get("file_id") for a in body.attachments if a.get("file_id")
            ]
        if body_file_ids:
            # Reassign (rather than mutate in place) so the ORM sees the change.
            thread.tool_resources = MessageService._merge_file_search_file_ids(
                thread.tool_resources, body_file_ids
            )
            session.add(thread)

        content = MessageService.format_message_content(body)
        db_message = Message.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "content": content},
            from_attributes=True,
        )
        session.add(db_message)
        await session.commit()
        await session.refresh(db_message)
        return db_message

    @staticmethod
    def get_message_sync(
        *, session: Session, thread_id: str, message_id: str
    ) -> Message:
        """Return message ``message_id`` of ``thread_id`` (sync).

        Raises:
            HTTPException: 404 when no such message exists in the thread.
        """
        statement = (
            select(Message)
            .where(Message.thread_id == thread_id)
            .where(Message.id == message_id)
        )
        result = session.execute(statement)
        message = result.scalars().one_or_none()
        if message is None:
            raise HTTPException(status_code=404, detail="Message not found")
        return message

    @staticmethod
    def _message_update_data(body: MessageUpdate) -> dict:
        """Return the fields explicitly set on ``body`` as a dict.

        A truthy ``content`` value is wrapped into a single text part.
        ``body`` itself is not modified.
        """
        update_data = body.model_dump(exclude_unset=True)
        if update_data.get("content"):
            update_data["content"] = [
                MessageService._text_part(update_data["content"])
            ]
        return update_data

    @staticmethod
    def modify_message_sync(
        *, session: Session, thread_id: str, message_id: str, body: MessageUpdate
    ) -> Message:
        """Apply the explicitly-set fields of ``body`` to a message (sync).

        The change is committed and the refreshed message is returned.
        ``HTTPException`` (404) propagates from ``get_message_sync`` when the
        message is missing.
        """
        # get thread -- result unused; presumably raises if the thread is missing
        ThreadService.get_thread_sync(thread_id=thread_id, session=session)
        # get message
        db_message = MessageService.get_message_sync(
            session=session, thread_id=thread_id, message_id=message_id
        )
        update_data = MessageService._message_update_data(body)
        for key, value in update_data.items():
            setattr(db_message, key, value)
        session.add(db_message)
        session.commit()
        session.refresh(db_message)
        return db_message

    @staticmethod
    async def modify_message(
        *, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate
    ) -> Message:
        """Apply the explicitly-set fields of ``body`` to a message.

        The change is committed and the refreshed message is returned.
        ``HTTPException`` (404) propagates from ``get_message`` when the message
        is missing.
        """
        # get thread -- result unused; presumably raises if the thread is missing
        await ThreadService.get_thread(thread_id=thread_id, session=session)
        # get message
        db_message = await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        update_data = MessageService._message_update_data(body)
        for key, value in update_data.items():
            setattr(db_message, key, value)
        session.add(db_message)
        await session.commit()
        await session.refresh(db_message)
        return db_message

    @staticmethod
    async def delete_message(
        *, session: AsyncSession, thread_id: str, message_id: str
    ) -> DeleteResponse:
        """Delete a message and its token relation, then commit.

        Returns a ``DeleteResponse`` with ``object="message.deleted"``.
        ``HTTPException`` (404) propagates from ``get_message`` when the message
        is missing.
        """
        message = await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        await session.delete(message)
        await auth_policy.delete_token_rel(
            session=session, relation_type=RelationType.Message, relation_id=message_id
        )
        await session.commit()
        return DeleteResponse(id=message_id, object="message.deleted", deleted=True)

    @staticmethod
    async def get_message(
        *, session: AsyncSession, thread_id: str, message_id: str
    ) -> Message:
        """Return message ``message_id`` of ``thread_id``.

        Raises:
            HTTPException: 404 when no such message exists in the thread.
        """
        statement = (
            select(Message)
            .where(Message.thread_id == thread_id)
            .where(Message.id == message_id)
        )
        result = await session.execute(statement)
        message = result.scalars().one_or_none()
        if message is None:
            raise HTTPException(status_code=404, detail="Message not found")
        return message

    @staticmethod
    async def get_message_file(
        *, session: AsyncSession, thread_id: str, message_id: str, file_id: str
    ) -> MessageFile:
        """Return file ``file_id`` attached to message ``message_id``.

        Raises:
            HTTPException: 404 (from ``get_message``) when the message is missing.
            ResourceNotFoundError: when the message has no such file.
        """
        await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        # get message files
        statement = (
            select(MessageFile)
            .where(MessageFile.id == file_id)
            .where(MessageFile.message_id == message_id)
        )
        result = await session.execute(statement)
        msg_file = result.scalars().one_or_none()
        if msg_file is None:
            raise ResourceNotFoundError(message="Message file not found")
        return msg_file

    @staticmethod
    async def copy_messages(
        *,
        session: AsyncSession,
        from_thread_id: str,
        to_thread_id: str,
        end_message_id: Optional[str] = None,
    ):
        """Copy the messages of ``from_thread_id`` into ``to_thread_id`` and commit.

        Messages are copied in ``Message.id`` order. When ``end_message_id`` is
        truthy, only messages with ``id <= end_message_id`` are copied. That
        comparison assumes ids sort in creation order -- TODO confirm the id
        format. Each copy gets a fresh id and timestamps.
        """
        statement = select(Message).where(Message.thread_id == from_thread_id)
        if end_message_id:
            statement = statement.where(Message.id <= end_message_id)
        result = await session.execute(statement.order_by(Message.id))
        original_messages = result.scalars().all()
        for original_message in original_messages:
            new_message = Message(
                thread_id=to_thread_id,
                **original_message.model_dump(
                    exclude={"id", "created_at", "updated_at", "thread_id"}
                ),
            )
            session.add(new_message)
        await session.commit()

    @staticmethod
    async def create_messages(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
        assistant_id: str,
        messages: list
    ):
        """Stage one new message per item of ``messages`` in ``thread_id``.

        Each item is expected to be ``MessageCreate``-like: it needs
        ``model_dump``, ``role`` and ``content``. The new rows are only added to
        the session; this method does not commit, so the caller must.
        """
        for original_message in messages:
            content = MessageService.format_message_content(original_message)
            new_message = Message.model_validate(
                original_message.model_dump(by_alias=True),
                update={
                    "thread_id": thread_id,
                    "run_id": run_id,
                    "assistant_id": assistant_id,
                    "content": content,
                    "role": original_message.role,
                },
            )
            session.add(new_message)
|