thread.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from sqlmodel import select
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlalchemy.orm import Session
  4. from app.exceptions.exception import ResourceNotFoundError, BadRequestError
  5. from app.models.message import MessageCreate
  6. from app.models.thread import Thread, ThreadUpdate, ThreadCreate
  7. from app.models.token_relation import RelationType
  8. from app.providers.auth_provider import auth_policy
  9. from app.schemas.common import DeleteResponse
  class ThreadService:
      """Database operations for ``Thread`` rows.

      Every method is a ``@staticmethod``. The async methods take an
      ``AsyncSession``; ``get_thread_sync`` takes a synchronous ``Session``.
      """

      # Roles accepted for messages supplied inline in ``ThreadCreate.messages``.
      _ALLOWED_MESSAGE_ROLES = ("user", "assistant")

      @staticmethod
      async def create_thread(*, session: AsyncSession, body: ThreadCreate, token_id=None) -> Thread:
          """Create a thread, optionally seeding it with messages.

          If ``body.messages`` is non-empty, each message is created in the new
          thread via ``MessageService.create_message``. Otherwise, if
          ``body.thread_id`` is set, messages are copied from that thread (up to
          ``body.end_message_id``) via ``MessageService.copy_messages``.

          Args:
              session: Async database session; committed by this method.
              body: Thread creation payload.
              token_id: Forwarded to ``auth_policy.insert_token_rel`` to record
                  the token/thread relation.

          Returns:
              The created ``Thread``, refreshed after the message step.

          Raises:
              BadRequestError: If any inline message has a role other than
                  "user" or "assistant". This is checked before anything is
                  written, so no thread is created in that case.
          """
          # Validate every inline message before touching the database, so an
          # invalid role cannot leave behind a committed thread that has only
          # some (or none) of its messages.
          messages = body.messages or []
          for message in messages:
              if message.role not in ThreadService._ALLOWED_MESSAGE_ROLES:
                  raise BadRequestError(message='Role must be "user" or "assistant"')

          db_thread = Thread.model_validate(body.model_dump(by_alias=True))
          session.add(db_thread)
          # NOTE(review): db_thread.id is read before commit/flush, which assumes
          # Thread assigns its id at construction. Also, unlike delete_token_rel
          # in delete_thread, insert_token_rel is not awaited here — confirm it
          # is synchronous.
          auth_policy.insert_token_rel(
              session=session, token_id=token_id, relation_type=RelationType.Thread, relation_id=db_thread.id
          )
          await session.commit()
          await session.refresh(db_thread)
          thread_id = db_thread.id

          if messages or body.thread_id:
              # Function-level import kept from the original; presumably this
              # avoids a circular import with the message service
              # (NOTE(review): confirm).
              from app.services.message.message import MessageService

              if messages:
                  for message in messages:
                      await MessageService.create_message(
                          session=session,
                          thread_id=thread_id,
                          body=MessageCreate.model_validate(message.model_dump(by_alias=True)),
                      )
              else:
                  # No inline messages: seed the thread by copying from an
                  # existing one.
                  await MessageService.copy_messages(
                      session=session,
                      from_thread_id=body.thread_id,
                      to_thread_id=thread_id,
                      end_message_id=body.end_message_id,
                  )
          # Refresh again after the message step, presumably so any thread
          # state changed by MessageService is reflected in the returned object.
          await session.refresh(db_thread)
          return db_thread

      @staticmethod
      async def modify_thread(*, session: AsyncSession, thread_id: str, body: ThreadUpdate) -> Thread:
          """Apply the fields explicitly set on ``body`` to an existing thread.

          Args:
              session: Async database session; committed by this method.
              thread_id: Id of the thread to update.
              body: Partial update; only fields the caller actually set are
                  applied.

          Returns:
              The updated, refreshed ``Thread``.

          Raises:
              ResourceNotFoundError: If no thread has ``thread_id``.
          """
          db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
          # model_dump is the Pydantic v2 API used elsewhere in this class
          # (.dict() is its deprecated v1 spelling). exclude_unset keeps a
          # partial update from overwriting fields the caller did not send.
          update_data = body.model_dump(exclude_unset=True)
          for key, value in update_data.items():
              setattr(db_thread, key, value)
          session.add(db_thread)
          await session.commit()
          await session.refresh(db_thread)
          return db_thread

      @staticmethod
      async def delete_thread(*, session: AsyncSession, thread_id: str) -> DeleteResponse:
          """Delete a thread and its token relation, then commit.

          Args:
              session: Async database session; committed by this method.
              thread_id: Id of the thread to delete.

          Returns:
              A ``DeleteResponse`` with ``object="thread.deleted"`` and
              ``deleted=True``.

          Raises:
              ResourceNotFoundError: If no thread has ``thread_id``.
          """
          db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
          await session.delete(db_thread)
          await auth_policy.delete_token_rel(session=session, relation_type=RelationType.Thread, relation_id=thread_id)
          await session.commit()
          return DeleteResponse(id=thread_id, object="thread.deleted", deleted=True)

      # Backward-compatible alias for the original, misleading name: this
      # deletes a thread, not an assistant. Existing callers keep working.
      delete_assistant = delete_thread

      @staticmethod
      async def get_thread(*, session: AsyncSession, thread_id: str) -> Thread:
          """Load a thread by id using an async session.

          Raises:
              ResourceNotFoundError: If no thread has ``thread_id``.
          """
          statement = select(Thread).where(Thread.id == thread_id)
          result = await session.execute(statement)
          thread = result.scalars().one_or_none()
          if thread is None:
              raise ResourceNotFoundError(message=f"thread {thread_id} not found")
          return thread

      @staticmethod
      def get_thread_sync(*, session: Session, thread_id: str) -> Thread:
          """Synchronous counterpart of ``get_thread`` for a plain ``Session``.

          Raises:
              ResourceNotFoundError: If no thread has ``thread_id``.
          """
          statement = select(Thread).where(Thread.id == thread_id)
          result = session.execute(statement)
          thread = result.scalars().one_or_none()
          if thread is None:
              raise ResourceNotFoundError(message=f"thread {thread_id} not found")
          return thread