from typing import List, Optional

from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlmodel import select

from app.exceptions.exception import ResourceNotFoundError
from app.models import MessageFile
from app.models.message import Message, MessageCreate, MessageUpdate
from app.services.thread.thread import ThreadService
from app.models.token_relation import RelationType
from app.providers.auth_provider import auth_policy
from app.schemas.common import DeleteResponse
from sqlalchemy.orm.attributes import flag_modified


class MessageService:
    """Create, read, update, delete and copy thread messages.

    Provides both synchronous (``Session``) and asynchronous
    (``AsyncSession``) variants of the lookup/modify operations.
    """

    # Content-part types that are stored exactly as the client sent them.
    _PASSTHROUGH_CONTENT_TYPES = ("image_file", "image_url", "input_audio")

    @staticmethod
    def _text_content(value) -> dict:
        """Wrap ``value`` in the stored ``text`` content-part shape."""
        return {"type": "text", "text": {"value": value, "annotations": []}}

    @staticmethod
    def _select_message(thread_id: str, message_id: str):
        """Build the query for one message scoped to its thread."""
        return (
            select(Message)
            .where(Message.thread_id == thread_id)
            .where(Message.id == message_id)
        )

    @staticmethod
    def _merge_file_search_file_ids(
        tool_resources: Optional[dict], file_ids: List[str]
    ) -> dict:
        """Merge ``file_ids`` into ``file_search.vector_stores[0].file_ids``.

        Missing levels (``file_search``, ``vector_stores``, ``file_ids``) are
        created on demand. Ids already present are not duplicated, and the
        existing order is preserved. When a non-empty dict is supplied it is
        mutated in place and returned; otherwise a new dict is returned.
        """
        resources = tool_resources if tool_resources else {}
        file_search = resources.get("file_search") or {}
        resources["file_search"] = file_search
        vector_stores = file_search.get("vector_stores") or [{"file_ids": []}]
        file_search["vector_stores"] = vector_stores
        store = vector_stores[0]
        merged = store.get("file_ids") or []
        for file_id in file_ids:
            if file_id not in merged:
                merged.append(file_id)
        store["file_ids"] = merged
        return resources

    @staticmethod
    def _apply_message_update(db_message: Message, body: MessageUpdate) -> None:
        """Copy the explicitly-set fields of ``body`` onto ``db_message``.

        A truthy ``content`` value is wrapped as a single ``text`` part.
        The wrapping is done on the dumped dict, so the caller's ``body``
        object is not mutated.
        """
        update_data = body.model_dump(exclude_unset=True)
        if update_data.get("content"):
            update_data["content"] = [
                MessageService._text_content(update_data["content"])
            ]
        for key, value in update_data.items():
            setattr(db_message, key, value)

    @staticmethod
    def format_message_content(message_create: MessageCreate) -> List:
        """Normalise ``message_create.content`` into a list of content parts.

        - A plain string becomes a single ``text`` part.
        - For a list, ``text`` parts are rewrapped as
          ``{"value": ..., "annotations": []}``.
        - ``image_file`` / ``image_url`` / ``input_audio`` parts are kept
          unchanged.
        - Parts of any other type are dropped.
        """
        content = []
        if isinstance(message_create.content, str):
            content.append(MessageService._text_content(message_create.content))
        elif isinstance(message_create.content, list):
            for msg in message_create.content:
                msg_type = msg.get("type")
                if msg_type == "text":
                    content.append(MessageService._text_content(msg.get("text")))
                elif msg_type in MessageService._PASSTHROUGH_CONTENT_TYPES:
                    content.append(msg)
        return content

    @staticmethod
    def new_message(
        *, session: Session, content, role, assistant_id, thread_id, run_id
    ) -> Message:
        """Insert and commit a single-text-part message (sync session)."""
        message = Message(
            content=[MessageService._text_content(content)],
            role=role,
            assistant_id=assistant_id,
            thread_id=thread_id,
            run_id=run_id,
        )
        session.add(message)
        session.commit()
        session.refresh(message)
        return message

    @staticmethod
    def get_message_list(*, session: Session, thread_id) -> List[Message]:
        """Return all messages of ``thread_id``, newest first (sync session)."""
        statement = (
            select(Message)
            .where(Message.thread_id == thread_id)
            .order_by(Message.created_at.desc())
        )
        return session.execute(statement).scalars().all()

    @staticmethod
    async def create_message(
        *, session: AsyncSession, body: MessageCreate, thread_id: str
    ) -> Message:
        """Create a message in ``thread_id``.

        Any referenced files are first registered on the thread's
        ``file_search`` tool resources. The thread update and the new
        message are committed in one transaction.
        """
        thread = await ThreadService.get_thread(session=session, thread_id=thread_id)

        # TODO message annotations
        body_file_ids = body.file_ids
        if body.attachments:
            # Attachments take precedence over the legacy ``file_ids`` field.
            # Entries without a ``file_id`` are skipped, so ``None`` is never
            # stored.
            body_file_ids = [
                a.get("file_id") for a in body.attachments if a.get("file_id")
            ]

        if body_file_ids:
            thread.tool_resources = MessageService._merge_file_search_file_ids(
                thread.tool_resources, body_file_ids
            )
            # The JSON value may have been mutated in place, which SQLAlchemy
            # does not detect; mark the attribute dirty explicitly.
            flag_modified(thread, "tool_resources")
            session.add(thread)

        content = MessageService.format_message_content(body)
        db_message = Message.model_validate(
            body.model_dump(by_alias=True),
            update={"thread_id": thread_id, "content": content},
            from_attributes=True,
        )
        session.add(db_message)
        # Single commit: the thread's file list and the message are persisted
        # together, or not at all.
        await session.commit()
        await session.refresh(db_message)
        return db_message

    @staticmethod
    def get_message_sync(
        *, session: Session, thread_id: str, message_id: str
    ) -> Message:
        """Fetch one message of a thread (sync session).

        Raises:
            HTTPException: 404 when no such message exists in the thread.
        """
        result = session.execute(MessageService._select_message(thread_id, message_id))
        message = result.scalars().one_or_none()
        if message is None:
            raise HTTPException(status_code=404, detail="Message not found")
        return message

    @staticmethod
    def modify_message_sync(
        *, session: Session, thread_id: str, message_id: str, body: MessageUpdate
    ) -> Message:
        """Apply the set fields of ``body`` to a message and commit (sync session)."""
        # Thread lookup presumably raises when the thread does not exist.
        ThreadService.get_thread_sync(thread_id=thread_id, session=session)
        db_message = MessageService.get_message_sync(
            session=session, thread_id=thread_id, message_id=message_id
        )
        MessageService._apply_message_update(db_message, body)
        session.add(db_message)
        session.commit()
        session.refresh(db_message)
        return db_message

    @staticmethod
    async def modify_message(
        *, session: AsyncSession, thread_id: str, message_id: str, body: MessageUpdate
    ) -> Message:
        """Apply the set fields of ``body`` to a message and commit (async session)."""
        # Thread lookup presumably raises when the thread does not exist.
        await ThreadService.get_thread(thread_id=thread_id, session=session)
        db_message = await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        MessageService._apply_message_update(db_message, body)
        session.add(db_message)
        await session.commit()
        await session.refresh(db_message)
        return db_message

    @staticmethod
    async def delete_message(
        *, session: AsyncSession, thread_id: str, message_id: str
    ) -> DeleteResponse:
        """Delete a message and its token relation, and commit.

        Returns a ``message.deleted`` response.
        """
        message = await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        await session.delete(message)
        await auth_policy.delete_token_rel(
            session=session, relation_type=RelationType.Message, relation_id=message_id
        )
        await session.commit()
        return DeleteResponse(id=message_id, object="message.deleted", deleted=True)

    @staticmethod
    async def get_message(
        *, session: AsyncSession, thread_id: str, message_id: str
    ) -> Message:
        """Fetch one message of a thread (async session).

        Raises:
            HTTPException: 404 when no such message exists in the thread.
        """
        result = await session.execute(
            MessageService._select_message(thread_id, message_id)
        )
        message = result.scalars().one_or_none()
        if message is None:
            raise HTTPException(status_code=404, detail="Message not found")
        return message

    @staticmethod
    async def get_message_file(
        *, session: AsyncSession, thread_id: str, message_id: str, file_id: str
    ) -> MessageFile:
        """Fetch a file attached to a message.

        The message itself is validated first.

        Raises:
            HTTPException: 404 when the message does not exist in the thread.
            ResourceNotFoundError: when the file is not attached to the message.
        """
        await MessageService.get_message(
            session=session, thread_id=thread_id, message_id=message_id
        )
        statement = (
            select(MessageFile)
            .where(MessageFile.id == file_id)
            .where(MessageFile.message_id == message_id)
        )
        result = await session.execute(statement)
        msg_file = result.scalars().one_or_none()
        if msg_file is None:
            raise ResourceNotFoundError(message="Message file not found")
        return msg_file

    @staticmethod
    async def copy_messages(
        *,
        session: AsyncSession,
        from_thread_id: str,
        to_thread_id: str,
        end_message_id: Optional[str] = None,
    ):
        """Copy the messages of ``from_thread_id`` into ``to_thread_id`` and commit.

        When ``end_message_id`` is given, only messages whose id is
        ``<= end_message_id`` are copied.
        NOTE(review): this assumes message ids sort in creation order —
        confirm against the id generator.
        """
        statement = select(Message).where(Message.thread_id == from_thread_id)
        if end_message_id:
            statement = statement.where(Message.id <= end_message_id)
        result = await session.execute(statement.order_by(Message.id))
        for original_message in result.scalars().all():
            session.add(
                Message(
                    thread_id=to_thread_id,
                    **original_message.model_dump(
                        exclude={"id", "created_at", "updated_at", "thread_id"}
                    ),
                )
            )
        await session.commit()

    @staticmethod
    async def create_messages(
        *,
        session: AsyncSession,
        thread_id: str,
        run_id: str,
        assistant_id: str,
        messages: list
    ):
        """Stage one ``Message`` per item of ``messages`` on ``session``.

        Each item is bound to the given thread, run and assistant.
        NOTE(review): unlike the other writers here this does NOT commit;
        presumably the caller commits as part of a larger transaction — confirm.
        """
        for original_message in messages:
            content = MessageService.format_message_content(original_message)
            new_message = Message.model_validate(
                original_message.model_dump(by_alias=True),
                update={
                    "thread_id": thread_id,
                    "run_id": run_id,
                    "assistant_id": assistant_id,
                    "content": content,
                    "role": original_message.role,
                },
            )
            session.add(new_message)