from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.exceptions.exception import ResourceNotFoundError, BadRequestError
from app.models.message import MessageCreate
from app.models.thread import Thread, ThreadUpdate, ThreadCreate
from app.models.token_relation import RelationType
from app.providers.auth_provider import auth_policy
from app.schemas.common import DeleteResponse


class ThreadService:
    """Create, update, delete and look up ``Thread`` rows."""

    @staticmethod
    async def create_thread(
        *, session: AsyncSession, body: ThreadCreate, token_id=None
    ) -> Thread:
        """Persist a new thread and, optionally, its initial messages.

        The thread is built from ``body``, registered with ``auth_policy``
        under ``token_id`` and committed. Afterwards, if ``body.messages`` is
        non-empty each message is stored via ``MessageService.create_message``;
        otherwise, if ``body.thread_id`` is set, messages are copied from that
        thread via ``MessageService.copy_messages`` (``body.end_message_id`` is
        passed through unchanged).

        Raises:
            BadRequestError: if any message has a role other than ``"user"``
                or ``"assistant"``. The check runs before anything is written.
        """
        # Validate every role up front so that an invalid payload cannot leave
        # behind a committed thread holding only part of its messages.
        for message in body.messages or ():
            if message.role not in ("user", "assistant"):
                raise BadRequestError(message='Role must be "user" or "assistant"')

        db_thread = Thread.model_validate(body.model_dump(by_alias=True))
        session.add(db_thread)
        # NOTE(review): called without ``await`` (unlike ``delete_token_rel``
        # below) -- presumably a synchronous helper; confirm in auth_provider.
        auth_policy.insert_token_rel(
            session=session,
            token_id=token_id,
            relation_type=RelationType.Thread,
            relation_id=db_thread.id,
        )
        await session.commit()
        await session.refresh(db_thread)
        thread_id = db_thread.id

        if body.messages:
            # Imported lazily rather than at module top -- presumably to avoid
            # a circular import with the message service; TODO confirm.
            from app.services.message.message import MessageService

            for message in body.messages:
                await MessageService.create_message(
                    session=session,
                    thread_id=thread_id,
                    body=MessageCreate.model_validate(
                        message.model_dump(by_alias=True)
                    ),
                )
        elif body.thread_id:
            # No explicit messages: copy them from the referenced thread.
            from app.services.message.message import MessageService

            await MessageService.copy_messages(
                session=session,
                from_thread_id=body.thread_id,
                to_thread_id=thread_id,
                end_message_id=body.end_message_id,
            )

        await session.refresh(db_thread)
        return db_thread

    @staticmethod
    async def modify_thread(
        *, session: AsyncSession, thread_id: str, body: ThreadUpdate
    ) -> Thread:
        """Apply the fields explicitly set on ``body`` to an existing thread.

        Raises:
            ResourceNotFoundError: if no thread with ``thread_id`` exists.
        """
        db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
        # ``model_dump`` matches the Pydantic-v2 style used elsewhere in this
        # file; ``exclude_unset`` keeps omitted fields from being overwritten.
        update_data = body.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_thread, key, value)
        session.add(db_thread)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread

    @staticmethod
    async def delete_assistant(
        *, session: AsyncSession, thread_id: str
    ) -> DeleteResponse:
        """Delete a thread and its token relation.

        NOTE(review): despite its name this method deletes a *thread*; the
        name is kept unchanged so existing callers continue to work.

        Raises:
            ResourceNotFoundError: if no thread with ``thread_id`` exists.
        """
        db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
        await session.delete(db_thread)
        await auth_policy.delete_token_rel(
            session=session, relation_type=RelationType.Thread, relation_id=thread_id
        )
        await session.commit()
        return DeleteResponse(id=thread_id, object="thread.deleted", deleted=True)

    @staticmethod
    async def get_thread(*, session: AsyncSession, thread_id: str) -> Thread:
        """Return the thread with ``thread_id`` using an async session.

        Raises:
            ResourceNotFoundError: if no such thread exists.
        """
        statement = select(Thread).where(Thread.id == thread_id)
        result = await session.execute(statement)
        thread = result.scalars().one_or_none()
        if thread is None:
            raise ResourceNotFoundError(message=f"thread {thread_id} not found")
        return thread

    @staticmethod
    def get_thread_sync(*, session: Session, thread_id: str) -> Thread:
        """Return the thread with ``thread_id`` using a synchronous session.

        Raises:
            ResourceNotFoundError: if no such thread exists.
        """
        statement = select(Thread).where(Thread.id == thread_id)
        result = session.execute(statement)
        thread = result.scalars().one_or_none()
        if thread is None:
            raise ResourceNotFoundError(message=f"thread {thread_id} not found")
        return thread