12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from sqlmodel import select
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import Session
- from app.exceptions.exception import ResourceNotFoundError, BadRequestError
- from app.models.message import MessageCreate
- from app.models.thread import Thread, ThreadUpdate, ThreadCreate
- from app.models.token_relation import RelationType
- from app.providers.auth_provider import auth_policy
- from app.schemas.common import DeleteResponse
- class ThreadService:
- @staticmethod
- async def create_thread(
- *, session: AsyncSession, body: ThreadCreate, token_id=None
- ) -> Thread:
- db_thread = Thread.model_validate(body.model_dump(by_alias=True))
- session.add(db_thread)
- auth_policy.insert_token_rel(
- session=session,
- token_id=token_id,
- relation_type=RelationType.Thread,
- relation_id=db_thread.id,
- )
- await session.commit()
- await session.refresh(db_thread)
- thread_id = db_thread.id
- # save messages
- if body.messages is not None and len(body.messages) > 0:
- from app.services.message.message import MessageService
- for message in body.messages:
- if message.role != "user" and message.role != "assistant":
- raise BadRequestError(message='Role must be "user" or "assistant"')
- await MessageService.create_message(
- session=session,
- thread_id=thread_id,
- body=MessageCreate.model_validate(
- message.model_dump(by_alias=True)
- ),
- )
- elif body.thread_id:
- # copy thread
- from app.services.message.message import MessageService
- await MessageService.copy_messages(
- session=session,
- from_thread_id=body.thread_id,
- to_thread_id=thread_id,
- end_message_id=body.end_message_id,
- )
- await session.refresh(db_thread)
- return db_thread
- @staticmethod
- async def modify_thread(
- *, session: AsyncSession, thread_id: str, body: ThreadUpdate
- ) -> Thread:
- db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
- update_data = body.dict(exclude_unset=True)
- for key, value in update_data.items():
- setattr(db_thread, key, value)
- session.add(db_thread)
- await session.commit()
- await session.refresh(db_thread)
- return db_thread
- @staticmethod
- async def delete_assistant(
- *, session: AsyncSession, thread_id: str
- ) -> DeleteResponse:
- db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
- await session.delete(db_thread)
- await auth_policy.delete_token_rel(
- session=session, relation_type=RelationType.Thread, relation_id=thread_id
- )
- await session.commit()
- return DeleteResponse(id=thread_id, object="thread.deleted", deleted=True)
- @staticmethod
- async def get_thread(*, session: AsyncSession, thread_id: str) -> Thread:
- statement = select(Thread).where(Thread.id == thread_id)
- result = await session.execute(statement)
- thread = result.scalars().one_or_none()
- if thread is None:
- raise ResourceNotFoundError(message=f"thread {thread_id} not found")
- return thread
- @staticmethod
- def get_thread_sync(*, session: Session, thread_id: str) -> Thread:
- statement = select(Thread).where(Thread.id == thread_id)
- result = session.execute(statement)
- thread = result.scalars().one_or_none()
- if thread is None:
- raise ResourceNotFoundError(message=f"thread {thread_id} not found")
- return thread
|