thread.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from sqlmodel import select
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlalchemy.orm import Session
  4. from app.exceptions.exception import ResourceNotFoundError, BadRequestError
  5. from app.models.message import MessageCreate
  6. from app.models.thread import Thread, ThreadUpdate, ThreadCreate
  7. from app.models.token_relation import RelationType
  8. from app.providers.auth_provider import auth_policy
  9. from app.schemas.common import DeleteResponse
  10. class ThreadService:
  11. @staticmethod
  12. async def create_thread(
  13. *, session: AsyncSession, body: ThreadCreate, token_id=None
  14. ) -> Thread:
  15. db_thread = Thread.model_validate(body.model_dump(by_alias=True))
  16. session.add(db_thread)
  17. auth_policy.insert_token_rel(
  18. session=session,
  19. token_id=token_id,
  20. relation_type=RelationType.Thread,
  21. relation_id=db_thread.id,
  22. )
  23. await session.commit()
  24. await session.refresh(db_thread)
  25. thread_id = db_thread.id
  26. # save messages
  27. if body.messages is not None and len(body.messages) > 0:
  28. from app.services.message.message import MessageService
  29. for message in body.messages:
  30. if message.role != "user" and message.role != "assistant":
  31. raise BadRequestError(message='Role must be "user" or "assistant"')
  32. await MessageService.create_message(
  33. session=session,
  34. thread_id=thread_id,
  35. body=MessageCreate.model_validate(
  36. message.model_dump(by_alias=True)
  37. ),
  38. )
  39. elif body.thread_id:
  40. # copy thread
  41. from app.services.message.message import MessageService
  42. await MessageService.copy_messages(
  43. session=session,
  44. from_thread_id=body.thread_id,
  45. to_thread_id=thread_id,
  46. end_message_id=body.end_message_id,
  47. )
  48. await session.refresh(db_thread)
  49. return db_thread
  50. @staticmethod
  51. async def modify_thread(
  52. *, session: AsyncSession, thread_id: str, body: ThreadUpdate
  53. ) -> Thread:
  54. db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
  55. update_data = body.dict(exclude_unset=True)
  56. for key, value in update_data.items():
  57. setattr(db_thread, key, value)
  58. session.add(db_thread)
  59. await session.commit()
  60. await session.refresh(db_thread)
  61. return db_thread
  62. @staticmethod
  63. async def delete_assistant(
  64. *, session: AsyncSession, thread_id: str
  65. ) -> DeleteResponse:
  66. db_thread = await ThreadService.get_thread(session=session, thread_id=thread_id)
  67. await session.delete(db_thread)
  68. await auth_policy.delete_token_rel(
  69. session=session, relation_type=RelationType.Thread, relation_id=thread_id
  70. )
  71. await session.commit()
  72. return DeleteResponse(id=thread_id, object="thread.deleted", deleted=True)
  73. @staticmethod
  74. async def get_thread(*, session: AsyncSession, thread_id: str) -> Thread:
  75. statement = select(Thread).where(Thread.id == thread_id)
  76. result = await session.execute(statement)
  77. thread = result.scalars().one_or_none()
  78. if thread is None:
  79. raise ResourceNotFoundError(message=f"thread {thread_id} not found")
  80. return thread
  81. @staticmethod
  82. def get_thread_sync(*, session: Session, thread_id: str) -> Thread:
  83. statement = select(Thread).where(Thread.id == thread_id)
  84. result = session.execute(statement)
  85. thread = result.scalars().one_or_none()
  86. if thread is None:
  87. raise ResourceNotFoundError(message=f"thread {thread_id} not found")
  88. return thread